Source code for boundlab.zono.bilinear

"""Bilinear operation handlers for zonotope abstract interpretation.

Provides McCormick-style linearization for matmul and element-wise product
when both operands are symbolic expressions (Expr @ Expr or Expr * Expr),
plus the tighter DeepT-Precise relaxation from Bonaert et al. (2021).
"""

from functools import reduce
import operator
import os
from typing import Union

import torch

from boundlab import prop
from boundlab.expr._core import Expr
from boundlab.expr._affine import AffineSum, ConstVal
from boundlab.expr._var import LpEpsilon
from boundlab.linearop._base import LinearOp


def bilinear_matmul(A: Expr, B: Expr) -> Expr:
    r"""Over-approximate ``A @ B`` when both operands are symbolic.

    ``A`` has shape ``(m, k)``, ``B`` has shape ``(k, n)`` and the result has
    shape ``(m, n)``.

    The product is expanded to first order around the expression centers:

    .. math::

        A B \approx c_A B + A c_B - c_A c_B + E

    where the remainder :math:`E` is bounded by the interval half-widths

    .. math::

        |E| \le \mathrm{hw}(A)\,\mathrm{hw}(B)

    and is represented with fresh noise symbols.  The actual relaxation is
    delegated to :func:`square_matmul`.

    Args:
        A: Left expression with shape ``(m, k)``.
        B: Right expression with shape ``(k, n)``.

    Returns:
        An expression over-approximating ``A @ B``.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.bilinear import bilinear_matmul
    >>> A = expr.ConstVal(torch.ones(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
    >>> B = expr.ConstVal(torch.ones(3, 4)) + 0.1 * expr.LpEpsilon([3, 4])
    >>> C = bilinear_matmul(A, B)
    >>> C.shape
    torch.Size([2, 4])
    """
    relaxed = square_matmul(A, B)
    return relaxed
# assert len(A.shape) >= 2 and len(B.shape) >= 2, \ # f"Need at least 2D for matmul, got {A.shape} @ {B.shape}" # assert A.shape[-1] == B.shape[-2], \ # f"Inner dims must match: {A.shape} @ {B.shape}" # Ac, As = A.split_const() # Ac: constant part, As: epsilon part # Bc, Bs = B.split_const() # Bc: constant part, Bs: epsilon part # result = Ac @ Bs + As @ Bc + Ac @ Bc # # Error bound: |E| ≤ hw(A) * hw(B) where hw = half-width # assert As.is_symmetric_to_0() and Bs.is_symmetric_to_0() # error_bound = As.ub() @ Bs.ub() # new_eps = LpEpsilon(error_bound.shape) # result = result + error_bound * new_eps # return result
def _sym_magnitude(sym: Expr) -> torch.Tensor:
    """Return an element-wise bound on ``|sym|`` over the whole range of ``sym``.

    For a part that is symmetric about zero the upper bound already is the
    magnitude.  Otherwise the larger of ``|ub|`` and ``|lb|`` is used: the
    half-width ``(ub - lb) / 2`` would under-estimate ``|sym|`` whenever the
    range is not centered at zero (e.g. ``[0, 2]`` has half-width 1 but
    magnitude 2).
    """
    if sym.is_symmetric_to_0():
        return sym.ub()
    upper, lower = sym.ublb()
    return torch.maximum(upper.abs(), lower.abs())


def bilinear_elementwise(A: Expr, B: Expr) -> Expr:
    r"""Linearize the element-wise product of two symbolic expressions.

    Both operands must have the same shape.  Writing each operand as
    ``const + sym`` (via ``split_const``), the product is approximated by

    .. math::

        A \odot B \approx c_A \odot B + A \odot c_B - c_A \odot c_B + E

    with the bilinear remainder :math:`E = A_s \odot B_s` bounded
    element-wise by

    .. math::

        |E| \le \max|A_s| \odot \max|B_s|

    When the symbolic parts are symmetric about zero this is exactly
    :math:`\mathrm{hw}(A) \odot \mathrm{hw}(B)`.  For a non-symmetric part the
    magnitude bound is used instead of the half-width so that the enclosure
    stays sound.

    Args:
        A: First expression.
        B: Second expression (same shape as ``A``).

    Returns:
        An expression over-approximating ``A * B``.

    Raises:
        AssertionError: If the two shapes differ.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.bilinear import bilinear_elementwise
    >>> A = expr.ConstVal(torch.ones(3)) + 0.2 * expr.LpEpsilon([3])
    >>> B = expr.ConstVal(torch.zeros(3)) + 0.3 * expr.LpEpsilon([3])
    >>> C = bilinear_elementwise(A, B)
    >>> C.shape
    torch.Size([3])
    """
    assert A.shape == B.shape, \
        f"Shapes must match for element-wise product: {A.shape} vs {B.shape}"

    Ac, As = A.split_const()  # Ac: constant part, As: zero-constant part
    Bc, Bs = B.split_const()  # Bc: constant part, Bs: zero-constant part

    # Exact affine part of (Ac + As) * (Bc + Bs); only As * Bs is left over.
    result = Ac * Bs + As * Bc + Ac * Bc

    # Remainder bound: |As * Bs| <= max|As| * max|Bs| element-wise.
    error_bound = _sym_magnitude(As) * _sym_magnitude(Bs)
    new_eps = LpEpsilon(error_bound.shape)
    return result + error_bound * new_eps
