"""Bilinear operation handlers for zonotope abstract interpretation.
Provides McCormick-style linearization for matmul and element-wise product
when both operands are symbolic expressions (Expr @ Expr or Expr * Expr),
plus the tighter DeepT-Precise relaxation from Bonaert et al. (2021).
"""
from functools import reduce
import operator
import os
from typing import Union
import torch
from boundlab import prop
from boundlab.expr._core import Expr
from boundlab.expr._affine import AffineSum, ConstVal
from boundlab.expr._var import LpEpsilon
from boundlab.linearop._base import LinearOp
[docs]
def bilinear_matmul(A: Expr, B: Expr) -> Expr:
    r"""Linearize ``A @ B`` when both operands are symbolic expressions.

    A: (m, k), B: (k, n) → result: (m, n)

    The method uses a first-order expansion around expression centers:

    .. math::

        A B \approx c_A B + A c_B - c_A c_B + E

    where :math:`E` is bounded by interval half-widths:

    .. math::

        |E| \le \mathrm{hw}(A)\,\mathrm{hw}(B)

    and represented using fresh noise symbols.

    This function is a thin wrapper: the relaxation itself is computed by
    :func:`square_matmul`, which uses the half-width product above as a cap
    on the residual interval.

    Args:
        A: Left expression with shape ``(m, k)``.
        B: Right expression with shape ``(k, n)``.

    Returns:
        An expression over-approximating ``A @ B``.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.bilinear import bilinear_matmul
    >>> A = expr.ConstVal(torch.ones(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
    >>> B = expr.ConstVal(torch.ones(3, 4)) + 0.1 * expr.LpEpsilon([3, 4])
    >>> C = bilinear_matmul(A, B)
    >>> C.shape
    torch.Size([2, 4])
    """
    return square_matmul(A, B)
    # NOTE: everything below is unreachable. It is the earlier plain
    # first-order implementation (interval half-width residual only), kept
    # as a commented-out reference.
    # assert len(A.shape) >= 2 and len(B.shape) >= 2, \
    #     f"Need at least 2D for matmul, got {A.shape} @ {B.shape}"
    # assert A.shape[-1] == B.shape[-2], \
    #     f"Inner dims must match: {A.shape} @ {B.shape}"
    # Ac, As = A.split_const()  # Ac: constant part, As: epsilon part
    # Bc, Bs = B.split_const()  # Bc: constant part, Bs: epsilon part
    # result = Ac @ Bs + As @ Bc + Ac @ Bc
    # # Error bound: |E| ≤ hw(A) * hw(B) where hw = half-width
    # assert As.is_symmetric_to_0() and Bs.is_symmetric_to_0()
    # error_bound = As.ub() @ Bs.ub()
    # new_eps = LpEpsilon(error_bound.shape)
    # result = result + error_bound * new_eps
    # return result
[docs]
def _symbolic_magnitude(sym: Expr) -> torch.Tensor:
    """Return an element-wise bound on ``|sym|`` for a zero-constant expression."""
    if sym.is_symmetric_to_0():
        # Symmetric range [-ub, ub]: the upper bound is the magnitude.
        return sym.ub()
    # Asymmetric range [lb, ub]: the symbolic part is NOT re-centered here, so
    # its magnitude can reach max(|ub|, |lb|), which exceeds the half-width
    # (ub - lb) / 2 whenever the range is off-center.
    upper, lower = sym.ublb()
    return torch.maximum(upper.abs(), lower.abs())


def bilinear_elementwise(A: Expr, B: Expr) -> Expr:
    r"""Linearize element-wise product of two symbolic expressions.

    Both A and B must have the same shape.

    With ``A = c_A + A_s`` and ``B = c_B + B_s`` (constant part plus
    zero-constant symbolic part), the approximation is:

    .. math::

        A \odot B \approx c_A \odot B + A \odot c_B - c_A \odot c_B + E

    where :math:`E = A_s \odot B_s` is bounded element-wise by

    .. math::

        |E| \le \sup|A_s| \odot \sup|B_s|

    and represented with a fresh noise symbol. When a symbolic part is
    symmetric about zero, :math:`\sup|\cdot|` is its half-width; otherwise it
    is ``max(|ub|, |lb|)`` (using the half-width of an un-centered range
    would under-estimate the residual).

    Args:
        A: First expression.
        B: Second expression (same shape as ``A``).

    Returns:
        An expression over-approximating ``A * B``.

    Raises:
        AssertionError: If the operand shapes differ.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.bilinear import bilinear_elementwise
    >>> A = expr.ConstVal(torch.ones(3)) + 0.2 * expr.LpEpsilon([3])
    >>> B = expr.ConstVal(torch.zeros(3)) + 0.3 * expr.LpEpsilon([3])
    >>> C = bilinear_elementwise(A, B)
    >>> C.shape
    torch.Size([3])
    """
    assert A.shape == B.shape, \
        f"Shapes must match for element-wise product: {A.shape} vs {B.shape}"

    Ac, As = A.split_const()  # Ac: constant part, As: zero-constant part
    Bc, Bs = B.split_const()  # Bc: constant part, Bs: zero-constant part

    # Exact affine part of (Ac + As) * (Bc + Bs); only As * Bs is left over.
    result = Ac * Bs + As * Bc + Ac * Bc

    # Residual bound: |As * Bs| <= sup|As| * sup|Bs|.
    error_bound = _symbolic_magnitude(As) * _symbolic_magnitude(Bs)
    new_eps = LpEpsilon(error_bound.shape)
    result = result + error_bound * new_eps
    return result
