"""Bilinear operation handlers for polytope abstract interpretation."""
from __future__ import annotations
from typing import Union
import torch
from boundlab.expr._affine import ConstVal
from boundlab.expr._core import Expr
from boundlab.expr._var import LpEpsilon
def bilinear_elementwise(A: Expr, B: Expr) -> Expr:
    """Linearize element-wise product of two symbolic expressions.

    Each operand is written as a centre plus a remainder, ``A = Ac + As`` and
    ``B = Bc + Bs``, so that::

        A * B = Ac * Bs + As * Bc + Ac * Bc + As * Bs

    The first three terms are kept as-is.  The bilinear remainder ``As * Bs``
    is bounded in magnitude by ``Ahw * Bhw`` (product of the remainders'
    half-widths) and replaced by that bound times a fresh ``LpEpsilon`` term.

    That magnitude bound only holds when ``As`` and ``Bs`` are centred at 0.
    If ``split_const`` already yields symbolic parts that are symmetric around
    0 they are used directly; otherwise both operands are re-centred on the
    midpoint of their concrete bounds first (the same strategy the
    non-symmetric branch of ``square_matmul`` uses).  Bounding an off-centre
    remainder by its half-width alone would be unsound: a remainder in
    ``[0, 2]`` has half-width 1 but can reach magnitude 2.

    Args:
        A: First symbolic operand.
        B: Second symbolic operand; must have the same shape as ``A``.

    Returns:
        An expression intended to over-approximate ``A * B`` element-wise.

    Raises:
        AssertionError: If ``A.shape != B.shape``.
    """
    assert A.shape == B.shape, \
        f"Shapes must match for element-wise product: {A.shape} vs {B.shape}"
    Ac, As = A.split_const()
    Bc, Bs = B.split_const()
    if As.is_symmetric_to_0() and Bs.is_symmetric_to_0():
        # Symmetric remainders: the upper bound is also the magnitude bound.
        Ahw = As.ub()
        Bhw = Bs.ub()
    else:
        # Re-centre on the midpoint of the concrete bounds so that
        # |As| <= Ahw and |Bs| <= Bhw hold element-wise.
        A_ub, A_lb = A.ublb()
        B_ub, B_lb = B.ublb()
        Ac = (A_ub + A_lb) / 2.0
        As = A - Ac
        Bc = (B_ub + B_lb) / 2.0
        Bs = B - Bc
        Ahw = (A_ub - A_lb) / 2.0
        Bhw = (B_ub - B_lb) / 2.0
    result = Ac * Bs + As * Bc + Ac * Bc
    error_bound = Ahw * Bhw
    new_eps = LpEpsilon(error_bound.shape)
    return result + error_bound * new_eps
