"""Bilinear handlers for the quadratic-zonotope interpreter."""
from __future__ import annotations
from functools import reduce
import operator
from typing import Union
import torch
from boundlab.expr._affine import AffineSum, ConstVal
from boundlab.expr._core import Expr
from boundlab.expr._var import LpEpsilon
from boundlab.linearop._base import LinearOp, ScalarOp, ZeroOp
from boundlab.linearop._einsum import EinsumOp
from boundlab.zono.bilinear import (
bilinear_elementwise,
square_matmul as _zono_square_matmul,
)
from .epsilon import ZonosqExpr
def _is_const(tensor: Union[torch.Tensor, Expr, int]) -> bool:
    """Tell whether ``tensor`` should be handled as a concrete constant.

    Raw tensors, :class:`ConstVal` expressions and Python ``int``/``float``
    scalars count as constants (``bool`` too, being an ``int`` subclass).
    """
    constant_types = (torch.Tensor, ConstVal, int, float)
    return isinstance(tensor, constant_types)
def _as_affine(expr: Expr) -> AffineSum:
    """Return ``expr`` as an :class:`AffineSum`.

    An existing ``AffineSum`` is returned unchanged; any other expression is
    wrapped as a single term with a unit :class:`ScalarOp` over its shape.
    """
    if not isinstance(expr, AffineSum):
        identity = ScalarOp(1.0, expr.shape)
        return AffineSum((identity, expr))
    return expr
def _numel(shape: torch.Size) -> int:
return reduce(operator.mul, shape, 1)
def _is_zero(op: LinearOp) -> bool:
    """Report whether ``op`` is structurally a :class:`ZeroOp`.

    This is a type check only; a non-``ZeroOp`` whose values happen to be
    all zero is not detected here.
    """
    structurally_zero = isinstance(op, ZeroOp)
    return structurally_zero
def _coeff_op(coeff: torch.Tensor, eps: LpEpsilon) -> LinearOp:
    """Wrap a coefficient tensor as an :class:`EinsumOp` acting on ``eps``.

    The trailing (flattened epsilon) dimension of ``coeff`` is unflattened
    to ``eps.shape``; the leading dimensions are kept as the op's output
    dimensions.
    """
    eps_rank = len(eps.shape)
    unflattened = coeff.reshape(coeff.shape[:-1] + eps.shape)
    return EinsumOp.from_full(unflattened, eps_rank)
def _add_coeff(out: ZonosqExpr, eps: LpEpsilon, coeff: torch.Tensor, idx: int) -> None:
    """Register ``coeff`` on ``out`` as term ``idx`` of ``eps``.

    Empty tensors and tensors without any non-zero entry are dropped, so no
    zero-valued op is ever added.
    """
    has_entries = coeff.numel() != 0
    if has_entries and coeff.any():
        out.add(eps, _coeff_op(coeff, eps), idx)
def _contract_pair(
    op_a: LinearOp,
    op_b: LinearOp,
    out_shape: torch.Size,
) -> torch.Tensor:
    """Contract two factor Jacobians into per-output bilinear coefficients.

    With ``J_a`` viewed as ``(batch, m, k, S)`` and ``J_b`` as
    ``(batch, k, n, T)``, returns a tensor of shape ``out_shape + (S, T)``
    whose entry ``[..., m, n, s, t]`` is ``Σ_k J_a[..., m, k, s] *
    J_b[..., k, n, t]``.  ``S`` and ``T`` are whatever is left of each
    Jacobian after the batch/row/inner (resp. batch/inner/column) dims.
    """
    batch = reduce(operator.mul, out_shape[:-2], 1)
    rows, cols = out_shape[-2], out_shape[-1]
    inner = op_a.output_shape[-1]
    jac_a = op_a.jacobian().reshape(batch, rows, inner, -1)
    jac_b = op_b.jacobian().reshape(batch, inner, cols, -1)
    contracted = torch.einsum("bmks,bknt->bmnst", jac_a, jac_b)
    return contracted.reshape(out_shape + (jac_a.shape[-1], jac_b.shape[-1]))