def _eps_children(e: Expr) -> dict[LpEpsilon, LinearOp]:
    """Map each noise symbol of ``e`` to the linear operator applied to it.

    ``e`` is expected to be the zero-constant half of a symmetric
    decomposition. An :class:`AffineSum` yields a copy of its
    ``children_dict`` (every key must be an :class:`LpEpsilon`); any other
    expression, e.g. a :class:`ConstVal`, yields an empty mapping.
    """
    if not isinstance(e, AffineSum):
        return {}

    eps_to_op = dict(e.children_dict)
    for node in eps_to_op.keys():
        assert isinstance(node, LpEpsilon), \
            f"DeepT-Precise requires LpEpsilon children, got {type(node).__name__}"
    return eps_to_op
def _topk_truncate(Tc, topk):
    """Keep only the ``topk`` largest-magnitude slots per output entry.

    Args:
        Tc: Coefficient tensor of shape ``(..., S_lo, S_hi)``; the last two
            axes are flattened and ranked together by ``|coef|``.
        topk: Number of slots to keep per output entry (clamped to the number
            of available slots).

    Returns:
        A pair ``(truncated, dropped)``: ``truncated`` has the shape of ``Tc``
        with every non-kept slot zeroed, and ``dropped`` has shape ``(...)``
        and holds the summed ``|coef|`` of the zeroed slots, to be folded back
        into the fresh-epsilon residual by the caller.
    """
    *pre, Slo, Shi = Tc.shape
    flat = Tc.reshape(*pre, Slo * Shi)
    K = min(topk, flat.shape[-1])

    magnitude = flat.abs()
    idx = magnitude.topk(K, dim=-1).indices
    keep = torch.zeros_like(flat, dtype=torch.bool).scatter_(-1, idx, True)

    zeros = torch.zeros_like(flat)
    # Select with `where` rather than multiplying by the mask: a kept slot
    # holding an infinite coefficient would otherwise contribute inf * 0 = NaN
    # to the dropped mass.
    dropped = torch.where(keep, zeros, magnitude).sum(dim=-1)
    flat = torch.where(keep, flat, zeros)
    return flat.reshape(*pre, Slo, Shi), dropped
