"""
Scalar Multiplication (element-wise):
a1·b1 − a2·b2 = a1·∆b + ∆a·b2
Dot Product / Matrix Multiply:
A1@B1 − A2@B2 = A1@∆B + ∆A@B2
"""
import torch
from boundlab import expr
from boundlab.expr._affine import ConstVal
from boundlab.expr._core import Expr
from boundlab.diff.expr import DiffExpr2, DiffExpr3
from boundlab.zono.bilinear import (
bilinear_elementwise,
bilinear_matmul,
deept_precise_matmul,
_is_const,
)
def _mm(p, q):
"""matmul for the differential terms, honoring ``config.matmul_mode``.
Square mode (the default) keeps the exact ``bilinear_matmul`` quarter-square
path, so existing results reproduce byte-for-byte. ``precise`` / ``precise+sym``
route genuine zonotope x zonotope products (e.g. attention ``Q@K^T``,
``attn@V``) through the DeepT-Precise relaxation (the shared-eps eps^2 diagonal
correction). Const/affine operands stay on ``bilinear_matmul`` -- that product
is exact, so the relaxation choice is moot there. This mirrors how the
single-model ``matmul_handler`` dispatches on ``config.matmul_mode``."""
from boundlab.config import config
if (config.matmul_mode in ("precise", "precise+sym")
and not _is_const(p) and not _is_const(q)):
return deept_precise_matmul(
p, q,
symmetrize=(config.matmul_mode == "precise+sym"),
interval_cap=config.matmul_interval_cap,
)
return bilinear_matmul(p, q)
[docs]
def diff_bilinear_elementwise(a: DiffExpr3, b: DiffExpr3) -> DiffExpr3:
assert a.shape == b.shape, \
f"Shapes must match: {a.shape} vs {b.shape}"
out_x = bilinear_elementwise(a.x, b.x)
out_y = bilinear_elementwise(a.y, b.y)
# Diff: a1·Δb + Δa·b2
term1 = bilinear_elementwise(a.x, b.diff)
term2 = bilinear_elementwise(a.diff, b.y)
out_diff = term1 + term2
return DiffExpr3(out_x, out_y, out_diff)
[docs]
def diff_bilinear_matmul(a: DiffExpr3, b: DiffExpr3) -> DiffExpr3:
out_x = _mm(a.x, b.x)
out_y = _mm(a.y, b.y)
# Diff: A1@(B1 - B2) + (A1 - A2)@B2
term1 = _mm(a.x, b.diff)
term2 = _mm(a.diff, b.y)
out_diff = term1 + term2
# Reset: if Z_Δ bound is wider than Z_x - Z_y, swap per-neuron.
from boundlab.prop import bound_width
sub_diff = out_x - out_y
bw_d = bound_width(out_diff)
bw_s = bound_width(sub_diff)
n_reset = (bw_s < bw_d).sum().item()
n_total = bw_d.numel()
max_d = bw_d.max().item()
max_s = bw_s.max().item()
mask = (bw_s < bw_d).float()
out_diff = mask * sub_diff + (1.0 - mask) * out_diff
return DiffExpr3(out_x, out_y, out_diff)
def diff_mul_handler(a, b):
if isinstance(a, DiffExpr3) and isinstance(b, DiffExpr3):
return diff_bilinear_elementwise(a, b)
if isinstance(a, DiffExpr3) and isinstance(b, DiffExpr2):
try:
return a * b
except TypeError:
b3 = DiffExpr3(b.x, b.y, b.x - b.y)
return diff_bilinear_elementwise(a, b3)
if isinstance(a, DiffExpr2) and isinstance(b, DiffExpr3):
try:
return b * a
except TypeError:
a3 = DiffExpr3(a.x, a.y, a.x - a.y)
return diff_bilinear_elementwise(a3, b)
if isinstance(a, DiffExpr2) and isinstance(b, DiffExpr2):
try:
return a * b
except TypeError:
a3 = DiffExpr3(a.x, a.y, a.x - a.y)
b3 = DiffExpr3(b.x, b.y, b.x - b.y)
return diff_bilinear_elementwise(a3, b3)
if isinstance(a, (DiffExpr3, DiffExpr2)):
return a * b
if isinstance(b, (DiffExpr3, DiffExpr2)):
return b * a
if isinstance(a, Expr) and isinstance(b, torch.Tensor):
return a * b
if isinstance(a, torch.Tensor) and isinstance(b, Expr):
return b * a
return a * b
def diff_matmul_handler(a, b):
from boundlab.zono.bilinear import matmul_handler as std_matmul_handler
if isinstance(a, DiffExpr2) and a.is_constant() or isinstance(b, (torch.Tensor, ConstVal)):
return a @ b
if isinstance(b, DiffExpr2) and b.is_constant() or isinstance(a, (torch.Tensor, ConstVal)):
return a @ b
if isinstance(a, DiffExpr2):
a = DiffExpr3(a.x, a.y, a.x - a.y)
if isinstance(b, DiffExpr2):
b = DiffExpr3(b.x, b.y, b.x - b.y)
if isinstance(a, DiffExpr3) or isinstance(b, DiffExpr3):
if isinstance(a, Expr):
a = DiffExpr3(a, a, expr.ConstVal(a.shape))
if isinstance(b, Expr):
b = DiffExpr3(b, b, expr.ConstVal(b.shape))
return diff_bilinear_matmul(a, b)
return std_matmul_handler(a, b)
__all__ = [
"diff_bilinear_elementwise",
"diff_bilinear_matmul",
"diff_mul_handler",
"diff_matmul_handler",
]