Source code for boundlab.zono.bilinear

"""Bilinear operation handlers for zonotope abstract interpretation.

Provides McCormick-style linearization for matmul and element-wise product
when both operands are symbolic expressions (Expr @ Expr or Expr * Expr),
plus the tighter DeepT-Precise relaxation from Bonaert et al. (2021).
"""

from functools import reduce
import operator
from typing import Union

import torch

from boundlab import prop
from boundlab.expr._core import Expr
from boundlab.expr._affine import AffineSum, ConstVal
from boundlab.expr._var import LpEpsilon
from boundlab.linearop._base import LinearOp


def bilinear_matmul(A: Expr, B: Expr) -> Expr:
    r"""Over-approximate the matrix product ``A @ B`` of two symbolic operands.

    ``A`` has shape ``(m, k)``, ``B`` has shape ``(k, n)`` and the result has
    shape ``(m, n)``.

    The relaxation is built on the first-order expansion around the
    expression centers

    .. math::

        A B \approx c_A B + A c_B - c_A c_B + E

    where the remainder :math:`E` is bounded through the interval
    half-widths,

    .. math::

        |E| \le \mathrm{hw}(A)\,\mathrm{hw}(B)

    and is represented with fresh noise symbols.  The computation itself is
    carried out by :func:`square_matmul`; this function forwards both
    operands to it unchanged.

    Args:
        A: Left expression with shape ``(m, k)``.
        B: Right expression with shape ``(k, n)``.

    Returns:
        An expression over-approximating ``A @ B``.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.bilinear import bilinear_matmul
    >>> A = expr.ConstVal(torch.ones(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
    >>> B = expr.ConstVal(torch.ones(3, 4)) + 0.1 * expr.LpEpsilon([3, 4])
    >>> C = bilinear_matmul(A, B)
    >>> C.shape
    torch.Size([2, 4])
    """
    # A plain first-order variant (remainder bounded only by
    # ``As.ub() @ Bs.ub()``) used to be inlined here; the call below is the
    # active implementation.
    product = square_matmul(A, B)
    return product
def _elementwise_magnitude(e: Expr) -> torch.Tensor:
    """Return an element-wise bound on ``|e|`` for a zero-constant expression."""
    if e.is_symmetric_to_0():
        # Symmetric range [-ub, ub]: the upper bound is the magnitude.
        return e.ub()
    ub, lb = e.ublb()
    # Asymmetric range [lb, ub]: the largest magnitude is at one of the ends.
    return torch.maximum(ub.abs(), lb.abs())


def bilinear_elementwise(A: Expr, B: Expr) -> Expr:
    r"""Linearize the element-wise product of two symbolic expressions.

    Both operands must have the same shape.  Writing :math:`A = c_A + A_s`
    and :math:`B = c_B + B_s` (the decomposition returned by
    ``split_const``), the product is expanded as

    .. math::

        A \odot B = c_A \odot B_s + A_s \odot c_B + c_A \odot c_B + E,
        \qquad E = A_s \odot B_s

    and the bilinear remainder is enclosed element-wise by

    .. math::

        |E| \le \max|A_s| \odot \max|B_s|

    When :math:`A_s` is symmetric about zero, :math:`\max|A_s|` is its
    half-width ``hw(A)``.  When it is not, the bound uses
    ``max(|ub|, |lb|)``: the affine part above stays centered on
    :math:`c_A`, so the half-width ``(ub - lb) / 2`` would under-estimate
    the remainder (e.g. ``A_s, B_s`` in ``[0, 2]`` give a product up to 4,
    not 1).

    Args:
        A: First expression.
        B: Second expression (same shape as ``A``).

    Returns:
        An expression over-approximating ``A * B``.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.bilinear import bilinear_elementwise
    >>> A = expr.ConstVal(torch.ones(3)) + 0.2 * expr.LpEpsilon([3])
    >>> B = expr.ConstVal(torch.zeros(3)) + 0.3 * expr.LpEpsilon([3])
    >>> C = bilinear_elementwise(A, B)
    >>> C.shape
    torch.Size([3])
    """
    assert A.shape == B.shape, \
        f"Shapes must match for element-wise product: {A.shape} vs {B.shape}"

    Ac, As = A.split_const()  # Ac: constant part, As: zero-constant part
    Bc, Bs = B.split_const()  # Bc: constant part, Bs: zero-constant part

    # Exact affine part of the product.
    result = Ac * Bs + As * Bc + Ac * Bc

    # Sound enclosure of the remaining As * Bs term via a fresh noise symbol.
    error_bound = _elementwise_magnitude(As) * _elementwise_magnitude(Bs)
    new_eps = LpEpsilon(error_bound.shape)
    return result + error_bound * new_eps