def _is_literal_zero(value) -> bool:
return isinstance(value, int) and value == 0
def _matmul_or_zero(left, right, out_shape: torch.Size) -> Expr:
    """Return ``left @ right``, short-circuiting a literal-zero operand.

    When either side is the integer literal ``0`` the product is a
    ``ConstVal`` of zeros shaped ``out_shape``.
    NOTE(review): ``torch.zeros`` uses the default dtype/device; confirm
    that matches the operands in mixed-precision or GPU use.
    """
    if not (_is_literal_zero(left) or _is_literal_zero(right)):
        return left @ right
    return ConstVal(torch.zeros(out_shape))
def _split_const_or_symbolic(value: Expr):
try:
return value.split_const()
except NotImplementedError:
return 0, value
def _add_linear_linear(
    out: ZonosqExpr,
    eps_a: LpEpsilon,
    eps_b: LpEpsilon,
    coeff: torch.Tensor,
) -> None:
    """Abstract a linear x linear bilinear coefficient block into ``out``.

    ``coeff[..., s, t]`` is the coefficient of ``eps_a[s] * eps_b[t]``, with
    both epsilon indices flattened (as produced by ``_contract_pair``).

    * ``eps_a is eps_b``: the diagonal (``s == t``) coefficients are added
      signed at kind index 1 (presumably the squared-symbol slot; confirm
      against ``ZonosqExpr``).  Off-diagonal products are bounded by
      non-negative magnitudes at kind index 2, each unordered pair
      ``{s, t}`` charging half its magnitude to ``s`` and half to ``t``.
      Because ``eps[s] * eps[t]`` and ``eps[t] * eps[s]`` are the same
      monomial, ``coeff[s, t]`` and ``coeff[t, s]`` are summed *before*
      taking absolute values so opposite signs cancel.  This is never looser
      than charging ``|coeff[s, t]| + |coeff[t, s]|`` and equals it when the
      two share a sign.
    * Distinct symbols: every product is bounded at kind index 2, half the
      absolute coefficient charged to each side.
    """
    if eps_a is eps_b:
        _add_coeff(out, eps_a, coeff.diagonal(dim1=-2, dim2=-1), 1)
        # Merge both orderings of every unordered pair; ``sym`` is a fresh
        # tensor, so zeroing its diagonal does not touch ``coeff``.
        sym = coeff + coeff.transpose(-1, -2)
        sym.diagonal(dim1=-2, dim2=-1).zero_()
        # ``sym`` is symmetric: row s lists every pair containing s, and each
        # pair is split evenly between its two indices.
        _add_coeff(out, eps_a, 0.5 * sym.abs().sum(dim=-1), 2)
        return
    abs_coeff = coeff.abs()
    _add_coeff(out, eps_a, 0.5 * abs_coeff.sum(dim=-1), 2)
    _add_coeff(out, eps_b, 0.5 * abs_coeff.sum(dim=-2), 2)
def _add_linear_quadratic(
    out: ZonosqExpr,
    linear_eps: LpEpsilon,
    quadratic_eps: LpEpsilon,
    coeff: torch.Tensor,
) -> None:
    """Abstract a linear x quadratic coefficient block into ``out``.

    ``coeff[..., s, t]`` pairs index ``s`` of ``linear_eps`` with index ``t``
    of ``quadratic_eps``.  For each ``t`` the absolute coefficients are
    summed over ``s`` and charged entirely to ``quadratic_eps`` at kind
    index 2; ``linear_eps`` receives nothing and is intentionally unused.
    """
    per_quadratic_index = coeff.abs().sum(dim=-2)
    _add_coeff(out, quadratic_eps, per_quadratic_index, 2)