def _eps_children(e: Expr) -> dict[LpEpsilon, LinearOp]:
    """Return ``{LpEpsilon: LinearOp}`` for the symbolic part of ``e``.

    ``e`` is assumed to be the zero-constant half of a symmetric
    decomposition: either an :class:`AffineSum` whose children are
    :class:`LpEpsilon` nodes, or anything else (treated as having no
    children, e.g. a :class:`ConstVal`).

    Raises:
        AssertionError: If an ``AffineSum`` child is not an ``LpEpsilon``.
    """
    if isinstance(e, AffineSum):
        children = dict(e.children_dict)
        for child in children:
            assert isinstance(child, LpEpsilon), \
                f"DeepT-Precise requires LpEpsilon children, got {type(child).__name__}"
        return children
    return {}


def _topk_truncate(Tc, topk):
    """Keep the ``topk`` largest-|coef| (i, j) slots per output entry.

    Args:
        Tc: Coefficient tensor of shape ``(..., S_lo, S_hi)``.
        topk: Number of slots to keep per output entry (clamped to
            ``S_lo * S_hi``).

    Returns:
        ``(truncated, dropped)`` where ``truncated`` has the shape of ``Tc``
        with all non-kept slots zeroed, and ``dropped`` has shape ``(...)``
        and holds the summed |coef| mass of the zeroed slots, to be folded
        back into the fresh-epsilon residual.
    """
    *pre, Slo, Shi = Tc.shape
    flat = Tc.reshape(*pre, Slo * Shi)
    K = min(topk, flat.shape[-1])
    idx = flat.abs().topk(K, dim=-1).indices
    keep = torch.zeros_like(flat, dtype=torch.bool).scatter_(-1, idx, True)
    dropped = (flat.abs() * (~keep)).sum(dim=-1)
    flat = torch.where(keep, flat, torch.zeros_like(flat))
    return flat.reshape(*pre, Slo, Shi), dropped