def _eps_children(e: Expr) -> dict[LpEpsilon, LinearOp]:
    """Return ``{LpEpsilon: LinearOp}`` for the symbolic part of ``e``.

    ``e`` is assumed to be the zero-constant half of a symmetric
    decomposition: either an :class:`AffineSum` whose children are
    :class:`LpEpsilon` nodes, or a :class:`ConstVal` (empty children).

    Raises:
        AssertionError: If an :class:`AffineSum` child is not an
            :class:`LpEpsilon`.
    """
    if not isinstance(e, AffineSum):
        return {}
    children = dict(e.children_dict)
    for child in children:
        assert isinstance(child, LpEpsilon), \
            f"DeepT-Precise requires LpEpsilon children, got {type(child).__name__}"
    return children


def deept_precise_matmul(A: Expr, B: Expr) -> Expr:
    r"""DeepT-Precise relaxation of ``A @ B``.

    Operands are matrices, optionally with identical leading batch
    dimensions.  The bilinear error :math:`A_s B_s` is expanded by
    enumerating every pair of noise symbols ``(eps_A, eps_B)`` with
    ``eps_A`` drawn from ``A`` and ``eps_B`` from ``B``.  For each pair:

    1. Materialise both jacobians.
    2. Contract the ``k`` dim per matmul rules to get
       :math:`T[i, j, s_A, s_B] = \sum_l \alpha_{eps_A, s_A}[i, l]\,
       \beta_{eps_B, s_B}[l, j]`.
    3. If ``eps_A is eps_B`` (shared epsilon), the positions where
       ``s_A == s_B`` represent :math:`\epsilon^2 \in [0, 1]`: they
       contribute ``relu(T)`` to the upper-error tensor and ``relu(-T)`` to
       the lower-error tensor.  Everywhere else (and every position of
       mismatched pairs) corresponds to
       :math:`\epsilon_i \epsilon_j \in [-1, 1]` and contributes ``|T|`` to
       both.
    4. Sum over all input dims and accumulate into ``upper_err`` /
       ``lower_err``.

    The resulting asymmetric bilinear interval ``[-lower_err, +upper_err]``
    is repackaged as ``center + half_width · ε_new`` where
    ``center = (upper_err - lower_err) / 2`` and
    ``half_width = (upper_err + lower_err) / 2``.

    Args:
        A: Left operand with shape ``(..., m, k)``.
        B: Right operand with shape ``(..., k, n)``; leading dims must equal
            those of ``A``.

    Returns:
        An expression over-approximating ``A @ B``.

    Raises:
        AssertionError: On rank/shape mismatch, or if a symbolic child is
            not an :class:`LpEpsilon` (see :func:`_eps_children`).
    """
    assert len(A.shape) >= 2 and len(B.shape) >= 2, \
        f"Need at least 2D for matmul, got {A.shape} @ {B.shape}"
    assert A.shape[:-2] == B.shape[:-2], \
        f"Batch dims must match: {A.shape} @ {B.shape}"
    assert A.shape[-1] == B.shape[-2], \
        f"Inner dims must match: {A.shape} @ {B.shape}"

    Ac, As = A.split_const()
    Bc, Bs = B.split_const()

    # Exact affine part of the product; only As @ Bs remains to be bounded.
    result = Ac @ Bs + As @ Bc + Ac @ Bc

    A_children = _eps_children(As)
    B_children = _eps_children(Bs)

    batch_shape = tuple(A.shape[:-2])
    b = reduce(operator.mul, batch_shape, 1)
    m, k, n = A.shape[-2], A.shape[-1], B.shape[-1]

    # Accumulators are created from the first contribution so that they
    # inherit the jacobians' device and dtype.
    upper_err = None
    lower_err = None
    for eps_A, op_A in A_children.items():
        # NOTE(review): assumes jacobian() is laid out as the operand's
        # output dims followed by the noise symbol's dims — TODO confirm.
        jac_A = op_A.jacobian().reshape(b, m, k, -1)
        for eps_B, op_B in B_children.items():
            jac_B = op_B.jacobian().reshape(b, k, n, -1)
            T = torch.einsum("bmki, bknj->bmnij", jac_A, jac_B)
            abs_T = T.abs()
            if eps_A is eps_B:
                # Same noise symbol: diagonal entries multiply eps_s ** 2,
                # which lies in [0, 1] rather than [-1, 1].
                diag = torch.diagonal(T, dim1=-2, dim2=-1)
                # abs_T is a fresh tensor, so clearing its diagonal in place
                # leaves exactly the cross terms eps_s * eps_t with s != t.
                torch.diagonal(abs_T, dim1=-2, dim2=-1).zero_()
                cross = abs_T.sum(dim=(-2, -1))
                pair_upper = cross + torch.relu(diag).sum(dim=-1)
                pair_lower = cross + torch.relu(-diag).sum(dim=-1)
            else:
                # Distinct symbols: every product lies in [-1, 1].
                pair_upper = pair_lower = abs_T.sum(dim=(-2, -1))
            upper_err = pair_upper if upper_err is None else upper_err + pair_upper
            lower_err = pair_lower if lower_err is None else lower_err + pair_lower

    if upper_err is None:
        # At least one operand has no noise symbols: the remainder is zero.
        upper_err = lower_err = torch.zeros(b, m, n)

    out_shape = batch_shape + (m, n)
    center = ((upper_err - lower_err) / 2).reshape(out_shape)
    half_width = ((upper_err + lower_err) / 2).reshape(out_shape)

    new_eps = LpEpsilon(half_width.shape)
    return result + (center + half_width * new_eps)