def square_matmul(A: Expr, B: Expr) -> Expr:
    """Linearize ``A @ B`` using a square-split bilinear relaxation.

    Both operands are written as a centre plus a remainder, ``A = Ac + As``
    and ``B = Bc + Bs``, so that::

        A @ B = Ac @ Bs + As @ Bc + Ac @ Bc + As @ Bs

    The first three terms are kept as-is.  The bilinear remainder
    ``As @ Bs`` is over-approximated per output entry by an interval
    ``[L, U]`` (see ``_square_split_bounds``), which is added back as its
    midpoint plus its half-width times a fresh ``LpEpsilon`` term.

    If the symbolic parts from ``split_const`` are both symmetric around 0
    they are used directly; otherwise both operands are re-centred on the
    midpoint of their concrete bounds.

    Args:
        A: Symbolic operand of shape ``(*batch, m, k)``.
        B: Symbolic operand of shape ``(*batch, k, n)``.

    Returns:
        An expression of shape ``(*batch, m, n)`` relaxing ``A @ B``.

    Raises:
        AssertionError: If either operand has fewer than 2 dims, the batch
            dims differ, or the inner dims do not match.
    """
    assert len(A.shape) >= 2 and len(B.shape) >= 2, \
        f"Need at least 2D for matmul, got {A.shape} @ {B.shape}"
    assert A.shape[:-2] == B.shape[:-2], \
        f"Batch dims must match: {A.shape} @ {B.shape}"
    assert A.shape[-1] == B.shape[-2], \
        f"Inner dims must match: {A.shape} @ {B.shape}"
    m, k, n = A.shape[-2], A.shape[-1], B.shape[-1]

    Ac, As = A.split_const()
    Bc, Bs = B.split_const()
    symmetric = As.is_symmetric_to_0() and Bs.is_symmetric_to_0()
    if symmetric:
        # Remainders already centred at 0: their upper bounds are also
        # bounds on their magnitudes.
        result = Ac @ Bs + As @ Bc + Ac @ Bc
        Ahw = As.ub()
        Bhw = Bs.ub()
    else:
        # Re-centre both operands on the midpoint of their concrete bounds so
        # the remainders lie in [-Ahw, Ahw] and [-Bhw, Bhw].  The linear terms
        # are only built here, not before the branch, to avoid constructing
        # symbolic matmuls that would be discarded.
        Au, Al = A.ublb()
        Bu, Bl = B.ublb()
        Ac = (Au + Al) / 2
        As = A - Ac
        Bc = (Bu + Bl) / 2
        Bs = B - Bc
        result = Ac @ Bc + As @ Bc + Ac @ Bs
        Ahw = (Au - Al) / 2
        Bhw = (Bu - Bl) / 2

    U, L = _square_split_bounds(As, Bs, Ahw, Bhw, m, k, n, symmetric)
    result += (U + L) / 2 + (U - L) / 2 * LpEpsilon(result.shape)
    return result


def _square_split_bounds(As, Bs, Ahw, Bhw, m, k, n, symmetric):
    """Return element-wise bounds ``(U, L)`` on the remainder ``As @ Bs``.

    Starts from the interval bound ``U = Ahw @ Bhw``, ``L = -U`` and tightens
    it with the square-split identity, applied to every summand ``x * y`` of
    the ``k`` contraction::

        x * y = ((lam*x + y/lam) ** 2 - (lam*x - y/lam) ** 2) / 4,
        lam   = sqrt(Ahw / Bhw)

    hence ``x * y <= max((lam*x + y/lam) ** 2) / 4`` and
    ``x * y >= -max((lam*x - y/lam) ** 2) / 4``.

    Args:
        As: Centred remainder of the left operand, shape ``(*batch, m, k)``.
        Bs: Centred remainder of the right operand, shape ``(*batch, k, n)``.
        Ahw: Magnitude bound (half-width) of ``As``, same shape as ``As``.
        Bhw: Magnitude bound (half-width) of ``Bs``, same shape as ``Bs``.
        m, k, n: Matmul dimensions.
        symmetric: If True the squared linear forms are bounded from ``ub()``
            alone; otherwise from both ends returned by ``ublb()``.

    Returns:
        ``(U, L)`` tensors of shape ``(*batch, m, n)``.
    """
    U = Ahw @ Bhw
    L = -U
    # Materialise the contraction explicitly: every factor becomes
    # (..., m, k, n) so summand (i, k, j) can be handled element-wise.
    Ahw = Ahw.unsqueeze(-1).expand(*Ahw.shape, n)
    Bhw = Bhw.unsqueeze(-3).expand(*Bhw.shape[:-2], m, k, n)
    As = As.unsqueeze(-1).expand(*As.shape, n)
    Bs = Bs.unsqueeze(-3).expand(*Bs.shape[:-2], m, k, n)
    a = torch.sqrt(Ahw)
    b = torch.sqrt(Bhw)
    lama = a / b
    lamb = b / a
    scaled_a = lama * As
    scaled_b = lamb * Bs
    plus = scaled_a + scaled_b
    minus = scaled_a - scaled_b
    if symmetric:
        plus_sq = plus.ub() ** 2
        minus_sq = minus.ub() ** 2
    else:
        plus_u, plus_l = plus.ublb()
        minus_u, minus_l = minus.ublb()
        plus_sq = torch.maximum(plus_u ** 2, plus_l ** 2)
        minus_sq = torch.maximum(minus_u ** 2, minus_l ** 2)
    # A zero half-width makes lama/lamb 0/0 or x/0; the resulting non-finite
    # entries are replaced by a large finite sentinel (+/-1e10) before
    # summing, and the min/max with the plain interval bound below then
    # decides the final value.
    pos = torch.nan_to_num(plus_sq / 4, nan=1e10, posinf=1e10, neginf=1e10)
    neg = torch.nan_to_num(-minus_sq / 4, nan=-1e10, posinf=-1e10, neginf=-1e10)
    U = torch.minimum(pos.sum(dim=-2), U)
    L = torch.maximum(neg.sum(dim=-2), L)
    return U, L