def _add_quadratic_quadratic(
    out: ZonosqExpr,
    eps_a: LpEpsilon,
    eps_b: LpEpsilon,
    coeff: torch.Tensor,
) -> None:
    """Abstract a quadratic x quadratic coefficient block into ``out``.

    ``coeff[..., s, t]`` pairs index ``s`` of ``eps_a`` with index ``t`` of
    ``eps_b``.  All magnitudes go to kind index 2: for distinct symbols half
    of each is charged to each side; for the same symbol the full row sum is
    charged to ``eps_a``.
    """
    magnitude = coeff.abs()
    if eps_a is not eps_b:
        _add_coeff(out, eps_a, 0.5 * magnitude.sum(dim=-1), 2)
        _add_coeff(out, eps_b, 0.5 * magnitude.sum(dim=-2), 2)
    else:
        _add_coeff(out, eps_a, magnitude.sum(dim=-1), 2)
def _add_coeff_term(
out: ZonosqExpr,
eps_a: LpEpsilon,
kind_a: int,
eps_b: LpEpsilon,
kind_b: int,
coeff: torch.Tensor,
) -> None:
if kind_a == 0 and kind_b == 0:
_add_linear_linear(out, eps_a, eps_b, coeff)
elif kind_a == 0 and kind_b in (1, 2):
_add_linear_quadratic(out, eps_a, eps_b, coeff)
elif kind_b == 0 and kind_a in (1, 2):
_add_linear_quadratic(out, eps_b, eps_a, coeff.transpose(-1, -2))
elif kind_a in (1, 2) and kind_b in (1, 2):
_add_quadratic_quadratic(out, eps_a, eps_b, coeff)
else:
raise AssertionError(f"Unexpected zonosq term kinds: {kind_a}, {kind_b}")
def _matmul_out_shape(A: Expr, B: Expr) -> torch.Size:
assert len(A.shape) >= 2 and len(B.shape) >= 2, \
f"Need at least 2D for matmul, got {A.shape} @ {B.shape}"
assert A.shape[:-2] == B.shape[:-2], \
f"Batch dims must match: {A.shape} @ {B.shape}"
assert A.shape[-1] == B.shape[-2], \
f"Inner dims must match: {A.shape} @ {B.shape}"
return torch.Size(A.shape[:-2] + (A.shape[-2], B.shape[-1]))
[docs]
def zonosq_matmuls(*pairs: tuple[Expr, Expr]) -> Expr:
    r"""Abstract the sum of matmuls ``Σ_i A_i @ B_i`` as a single zonosq expression.

    Equivalent to concatenating the factors block-wise
    (``[A_1 | … | A_n] @ [B_1; … ; B_n]``) and abstracting one matmul, but
    without materialising the concatenation (which would hide the per-term
    :class:`LpEpsilon` structure behind slicing ops). Crucially, the bilinear
    coefficients of every product are accumulated **before** the error
    abstraction takes absolute values, so contributions of shared epsilon
    symbols cancel across terms — making this at least as tight as, and never
    looser than, abstracting each ``A_i @ B_i`` independently and summing.

    Coefficients are keyed by ``(eps_a, kind_a, eps_b, kind_b)``. A product
    that meets the same two symbol/kind slots in the opposite order describes
    the same monomials, so its coefficient is transposed and folded into the
    existing entry instead of being abstracted separately.

    Args:
        *pairs: One or more ``(A, B)`` factor pairs; every ``A @ B`` must
            have the same output shape.

    Returns:
        An expression for ``Σ_i A_i @ B_i``. If no bilinear terms arise, the
        accumulated constant/affine part is returned directly. If building
        the terms raises ``AssertionError``, ``TypeError``, ``ValueError`` or
        ``NotImplementedError`` (including shape validation failures), the
        result is instead the sum of ``boundlab.zono.bilinear.square_matmul``
        over the pairs.

    Raises:
        AssertionError: if ``pairs`` is empty.
    """
    assert len(pairs) >= 1, "zonosq_matmuls requires at least one (A, B) pair."
    try:
        out_shape = _matmul_out_shape(*pairs[0])
        # Explicit raise (not ``assert``) so mismatched pairs still take the
        # fallback below when Python runs with ``-O``.
        if any(_matmul_out_shape(A, B) != out_shape for A, B in pairs[1:]):
            raise ValueError("All pairs must share the same output shape.")
        base: Expr | None = None
        # (eps_a, kind_a, eps_b, kind_b) -> accumulated contraction coefficient.
        coeff_terms: dict[tuple[LpEpsilon, int, LpEpsilon, int], torch.Tensor] = {}
        for A, B in pairs:
            Ac, As = _split_const_or_symbolic(A)
            Bc, Bs = _split_const_or_symbolic(B)
            pair_base = (
                _matmul_or_zero(Ac, Bc, out_shape)
                + _matmul_or_zero(Ac, Bs, out_shape)
                + _matmul_or_zero(As, Bc, out_shape)
            )
            base = pair_base if base is None else base + pair_base
            # A literal-zero symbolic factor (already treated as zero by
            # ``_matmul_or_zero``) has no bilinear terms; skip it instead of
            # passing an ``int`` to ``_as_affine``.
            if _is_literal_zero(As) or _is_literal_zero(Bs):
                continue
            a_terms = ZonosqExpr(_as_affine(As)).children_dict
            b_terms = ZonosqExpr(_as_affine(Bs)).children_dict
            for eps_a, ops_a in a_terms.items():
                for eps_b, ops_b in b_terms.items():
                    for kind_a, op_a in enumerate(ops_a):
                        if _is_zero(op_a):
                            continue
                        for kind_b, op_b in enumerate(ops_b):
                            if _is_zero(op_b):
                                continue
                            coeff = _contract_pair(op_a, op_b, out_shape)
                            key = (eps_a, kind_a, eps_b, kind_b)
                            swapped = (eps_b, kind_b, eps_a, kind_a)
                            # x*y == y*x: accumulate into an existing
                            # mirrored entry (with transposed epsilon axes)
                            # so its coefficients can cancel with these.
                            if key not in coeff_terms and swapped in coeff_terms:
                                key, coeff = swapped, coeff.transpose(-1, -2)
                            prev = coeff_terms.get(key)
                            coeff_terms[key] = coeff if prev is None else prev + coeff
        if not coeff_terms:
            return base
    except (AssertionError, TypeError, ValueError, NotImplementedError):
        return reduce(operator.add, (_zono_square_matmul(A, B) for A, B in pairs))
    out = ZonosqExpr(_as_affine(base))
    for (eps_a, kind_a, eps_b, kind_b), coeff in coeff_terms.items():
        _add_coeff_term(out, eps_a, kind_a, eps_b, kind_b, coeff)
    return out.affine_sum(out_shape)
