# Source code for boundlab.zonosq.bilinear

"""Bilinear handlers for the quadratic-zonotope interpreter."""

from __future__ import annotations

from functools import reduce
import operator
from typing import Union

import torch

from boundlab.expr._affine import AffineSum, ConstVal
from boundlab.expr._core import Expr
from boundlab.expr._var import LpEpsilon
from boundlab.linearop._base import LinearOp, ScalarOp, ZeroOp
from boundlab.linearop._einsum import EinsumOp
from boundlab.zono.bilinear import (
    bilinear_elementwise,
    square_matmul as _zono_square_matmul,
)

from .epsilon import ZonosqExpr


def _is_const(tensor: Union[torch.Tensor, Expr, int]) -> bool:
    """Report whether ``tensor`` counts as a constant matmul operand.

    An operand is treated as constant when it is an instance of
    ``torch.Tensor``, ``ConstVal``, ``int`` or ``float``.
    """
    constant_types = (torch.Tensor, ConstVal, int, float)
    return isinstance(tensor, constant_types)


def _as_affine(expr: Expr) -> AffineSum:
    """Return ``expr`` as an :class:`AffineSum`.

    An existing ``AffineSum`` is handed back untouched. Anything else is
    wrapped in a one-term ``AffineSum`` whose operator is
    ``ScalarOp(1.0, expr.shape)``.
    """
    if not isinstance(expr, AffineSum):
        unit_op = ScalarOp(1.0, expr.shape)
        return AffineSum((unit_op, expr))
    return expr


def _numel(shape: torch.Size) -> int:
    return reduce(operator.mul, shape, 1)


def _is_zero(op: LinearOp) -> bool:
    """Report whether ``op`` is a :class:`ZeroOp` instance.

    This is a type check only; the operator's values are not inspected.
    """
    is_zero_op = isinstance(op, ZeroOp)
    return is_zero_op


def _coeff_op(coeff: torch.Tensor, eps: LpEpsilon) -> LinearOp:
    """Build an ``EinsumOp`` from a coefficient tensor for symbol ``eps``.

    The last axis of ``coeff`` is reshaped to ``eps.shape`` and the result
    is passed to ``EinsumOp.from_full`` together with the rank of ``eps``.
    """
    leading_dims = coeff.shape[:-1]
    unflattened = coeff.reshape(leading_dims + eps.shape)
    eps_rank = len(eps.shape)
    return EinsumOp.from_full(unflattened, eps_rank)


def _add_coeff(out: ZonosqExpr, eps: LpEpsilon, coeff: torch.Tensor, idx: int) -> None:
    """Register ``coeff`` on ``out`` for symbol ``eps`` in slot ``idx``.

    Nothing is added when ``coeff`` has no elements or is entirely zero.
    """
    if coeff.numel() > 0 and coeff.any():
        out.add(eps, _coeff_op(coeff, eps), idx)


def _contract_pair(
    op_a: LinearOp,
    op_b: LinearOp,
    out_shape: torch.Size,
) -> torch.Tensor:
    """Contract the Jacobians of two operators along the matmul inner axis.

    Both Jacobians are flattened to ``(batch, rows, inner, symbols)`` and
    ``(batch, inner, cols, symbols)`` respectively, and the shared ``inner``
    axis is summed out.

    Returns:
        A tensor of shape ``out_shape + (s, t)``, where ``s`` and ``t`` are
        the flattened trailing sizes of the two Jacobians.
    """
    rows, cols = out_shape[-2], out_shape[-1]
    inner = op_a.output_shape[-1]
    n_batch = _numel(out_shape[:-2])

    jac_a = op_a.jacobian().reshape(n_batch, rows, inner, -1)
    jac_b = op_b.jacobian().reshape(n_batch, inner, cols, -1)

    # Sum over the shared inner axis ``k``; keep both symbol axes separate.
    pairwise = torch.einsum("bmks,bknt->bmnst", jac_a, jac_b)
    symbol_dims = (jac_a.shape[-1], jac_b.shape[-1])
    return pairwise.reshape(out_shape + symbol_dims)


def _is_literal_zero(value) -> bool:
    return isinstance(value, int) and value == 0


def _matmul_or_zero(left, right, out_shape: torch.Size) -> Expr:
    """Compute ``left @ right``, short-circuiting on a literal zero operand.

    When either operand is the integer ``0``, the result is a ``ConstVal``
    wrapping ``torch.zeros(out_shape)`` instead of a matmul.
    """
    has_zero_operand = _is_literal_zero(left) or _is_literal_zero(right)
    if not has_zero_operand:
        return left @ right
    return ConstVal(torch.zeros(out_shape))


def _split_const_or_symbolic(value: Expr):
    try:
        return value.split_const()
    except NotImplementedError:
        return 0, value