def matmul_handler(A, B):
    """Dispatcher implementation for ``torch.matmul``.

    Each operand is classified as symbolic (an ``Expr``) and/or constant
    (as judged by ``_is_const``), and the call is routed accordingly:

    - exactly one side symbolic, the other constant: exact affine path via
      the ``@`` operator;
    - both sides constant: delegated to ``torch.matmul``;
    - both sides symbolic: bilinear relaxation via :func:`bilinear_matmul`;
    - anything else: ``NotImplemented``.

    The rules are tried in the order listed, so the first matching rule
    wins.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.bilinear import matmul_handler
    >>> A = expr.ConstVal(torch.ones(1, 2)) + expr.LpEpsilon([1, 2])
    >>> B = torch.ones(2, 1)
    >>> matmul_handler(A, B).shape
    torch.Size([1, 1])
    """
    a_is_expr, b_is_expr = isinstance(A, Expr), isinstance(B, Expr)
    a_is_const, b_is_const = _is_const(A), _is_const(B)

    # Expr @ const goes through Expr.__matmul__; const @ Expr falls back to
    # Expr.__rmatmul__.  Both are spelled the same way.
    if (a_is_expr and b_is_const) or (a_is_const and b_is_expr):
        return A @ B

    if a_is_const and b_is_const:
        return torch.matmul(A, B)

    if a_is_expr and b_is_expr:
        # deept_precise_matmul is an alternative relaxation that is not
        # wired in here.
        return bilinear_matmul(A, B)

    return NotImplemented
def _is_const(tensor: Union[torch.Tensor, Expr, int]) -> bool:
    """Return ``True`` for a ``torch.Tensor``, a ``ConstVal`` or an ``int``."""
    return isinstance(tensor, (torch.Tensor, ConstVal, int))