def bilinear_matmul(A: Expr, B: Expr) -> Expr:
    """Relax ``A @ B`` where both ``A`` and ``B`` are symbolic expressions.

    This is the entry point ``matmul_handler`` uses for the symbolic @
    symbolic case; it currently delegates to ``square_matmul``.
    """
    relaxed = square_matmul(A, B)
    return relaxed
[docs]
def matmul_handler(A, B):
    """Dispatcher implementation for ``torch.matmul``.

    Routing:
      * one operand an ``Expr`` and the other constant -> the operand's own
        ``A @ B``;
      * both constant -> ``torch.matmul``;
      * both ``Expr`` -> ``bilinear_matmul`` relaxation;
      * anything else -> ``NotImplemented``.
    """
    lhs_sym, rhs_sym = isinstance(A, Expr), isinstance(B, Expr)
    lhs_const, rhs_const = _is_const(A), _is_const(B)
    if (lhs_sym and rhs_const) or (lhs_const and rhs_sym):
        return A @ B
    if lhs_const and rhs_const:
        return torch.matmul(A, B)
    if lhs_sym and rhs_sym:
        return bilinear_matmul(A, B)
    return NotImplemented
[docs]
def mul_handler(A, B):
    """Dispatcher implementation for element-wise multiplication.

    Routing:
      * ``Expr`` times constant (either order) -> ``_mul_expr_const`` with the
        ``Expr`` placed first;
      * both constant -> ``torch.mul``;
      * both ``Expr`` -> ``bilinear_elementwise`` relaxation;
      * anything else -> ``NotImplemented``.
    """
    a_sym, b_sym = isinstance(A, Expr), isinstance(B, Expr)
    a_const, b_const = _is_const(A), _is_const(B)
    if a_sym and b_const:
        return _mul_expr_const(A, B)
    if a_const:
        if b_sym:
            # Element-wise product commutes: put the symbolic side first.
            return _mul_expr_const(B, A)
        if b_const:
            return torch.mul(A, B)
    if a_sym and b_sym:
        return bilinear_elementwise(A, B)
    return NotImplemented
def _mul_expr_const(x: Expr, c):
    """Multiply the symbolic ``x`` by the constant ``c``.

    A ``ConstVal`` is unwrapped to its ``.value`` first.  Python numbers and
    0-d tensors are multiplied directly.  For a tensor with at least one
    dimension, the broadcast shape of ``x`` and ``c`` is computed and each
    side whose shape differs from it is ``expand``-ed to it before the
    product.
    """
    factor = c.value if isinstance(c, ConstVal) else c
    needs_broadcast = isinstance(factor, torch.Tensor) and factor.dim() != 0
    if not needs_broadcast:
        return x * factor
    target = torch.broadcast_shapes(tuple(x.shape), tuple(factor.shape))
    if tuple(x.shape) != target:
        x = x.expand(*target)
    if tuple(factor.shape) != target:
        factor = factor.expand(*target)
    return x * factor
def _is_const(tensor: Union[torch.Tensor, Expr, int]) -> bool:
    """Return True if ``tensor`` is treated as a constant operand.

    Constants are ``torch.Tensor`` instances, ``ConstVal`` wrappers, and
    Python ``int``/``float`` values (including ``bool``, an ``int`` subclass).
    """
    return isinstance(tensor, (torch.Tensor, ConstVal, int, float))