def _add_linear_linear(
    out: ZonosqExpr,
    eps_a: LpEpsilon,
    eps_b: LpEpsilon,
    coeff: torch.Tensor,
) -> None:
    """Fold a product of two kind-0 terms into ``out``.

    ``coeff`` carries the pairwise coefficients in its last two axes (the
    ``eps_a`` index, then the ``eps_b`` index).

    * Distinct symbols: half of the absolute row sums goes to slot 2 of
      ``eps_a`` and half of the absolute column sums to slot 2 of ``eps_b``.
    * Same symbol: the signed diagonal goes to slot 1, and the absolute
      off-diagonal mass (half row sums plus half column sums) to slot 2.
    """
    magnitude = coeff.abs()

    if eps_a is not eps_b:
        _add_coeff(out, eps_a, 0.5 * magnitude.sum(dim=-1), 2)
        _add_coeff(out, eps_b, 0.5 * magnitude.sum(dim=-2), 2)
        return

    # Same symbol on both sides: the diagonal keeps its sign.
    _add_coeff(out, eps_a, coeff.diagonal(dim1=-2, dim2=-1), 1)

    # Zero the diagonal on a copy so only cross terms remain.
    cross = magnitude.clone()
    cross.diagonal(dim1=-2, dim2=-1).zero_()
    row_share = 0.5 * cross.sum(dim=-1)
    col_share = 0.5 * cross.sum(dim=-2)
    _add_coeff(out, eps_a, row_share + col_share, 2)


def _add_linear_quadratic(
    out: ZonosqExpr,
    linear_eps: LpEpsilon,
    quadratic_eps: LpEpsilon,
    coeff: torch.Tensor,
) -> None:
    """Fold a kind-0 times kind-1/2 product into ``out``.

    The absolute coefficients are summed over the second-to-last axis (the
    ``linear_eps`` index) and added to slot 2 of ``quadratic_eps``.
    ``linear_eps`` is accepted for signature symmetry only and is unused.
    """
    del linear_eps
    summed_magnitude = coeff.abs().sum(dim=-2)
    _add_coeff(out, quadratic_eps, summed_magnitude, 2)


def _add_quadratic_quadratic(
    out: ZonosqExpr,
    eps_a: LpEpsilon,
    eps_b: LpEpsilon,
    coeff: torch.Tensor,
) -> None:
    """Fold a product of two kind-1/2 terms into ``out``.

    Only absolute coefficients are used, and everything lands in slot 2.

    * Same symbol: the full sum over the last axis goes to ``eps_a``.
    * Distinct symbols: half of the last-axis sum goes to ``eps_a`` and
      half of the second-to-last-axis sum to ``eps_b``.
    """
    magnitude = coeff.abs()
    if eps_a is not eps_b:
        _add_coeff(out, eps_a, 0.5 * magnitude.sum(dim=-1), 2)
        _add_coeff(out, eps_b, 0.5 * magnitude.sum(dim=-2), 2)
    else:
        _add_coeff(out, eps_a, magnitude.sum(dim=-1), 2)


def _add_coeff_term(
    out: ZonosqExpr,
    eps_a: LpEpsilon,
    kind_a: int,
    eps_b: LpEpsilon,
    kind_b: int,
    coeff: torch.Tensor,
) -> None:
    if kind_a == 0 and kind_b == 0:
        _add_linear_linear(out, eps_a, eps_b, coeff)
    elif kind_a == 0 and kind_b in (1, 2):
        _add_linear_quadratic(out, eps_a, eps_b, coeff)
    elif kind_b == 0 and kind_a in (1, 2):
        _add_linear_quadratic(out, eps_b, eps_a, coeff.transpose(-1, -2))
    elif kind_a in (1, 2) and kind_b in (1, 2):
        _add_quadratic_quadratic(out, eps_a, eps_b, coeff)
    else:
        raise AssertionError(f"Unexpected zonosq term kinds: {kind_a}, {kind_b}")


def _matmul_out_shape(A: Expr, B: Expr) -> torch.Size:
    assert len(A.shape) >= 2 and len(B.shape) >= 2, \
        f"Need at least 2D for matmul, got {A.shape} @ {B.shape}"
    assert A.shape[:-2] == B.shape[:-2], \
        f"Batch dims must match: {A.shape} @ {B.shape}"
    assert A.shape[-1] == B.shape[-2], \
        f"Inner dims must match: {A.shape} @ {B.shape}"
    return torch.Size(A.shape[:-2] + (A.shape[-2], B.shape[-1]))