def _square_error_bounds(As, Bs, A_hw, B_hw, m, k, n, symmetric):
    """Bound the bilinear remainder ``As @ Bs`` from above and below.

    ``As`` / ``Bs`` are the centered operands and ``A_hw`` / ``B_hw`` their
    element-wise half-widths.  Returns tensors ``(U, L)`` of shape
    ``(..., m, n)`` with ``L <= As @ Bs <= U``.
    """
    # Baseline: plain interval product of the half-widths.
    U = A_hw @ B_hw
    L = -U

    # Broadcast everything to (..., m, k, n) so each scalar product
    # As[i, l] * Bs[l, j] can be treated individually.
    A_hw = A_hw.unsqueeze(-1).expand(*A_hw.shape, n)
    B_hw = B_hw.unsqueeze(-3).expand(*B_hw.shape[:-2], m, k, n)
    As = As.unsqueeze(-1).expand(*As.shape, n)
    Bs = Bs.unsqueeze(-3).expand(*Bs.shape[:-2], m, k, n)

    # A zero half-width means that factor is identically zero, so the term
    # contributes exactly 0.  Replace those entries by 1 before the
    # division so no inf/NaN coefficient is ever built; the terms are
    # zeroed again below.
    degenerate = (A_hw == 0) | (B_hw == 0)
    a = torch.sqrt(A_hw.masked_fill(degenerate, 1.0))
    b = torch.sqrt(B_hw.masked_fill(degenerate, 1.0))

    # Polarization identity with lamb == 1 / lama:
    #   x * y = ((lama*x + lamb*y)**2 - (lama*x - lamb*y)**2) / 4
    # NOTE(review): lama = sqrt(hw_A / hw_B) scales As by the *larger*
    # ratio when hw_A > hw_B; the balanced choice would be
    # sqrt(hw_B / hw_A).  Any positive factor is sound, so this is left
    # as-is — confirm intent.
    lama = a / b
    lamb = b / a
    pos_expr = lama * As + lamb * Bs
    neg_expr = lama * As - lamb * Bs

    if symmetric:
        # Symmetric about zero: the upper bound is the largest magnitude.
        pos_sq = pos_expr.ub() ** 2
        neg_sq = neg_expr.ub() ** 2
    else:
        PosU, PosL = pos_expr.ublb()
        NegU, NegL = neg_expr.ublb()
        pos_sq = torch.maximum(PosU ** 2, PosL ** 2)
        neg_sq = torch.maximum(NegU ** 2, NegL ** 2)

    # nan_to_num stays as a safety net: a non-finite term falls back to a
    # huge value, which the min/max below then discards in favour of the
    # interval baseline.
    big = 1e10
    Pos = torch.nan_to_num(pos_sq / 4, nan=big, posinf=big, neginf=big)
    Neg = torch.nan_to_num(-neg_sq / 4, nan=-big, posinf=-big, neginf=-big)
    Pos = Pos.masked_fill(degenerate, 0.0)
    Neg = Neg.masked_fill(degenerate, 0.0)

    U = torch.minimum(Pos.sum(dim=-2), U)  # (..., m, n)
    L = torch.maximum(Neg.sum(dim=-2), L)  # (..., m, n)
    return U, L


def square_matmul(A: Expr, B: Expr) -> Expr:
    """Over-approximate ``A @ B`` with a square-based bilinear relaxation.

    Both operands are split into a center and a centered remainder.  The
    affine part ``Ac @ Bc + As @ Bc + Ac @ Bs`` is kept exactly and the
    bilinear remainder ``As @ Bs`` is enclosed in ``[L, U]``, where the
    bounds are the tighter of the interval product of half-widths and a
    per-term bound from the polarization identity.  The enclosure is
    returned as ``(U + L) / 2 + (U - L) / 2 * eps`` with a fresh noise
    symbol.

    If both remainders from ``split_const`` are symmetric about zero they
    are used directly; otherwise the operands are re-centered on the
    midpoint of their concrete bounds.

    Terms whose half-width is zero in either factor contribute exactly
    zero and are excluded from the per-term bound.  Previously they caused
    a division by zero that discarded the tightened bound for the whole
    output entry.

    Args:
        A: Left operand with shape ``(..., m, k)``.
        B: Right operand with shape ``(..., k, n)``; leading dims must equal
            those of ``A``.

    Returns:
        An expression of shape ``(..., m, n)`` over-approximating ``A @ B``.

    Raises:
        AssertionError: On rank, batch-dim or inner-dim mismatch.
    """
    assert len(A.shape) >= 2 and len(B.shape) >= 2, \
        f"Need at least 2D for matmul, got {A.shape} @ {B.shape}"
    assert A.shape[:-2] == B.shape[:-2], \
        f"Batch dims must match: {A.shape} @ {B.shape}"
    assert A.shape[-1] == B.shape[-2], \
        f"Inner dims must match: {A.shape} @ {B.shape}"

    m, k, n = A.shape[-2], A.shape[-1], B.shape[-1]

    Ac, As = A.split_const()
    Bc, Bs = B.split_const()
    symmetric = As.is_symmetric_to_0() and Bs.is_symmetric_to_0()

    if symmetric:
        result = Ac @ Bs + As @ Bc + Ac @ Bc
        A_hw = As.ub()
        B_hw = Bs.ub()
    else:
        # Re-center on the midpoint of the concrete bounds so that the
        # remainders have symmetric ranges.
        Au, Al = A.ublb()
        Bu, Bl = B.ublb()
        Ac = (Au + Al) / 2
        As = A - Ac
        Bc = (Bu + Bl) / 2
        Bs = B - Bc
        result = Ac @ Bc + As @ Bc + Ac @ Bs
        A_hw = (Au - Al) / 2
        B_hw = (Bu - Bl) / 2

    U, L = _square_error_bounds(As, Bs, A_hw, B_hw, m, k, n, symmetric)

    result += (U + L) / 2 + (U - L) / 2 * LpEpsilon(result.shape)
    return result