Source code for boundlab.zonosq.epsilon

"""Quadratic epsilon terms for the zonosq interpreter."""

from __future__ import annotations

from typing import Literal, Union

import torch

from boundlab.expr._affine import AffineSum, ConstVal
from boundlab.expr._core import Expr
from boundlab.expr._tuple import GetTupleItem, TupleExpr
from boundlab.expr._var import LpEpsilon
from boundlab.linearop._base import LinearOp, ZeroOp


class QuadraticEpsilon(TupleExpr):
    r"""Tuple of ``(eps, eps^2, eps' * eps^2)`` for ``eps, eps' in [-1, 1]``.

    Wraps a single :class:`LpEpsilon` and exposes three tuple components,
    each with the same shape as the wrapped epsilon. It has no children and
    is concretized directly by :meth:`backward`.
    """

    def __init__(self, epsilon: LpEpsilon):
        super().__init__()
        # The underlying epsilon variable whose components this tuple exposes.
        self.epsilon = epsilon
        # Shape shared by all three tuple components.
        self.inner_shape = epsilon.shape

    def __eq__(self, other):
        # Equality (and hashing) is delegated to the wrapped epsilon, so two
        # separately constructed wrappers of the same epsilon compare equal.
        if not isinstance(other, QuadraticEpsilon):
            return False
        return self.epsilon == other.epsilon

    def __hash__(self):
        return hash(self.epsilon)

    @property
    def shape(self) -> tuple[torch.Size, ...]:
        """Shapes of the three components, all equal to ``inner_shape``."""
        return (self.inner_shape, self.inner_shape, self.inner_shape)

    @property
    def children(self) -> tuple[Expr, ...]:
        """Children expressions that contribute to this TupleExpr.

        This is used for topological sorting and weight propagation.
        A QuadraticEpsilon has no children, so this is always empty.
        """
        return ()

    @staticmethod
    def _materialize_weight(
        weight: Union[LinearOp, Literal[0]],
        output_shape: torch.Size,
        input_shape: torch.Size,
        *,
        like: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Return ``weight`` as a dense tensor of shape ``output_shape + input_shape``.

        Args:
            weight: a LinearOp (densified via its ``jacobian()``), or the
                literal ``0`` meaning "this component does not contribute".
            output_shape: leading (output) dimensions of the dense tensor.
            input_shape: trailing (input / epsilon) dimensions.
            like: optional reference tensor. When given, a zero weight is
                allocated with its dtype and device so that it can be
                combined with other materialized weights.

        Raises:
            TypeError: if ``weight`` is neither a LinearOp nor zero.
        """
        full_shape = output_shape + input_shape
        # Check the type first: comparing an arbitrary object with ``== 0``
        # may go through an overloaded ``__eq__`` (ambiguous for tensors).
        if isinstance(weight, LinearOp):
            return weight.jacobian().reshape(full_shape)
        if isinstance(weight, (int, float)) and weight == 0:
            if like is None:
                return torch.zeros(full_shape)
            return torch.zeros(full_shape, dtype=like.dtype, device=like.device)
        raise TypeError(
            f"QuadraticEpsilon expects LinearOp weights, got {type(weight).__name__}"
        )

    @classmethod
    def _materialize_weights(
        cls,
        weights: tuple,
        output_shape: torch.Size,
        input_shape: torch.Size,
    ) -> tuple[torch.Tensor, ...]:
        """Densify every weight, allocating zero weights like the first dense one.

        LinearOp weights are materialized first. Their first result then
        serves as the dtype/device reference for any literal-zero weights,
        which avoids device mismatches when the jacobians are not on the
        default device.
        """
        dense = [
            cls._materialize_weight(w, output_shape, input_shape)
            if isinstance(w, LinearOp)
            else None
            for w in weights
        ]
        reference = next((t for t in dense if t is not None), None)
        return tuple(
            t
            if t is not None
            else cls._materialize_weight(w, output_shape, input_shape, like=reference)
            for t, w in zip(dense, weights)
        )

    @staticmethod
    def _quadratic_extremum(
        linear: torch.Tensor,
        quadratic: torch.Tensor,
        *,
        upper: bool,
        input_ndim: int,
    ) -> torch.Tensor:
        """Extremum of ``linear*e + quadratic*e^2`` over ``e in [-1, 1]``.

        The extremum is taken element-wise, then summed over the trailing
        ``input_ndim`` dimensions. It is the maximum when ``upper`` is true
        and the minimum otherwise.

        Candidates are the two endpoints ``e = +-1`` and, when the parabola
        opens the right way and its vertex lies inside ``[-1, 1]``, the vertex.
        """
        endpoint_hi = torch.maximum(linear + quadratic, -linear + quadratic)
        endpoint_lo = torch.minimum(linear + quadratic, -linear + quadratic)
        if upper:
            # A concave parabola (quadratic < 0) can peak at its vertex.
            vertex_mask = quadratic < 0
            fill = torch.full_like(quadratic, -torch.inf)
            endpoint = endpoint_hi
            reduce_fn = torch.maximum
        else:
            # A convex parabola (quadratic > 0) can bottom out at its vertex.
            vertex_mask = quadratic > 0
            fill = torch.full_like(quadratic, torch.inf)
            endpoint = endpoint_lo
            reduce_fn = torch.minimum
        # Replace masked-out coefficients by 1 so the division never hits 0.
        safe_quadratic = torch.where(vertex_mask, quadratic, torch.ones_like(quadratic))
        vertex = -linear / (2.0 * safe_quadratic)
        vertex_mask = vertex_mask & (vertex >= -1.0) & (vertex <= 1.0)
        vertex_value = linear * vertex + quadratic * vertex * vertex
        value = reduce_fn(endpoint, torch.where(vertex_mask, vertex_value, fill))
        if input_ndim == 0:
            return value
        input_dims = tuple(range(value.dim() - input_ndim, value.dim()))
        return value.sum(dim=input_dims)

    def backward(self, *weights, direction="==") -> tuple[Union[torch.Tensor, Literal[0]], list] | None:
        """Concretize ``a*eps + b*eps^2 + c*eps'*eps^2`` for each output.

        Args:
            *weights: exactly three weights ``(a, b, c)``, one per tuple
                component. Each is a LinearOp or the literal ``0``.
            direction: ``"=="`` returns ``None`` (no concretization).
                ``"<="`` computes an upper bound; any other value computes a
                lower bound.

        Returns:
            ``None`` for ``"=="``. Otherwise a pair ``(bound, [])``, where
            ``bound`` has the output shape of the LinearOp weights (a scalar
            shape when no weight is a LinearOp).
        """
        w1, w2, w3 = weights
        if direction == "==":
            return None
        upper = direction == "<="
        output_shape = next(
            (w.output_shape for w in weights if isinstance(w, LinearOp)),
            torch.Size(()),
        )
        w1_t, w2_t, w3_t = self._materialize_weights(
            (w1, w2, w3), output_shape, self.inner_shape
        )
        # eps^2 >= 0 and eps' in [-1, 1], so c*eps'*eps^2 lies in
        # [-|c|*eps^2, |c|*eps^2]. Fold the relevant side into the quadratic
        # coefficient.
        quadratic = w2_t + w3_t.abs() if upper else w2_t - w3_t.abs()
        return self._quadratic_extremum(
            w1_t,
            quadratic,
            upper=upper,
            input_ndim=len(self.inner_shape),
        ), []

    def with_children(self, *new_children: Expr) -> "TupleExpr":
        """Return a new TupleExpr with the same flags but new children.

        This is used for expression rewriting during bound propagation.
        A QuadraticEpsilon has no children, so it returns itself unchanged.
        """
        return self

    def to_string(self, *children_str: str, indent: int = 0) -> str:
        """Convert this expression to a string for debugging purposes.

        The children_str arguments are the string representations of the
        children expressions, in the same order as self.children. Both they
        and ``indent`` are ignored here. The wrapped epsilon's string is
        reused with its symbol expanded into the three-component tuple.
        """
        return self.epsilon.to_string().replace("𝜀", "(𝜀, 𝜀^2, 𝜀' 𝜀^2)")
class ZonosqExpr:
    """Editable coefficient table for an affine form over quadratic epsilons.

    ``children_dict`` maps each :class:`LpEpsilon` to a list of three
    LinearOps. Slot ``i`` holds the coefficient of component ``i`` of the
    ``QuadraticEpsilon`` tuple ``(eps, eps^2, eps' * eps^2)``. A bare
    ``LpEpsilon`` term is stored in slot 0.
    """

    constant: Union[torch.Tensor, None]
    children_dict: dict[LpEpsilon, list[LinearOp]]

    def __init__(self, affine_sum: AffineSum):
        """Split ``affine_sum`` into per-epsilon coefficient slots.

        Raises:
            ValueError: if a term is neither an ``LpEpsilon`` nor an item
                taken from a ``QuadraticEpsilon`` tuple.
        """
        self.constant = affine_sum.constant
        self.children_dict = {}
        for expr, op in affine_sum.children_dict.items():
            if isinstance(expr, LpEpsilon):
                self.add(expr, op, 0)
                continue
            from_quadratic = isinstance(expr, GetTupleItem) and isinstance(
                expr.tuple_expr, QuadraticEpsilon
            )
            if not from_quadratic:
                raise ValueError(f"Unexpected expression type: {type(expr)}")
            self.add(expr.tuple_expr.epsilon, op, expr.index)

    def add(self, expr: LpEpsilon, lin: LinearOp, idx: int) -> None:
        """Accumulate ``lin`` into slot ``idx`` of ``expr``'s coefficients.

        The first time ``expr`` is seen, all three slots start as
        ``ZeroOp.like(lin)``.
        """
        slots = self.children_dict.get(expr)
        if slots is None:
            slots = [ZeroOp.like(lin) for _ in range(3)]
            self.children_dict[expr] = slots
        slots[idx] = slots[idx] + lin

    def affine_sum(self, shape: torch.Size | None = None) -> AffineSum:
        """Rebuild an expression from the coefficient table.

        Slots that are ``ZeroOp`` or whose jacobian is all zeros are dropped.
        When nothing remains, a ``ConstVal`` is returned instead: the stored
        constant if it is a tensor, otherwise zeros of ``shape``. ``shape``
        must then be provided.
        """
        terms = []
        for base_eps, ops in self.children_dict.items():
            quad = QuadraticEpsilon(base_eps)
            for slot, op in enumerate(ops):
                # ZeroOp is checked first so its jacobian is never built.
                if isinstance(op, ZeroOp) or not op.jacobian().any():
                    continue
                terms.append((op, quad[slot]))
        if terms:
            return AffineSum(*terms, const=self.constant)
        if isinstance(self.constant, torch.Tensor):
            return ConstVal(self.constant)
        assert shape is not None, "Need an output shape for a zero ZonosqExpr."
        return ConstVal(torch.zeros(shape))