def zonosq_matmuls(*pairs: tuple[Expr, Expr]) -> Expr:
    r"""Abstract the sum of matmuls ``Σ_i A_i @ B_i`` as a single zonosq expression.

    Equivalent to concatenating the factors block-wise
    (``[A_1 | … | A_n] @ [B_1; … ; B_n]``) and abstracting one matmul, but
    without materialising the concatenation (which would hide the per-term
    :class:`LpEpsilon` structure behind slicing ops).

    Crucially, the bilinear coefficients of every product are accumulated
    **before** the error abstraction takes absolute values, so contributions
    of shared epsilon symbols cancel across terms — making this at least as
    tight as, and never looser than, abstracting each ``A_i @ B_i``
    independently and summing.

    Args:
        *pairs: One or more ``(A, B)`` operand pairs. All pairs must produce
            the same matmul output shape.

    Returns:
        The abstracted sum. If no bilinear coefficient is collected, the
        constant/affine part is returned directly. If the main path raises
        ``AssertionError``, ``TypeError``, ``ValueError`` or
        ``NotImplementedError``, the result is instead the sum of
        ``square_matmul`` from ``boundlab.zono.bilinear`` applied per pair.

    Raises:
        AssertionError: if no pair is given.
    """
    # Explicit raise (not ``assert``) so the check survives ``python -O``.
    if len(pairs) < 1:
        raise AssertionError("zonosq_matmuls requires at least one (A, B) pair.")

    try:
        out_shape = _matmul_out_shape(*pairs[0])
        if not all(_matmul_out_shape(A, B) == out_shape for A, B in pairs[1:]):
            raise AssertionError("All pairs must share the same output shape.")

        base: Expr | None = None
        # (eps_a, kind_a, eps_b, kind_b) -> accumulated contraction coefficient.
        coeff_terms: dict[tuple[LpEpsilon, int, LpEpsilon, int], torch.Tensor] = {}

        for A, B in pairs:
            Ac, As = _split_const_or_symbolic(A)
            Bc, Bs = _split_const_or_symbolic(B)

            # Every product with at least one constant factor stays exact.
            pair_base = (
                _matmul_or_zero(Ac, Bc, out_shape)
                + _matmul_or_zero(Ac, Bs, out_shape)
                + _matmul_or_zero(As, Bc, out_shape)
            )
            base = pair_base if base is None else base + pair_base

            # Symbolic x symbolic: collect signed coefficients per symbol
            # pair so that they can cancel across pairs before abstraction.
            a_terms = ZonosqExpr(_as_affine(As)).children_dict
            b_terms = ZonosqExpr(_as_affine(Bs)).children_dict
            for eps_a, ops_a in a_terms.items():
                for eps_b, ops_b in b_terms.items():
                    for kind_a, op_a in enumerate(ops_a):
                        if _is_zero(op_a):
                            continue
                        for kind_b, op_b in enumerate(ops_b):
                            if _is_zero(op_b):
                                continue
                            coeff = _contract_pair(op_a, op_b, out_shape)
                            key = (eps_a, kind_a, eps_b, kind_b)
                            prev = coeff_terms.get(key)
                            coeff_terms[key] = coeff if prev is None else prev + coeff

        if not coeff_terms:
            return base
    except (AssertionError, TypeError, ValueError, NotImplementedError):
        # Best-effort fallback: abstract each product independently.
        return reduce(operator.add, (_zono_square_matmul(A, B) for A, B in pairs))

    out = ZonosqExpr(_as_affine(base))
    for (eps_a, kind_a, eps_b, kind_b), coeff in coeff_terms.items():
        _add_coeff_term(out, eps_a, kind_a, eps_b, kind_b, coeff)
    return out.affine_sum(out_shape)
def zonosq_matmul(A: Expr, B: Expr) -> Expr:
    """Abstract the single product ``A @ B`` as a zonosq expression.

    Delegates to :func:`zonosq_matmuls` with the one pair ``(A, B)``.
    """
    return zonosq_matmuls((A, B))
def bilinear_matmul(A: Expr, B: Expr) -> Expr:
    """Alias of :func:`zonosq_matmul`; returns ``zonosq_matmul(A, B)``."""
    return zonosq_matmul(A, B)
def matmul_handler(A, B):
    """Dispatch ``A @ B`` according to which operands are symbolic.

    The cases are tried in this order:

    * ``Expr`` with a constant, on either side: return ``A @ B``.
    * Both constants: return ``torch.matmul(A, B)``.
    * Both ``Expr``: return ``zonosq_matmul(A, B)``.

    Returns:
        The product, or ``NotImplemented`` when no case matches.
    """
    if isinstance(A, Expr) and _is_const(B):
        return A @ B
    if _is_const(A) and isinstance(B, Expr):
        return A @ B
    if _is_const(A) and _is_const(B):
        return torch.matmul(A, B)
    if isinstance(A, Expr) and isinstance(B, Expr):
        return zonosq_matmul(A, B)
    return NotImplemented
# Public names of this module; ``bilinear_elementwise`` is imported from
# ``boundlab.zono.bilinear`` and re-exported here.
__all__ = [
    "bilinear_elementwise",
    "bilinear_matmul",
    "matmul_handler",
    "zonosq_matmul",
    "zonosq_matmuls",
]