# Source code for boundlab.diff.zono3.default.bilinear

"""

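Differential rules for bilinear operations on paired expressions, writing
∆a = a1 − a2 and ∆b = b1 − b2 (likewise ∆A = A1 − A2 and ∆B = B1 − B2).

Scalar Multiplication (element-wise):
    a1·b1 − a2·b2 = a1·∆b + ∆a·b2

Dot Product / Matrix Multiply:
    A1@B1 − A2@B2 = A1@∆B + ∆A@B2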

"""

import torch

from boundlab import expr
from boundlab.expr._affine import ConstVal
from boundlab.expr._core import Expr
from boundlab.diff.expr import DiffExpr2, DiffExpr3
from boundlab.zono.bilinear import (
    bilinear_elementwise,
    bilinear_matmul,
    deept_precise_matmul,
    _is_const,
)


def _mm(p, q):
    """Multiply ``p @ q`` using the relaxation selected by ``config.matmul_mode``.

    * Any mode other than ``"precise"`` / ``"precise+sym"`` (square mode is the
      default) goes through ``bilinear_matmul``, the quarter-square path, so
      existing results reproduce byte-for-byte.
    * In ``"precise"`` / ``"precise+sym"`` mode, a genuine zonotope x zonotope
      product (e.g. attention ``Q@K^T`` or ``attn@V``) is sent to
      ``deept_precise_matmul``, the DeepT-Precise relaxation with the
      shared-eps eps^2 diagonal correction. ``"precise+sym"`` additionally
      turns on ``symmetrize``.
    * If either operand is const/affine (per ``_is_const``), the product is
      exact, so the relaxation choice is moot and ``bilinear_matmul`` is used
      regardless of mode.

    This mirrors the dispatch on ``config.matmul_mode`` done by the
    single-model ``matmul_handler``.
    """
    # Imported at call time, as in the rest of this module's lazy imports.
    from boundlab.config import config

    mode = config.matmul_mode
    precise_requested = mode in ("precise", "precise+sym")

    # Guard clause: everything that is not a "precise" zonotope x zonotope
    # product stays on the exact / quarter-square path.
    if not precise_requested or _is_const(p) or _is_const(q):
        return bilinear_matmul(p, q)

    return deept_precise_matmul(
        p,
        q,
        symmetrize=(mode == "precise+sym"),
        interval_cap=config.matmul_interval_cap,
    )


def diff_bilinear_elementwise(a: DiffExpr3, b: DiffExpr3) -> DiffExpr3:
    """Element-wise product of two differential triples.

    Each operand carries an ``x`` component, a ``y`` component and an explicit
    ``diff`` component. The product is formed component by component with
    ``bilinear_elementwise``; the difference uses the identity

        a1·b1 − a2·b2 = a1·Δb + Δa·b2

    Args:
        a: Left operand.
        b: Right operand; must have the same ``shape`` as ``a``.

    Returns:
        A ``DiffExpr3`` holding ``(a.x*b.x, a.y*b.y, a.x*b.diff + a.diff*b.y)``.

    Raises:
        AssertionError: If ``a.shape != b.shape``.
    """
    assert a.shape == b.shape, \
        f"Shapes must match: {a.shape} vs {b.shape}"

    prod_x = bilinear_elementwise(a.x, b.x)
    prod_y = bilinear_elementwise(a.y, b.y)

    # Difference channel: a1·Δb first, then Δa·b2 (left operand of ``+`` is
    # evaluated first, so the call order matches the identity above).
    prod_diff = (
        bilinear_elementwise(a.x, b.diff)
        + bilinear_elementwise(a.diff, b.y)
    )
    return DiffExpr3(prod_x, prod_y, prod_diff)
def diff_bilinear_matmul(a: DiffExpr3, b: DiffExpr3) -> DiffExpr3:
    """Matrix product of two differential triples.

    The ``x`` and ``y`` components are multiplied independently with ``_mm``.
    The difference component uses the identity

        A1@B1 − A2@B2 = A1@(B1 − B2) + (A1 − A2)@B2

    and is then compared, entry by entry, against the plain subtraction
    ``out_x - out_y``: wherever the subtraction has the strictly narrower
    bound width it replaces the bilinear difference.

    Args:
        a: Left operand.
        b: Right operand.

    Returns:
        A ``DiffExpr3`` of ``(a.x@b.x, a.y@b.y, selected difference)``.
    """
    # Imported at call time, exactly as the original block did.
    from boundlab.prop import bound_width

    out_x = _mm(a.x, b.x)
    out_y = _mm(a.y, b.y)

    # Diff: A1@(B1 - B2) + (A1 - A2)@B2
    out_diff = _mm(a.x, b.diff) + _mm(a.diff, b.y)

    # Reset: if the Z_Δ bound is wider than Z_x - Z_y, swap per-neuron.
    sub_diff = out_x - out_y
    width_diff = bound_width(out_diff)
    width_sub = bound_width(sub_diff)

    # 1.0 where the subtraction is strictly tighter, 0.0 elsewhere. The
    # comparison is computed once; the former debug statistics (reset count,
    # element count, max widths) were never used, forced ``.item()`` host
    # synchronisations, and made ``.max()`` fail on empty outputs, so they
    # are gone.
    # NOTE(review): assumes bound_width returns a torch tensor shaped like the
    # expression (it is compared element-wise and cast with .float()) — confirm.
    use_sub = (width_sub < width_diff).float()
    out_diff = use_sub * sub_diff + (1.0 - use_sub) * out_diff

    return DiffExpr3(out_x, out_y, out_diff)
def _promote_to_diff3(d):
    """Build a ``DiffExpr3`` from ``d.x`` and ``d.y`` with ``d.x - d.y`` as its difference."""
    return DiffExpr3(d.x, d.y, d.x - d.y)


def diff_mul_handler(a, b):
    """Dispatch element-wise multiplication over differential and plain operands.

    Rules, checked in this order:

    1. ``DiffExpr3 * DiffExpr3`` -> ``diff_bilinear_elementwise``.
    2. A ``DiffExpr3``/``DiffExpr2`` pair, or two ``DiffExpr2``: try the
       operands' own ``*`` first; on ``TypeError`` promote every ``DiffExpr2``
       to a ``DiffExpr3`` and use ``diff_bilinear_elementwise``.
    3. Exactly one differential operand: use ``*`` with the differential
       operand on the left.
    4. Anything else: plain ``*``, with the ``Expr`` on the left when the
       other operand is a ``torch.Tensor``.
    """
    a_is3 = isinstance(a, DiffExpr3)
    b_is3 = isinstance(b, DiffExpr3)
    a_is2 = isinstance(a, DiffExpr2)
    b_is2 = isinstance(b, DiffExpr2)

    if a_is3 and b_is3:
        return diff_bilinear_elementwise(a, b)

    if a_is3 and b_is2:
        try:
            return a * b
        except TypeError:
            return diff_bilinear_elementwise(a, _promote_to_diff3(b))

    if a_is2 and b_is3:
        try:
            # The DiffExpr3 goes on the left of the operator here.
            return b * a
        except TypeError:
            return diff_bilinear_elementwise(_promote_to_diff3(a), b)

    if a_is2 and b_is2:
        try:
            return a * b
        except TypeError:
            lhs = _promote_to_diff3(a)
            rhs = _promote_to_diff3(b)
            return diff_bilinear_elementwise(lhs, rhs)

    # Only one side (at most) is differential from here on.
    if a_is3 or a_is2:
        return a * b
    if b_is3 or b_is2:
        return b * a

    # Tensor * Expr is flipped so the Expr is the left operand; the
    # Expr * Tensor case keeps priority, as it was tested first originally.
    expr_then_tensor = isinstance(a, Expr) and isinstance(b, torch.Tensor)
    tensor_then_expr = isinstance(a, torch.Tensor) and isinstance(b, Expr)
    if tensor_then_expr and not expr_then_tensor:
        return b * a
    return a * b


def diff_matmul_handler(a, b):
    """Dispatch matrix multiplication over differential and plain operands.

    * If either operand is a ``torch.Tensor`` / ``ConstVal``, or a
      ``DiffExpr2`` whose ``is_constant()`` is true, the operands' own ``@``
      is used.
    * Otherwise every ``DiffExpr2`` is promoted to a ``DiffExpr3``. If at
      least one operand is then a ``DiffExpr3``, any operand that is an
      ``Expr`` is lifted to ``DiffExpr3(e, e, expr.ConstVal(e.shape))`` and
      the pair goes to ``diff_bilinear_matmul``.
    * With no differential operand at all, the standard zonotope
      ``matmul_handler`` handles the product.
    """
    from boundlab.zono.bilinear import matmul_handler as std_matmul_handler

    plain_types = (torch.Tensor, ConstVal)

    # ``and`` binds tighter than ``or``; the parentheses spell out the
    # grouping, and the tests run in the same order as the two original checks.
    if (
        (isinstance(a, DiffExpr2) and a.is_constant())
        or isinstance(b, plain_types)
        or (isinstance(b, DiffExpr2) and b.is_constant())
        or isinstance(a, plain_types)
    ):
        return a @ b

    if isinstance(a, DiffExpr2):
        a = _promote_to_diff3(a)
    if isinstance(b, DiffExpr2):
        b = _promote_to_diff3(b)

    if not (isinstance(a, DiffExpr3) or isinstance(b, DiffExpr3)):
        return std_matmul_handler(a, b)

    # Lift an Expr operand to a triple whose x and y are the same expression.
    # NOTE(review): expr.ConstVal(shape) is presumably a zero constant of that
    # shape (x and y coincide, so their difference would be zero) — confirm.
    if isinstance(a, Expr):
        a = DiffExpr3(a, a, expr.ConstVal(a.shape))
    if isinstance(b, Expr):
        b = DiffExpr3(b, b, expr.ConstVal(b.shape))
    return diff_bilinear_matmul(a, b)


__all__ = [
    "diff_bilinear_elementwise",
    "diff_bilinear_matmul",
    "diff_mul_handler",
    "diff_matmul_handler",
]