def deept_precise_matmul(A: Expr, B: Expr,
                         symmetrize: bool = False,
                         interval_cap: bool = True,
                         chunk_elems: int = 32_000_000) -> Expr:
    r"""DeepT-Precise relaxation of ``A @ B``.

    The operands are split into constant and symbolic parts,
    ``A = A_c + A_s`` and ``B = B_c + B_s``. The affine part
    ``A_c B_s + A_s B_c + A_c B_c`` is kept exactly; the bilinear error
    :math:`A_s B_s` is expanded by enumerating every pair of noise symbols
    ``(eps_A, eps_B)`` with ``eps_A`` drawn from ``A`` and ``eps_B`` from
    ``B``. For each pair:

    1. Materialise both jacobians.
    2. Contract the ``k`` dim per matmul rules to get
       :math:`T[i, j, s_A, s_B] = \sum_l \alpha_{eps_A, s_A}[i, l]\,
       \beta_{eps_B, s_B}[l, j]`.
    3. If ``eps_A is eps_B`` (shared epsilon), the positions where
       ``s_A == s_B`` represent :math:`\epsilon^2 \in [0, 1]`: they contribute
       ``relu(T)`` to the upper error and ``relu(-T)`` to the magnitude of
       the lower error. Everywhere else (and every position of mismatched
       pairs) corresponds to :math:`\epsilon_i \epsilon_j \in [-1, 1]` and
       contributes ``|T|`` to both.
    4. Sum over all input dims and accumulate into the residual interval.

    The resulting asymmetric residual interval ``[L, U]`` (``L <= 0 <= U``)
    is repackaged as ``center + half_width · ε_new`` with
    ``center = (U + L) / 2`` and ``half_width = (U - L) / 2``.

    When ``config.quad_registry`` is set, a pair of *distinct* symbols that
    both carry a truthy ``is_input`` attribute is not folded into the
    residual; its coefficient tensor is attached to a registry-shared pair
    symbol instead (optionally truncated to ``config.quad_registry_topk``
    slots per output entry, with the dropped mass folded into the residual).

    If either symbolic part is not symmetric about zero, the whole product is
    delegated to :func:`square_matmul`.

    Args:
        A: Left operand with shape ``(..., m, k)``.
        B: Right operand with shape ``(..., k, n)``; leading batch dims must
            equal those of ``A`` (they are flattened into one batch axis).
        symmetrize: Enable the ``"precise+sym"`` off-diagonal tightening for
            shared symbols (``|T_ij + T_ji|`` instead of
            ``|T_ij| + |T_ji|``).
        interval_cap: Keep the naive interval product ``hw(A) @ hw(B)`` as a
            sound cap on the residual.
        chunk_elems: Element budget for the per-pair einsum tensor; the
            output column axis ``n`` is processed in chunks to respect it.

    Returns:
        An expression over-approximating ``A @ B``.

    Raises:
        AssertionError: If the operands are not at least 2D or their batch /
            inner dimensions do not match.
    """
    assert len(A.shape) >= 2 and len(B.shape) >= 2, \
        f"Need at least 2D for matmul, got {A.shape} @ {B.shape}"
    assert A.shape[:-2] == B.shape[:-2], \
        f"Batch dims must match: {A.shape} @ {B.shape}"
    assert A.shape[-1] == B.shape[-2], \
        f"Inner dims must match: {A.shape} @ {B.shape}"

    Ac, As = A.split_const()
    Bc, Bs = B.split_const()
    if not (As.is_symmetric_to_0() and Bs.is_symmetric_to_0()):
        # Rare: a non-symmetric symbolic part (asymmetric upstream construct).
        # square_matmul handles centering soundly; defer to it.
        return square_matmul(A, B)

    result = Ac @ Bs + As @ Bc + Ac @ Bc
    hwA = As.ub()  # (..., m, k) half-widths (lb = -ub)
    hwB = Bs.ub()  # (..., k, n)
    A_children = _eps_children(As)
    B_children = _eps_children(Bs)
    b = reduce(operator.mul, A.shape[:-2], 1)
    m, k, n = A.shape[-2], A.shape[-1], B.shape[-1]

    from boundlab.config import config
    use_reg = config.quad_registry
    topk = config.quad_registry_topk
    pair_terms = []  # [(coef (b,m,n,S_lo,S_hi), PairEpsilon)]

    # Accumulate an ASYMMETRIC residual interval [err_L, err_U] per (b, m, n):
    # the ε²∈[0,1] diagonal correction makes the residual genuinely one-sided.
    # Allocate with the operands' dtype/device: the in-place accumulations
    # below would otherwise downcast float64 contributions to the default
    # dtype, or fail on a device mismatch.
    err_U = torch.zeros(b, m, n, dtype=hwA.dtype, device=hwA.device)
    err_L = torch.zeros_like(err_U)

    for eps_A, op_A in A_children.items():
        jac_A = op_A.jacobian().reshape(b, m, k, -1)  # (b, m, k, S_A)
        SA = jac_A.shape[-1]
        for eps_B, op_B in B_children.items():
            jac_B = op_B.jacobian().reshape(b, k, n, -1)  # (b, k, n, S_B)
            SB = jac_B.shape[-1]
            is_diag = eps_A is eps_B
            is_reg = (not is_diag and use_reg
                      and getattr(eps_A, "is_input", False)
                      and getattr(eps_B, "is_input", False))
            if is_reg:
                # Distinct INPUT-level pair: keep as an identity-tracked
                # PairEpsilon coefficient (registry-shared) so equal-and-opposite
                # contributions cancel in a common AffineSum slot (e.g. across a
                # difference). Needs the full (b,m,n,S_lo,S_hi) coefficient tensor
                # for its EinsumOp, so materialize it (one pair at a time).
                from boundlab.expr._quad import get_pair
                T = torch.einsum("bmki,bknj->bmnij", jac_A, jac_B)
                pair, need_t = get_pair(eps_A, eps_B)
                Tc = T.transpose(-1, -2) if need_t else T  # → (b,m,n,S_lo,S_hi)
                if topk and topk > 0:
                    # Truncated slots are no longer tracked symbolically, so
                    # their |coef| mass widens the residual on both sides.
                    Tc, dropped = _topk_truncate(Tc, topk)
                    err_U += dropped
                    err_L -= dropped
                pair_terms.append((Tc, pair))
                continue

            # Concretizing path (shared-ε ε² diagonal, or distinct fold): chunk
            # over n so the peak (b, m, chunk_n, S_A, S_B) einsum tensor stays
            # within the element budget. The residual is a per-output-entry sum,
            # so chunking only partitions it — numerically identical.
            chunk_n = max(1, min(n, chunk_elems // max(1, b * m * SA * SB)))
            for n0 in range(0, n, chunk_n):
                n1 = min(n0 + chunk_n, n)
                T = torch.einsum("bmki,bknj->bmnij", jac_A, jac_B[:, :, n0:n1, :])
                if is_diag:
                    # Shared ε quadratic form Σ_ij T_ij ε_i ε_j, ε ∈ [-1,1]^S:
                    #   diagonal ε_i²∈[0,1]   → +relu(T_ii) up, -relu(-T_ii) down
                    #   off-diag ε_iε_j∈[-1,1] → ±|·| (symmetrized if requested)
                    T_diag = T.diagonal(dim1=-2, dim2=-1)
                    err_U[:, :, n0:n1] += torch.relu(T_diag).sum(dim=-1)
                    err_L[:, :, n0:n1] -= torch.relu(-T_diag).sum(dim=-1)
                    if symmetrize:
                        Msym = 0.5 * (T + T.transpose(-1, -2))
                        off = (Msym.abs().sum(dim=(-2, -1))
                               - Msym.diagonal(dim1=-2, dim2=-1).abs().sum(dim=-1))
                    else:
                        off = (T.abs().sum(dim=(-2, -1)) - T_diag.abs().sum(dim=-1))
                    err_U[:, :, n0:n1] += off
                    err_L[:, :, n0:n1] -= off
                else:
                    # Distinct symbols: each product ε_iε_j ∈ [-1, 1] independently.
                    absum = T.abs().sum(dim=(-2, -1))
                    err_U[:, :, n0:n1] += absum
                    err_L[:, :, n0:n1] -= absum

    err_U = err_U.reshape(A.shape[:-2] + (m, n))
    err_L = err_L.reshape(A.shape[:-2] + (m, n))

    # A NaN entry (e.g. from 0 * inf) is widened to the conservative extreme
    # (+inf above, -inf below) instead of being propagated as NaN.
    U = torch.where(torch.isnan(err_U), torch.full_like(err_U, torch.inf), err_U)
    L = torch.where(torch.isnan(err_L), torch.full_like(err_L, -torch.inf), err_L)
    if interval_cap:
        cap = hwA @ hwB  # (..., m, n), >= 0
        U = torch.minimum(U, cap)
        L = torch.maximum(L, -cap)
    result = result + (U + L) / 2 + (U - L) / 2 * LpEpsilon(result.shape,
                                                           reason="precise_matmul")

    # Attach identity-tracked distinct-pair terms (uncapped — each backward-
    # concretizes to Σ|T|, the same as folding, but the shared registry symbol
    # lets equal-and-opposite contributions cancel in a common AffineSum slot).
    if pair_terms:
        from boundlab.linearop._einsum import EinsumOp
        out_shape = tuple(A.shape[:-2]) + (m, n)
        no = len(out_shape)
        for Tc, pair in pair_terms:
            Slo, Shi = Tc.shape[-2], Tc.shape[-1]
            tensor = Tc.reshape(out_shape + (Slo, Shi))
            op = EinsumOp(tensor, input_dims=[no, no + 1], output_dims=list(range(no)))
            result = result + AffineSum((op, pair))
    return result
[docs]
def matmul_handler(A, B):
    """Dispatcher implementation for ``torch.matmul``.

    Routing rules, checked in this order:

    - ``Expr @ const`` or ``const @ Expr``: exact affine path (``A @ B``).
    - ``const @ const``: delegated to ``torch.matmul``.
    - ``Expr @ Expr``: bilinear relaxation. ``config.matmul_mode`` selects
      :func:`deept_precise_matmul` for ``"precise"`` / ``"precise+sym"`` and
      :func:`bilinear_matmul` otherwise. If ``config.softmax_simplex_clamp``
      is enabled and ``A`` carries a truthy ``_simplex_rows`` attribute, the
      result is additionally capped by :func:`_simplex_clamp`.
    - Anything else: ``NotImplemented``.

    Here "const" means whatever :func:`_is_const` accepts.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.bilinear import matmul_handler
    >>> A = expr.ConstVal(torch.ones(1, 2)) + expr.LpEpsilon([1, 2])
    >>> B = torch.ones(2, 1)
    >>> matmul_handler(A, B).shape
    torch.Size([1, 1])
    """
    lhs_is_expr = isinstance(A, Expr)
    rhs_is_expr = isinstance(B, Expr)
    lhs_is_const = _is_const(A)
    rhs_is_const = _is_const(B)

    if lhs_is_expr and rhs_is_const:
        return A @ B  # Expr.__matmul__(Tensor)
    if lhs_is_const and rhs_is_expr:
        return A @ B  # Tensor.__matmul__ → Expr.__rmatmul__(Tensor)
    if lhs_is_const and rhs_is_const:
        return torch.matmul(A, B)
    if not (lhs_is_expr and rhs_is_expr):
        return NotImplemented

    from boundlab.config import config

    mode = config.matmul_mode
    if mode in ("precise", "precise+sym"):
        product = deept_precise_matmul(
            A, B,
            symmetrize=(mode == "precise+sym"),
            interval_cap=config.matmul_interval_cap,
        )
    else:
        product = bilinear_matmul(A, B)

    # Simplex clamp: when the left operand is a softmax output, each row sums
    # to 1, so S@V is a per-token convex combination of V's rows — bounded by
    # the per-column hull of V regardless of how loose the softmax zonotope
    # is. Cap the bilinear result against that hull where it is tighter.
    if config.softmax_simplex_clamp and getattr(A, "_simplex_rows", False):
        product = _simplex_clamp(product, B)
    return product
def _simplex_clamp(R: Expr, V: Expr) -> Expr:
    r"""Cap ``R = S @ V`` by the per-column interval hull of ``V``.

    With the rows of ``S`` summing to 1 over the contracted dim, every output
    entry is a convex combination of ``V``'s rows, hence
    ``out[..., t, j] ∈ [min_l lb(V)[l, j], max_l ub(V)[l, j]]``.

    When all bounds of ``R`` are finite, the hull box replaces ``R`` only in
    the entries where the box is strictly narrower than ``R``'s own interval
    (a 0/1 blend of two independently valid enclosures). Otherwise ``R``'s
    symbolic structure is dropped and the per-entry intersection of its
    interval (where finite) with the hull box is returned.
    """
    v_upper, v_lower = V.ublb()
    hull_hi = v_upper.amax(dim=-2, keepdim=True)  # (..., 1, d)
    hull_lo = v_lower.amin(dim=-2, keepdim=True)
    r_upper, r_lower = R.ublb()
    hull_width = hull_hi - hull_lo  # broadcasts over the token dim

    all_finite = torch.isfinite(r_upper).all() and torch.isfinite(r_lower).all()
    if not all_finite:
        # Catastrophic regime: R carries non-finite coefficients, so a 0/1
        # blend would evaluate 0*inf = NaN. Fall back to a plain interval.
        clipped_hi = torch.where(torch.isfinite(r_upper),
                                 torch.minimum(r_upper, hull_hi.expand_as(r_upper)),
                                 hull_hi.expand_as(r_upper))
        clipped_lo = torch.where(torch.isfinite(r_lower),
                                 torch.maximum(r_lower, hull_lo.expand_as(r_lower)),
                                 hull_lo.expand_as(r_lower))
        midpoint = (clipped_hi + clipped_lo) / 2
        half_width = (clipped_hi - clipped_lo) / 2
        return ConstVal(midpoint) + half_width * LpEpsilon(list(R.shape), reason="simplex_clamp")

    # (..., T, d): 1 where the hull box is tighter than R, else 0.
    use_box = (hull_width < (r_upper - r_lower)).to(r_upper.dtype)
    midpoint = ((hull_hi + hull_lo) / 2).expand_as(r_upper)
    half_width = (hull_width / 2).expand_as(r_upper)
    hull_expr = ConstVal(midpoint) + half_width * LpEpsilon(list(R.shape), reason="simplex_clamp")
    return use_box * hull_expr + (1.0 - use_box) * R
def _is_const(tensor: Union[torch.Tensor, Expr, int]) -> bool:
    """Return True for operands treated as constants by the matmul dispatcher:
    a ``torch.Tensor``, a :class:`ConstVal`, or an ``int``."""
    return isinstance(tensor, (torch.Tensor, ConstVal, int))
def _square_matmul_symmetric_legacy(Ac, As, Bc, Bs):
    """Reference Expr-level quarter-square relaxation for symmetric operands.

    Builds the affine part ``Ac @ Bs + As @ Bc + Ac @ Bc`` and bounds the
    bilinear residual per ``(i, l, j)`` by expanding both symbolic operands
    to ``(..., m, k, n)``. Retained (selected through the environment
    variable ``BOUNDLAB_LEGACY_SQUARE_MATMUL``) so the generator-level path
    can be checked against it for numerical equivalence.
    """
    affine = Ac @ Bs + As @ Bc + Ac @ Bc
    rows, inner = As.shape[-2], As.shape[-1]
    cols = Bs.shape[-1]

    hw_a = As.ub()
    hw_b = Bs.ub()
    # Naive interval product: always-valid cap on the residual.
    upper = hw_a @ hw_b
    lower = -upper

    # Broadcast everything to the per-(i, l, j) layout (..., m, k, n).
    hw_a_full = hw_a.unsqueeze(-1).expand(*hw_a.shape, cols)
    hw_b_full = hw_b.unsqueeze(-3).expand(*hw_b.shape[:-2], rows, inner, cols)
    a_full = As.unsqueeze(-1).expand(*As.shape, cols)
    b_full = Bs.unsqueeze(-3).expand(*Bs.shape[:-2], rows, inner, cols)

    # Heuristic scaling λa = sqrt(hwA / hwB), λb = 1 / λa.
    root_a = torch.sqrt(hw_a_full)
    root_b = torch.sqrt(hw_b_full)
    scale_a = root_a / root_b
    scale_b = root_b / root_a

    sum_ub = (scale_a * a_full + scale_b * b_full).ub()
    diff_ub = (scale_a * a_full - scale_b * b_full).ub()
    square_hi = sum_ub ** 2 / 4
    square_lo = -(diff_ub ** 2) / 4
    # NaN (e.g. a 0/0 scaling) is widened to the conservative extreme.
    square_hi = torch.where(torch.isnan(square_hi),
                            torch.full_like(square_hi, torch.inf), square_hi)
    square_lo = torch.where(torch.isnan(square_lo),
                            torch.full_like(square_lo, -torch.inf), square_lo)

    # Sum over the contraction axis, then intersect with the interval cap.
    upper = torch.minimum(square_hi.sum(dim=-2), upper)
    lower = torch.maximum(square_lo.sum(dim=-2), lower)
    affine += (upper + lower) / 2 + (upper - lower) / 2 * LpEpsilon(affine.shape, reason="square_matmul")
    return affine
def _ubst_chunk(lama, lamb, HA, HB_chunk, shared_jac, n0, n1, chunk_s):
    r"""Bound ``|λa·As + λb·Bs|`` and ``|λa·As − λb·Bs|`` for one n-chunk.

    Works directly on jacobians. ``HA`` ``(b, m, k)`` and ``HB_chunk``
    ``(b, k, nc)`` hold the summed ``|coef|`` of the symbols that are not
    shared between the operands; each ``(jA, jB)`` in ``shared_jac`` holds
    the jacobians ``(b, m, k, S)`` / ``(b, k, n, S)`` of one shared symbol,
    whose coefficients are combined *before* taking the absolute value.
    ``jB`` is sliced to columns ``n0:n1`` here; the symbol axis is processed
    in slices of ``chunk_s``. ``lama`` / ``lamb`` broadcast as
    ``(b, m, k, nc)`` or ``(b, m, 1, nc)``.

    Returns the pair ``(ub_sum, ub_diff)``.
    """
    unshared = lama * HA.unsqueeze(-1) + lamb * HB_chunk.unsqueeze(-3)  # (b,m,k,nc)
    sum_acc = torch.zeros_like(unshared)
    diff_acc = torch.zeros_like(unshared)

    # Trailing axis added once so the scalings broadcast over symbol slices.
    scale_a = lama.unsqueeze(-1)
    scale_b = lamb.unsqueeze(-1)

    for jac_a, jac_b in shared_jac:
        jac_b_cols = jac_b[:, :, n0:n1, :]  # (b, k, nc, S)
        num_symbols = jac_a.shape[-1]
        for lo in range(0, num_symbols, chunk_s):
            hi = min(lo + chunk_s, num_symbols)
            part_a = scale_a * jac_a[:, :, :, lo:hi].unsqueeze(-2)      # (b,m,k,nc,sc)
            part_b = scale_b * jac_b_cols[:, :, :, lo:hi].unsqueeze(1)  # (b,m,k,nc,sc)
            sum_acc = sum_acc + (part_a + part_b).abs().sum(-1)
            diff_acc = diff_acc + (part_a - part_b).abs().sum(-1)

    return unshared + sum_acc, unshared + diff_acc
def _residual_chunk_size(b, m, k, maxS, chunk_elems, chunk_s, n):
    """Return how many output columns to process per chunk.

    One column costs ``b * m * k * min(maxS, chunk_s)`` elements (at least
    1); the result is the number of columns fitting in ``chunk_elems``,
    clamped to the range ``[1, n]`` (and never below 1, even when ``n`` is 0).
    """
    symbols_per_pass = min(maxS, chunk_s)
    elems_per_column = max(1, b * m * k * symbols_per_pass)
    columns_in_budget = chunk_elems // elems_per_column
    return max(1, min(n, columns_in_budget))
def _accumulate_UL(lama_fn, HA, HB, shared_jac, cap, b, m, k, n, chunk_n, chunk_s):
    """Compute the final residual interval ``[L, U]`` chunk by chunk.

    Args:
        lama_fn: Callable ``(n0, n1) -> (lama, lamb)`` giving the scalings
            for output columns ``n0:n1``.
        HA, HB: Summed ``|coef|`` of the non-shared symbols, ``(b, m, k)``
            and ``(b, k, n)``.
        shared_jac: Jacobian pairs of the shared symbols, as consumed by
            :func:`_ubst_chunk`.
        cap: Interval cap ``(b, m, n)``; the result is clipped to
            ``[-cap, cap]``.
        b, m, k, n: Problem sizes (``k`` is unused here).
        chunk_n: Number of output columns per chunk.
        chunk_s: Symbol-axis chunk size forwarded to :func:`_ubst_chunk`.

    Returns:
        ``(U, L)``, each ``(b, m, n)``, with the dtype and device of ``cap``.
    """
    # Allocate like `cap` rather than with the global defaults: writing a
    # float64 chunk into a default-dtype buffer would silently round the
    # bounds, and the result would land on the default device.
    U = cap.new_empty((b, m, n))
    L = cap.new_empty((b, m, n))
    for n0 in range(0, n, chunk_n):
        n1 = min(n0 + chunk_n, n)
        lama, lamb = lama_fn(n0, n1)
        ub_s, ub_t = _ubst_chunk(lama, lamb, HA, HB[:, :, n0:n1], shared_jac, n0, n1, chunk_s)
        _pos = ub_s ** 2 / 4
        _neg = -(ub_t ** 2) / 4
        # NaN (e.g. from a 0/0 scaling) is widened to ±inf so that the cap
        # below takes over instead of NaN propagating.
        Pos = torch.where(torch.isnan(_pos), torch.full_like(_pos, torch.inf), _pos)
        Neg = torch.where(torch.isnan(_neg), torch.full_like(_neg, -torch.inf), _neg)
        cap_c = cap[:, :, n0:n1]
        # Sum over the contraction axis, then intersect with the cap.
        U[:, :, n0:n1] = torch.minimum(Pos.sum(dim=-2), cap_c)
        L[:, :, n0:n1] = torch.maximum(Neg.sum(dim=-2), -cap_c)
    return U, L
def _optimize_theta(theta0, HA, HB, shared_jac, cap, b, m, k, n, chunk_n, chunk_s,
                    steps, lr=0.1):
    r"""Tune the per-(i,l,j) quarter-square parameter θ with Adam.

    The scalings are ``λa = exp(θ)`` and ``λb = exp(−θ)``, so ``λa·λb = 1``
    and both are positive for every θ. Starting from ``theta0``, each step
    evaluates the total residual width ``Σ(U − L)`` over all n-chunks
    (back-propagating chunk by chunk to bound peak memory), remembers the θ
    with the smallest width seen so far, and then takes an Adam step.

    Optimization stops early if a chunk loss or the accumulated gradient is
    non-finite (or no gradient was produced); the best θ recorded up to that
    point is returned, which is a copy of ``theta0`` if no step completed.

    Returns:
        A detached tensor with the shape of ``theta0``.
    """
    theta = theta0.detach().clone().requires_grad_(True)
    optimizer = torch.optim.Adam([theta], lr=lr)
    best_theta, best_width = theta0.detach().clone(), float("inf")

    for _step in range(steps):
        optimizer.zero_grad()
        width = 0.0
        finite = True
        for lo in range(0, n, chunk_n):
            hi = min(lo + chunk_n, n)
            theta_cols = theta[:, :, :, lo:hi]
            scale_a = torch.exp(theta_cols)
            scale_b = torch.exp(-theta_cols)
            ub_sum, ub_diff = _ubst_chunk(scale_a, scale_b, HA, HB[:, :, lo:hi],
                                          shared_jac, lo, hi, chunk_s)
            cap_cols = cap[:, :, lo:hi]
            upper = torch.minimum((ub_sum ** 2 / 4).sum(dim=-2), cap_cols)
            lower = torch.maximum((-(ub_diff ** 2) / 4).sum(dim=-2), -cap_cols)
            chunk_width = (upper - lower).sum()
            if not torch.isfinite(chunk_width):
                finite = False
                break
            chunk_width.backward()
            width += chunk_width.item()

        grad = theta.grad
        if not finite or grad is None or not torch.isfinite(grad).all():
            break  # keep best-so-far (θ0 at minimum)

        # `width` was measured at the current θ (before the update), so this
        # records the best θ actually evaluated; θ0 is evaluated at step 0.
        if width < best_width:
            best_width, best_theta = width, theta.detach().clone()
        optimizer.step()

    return best_theta
def _square_matmul_symmetric_gen(Ac, As, Bc, Bs, chunk_elems, chunk_s,
                                 optimize=False, opt_steps=25,
                                 lin_As=None, lin_Bs=None):
    r"""Generator-level symmetric quarter-square residual.

    Numerically equivalent to :func:`_square_matmul_symmetric_legacy` (for the
    heuristic λ) but never materializes the ``(..., m, k, n)`` symbolic operand
    (the source of the ``O(m·k·n·E)`` backward-pass blow-up). It bounds
    ``ub|λa·As + λb·Bs|[i,l,j] = Σ_symbols |coef|`` directly from the jacobians,
    splitting symbols into A-only, B-only and *shared* (combined before abs so
    A–B correlation is preserved), chunked over ``n`` (and the symbol axis).

    With ``optimize`` the per-(i,l,j) λ is tuned by Adam (α-CROWN style) from the
    heuristic init; any λ>0 is sound, so this only tightens. The optimization
    is skipped when the operands share no symbol.

    ``lin_As`` / ``lin_Bs`` override the operands used for the *exact affine*
    term ``Ac @ Bs + As @ Bc + Ac @ Bc`` only; the ``U``/``L`` bilinear residual
    is always computed from the (flat, ``LpEpsilon``-leaf) ``As`` / ``Bs``. This
    lets the zonohex path keep the un-flattened ``ZonoHexGate`` expressions in
    the affine term (so the gate's x/y coupling still cancels downstream in the
    differential) while the residual uses the flattened symmetric parts the
    jacobian machinery requires.

    Args:
        Ac, Bc: Constant parts of the operands.
        As, Bs: Zero-constant symbolic parts, ``(..., m, k)`` and
            ``(..., k, n)``, whose children are ``LpEpsilon`` symbols.
        chunk_elems: Element budget used to size the n-chunks.
        chunk_s: Chunk size along the shared-symbol axis.
        optimize: Tune λ with :func:`_optimize_theta`.
        opt_steps: Number of optimization steps when ``optimize`` is set.
        lin_As, lin_Bs: Optional replacements for ``As`` / ``Bs`` in the
            affine term only.

    Returns:
        An expression over-approximating the product, shape ``(..., m, n)``.
    """
    lin_As = As if lin_As is None else lin_As
    lin_Bs = Bs if lin_Bs is None else lin_Bs
    result = Ac @ lin_Bs + lin_As @ Bc + Ac @ Bc

    *batch, m, k = As.shape
    n = Bs.shape[-1]
    b = reduce(operator.mul, batch, 1)
    Au = As.ub().reshape(b, m, k)
    Bu = Bs.ub().reshape(b, k, n)
    cap = torch.bmm(Au, Bu)  # (b, m, n)

    A_children = _eps_children(As)
    B_children = _eps_children(Bs)
    shared = [c for c in A_children if c in B_children]
    a_only = [c for c in A_children if c not in B_children]
    b_only = [c for c in B_children if c not in A_children]

    # Summed |coef| of the symbols exclusive to one operand. Allocate like the
    # half-width tensors so the accumulators live on the operands' device and
    # dtype even when there is no exclusive symbol to add (a default-device
    # buffer would then clash with the jacobians downstream).
    HA = torch.zeros_like(Au)  # (b, m, k)
    for c in a_only:
        HA = HA + A_children[c].jacobian().reshape(b, m, k, -1).abs().sum(-1)
    HB = torch.zeros_like(Bu)  # (b, k, n)
    for c in b_only:
        HB = HB + B_children[c].jacobian().reshape(b, k, n, -1).abs().sum(-1)
    shared_jac = [(A_children[c].jacobian().reshape(b, m, k, -1),
                   B_children[c].jacobian().reshape(b, k, n, -1)) for c in shared]

    sa = torch.sqrt(Au)  # (b, m, k)
    sb = torch.sqrt(Bu)  # (b, k, n)
    maxS = max((jA.shape[-1] for jA, _ in shared_jac), default=1)
    chunk_n = _residual_chunk_size(b, m, k, maxS, chunk_elems, chunk_s, n)

    # Optimization only helps when there are shared symbols (i.e. A–B
    # correlation the no-correlation heuristic λ does not exploit).
    if optimize and shared_jac:
        # θ0 = log of the heuristic λa = sqrt(Au / Bu); clamping keeps the
        # logarithms finite where a half-width is zero.
        tiny = torch.finfo(Au.dtype).tiny
        theta0 = 0.5 * (torch.log(Au.clamp_min(tiny)).unsqueeze(-1)
                        - torch.log(Bu.clamp_min(tiny)).unsqueeze(-3))  # (b,m,k,n)
        theta = _optimize_theta(theta0, HA, HB, shared_jac, cap, b, m, k, n,
                                chunk_n, chunk_s, opt_steps)

        def lama_fn(n0, n1):
            th = theta[:, :, :, n0:n1]
            return torch.exp(th), torch.exp(-th)
    else:
        def lama_fn(n0, n1):
            # Heuristic λa = sqrt(Au) / sqrt(Bu) and its reciprocal.
            sb_c = sb[:, :, n0:n1]
            return (sa.unsqueeze(-1) / sb_c.unsqueeze(-3),
                    sb_c.unsqueeze(-3) / sa.unsqueeze(-1))

    U, L = _accumulate_UL(lama_fn, HA, HB, shared_jac, cap, b, m, k, n, chunk_n, chunk_s)
    out_shape = tuple(batch) + (m, n)
    U = U.reshape(out_shape)
    L = L.reshape(out_shape)
    # Repackage [L, U] as center + half-width times a fresh noise symbol.
    result = result + (U + L) / 2 + (U - L) / 2 * LpEpsilon(result.shape, reason="square_matmul")
    return result
def _flatten_zonohex(s):
    """Flatten any ``ZonoHexGate`` inside ``s`` and split off the constant.

    Returns ``(sym, const)``. If no subnode of ``s`` is a ``ZonoHexGate``,
    ``s`` is returned untouched together with the constant ``0``, so the
    non-zonohex path is unaffected. Otherwise ``s`` is passed through
    ``eqprop`` and split with ``split_const()``: ``sym`` is the symbolic part
    and ``const`` the constant that the flattening exposed.

    Why this exists: on the zonohex differential path a gate in the symbolic
    operand is not flagged symmetric, so the fast generator residual cannot
    consume it directly. ``const`` is returned so that the full-flatten path
    can fold it into the center.
    """
    from boundlab.diff.zonohex import ZonoHexGate

    contains_gate = False
    for node in s.all_subnodes():
        if isinstance(node, ZonoHexGate):
            contains_gate = True
            break
    if not contains_gate:
        return s, 0

    # Imported lazily: only needed when there is a gate to flatten.
    from boundlab.prop import eqprop

    const, symbolic = eqprop(s).split_const()
    return symbolic, const
def square_matmul(A: Expr, B: Expr, chunk_elems: int = 32_000_000,
                  chunk_s: int = 4096) -> Expr:
    r"""Over-approximate ``A @ B`` with a quarter-square relaxation.

    Both operands are split into a constant and a symbolic part. The affine
    terms of the product are kept exactly and the bilinear residual is
    enclosed in an interval ``[L, U]`` that is re-expressed as
    ``(U + L) / 2 + (U - L) / 2 · ε`` with a fresh ``LpEpsilon``.

    Two paths exist:

    * **Symmetric path** — taken when both symbolic parts (after flattening
      any ``ZonoHexGate`` via :func:`_flatten_zonohex`) are symmetric about
      zero. Delegates to :func:`_square_matmul_symmetric_gen`, or to
      :func:`_square_matmul_symmetric_legacy` when the environment variable
      ``BOUNDLAB_LEGACY_SQUARE_MATMUL`` is set to a non-empty value.
      ``config.matmul_zonohex_flatten`` decides whether the affine term also
      uses the flattened operands or keeps the original gate expressions.
    * **Fallback path** — otherwise the operands are re-centered on their
      interval midpoints and, per ``(i, l, j)``, the scaled sum
      ``λa·As + λb·Bs`` and difference ``λa·As − λb·Bs`` are bounded with
      ``ublb()``; their squares divided by 4 give the upper / lower residual,
      which is summed over the contraction axis and clipped to the interval
      product of the half-widths.

    Args:
        A: Left operand, shape ``(..., m, k)``.
        B: Right operand, shape ``(..., k, n)``, same leading batch dims.
        chunk_elems: Element budget forwarded to the generator path.
        chunk_s: Symbol-axis chunk size forwarded to the generator path.

    Returns:
        An expression of shape ``(..., m, n)`` over-approximating ``A @ B``.

    Raises:
        AssertionError: If the operands are not at least 2D or their batch /
            inner dimensions do not match.
    """
    assert len(A.shape) >= 2 and len(B.shape) >= 2, \
        f"Need at least 2D for matmul, got {A.shape} @ {B.shape}"
    assert A.shape[:-2] == B.shape[:-2], \
        f"Batch dims must match: {A.shape} @ {B.shape}"
    assert A.shape[-1] == B.shape[-2], \
        f"Inner dims must match: {A.shape} @ {B.shape}"
    Ac, As = A.split_const()
    Bc, Bs = B.split_const()
    # zonohex fast path: a ZonoHexGate in the symbolic part is not flagged
    # symmetric, so it would otherwise fall into the slow per-element else branch
    # (materializes the (m,k,n) operand and calls ublb). Flatten the gate to a
    # plain symmetric AffineSum over LpEpsilon leaves (eqprop, via
    # ZonoHexGate.simplify's exact `==` machinery) — that restores the symmetric
    # flag and lets the fast generator path consume it.
    from boundlab.config import config
    As_res, ac2 = _flatten_zonohex(As)
    Bs_res, bc2 = _flatten_zonohex(Bs)
    if config.matmul_zonohex_flatten:
        # full-flatten: the affine term uses the flat operands too (gates stop
        # accumulating across layers → far faster on deep nets, marginally
        # looser). Fold the exposed (~0) constant into the center.
        lin_As, lin_Bs = As_res, Bs_res
        Ac, Bc = Ac + ac2, Bc + bc2
    else:
        # gate-affine: the exact affine term keeps the un-flattened gate
        # (preserves x/y coupling → exact); only the U/L residual is flat.
        lin_As, lin_Bs = As, Bs
    if As_res.is_symmetric_to_0() and Bs_res.is_symmetric_to_0():
        # Any non-empty value of the env var selects the reference path.
        if os.environ.get("BOUNDLAB_LEGACY_SQUARE_MATMUL"):
            return _square_matmul_symmetric_legacy(Ac, As_res, Bc, Bs_res)
        return _square_matmul_symmetric_gen(
            Ac, As_res, Bc, Bs_res, chunk_elems, chunk_s,
            optimize=config.matmul_optimize_lambda,
            opt_steps=config.matmul_lambda_steps,
            lin_As=lin_As, lin_Bs=lin_Bs)
    else:
        # Fallback for non-symmetric symbolic parts. The split above is
        # discarded: the operands are re-centered on their interval midpoints
        # so that the half-widths below bound the symbolic parts.
        m, k, n = A.shape[-2], A.shape[-1], B.shape[-1]
        Au, Al = A.ublb()
        Bu, Bl = B.ublb()
        Ac = (Au + Al) / 2
        As = A - Ac
        Bc = (Bu + Bl) / 2
        Bs = B - Bc
        result = Ac @ Bc + As @ Bc + Ac @ Bs
        Asu = (Au - Al) / 2
        Bsu = (Bu - Bl) / 2
        # Naive interval product: always-valid cap on the residual.
        U = Asu @ Bsu
        L = -U
        Asu = Asu.unsqueeze(-1).expand(*Asu.shape, n)  # (..., m, k, n)
        Bsu = Bsu.unsqueeze(-3).expand(*Bsu.shape[:-2], m, k, n)  # (..., m, k, n)
        As = As.unsqueeze(-1).expand(*As.shape, n)  # (..., m, k, n)
        Bs = Bs.unsqueeze(-3).expand(*Bs.shape[:-2], m, k, n)  # (..., m, k, n)
        # Heuristic scaling λa = sqrt(hwA) / sqrt(hwB) and its reciprocal λb.
        a = torch.sqrt(Asu)
        b = torch.sqrt(Bsu)
        lama = a / b
        lamb = b / a
        pos_expr = lama * As + lamb * Bs
        neg_expr = lama * As - lamb * Bs
        # Simplify any ZonoHexGate child before concretizing the bounds.
        from boundlab.diff.zonohex import ZonoHexGate
        if isinstance(pos_expr, AffineSum):
            pos_expr = pos_expr.replace_subnode_once(lambda c: c.simplify() if isinstance(c, ZonoHexGate) else None)
        if isinstance(neg_expr, AffineSum):
            neg_expr = neg_expr.replace_subnode_once(lambda c: c.simplify() if isinstance(c, ZonoHexGate) else None)
        PosU, PosL = pos_expr.ublb()
        NegU, NegL = neg_expr.ublb()
        # The square of a value in [lb, ub] is at most max(ub², lb²).
        _pos = torch.maximum(PosU ** 2, PosL ** 2) / 4
        _neg = -torch.maximum(NegU ** 2, NegL ** 2) / 4
        # NaN (e.g. a 0/0 scaling) is widened to the conservative extreme so
        # that the interval cap takes over instead of NaN propagating.
        Pos = torch.where(torch.isnan(_pos), torch.full_like(_pos, torch.inf), _pos)
        Neg = torch.where(torch.isnan(_neg), torch.full_like(_neg, -torch.inf), _neg)
        U = torch.minimum(Pos.sum(dim=-2), U)  # (..., m, n)
        L = torch.maximum(Neg.sum(dim=-2), L)  # (..., m, n)
        result += (U + L) / 2 + (U - L) / 2 * LpEpsilon(result.shape)
        return result