def deept_precise_matmul(A: Expr, B: Expr, symmetrize: bool = False,
                         interval_cap: bool = True,
                         chunk_elems: int = 32_000_000) -> Expr:
    r"""DeepT-Precise relaxation of ``A @ B``.

    Operands need at least two dims and identical leading batch dims.

    The bilinear error :math:`A_s B_s` is expanded by enumerating every pair
    of noise symbols ``(eps_A, eps_B)`` with ``eps_A`` drawn from ``A`` and
    ``eps_B`` from ``B``.  For each pair:

    1. Materialise both jacobians.
    2. Contract the ``k`` dim per matmul rules to get
       :math:`T[i, j, s_A, s_B] = \sum_l \alpha_{eps_A, s_A}[i, l]\,
       \beta_{eps_B, s_B}[l, j]`.
    3. If ``eps_A is eps_B`` (shared epsilon), the positions where
       ``s_A == s_B`` represent :math:`\epsilon^2 \in [0, 1]`: they
       contribute ``relu(T)`` to the upper-error tensor and ``relu(-T)`` to
       the lower-error tensor.  Everywhere else (and every position of
       mismatched pairs) corresponds to
       :math:`\epsilon_i \epsilon_j \in [-1, 1]` and contributes ``|T|`` to
       both.
    4. Sum over all input dims and accumulate into the upper / lower error.

    The resulting asymmetric interval ``[L, U]`` is repackaged as
    ``(U + L) / 2 + (U - L) / 2 · ε_new`` with a fresh noise symbol.

    If either symbolic part is not symmetric about zero the call is deferred
    to :func:`square_matmul`.

    Args:
        A: Left operand with shape ``(..., m, k)``.
        B: Right operand with shape ``(..., k, n)``.
        symmetrize: Enable the ``"precise+sym"`` off-diagonal tightening
            (``|T_ij + T_ji|`` instead of ``|T_ij| + |T_ji|``) for shared
            symbols.
        interval_cap: Keep the naive interval product ``hw(A) @ hw(B)`` as an
            upper cap on the residual.
        chunk_elems: Element budget for the per-chunk einsum tensor on the
            concretizing path.

    Returns:
        An expression over-approximating ``A @ B``.

    Raises:
        AssertionError: On rank, batch-dim or inner-dim mismatch.
    """
    assert len(A.shape) >= 2 and len(B.shape) >= 2, \
        f"Need at least 2D for matmul, got {A.shape} @ {B.shape}"
    assert A.shape[:-2] == B.shape[:-2], \
        f"Batch dims must match: {A.shape} @ {B.shape}"
    assert A.shape[-1] == B.shape[-2], \
        f"Inner dims must match: {A.shape} @ {B.shape}"

    Ac, As = A.split_const()
    Bc, Bs = B.split_const()
    if not (As.is_symmetric_to_0() and Bs.is_symmetric_to_0()):
        # Rare: a non-symmetric symbolic part (asymmetric upstream construct).
        # square_matmul handles centering soundly; defer to it.
        return square_matmul(A, B)

    result = Ac @ Bs + As @ Bc + Ac @ Bc

    hwA = As.ub()  # (..., m, k) half-widths (lb = -ub)
    hwB = Bs.ub()  # (..., k, n)

    A_children = _eps_children(As)
    B_children = _eps_children(Bs)
    b = reduce(operator.mul, A.shape[:-2], 1)
    m, k, n = A.shape[-2], A.shape[-1], B.shape[-1]

    from boundlab.config import config
    use_reg = config.quad_registry
    topk = config.quad_registry_topk
    pair_terms = []  # [(coef (b,m,n,S_lo,S_hi), PairEpsilon)]

    # Accumulate an ASYMMETRIC residual interval [err_L, err_U] per (b, m, n):
    # the ε²∈[0,1] diagonal correction makes the residual genuinely one-sided.
    # Allocate with the dtype/device of the half-widths so the in-place
    # accumulation below neither downcasts (float64 -> default float32) nor
    # mixes devices.
    err_U = hwA.new_zeros((b, m, n))
    err_L = hwA.new_zeros((b, m, n))

    for eps_A, op_A in A_children.items():
        jac_A = op_A.jacobian().reshape(b, m, k, -1)  # (b, m, k, S_A)
        SA = jac_A.shape[-1]
        for eps_B, op_B in B_children.items():
            jac_B = op_B.jacobian().reshape(b, k, n, -1)  # (b, k, n, S_B)
            SB = jac_B.shape[-1]
            is_diag = eps_A is eps_B
            is_reg = (not is_diag and use_reg
                      and getattr(eps_A, "is_input", False)
                      and getattr(eps_B, "is_input", False))

            if is_reg:
                # Distinct INPUT-level pair: keep as an identity-tracked
                # PairEpsilon coefficient (registry-shared) so equal-and-opposite
                # contributions cancel in a common AffineSum slot (e.g. across a
                # difference). Needs the full (b,m,n,S_lo,S_hi) coefficient tensor
                # for its EinsumOp, so materialize it (one pair at a time).
                from boundlab.expr._quad import get_pair
                T = torch.einsum("bmki,bknj->bmnij", jac_A, jac_B)
                pair, need_t = get_pair(eps_A, eps_B)
                Tc = T.transpose(-1, -2) if need_t else T  # → (b,m,n,S_lo,S_hi)
                if topk and topk > 0:
                    # Dropped coefficient mass goes into the fresh-ε residual.
                    Tc, dropped = _topk_truncate(Tc, topk)
                    err_U += dropped
                    err_L -= dropped
                pair_terms.append((Tc, pair))
                continue

            # Concretizing path (shared-ε ε² diagonal, or distinct fold): chunk
            # over n so the peak (b, m, chunk_n, S_A, S_B) einsum tensor stays
            # within the element budget. The residual is a per-output-entry sum,
            # so chunking only partitions it.
            chunk_n = max(1, min(n, chunk_elems // max(1, b * m * SA * SB)))
            for n0 in range(0, n, chunk_n):
                n1 = min(n0 + chunk_n, n)
                T = torch.einsum("bmki,bknj->bmnij", jac_A, jac_B[:, :, n0:n1, :])
                if is_diag:
                    # Shared ε quadratic form Σ_ij T_ij ε_i ε_j, ε ∈ [-1,1]^S:
                    #   diagonal ε_i²∈[0,1]  → +relu(T_ii) up, -relu(-T_ii) down
                    #   off-diag ε_iε_j∈[-1,1] → ±|·| (symmetrized if requested)
                    T_diag = T.diagonal(dim1=-2, dim2=-1)
                    err_U[:, :, n0:n1] += torch.relu(T_diag).sum(dim=-1)
                    err_L[:, :, n0:n1] -= torch.relu(-T_diag).sum(dim=-1)
                    if symmetrize:
                        Msym = 0.5 * (T + T.transpose(-1, -2))
                        off = (Msym.abs().sum(dim=(-2, -1))
                               - Msym.diagonal(dim1=-2, dim2=-1).abs().sum(dim=-1))
                    else:
                        off = (T.abs().sum(dim=(-2, -1))
                               - T_diag.abs().sum(dim=-1))
                    err_U[:, :, n0:n1] += off
                    err_L[:, :, n0:n1] -= off
                else:
                    # Distinct symbols: each product ε_iε_j ∈ [-1, 1] independently.
                    absum = T.abs().sum(dim=(-2, -1))
                    err_U[:, :, n0:n1] += absum
                    err_L[:, :, n0:n1] -= absum

    err_U = err_U.reshape(A.shape[:-2] + (m, n))
    err_L = err_L.reshape(A.shape[:-2] + (m, n))

    # NaN entries (e.g. 0*inf) are widened to the conservative extreme (±inf)
    # before the optional interval cap is applied.
    U = torch.where(torch.isnan(err_U), torch.full_like(err_U, torch.inf), err_U)
    L = torch.where(torch.isnan(err_L), torch.full_like(err_L, -torch.inf), err_L)
    if interval_cap:
        cap = hwA @ hwB  # (..., m, n), >= 0
        U = torch.minimum(U, cap)
        L = torch.maximum(L, -cap)

    result = result + (U + L) / 2 + (U - L) / 2 * LpEpsilon(result.shape, reason="precise_matmul")

    # Attach identity-tracked distinct-pair terms (uncapped — each backward-
    # concretizes to Σ|T|, the same as folding, but the shared registry symbol
    # lets equal-and-opposite contributions cancel in a common AffineSum slot).
    if pair_terms:
        from boundlab.linearop._einsum import EinsumOp
        out_shape = tuple(A.shape[:-2]) + (m, n)
        no = len(out_shape)
        for Tc, pair in pair_terms:
            Slo, Shi = Tc.shape[-2], Tc.shape[-1]
            tensor = Tc.reshape(out_shape + (Slo, Shi))
            op = EinsumOp(tensor, input_dims=[no, no + 1],
                          output_dims=list(range(no)))
            result = result + AffineSum((op, pair))

    return result
def matmul_handler(A, B):
    """Dispatcher implementation for ``torch.matmul``.

    Routing rules, checked in this order:

    - ``Expr @ const`` or ``const @ Expr``: exact affine path via ``A @ B``.
    - ``const @ const``: delegated to ``torch.matmul``.
    - ``Expr @ Expr``: bilinear relaxation.  ``config.matmul_mode`` values
      ``"precise"`` / ``"precise+sym"`` select :func:`deept_precise_matmul`;
      any other mode uses :func:`bilinear_matmul`.  When
      ``config.softmax_simplex_clamp`` is set and ``A`` carries a truthy
      ``_simplex_rows`` attribute, the result is additionally capped by
      :func:`_simplex_clamp`.
    - anything else: ``NotImplemented``.

    "const" means whatever :func:`_is_const` accepts.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.bilinear import matmul_handler
    >>> A = expr.ConstVal(torch.ones(1, 2)) + expr.LpEpsilon([1, 2])
    >>> B = torch.ones(2, 1)
    >>> matmul_handler(A, B).shape
    torch.Size([1, 1])
    """
    a_is_expr = isinstance(A, Expr)
    b_is_expr = isinstance(B, Expr)
    a_is_const = _is_const(A)
    b_is_const = _is_const(B)

    # One symbolic operand, one constant: the operator overloads on Expr
    # (``__matmul__`` / ``__rmatmul__``) handle it.
    if (a_is_expr and b_is_const) or (a_is_const and b_is_expr):
        return A @ B

    if a_is_const and b_is_const:
        return torch.matmul(A, B)

    if not (a_is_expr and b_is_expr):
        return NotImplemented

    from boundlab.config import config
    mode = config.matmul_mode
    if mode in ("precise", "precise+sym"):
        R = deept_precise_matmul(
            A, B,
            symmetrize=(mode == "precise+sym"),
            interval_cap=config.matmul_interval_cap,
        )
    else:
        R = bilinear_matmul(A, B)

    # Simplex clamp: when the left operand is a softmax output, each row sums
    # to 1, so S@V is a per-token convex combination of V's rows — bounded by
    # the per-column hull of V regardless of how loose the softmax zonotope
    # is. Cap the bilinear result against that hull where it is tighter.
    wants_clamp = config.softmax_simplex_clamp and getattr(A, "_simplex_rows", False)
    if wants_clamp:
        R = _simplex_clamp(R, B)
    return R
def _simplex_clamp(R: Expr, V: Expr) -> Expr:
    r"""Cap ``R = S @ V`` by the per-column hull of ``V``.

    ``S`` is assumed to have softmax rows (summing to 1 over the contracted
    dim), so ``out[..., t, j] ∈ [min_l lb(V)[l, j], max_l ub(V)[l, j]]``.
    Sound (convex combination), and applied per entry only where it tightens
    ``R`` (mask of two independently-valid enclosures).

    Args:
        R: Relaxation of ``S @ V``.
        V: Right operand of the product.

    Returns:
        An expression with the shape of ``R``.
    """
    Vu, Vl = V.ublb()
    box_hi = Vu.amax(dim=-2, keepdim=True)  # (..., 1, d)
    box_lo = Vl.amin(dim=-2, keepdim=True)
    Ru, Rl = R.ublb()
    box_w = box_hi - box_lo  # broadcasts over the token dim
    if torch.isfinite(Ru).all() and torch.isfinite(Rl).all():
        mask = (box_w < (Ru - Rl)).to(Ru.dtype)  # (..., T, d), 1 where box tighter
        center = ((box_hi + box_lo) / 2).expand_as(Ru)
        radius = (box_w / 2).expand_as(Ru)
        box_expr = ConstVal(center) + radius * LpEpsilon(list(R.shape), reason="simplex_clamp")
        return mask * box_expr + (1.0 - mask) * R
    # Catastrophic regime: R carries non-finite coefficients, so a 0/1 blend
    # would evaluate 0*inf = NaN. Drop R's symbolic structure and return the
    # per-entry intersection of R's interval (where finite) with the hull box.
    new_hi = torch.where(torch.isfinite(Ru),
                         torch.minimum(Ru, box_hi.expand_as(Ru)),
                         box_hi.expand_as(Ru))
    new_lo = torch.where(torch.isfinite(Rl),
                         torch.maximum(Rl, box_lo.expand_as(Rl)),
                         box_lo.expand_as(Rl))
    center = (new_hi + new_lo) / 2
    radius = (new_hi - new_lo) / 2
    return ConstVal(center) + radius * LpEpsilon(list(R.shape), reason="simplex_clamp")


def _is_const(tensor: Union[torch.Tensor, Expr, int]) -> bool:
    """Return True for a ``torch.Tensor``, a ``ConstVal`` or an ``int``."""
    return isinstance(tensor, torch.Tensor) or isinstance(tensor, ConstVal) or isinstance(tensor, int)


def _square_matmul_symmetric_legacy(Ac, As, Bc, Bs):
    """Original Expr-level symmetric quarter-square path.

    Kept as a reference implementation (selected by env var
    BOUNDLAB_LEGACY_SQUARE_MATMUL) so the generator-level path can be checked
    for numerical equivalence.
    """
    result = Ac @ Bs + As @ Bc + Ac @ Bc
    m, k = As.shape[-2], As.shape[-1]
    n = Bs.shape[-1]
    Au = As.ub()
    Bu = Bs.ub()
    # Naive interval product, used as the cap on the residual.
    U = Au @ Bu
    L = -U
    # Broadcast every operand to (..., m, k, n).
    Au = Au.unsqueeze(-1).expand(*Au.shape, n)
    Bu = Bu.unsqueeze(-3).expand(*Bu.shape[:-2], m, k, n)
    As = As.unsqueeze(-1).expand(*As.shape, n)
    Bs = Bs.unsqueeze(-3).expand(*Bs.shape[:-2], m, k, n)
    a = torch.sqrt(Au)
    b = torch.sqrt(Bu)
    lama = a / b
    lamb = b / a
    _pos = (lama * As + lamb * Bs).ub() ** 2 / 4
    _neg = -(lama * As - lamb * Bs).ub() ** 2 / 4
    # NaN (from 0/0 in lama/lamb) widens to ±inf so the cap takes over.
    Pos = torch.where(torch.isnan(_pos), torch.full_like(_pos, torch.inf), _pos)
    Neg = torch.where(torch.isnan(_neg), torch.full_like(_neg, -torch.inf), _neg)
    U = torch.minimum(Pos.sum(dim=-2), U)
    L = torch.maximum(Neg.sum(dim=-2), L)
    result += (U + L) / 2 + (U - L) / 2 * LpEpsilon(result.shape, reason="square_matmul")
    return result


def _ubst_chunk(lama, lamb, HA, HB_chunk, shared_jac, n0, n1, chunk_s):
    r"""Compute ``ub|λa·As + λb·Bs|`` and ``ub|λa·As − λb·Bs|`` for an n-chunk.

    Both bounds are computed from the jacobians.  ``lama`` / ``lamb``
    broadcast as ``(b,m,k,nc)`` or ``(b,m,1,nc)`` (λ shared across the
    contraction dim l).

    Args:
        lama, lamb: Scaling factors for the A and B side.
        HA: ``(b, m, k)`` |coef| mass of the A-only symbols.
        HB_chunk: ``(b, k, nc)`` |coef| mass of the B-only symbols for the
            chunk.
        shared_jac: List of ``(jA, jB)`` jacobian pairs of shared symbols,
            shaped ``(b, m, k, S)`` and ``(b, k, n, S)``.
        n0, n1: Chunk bounds along ``n``.
        chunk_s: Chunk size along the symbol axis.

    Returns:
        ``(ub_sum, ub_diff)``, each ``(b, m, k, nc)``.
    """
    base = lama * HA.unsqueeze(-1) + lamb * HB_chunk.unsqueeze(-3)  # (b,m,k,nc)
    sterm = torch.zeros_like(base)
    tterm = torch.zeros_like(base)
    for jA, jB in shared_jac:
        jBc = jB[:, :, n0:n1, :]  # (b, k, nc, S)
        S = jA.shape[-1]
        for s0 in range(0, S, chunk_s):
            s1 = min(s0 + chunk_s, S)
            pa = lama.unsqueeze(-1) * jA[:, :, :, s0:s1].unsqueeze(-2)  # (b,m,k,nc,sc)
            pb = lamb.unsqueeze(-1) * jBc[:, :, :, s0:s1].unsqueeze(1)  # (b,m,k,nc,sc)
            # Shared symbols are combined before abs so A–B correlation is kept.
            sterm = sterm + (pa + pb).abs().sum(-1)
            tterm = tterm + (pa - pb).abs().sum(-1)
    return base + sterm, base + tterm


def _residual_chunk_size(b, m, k, maxS, chunk_elems, chunk_s, n):
    """Return the n-chunk size (in ``[1, n]`` for ``n >= 1``) that keeps
    ``b * m * k * min(maxS, chunk_s) * chunk`` within ``chunk_elems``."""
    return max(1, min(n, chunk_elems // max(1, b * m * k * min(maxS, chunk_s))))


def _accumulate_UL(lama_fn, HA, HB, shared_jac, cap, b, m, k, n, chunk_n, chunk_s):
    """Final (no-grad) residual interval [L, U] (b,m,n) with the nan→±inf
    rescue and the interval cap, computed chunk by chunk.

    Args:
        lama_fn: ``(n0, n1) -> (lama, lamb)`` for the chunk.
        HA, HB, shared_jac, chunk_s: Forwarded to :func:`_ubst_chunk`.
        cap: ``(b, m, n)`` interval cap.
        b, m, k, n: Problem sizes.
        chunk_n: Chunk size along ``n``.

    Returns:
        ``(U, L)``, each ``(b, m, n)`` with the dtype and device of ``cap``.
    """
    # Allocate like ``cap`` so float64 bounds are not silently rounded to the
    # default dtype and no device mismatch occurs on slice assignment.
    U = cap.new_empty((b, m, n))
    L = cap.new_empty((b, m, n))
    for n0 in range(0, n, chunk_n):
        n1 = min(n0 + chunk_n, n)
        lama, lamb = lama_fn(n0, n1)
        ub_s, ub_t = _ubst_chunk(lama, lamb, HA, HB[:, :, n0:n1],
                                 shared_jac, n0, n1, chunk_s)
        _pos = ub_s ** 2 / 4
        _neg = -(ub_t ** 2) / 4
        Pos = torch.where(torch.isnan(_pos), torch.full_like(_pos, torch.inf), _pos)
        Neg = torch.where(torch.isnan(_neg), torch.full_like(_neg, -torch.inf), _neg)
        cap_c = cap[:, :, n0:n1]
        U[:, :, n0:n1] = torch.minimum(Pos.sum(dim=-2), cap_c)
        L[:, :, n0:n1] = torch.maximum(Neg.sum(dim=-2), -cap_c)
    return U, L


def _optimize_theta(theta0, HA, HB, shared_jac, cap, b, m, k, n,
                    chunk_n, chunk_s, steps, lr=0.1):
    r"""α-CROWN-style optimization of the per-(i,l,j) quarter-square parameter.

    ``λa = exp(θ)``, ``λb = exp(−θ)`` (so ``λa·λb = 1`` and both > 0 — every θ
    is sound because the quarter-square identity is exact for any λ>0; only
    the bounding of the squares is relaxed).  ``θ`` is initialized to
    ``theta0``, then Adam minimizes the total residual width ``Σ(U − L)``.
    Backward is done per n-chunk so peak memory stays bounded; θ is detached
    on return.  If a non-finite loss/gradient appears the loop stops and the
    best θ seen so far is returned (``theta0`` at minimum).

    Gradient tracking is enabled locally, so the function also works when the
    caller runs under ``torch.no_grad()``.

    Args:
        theta0: Initial θ, shape ``(b, m, k, n)``.
        HA, HB, shared_jac, chunk_s: Forwarded to :func:`_ubst_chunk`.
        cap: ``(b, m, n)`` interval cap.
        b, m, k, n: Problem sizes.
        chunk_n: Chunk size along ``n``.
        steps: Number of Adam steps.
        lr: Adam learning rate.

    Returns:
        The detached θ with the smallest total width observed.
    """
    best_theta = theta0.detach().clone()
    best_val = float("inf")
    # backward() needs autograd even if the surrounding bound computation is
    # wrapped in torch.no_grad().
    with torch.enable_grad():
        theta = theta0.detach().clone().requires_grad_(True)
        opt = torch.optim.Adam([theta], lr=lr)
        for _ in range(steps):
            opt.zero_grad()
            total = 0.0
            bad = False
            for n0 in range(0, n, chunk_n):
                n1 = min(n0 + chunk_n, n)
                th = theta[:, :, :, n0:n1]
                lama = torch.exp(th)
                lamb = torch.exp(-th)
                ub_s, ub_t = _ubst_chunk(lama, lamb, HA, HB[:, :, n0:n1],
                                         shared_jac, n0, n1, chunk_s)
                cap_c = cap[:, :, n0:n1]
                Uc = torch.minimum((ub_s ** 2 / 4).sum(dim=-2), cap_c)
                Lc = torch.maximum((-(ub_t ** 2) / 4).sum(dim=-2), -cap_c)
                loss = (Uc - Lc).sum()
                if not torch.isfinite(loss):
                    bad = True
                    break
                loss.backward()
                total += float(loss.detach())
            if bad or theta.grad is None or not torch.isfinite(theta.grad).all():
                break  # keep best-so-far (θ0 at minimum)
            # Track the best θ seen (θ0 is evaluated at step 0), so the returned
            # bound's total width never exceeds the initial one.
            if total < best_val:
                best_val = total
                best_theta = theta.detach().clone()
            opt.step()
    return best_theta


def _square_matmul_symmetric_gen(Ac, As, Bc, Bs, chunk_elems, chunk_s,
                                 optimize=False, opt_steps=25,
                                 lin_As=None, lin_Bs=None):
    r"""Generator-level symmetric quarter-square residual.

    Numerically equivalent to :func:`_square_matmul_symmetric_legacy` (for the
    heuristic λ) but never materializes the ``(..., m, k, n)`` symbolic
    operand (the source of the ``O(m·k·n·E)`` backward-pass blow-up).  It
    bounds ``ub|λa·As + λb·Bs|[i,l,j] = Σ_symbols |coef|`` directly from the
    jacobians, splitting symbols into A-only, B-only and *shared* (combined
    before abs so A–B correlation is preserved), chunked over ``n`` (and the
    symbol axis).

    With ``optimize`` the per-(i,l,j) λ is tuned by Adam (α-CROWN style) from
    the heuristic init; any λ>0 is sound, so this only tightens.

    ``lin_As`` / ``lin_Bs`` override the operands used for the *exact affine*
    term ``Ac @ Bs + As @ Bc + Ac @ Bc`` only; the ``U``/``L`` bilinear
    residual is always computed from the (flat, ``LpEpsilon``-leaf) ``As`` /
    ``Bs``.  This lets the zonohex path keep the un-flattened ``ZonoHexGate``
    expressions in the affine term (so the gate's x/y coupling still cancels
    downstream in the differential) while the residual uses the flattened
    symmetric parts the jacobian machinery requires.

    Args:
        Ac, Bc: Constant parts of the operands.
        As, Bs: Symmetric symbolic parts, shapes ``(..., m, k)`` / ``(..., k, n)``.
        chunk_elems: Element budget used to size the n-chunks.
        chunk_s: Chunk size along the symbol axis.
        optimize: Tune λ with :func:`_optimize_theta` (only when shared
            symbols exist).
        opt_steps: Number of optimizer steps.
        lin_As, lin_Bs: Optional overrides for the affine term.

    Returns:
        An expression over-approximating the product.
    """
    lin_As = As if lin_As is None else lin_As
    lin_Bs = Bs if lin_Bs is None else lin_Bs
    result = Ac @ lin_Bs + lin_As @ Bc + Ac @ Bc

    *batch, m, k = As.shape
    n = Bs.shape[-1]
    b = reduce(operator.mul, batch, 1)

    Au = As.ub().reshape(b, m, k)
    Bu = Bs.ub().reshape(b, k, n)
    cap = torch.bmm(Au, Bu)  # (b, m, n)

    A_children = _eps_children(As)
    B_children = _eps_children(Bs)
    shared = [c for c in A_children if c in B_children]
    a_only = [c for c in A_children if c not in B_children]
    b_only = [c for c in B_children if c not in A_children]

    # Allocate like the half-widths so dtype and device match the jacobians.
    HA = Au.new_zeros((b, m, k))
    for c in a_only:
        HA = HA + A_children[c].jacobian().reshape(b, m, k, -1).abs().sum(-1)
    HB = Bu.new_zeros((b, k, n))
    for c in b_only:
        HB = HB + B_children[c].jacobian().reshape(b, k, n, -1).abs().sum(-1)
    shared_jac = [(A_children[c].jacobian().reshape(b, m, k, -1),
                   B_children[c].jacobian().reshape(b, k, n, -1))
                  for c in shared]

    sa = torch.sqrt(Au)  # (b, m, k)
    sb = torch.sqrt(Bu)  # (b, k, n)

    maxS = max((jA.shape[-1] for jA, _ in shared_jac), default=1)
    chunk_n = _residual_chunk_size(b, m, k, maxS, chunk_elems, chunk_s, n)

    # Optimization only helps when there are shared symbols (i.e. A–B
    # correlation the no-correlation heuristic λ does not exploit).
    if optimize and shared_jac:
        tiny = torch.finfo(Au.dtype).tiny
        theta0 = 0.5 * (torch.log(Au.clamp_min(tiny)).unsqueeze(-1)
                        - torch.log(Bu.clamp_min(tiny)).unsqueeze(-3))  # (b,m,k,n)
        theta = _optimize_theta(theta0, HA, HB, shared_jac, cap,
                                b, m, k, n, chunk_n, chunk_s, opt_steps)

        def lama_fn(n0, n1):
            th = theta[:, :, :, n0:n1]
            return torch.exp(th), torch.exp(-th)
    else:
        def lama_fn(n0, n1):
            sb_c = sb[:, :, n0:n1]
            return (sa.unsqueeze(-1) / sb_c.unsqueeze(-3),
                    sb_c.unsqueeze(-3) / sa.unsqueeze(-1))

    U, L = _accumulate_UL(lama_fn, HA, HB, shared_jac, cap,
                          b, m, k, n, chunk_n, chunk_s)

    out_shape = tuple(batch) + (m, n)
    U = U.reshape(out_shape)
    L = L.reshape(out_shape)
    result = result + (U + L) / 2 + (U - L) / 2 * LpEpsilon(result.shape, reason="square_matmul")
    return result


def _flatten_zonohex(s):
    """Flatten a ``ZonoHexGate`` symbolic operand to a plain ``AffineSum``
    over ``LpEpsilon`` leaves.

    Returns ``(sym, const)`` where ``sym`` is the symmetric part and ``const``
    the exposed constant.  When ``s`` contains no gate, ``s`` itself and ``0``
    are returned, so the non-zonohex path is unchanged.

    For the zonohex differential path a gate in the symbolic operand is not
    flagged symmetric, so the fast generator residual cannot consume it.
    ``eqprop`` flattens the gate exactly (``==`` propagation, the same
    machinery ``ZonoHexGate.simplify`` uses).  The gate output is
    zero-centered, so ``const`` is numerically ~0; it is returned so the
    full-flatten path can fold it into the center.
    """
    from boundlab.diff.zonohex import ZonoHexGate
    if not any(isinstance(n, ZonoHexGate) for n in s.all_subnodes()):
        return s, 0
    from boundlab.prop import eqprop
    c, sym = eqprop(s).split_const()
    return sym, c


def square_matmul(A: Expr, B: Expr, chunk_elems: int = 32_000_000,
                  chunk_s: int = 4096) -> Expr:
    r"""Quarter-square relaxation of ``A @ B`` for two symbolic operands.

    The exact affine part ``Ac @ Bs + As @ Bc + Ac @ Bc`` is kept and the
    bilinear remainder is enclosed in an interval ``[L, U]`` obtained from
    squared bounds of ``λa·As ± λb·Bs``, capped by the naive interval product
    of the half-widths, and attached with a fresh noise symbol.

    When both (zonohex-flattened) symbolic parts are symmetric about zero the
    generator-level path :func:`_square_matmul_symmetric_gen` is used, or
    :func:`_square_matmul_symmetric_legacy` when the environment variable
    ``BOUNDLAB_LEGACY_SQUARE_MATMUL`` is set.  Otherwise the operands are
    re-centered from their interval bounds and the ``(..., m, k, n)``
    operands are materialized.

    Args:
        A: Left operand with shape ``(..., m, k)``.
        B: Right operand with shape ``(..., k, n)``.
        chunk_elems: Element budget for the chunked residual computation.
        chunk_s: Chunk size along the symbol axis.

    Returns:
        An expression over-approximating ``A @ B``.

    Raises:
        AssertionError: On rank, batch-dim or inner-dim mismatch.
    """
    assert len(A.shape) >= 2 and len(B.shape) >= 2, \
        f"Need at least 2D for matmul, got {A.shape} @ {B.shape}"
    assert A.shape[:-2] == B.shape[:-2], \
        f"Batch dims must match: {A.shape} @ {B.shape}"
    assert A.shape[-1] == B.shape[-2], \
        f"Inner dims must match: {A.shape} @ {B.shape}"

    Ac, As = A.split_const()
    Bc, Bs = B.split_const()

    # zonohex fast path: a ZonoHexGate in the symbolic part is not flagged
    # symmetric, so it would otherwise fall into the slow per-element else branch
    # (materializes the (m,k,n) operand and calls ublb). Flatten the gate to a
    # plain symmetric AffineSum over LpEpsilon leaves (eqprop, via
    # ZonoHexGate.simplify's exact `==` machinery) — that restores the symmetric
    # flag and lets the fast generator path consume it.
    from boundlab.config import config
    As_res, ac2 = _flatten_zonohex(As)
    Bs_res, bc2 = _flatten_zonohex(Bs)
    if config.matmul_zonohex_flatten:
        # full-flatten: the affine term uses the flat operands too (gates stop
        # accumulating across layers → far faster on deep nets, marginally
        # looser). Fold the exposed (~0) constant into the center.
        lin_As, lin_Bs = As_res, Bs_res
        Ac, Bc = Ac + ac2, Bc + bc2
    else:
        # gate-affine: the exact affine term keeps the un-flattened gate
        # (preserves x/y coupling → exact); only the U/L residual is flat.
        lin_As, lin_Bs = As, Bs

    if As_res.is_symmetric_to_0() and Bs_res.is_symmetric_to_0():
        if os.environ.get("BOUNDLAB_LEGACY_SQUARE_MATMUL"):
            return _square_matmul_symmetric_legacy(Ac, As_res, Bc, Bs_res)
        return _square_matmul_symmetric_gen(
            Ac, As_res, Bc, Bs_res, chunk_elems, chunk_s,
            optimize=config.matmul_optimize_lambda,
            opt_steps=config.matmul_lambda_steps,
            lin_As=lin_As, lin_Bs=lin_Bs)
    else:
        m, k, n = A.shape[-2], A.shape[-1], B.shape[-1]
        # Re-center both operands from their interval bounds so the symbolic
        # parts used below are centered at zero.
        Au, Al = A.ublb()
        Bu, Bl = B.ublb()
        Ac = (Au + Al) / 2
        As = A - Ac
        Bc = (Bu + Bl) / 2
        Bs = B - Bc
        result = Ac @ Bc + As @ Bc + Ac @ Bs
        Asu = (Au - Al) / 2
        Bsu = (Bu - Bl) / 2
        # Naive interval product, used as the cap on the residual.
        U = Asu @ Bsu
        L = -U
        Asu = Asu.unsqueeze(-1).expand(*Asu.shape, n)  # (..., m, k, n)
        Bsu = Bsu.unsqueeze(-3).expand(*Bsu.shape[:-2], m, k, n)  # (..., m, k, n)
        As = As.unsqueeze(-1).expand(*As.shape, n)  # (..., m, k, n)
        Bs = Bs.unsqueeze(-3).expand(*Bs.shape[:-2], m, k, n)  # (..., m, k, n)
        a = torch.sqrt(Asu)
        b = torch.sqrt(Bsu)
        lama = a / b
        lamb = b / a
        pos_expr = lama * As + lamb * Bs
        neg_expr = lama * As - lamb * Bs
        from boundlab.diff.zonohex import ZonoHexGate
        if isinstance(pos_expr, AffineSum):
            pos_expr = pos_expr.replace_subnode_once(
                lambda c: c.simplify() if isinstance(c, ZonoHexGate) else None)
        if isinstance(neg_expr, AffineSum):
            neg_expr = neg_expr.replace_subnode_once(
                lambda c: c.simplify() if isinstance(c, ZonoHexGate) else None)
        PosU, PosL = pos_expr.ublb()
        NegU, NegL = neg_expr.ublb()
        # The square of a value in [lb, ub] is at most max(ub², lb²).
        _pos = torch.maximum(PosU ** 2, PosL ** 2) / 4
        _neg = -torch.maximum(NegU ** 2, NegL ** 2) / 4
        # NaN (from 0/0 in lama/lamb) widens to ±inf so the cap takes over.
        Pos = torch.where(torch.isnan(_pos), torch.full_like(_pos, torch.inf), _pos)
        Neg = torch.where(torch.isnan(_neg), torch.full_like(_neg, -torch.inf), _neg)
        U = torch.minimum(Pos.sum(dim=-2), U)  # (..., m, n)
        L = torch.maximum(Neg.sum(dim=-2), L)  # (..., m, n)
        result += (U + L) / 2 + (U - L) / 2 * LpEpsilon(result.shape)
        return result