[docs]
def zonosq_matmul(A: Expr, B: Expr) -> Expr:
    """Abstract a single ``A @ B``; shorthand for ``zonosq_matmuls((A, B))``."""
    pair = (A, B)
    return zonosq_matmuls(pair)
[docs]
def bilinear_matmul(A: Expr, B: Expr) -> Expr:
    """Alias of ``zonosq_matmul``: abstract ``A @ B`` as a zonosq expression."""
    result = zonosq_matmul(A, B)
    return result
[docs]
def matmul_handler(A, B):
    """Handle ``A @ B`` for the interpreter.

    * Expression with a constant on either side: delegate to the ``@``
      operator (the expression's own matmul handling).
    * Two constants: ``torch.matmul``.
    * Two non-constant expressions: the bilinear zonosq abstraction.
    * Anything else: ``NotImplemented``.
    """
    a_is_expr = isinstance(A, Expr)
    b_is_expr = isinstance(B, Expr)
    a_is_const = _is_const(A)
    b_is_const = _is_const(B)

    if (a_is_expr and b_is_const) or (a_is_const and b_is_expr):
        return A @ B
    if a_is_const and b_is_const:
        return torch.matmul(A, B)
    if a_is_expr and b_is_expr:
        return zonosq_matmul(A, B)
    return NotImplemented
# Public API of this module. ``bilinear_elementwise`` is re-exported
# unchanged from ``boundlab.zono.bilinear``; the matmul entries are defined
# in this file.
__all__ = [
"bilinear_elementwise",
"bilinear_matmul",
"matmul_handler",
"zonosq_matmul",
"zonosq_matmuls",
]