Source code for boundlab.diff.zonosq3.default.bilinear

"""Differential bilinear handlers backed by :mod:`boundlab.zonosq`."""

from __future__ import annotations

import torch

from boundlab import expr
from boundlab.diff.expr import DiffExpr2, DiffExpr3
from boundlab.expr._affine import ConstVal
from boundlab.expr._core import Expr
from boundlab.zonosq.bilinear import (
    bilinear_elementwise,
    bilinear_matmul,
    matmul_handler as zonosq_matmul_handler,
    zonosq_matmuls,
)


def _zero_like_expr(e: Expr) -> Expr:
    return expr.ConstVal(torch.zeros(e.shape))


[docs] def diff_bilinear_elementwise(a: DiffExpr3, b: DiffExpr3) -> DiffExpr3: assert a.shape == b.shape, f"Shapes must match: {a.shape} vs {b.shape}" out_x = bilinear_elementwise(a.x, b.x) out_y = bilinear_elementwise(a.y, b.y) # a1*b1 - a2*b2 = a1*(b1-b2) + (a1-a2)*b2 term1 = bilinear_elementwise(a.x, b.diff) term2 = bilinear_elementwise(a.diff, b.y) out_diff = term1 + term2 return DiffExpr3(out_x, out_y, out_diff)
[docs] def diff_bilinear_matmul(a: DiffExpr3, b: DiffExpr3) -> DiffExpr3: out_x = bilinear_matmul(a.x, b.x) out_y = bilinear_matmul(a.y, b.y) # A1@B1 - A2@B2 = A1@(B1-B2) + (A1-A2)@B2 # Abstract both products as a single matmul so the bilinear error bounds # shared epsilon symbols across both terms jointly (accumulating # coefficients before taking absolute values), which is at least as tight # as abstracting the two products independently and adding. out_diff = zonosq_matmuls((a.x, b.diff), (a.diff, b.y)) # Keep the tighter per-coordinate expression between direct diff tracking # and subtracting the two branch outputs. from boundlab.prop import bound_width sub_diff = out_x - out_y bw_d = bound_width(out_diff) bw_s = bound_width(sub_diff) finite_d = torch.isfinite(bw_d) finite_s = torch.isfinite(bw_s) use_sub = finite_s & ((~finite_d) | (bw_s < bw_d)) if bool(use_sub.all()): out_diff = sub_diff elif bool(use_sub.any()) and bool(finite_d.all()) and bool(finite_s.all()): mask = use_sub.float() out_diff = mask * sub_diff + (1.0 - mask) * out_diff return DiffExpr3(out_x, out_y, out_diff)
def _promote2(x: DiffExpr2) -> DiffExpr3: return DiffExpr3(x.x, x.y, x.x - x.y) def diff_mul_handler(a, b): if isinstance(a, DiffExpr3) and isinstance(b, DiffExpr3): return diff_bilinear_elementwise(a, b) if isinstance(a, DiffExpr3) and isinstance(b, DiffExpr2): try: return a * b except TypeError: return diff_bilinear_elementwise(a, _promote2(b)) if isinstance(a, DiffExpr2) and isinstance(b, DiffExpr3): try: return b * a except TypeError: return diff_bilinear_elementwise(_promote2(a), b) if isinstance(a, DiffExpr2) and isinstance(b, DiffExpr2): try: return a * b except TypeError: return diff_bilinear_elementwise(_promote2(a), _promote2(b)) if isinstance(a, (DiffExpr3, DiffExpr2)): return a * b if isinstance(b, (DiffExpr3, DiffExpr2)): return b * a if isinstance(a, Expr) and isinstance(b, torch.Tensor): return a * b if isinstance(a, torch.Tensor) and isinstance(b, Expr): return b * a return a * b def diff_matmul_handler(a, b): if (isinstance(a, DiffExpr2) and a.is_constant()) or isinstance(b, (torch.Tensor, ConstVal)): return a @ b if (isinstance(b, DiffExpr2) and b.is_constant()) or isinstance(a, (torch.Tensor, ConstVal)): return a @ b if isinstance(a, DiffExpr2): a = _promote2(a) if isinstance(b, DiffExpr2): b = _promote2(b) if isinstance(a, DiffExpr3) or isinstance(b, DiffExpr3): if isinstance(a, Expr): a = DiffExpr3(a, a, _zero_like_expr(a)) if isinstance(b, Expr): b = DiffExpr3(b, b, _zero_like_expr(b)) return diff_bilinear_matmul(a, b) return zonosq_matmul_handler(a, b) __all__ = [ "diff_bilinear_elementwise", "diff_bilinear_matmul", "diff_mul_handler", "diff_matmul_handler", ]