"""Differential bilinear handlers backed by :mod:`boundlab.zonosq`."""
from __future__ import annotations
import torch
from boundlab import expr
from boundlab.diff.expr import DiffExpr2, DiffExpr3
from boundlab.expr._affine import ConstVal
from boundlab.expr._core import Expr
from boundlab.zonosq.bilinear import (
bilinear_elementwise,
bilinear_matmul,
matmul_handler as zonosq_matmul_handler,
zonosq_matmuls,
)
def _zero_like_expr(e: Expr) -> Expr:
return expr.ConstVal(torch.zeros(e.shape))
[docs]
def diff_bilinear_elementwise(a: DiffExpr3, b: DiffExpr3) -> DiffExpr3:
assert a.shape == b.shape, f"Shapes must match: {a.shape} vs {b.shape}"
out_x = bilinear_elementwise(a.x, b.x)
out_y = bilinear_elementwise(a.y, b.y)
# a1*b1 - a2*b2 = a1*(b1-b2) + (a1-a2)*b2
term1 = bilinear_elementwise(a.x, b.diff)
term2 = bilinear_elementwise(a.diff, b.y)
out_diff = term1 + term2
return DiffExpr3(out_x, out_y, out_diff)
[docs]
def diff_bilinear_matmul(a: DiffExpr3, b: DiffExpr3) -> DiffExpr3:
out_x = bilinear_matmul(a.x, b.x)
out_y = bilinear_matmul(a.y, b.y)
# A1@B1 - A2@B2 = A1@(B1-B2) + (A1-A2)@B2
# Abstract both products as a single matmul so the bilinear error bounds
# shared epsilon symbols across both terms jointly (accumulating
# coefficients before taking absolute values), which is at least as tight
# as abstracting the two products independently and adding.
out_diff = zonosq_matmuls((a.x, b.diff), (a.diff, b.y))
# Keep the tighter per-coordinate expression between direct diff tracking
# and subtracting the two branch outputs.
from boundlab.prop import bound_width
sub_diff = out_x - out_y
bw_d = bound_width(out_diff)
bw_s = bound_width(sub_diff)
finite_d = torch.isfinite(bw_d)
finite_s = torch.isfinite(bw_s)
use_sub = finite_s & ((~finite_d) | (bw_s < bw_d))
if bool(use_sub.all()):
out_diff = sub_diff
elif bool(use_sub.any()) and bool(finite_d.all()) and bool(finite_s.all()):
mask = use_sub.float()
out_diff = mask * sub_diff + (1.0 - mask) * out_diff
return DiffExpr3(out_x, out_y, out_diff)
def _promote2(x: DiffExpr2) -> DiffExpr3:
return DiffExpr3(x.x, x.y, x.x - x.y)
def diff_mul_handler(a, b):
if isinstance(a, DiffExpr3) and isinstance(b, DiffExpr3):
return diff_bilinear_elementwise(a, b)
if isinstance(a, DiffExpr3) and isinstance(b, DiffExpr2):
try:
return a * b
except TypeError:
return diff_bilinear_elementwise(a, _promote2(b))
if isinstance(a, DiffExpr2) and isinstance(b, DiffExpr3):
try:
return b * a
except TypeError:
return diff_bilinear_elementwise(_promote2(a), b)
if isinstance(a, DiffExpr2) and isinstance(b, DiffExpr2):
try:
return a * b
except TypeError:
return diff_bilinear_elementwise(_promote2(a), _promote2(b))
if isinstance(a, (DiffExpr3, DiffExpr2)):
return a * b
if isinstance(b, (DiffExpr3, DiffExpr2)):
return b * a
if isinstance(a, Expr) and isinstance(b, torch.Tensor):
return a * b
if isinstance(a, torch.Tensor) and isinstance(b, Expr):
return b * a
return a * b
def diff_matmul_handler(a, b):
if (isinstance(a, DiffExpr2) and a.is_constant()) or isinstance(b, (torch.Tensor, ConstVal)):
return a @ b
if (isinstance(b, DiffExpr2) and b.is_constant()) or isinstance(a, (torch.Tensor, ConstVal)):
return a @ b
if isinstance(a, DiffExpr2):
a = _promote2(a)
if isinstance(b, DiffExpr2):
b = _promote2(b)
if isinstance(a, DiffExpr3) or isinstance(b, DiffExpr3):
if isinstance(a, Expr):
a = DiffExpr3(a, a, _zero_like_expr(a))
if isinstance(b, Expr):
b = DiffExpr3(b, b, _zero_like_expr(b))
return diff_bilinear_matmul(a, b)
return zonosq_matmul_handler(a, b)
__all__ = [
"diff_bilinear_elementwise",
"diff_bilinear_matmul",
"diff_mul_handler",
"diff_matmul_handler",
]