"""Quadratic epsilon terms for the zonosq interpreter."""
from __future__ import annotations
from typing import Literal, Union
import torch
from boundlab.expr._affine import AffineSum, ConstVal
from boundlab.expr._core import Expr
from boundlab.expr._tuple import GetTupleItem, TupleExpr
from boundlab.expr._var import LpEpsilon
from boundlab.linearop._base import LinearOp, ZeroOp
[docs]
class QuadraticEpsilon(TupleExpr):
    r"""Tuple of ``(eps, eps^2, eps' * eps^2)`` for ``eps, eps' in [-1, 1]``.

    Wraps a single :class:`LpEpsilon`; all three components have the wrapped
    epsilon's shape. The node has no children, so :meth:`backward` turns the
    weights it receives directly into a concrete bound.
    """

    def __init__(self, epsilon: LpEpsilon):
        super().__init__()
        self.epsilon = epsilon
        # All three tuple components share the wrapped epsilon's shape.
        self.inner_shape = epsilon.shape

    def __eq__(self, other):
        # Equality (and hashing) depends only on the wrapped epsilon, so two
        # wrappers around the same epsilon are interchangeable.
        if not isinstance(other, QuadraticEpsilon):
            return False
        return self.epsilon == other.epsilon

    def __hash__(self):
        return hash(self.epsilon)

    @property
    def shape(self) -> tuple[torch.Size, ...]:
        """Shapes of the ``eps``, ``eps^2`` and ``eps' * eps^2`` components."""
        return (self.inner_shape, self.inner_shape, self.inner_shape)

    @property
    def children(self) -> tuple[Expr, ...]:
        """Children expressions that contribute to this TupleExpr. This is used for topological sorting and weight propagation.

        A QuadraticEpsilon is a leaf, so this is always empty.
        """
        return ()

    @staticmethod
    def _materialize_weight(
        weight,
        output_shape: torch.Size,
        input_shape: torch.Size,
        *,
        like: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return ``weight`` as a dense tensor of shape ``output_shape + input_shape``.

        Args:
            weight: a LinearOp, whose Jacobian is reshaped, or a scalar zero,
                for which a zero tensor is built.
            output_shape: leading (output) dimensions of the result.
            input_shape: trailing (epsilon) dimensions of the result.
            like: optional tensor whose dtype/device the zero tensor adopts, so it
                can be combined with Jacobians on another device or dtype.

        Raises:
            TypeError: if ``weight`` is neither a LinearOp nor zero.
        """
        full_shape = torch.Size(output_shape) + torch.Size(input_shape)
        # Check for LinearOp first so we never rely on LinearOp.__eq__ semantics.
        if isinstance(weight, LinearOp):
            return weight.jacobian().reshape(full_shape)
        # Tensors are excluded: ``tensor == 0`` is elementwise and not a valid condition.
        if not isinstance(weight, torch.Tensor) and weight == 0:
            if like is None:
                return torch.zeros(full_shape)
            return torch.zeros(full_shape, dtype=like.dtype, device=like.device)
        raise TypeError(
            f"QuadraticEpsilon expects LinearOp weights, got {type(weight).__name__}"
        )

    @staticmethod
    def _quadratic_extremum(
        linear: torch.Tensor,
        quadratic: torch.Tensor,
        *,
        upper: bool,
        input_ndim: int,
    ) -> torch.Tensor:
        """Elementwise max (``upper``) or min of ``linear*e + quadratic*e**2`` over ``e in [-1, 1]``.

        The per-element extrema are summed over the trailing ``input_ndim``
        dimensions (one independent ``e`` per input element). When
        ``input_ndim == 0`` the elementwise result is returned unsummed.
        """
        endpoint_hi = torch.maximum(linear + quadratic, -linear + quadratic)
        endpoint_lo = torch.minimum(linear + quadratic, -linear + quadratic)
        if upper:
            # A concave parabola (quadratic < 0) may peak strictly inside [-1, 1].
            vertex_mask = quadratic < 0
            fill = torch.full_like(quadratic, -torch.inf)
            endpoint = endpoint_hi
            reduce_fn = torch.maximum
        else:
            # A convex parabola (quadratic > 0) may dip strictly inside [-1, 1].
            vertex_mask = quadratic > 0
            fill = torch.full_like(quadratic, torch.inf)
            endpoint = endpoint_lo
            reduce_fn = torch.minimum
        # Replace non-candidate coefficients by 1 to avoid dividing by zero;
        # those positions are masked out below anyway.
        safe_quadratic = torch.where(vertex_mask, quadratic, torch.ones_like(quadratic))
        vertex = -linear / (2.0 * safe_quadratic)
        vertex_mask = vertex_mask & (vertex >= -1.0) & (vertex <= 1.0)
        vertex_value = linear * vertex + quadratic * vertex * vertex
        value = reduce_fn(endpoint, torch.where(vertex_mask, vertex_value, fill))
        if input_ndim == 0:
            return value
        input_dims = tuple(range(value.dim() - input_ndim, value.dim()))
        return value.sum(dim=input_dims)

    def backward(self, *weights, direction="==") -> tuple[Union[torch.Tensor, Literal[0]], list] | None:
        """Concretize ``a*eps + b*eps^2 + c*eps'*eps^2`` for each output.

        Args:
            weights: the coefficients ``(a, b, c)`` of the three components,
                each a LinearOp or ``0``.
            direction: ``"<="`` computes an upper bound, ``"=="`` returns
                ``None``, and any other value computes a lower bound.

        Returns:
            ``None`` for ``"=="``; otherwise ``(bound, [])``, where ``bound`` has
            the weights' output shape (summed over the epsilon dimensions).
        """
        w1, w2, w3 = weights
        if direction == "==":
            return None
        output_shape = next(
            (w.output_shape for w in weights if isinstance(w, LinearOp)),
            torch.Size(()),
        )
        input_shape = self.inner_shape
        # Materialize real LinearOps first so the zero placeholders can copy their
        # dtype/device; CPU zeros mixed with CUDA Jacobians would fail below.
        dense = [
            self._materialize_weight(w, output_shape, input_shape) if isinstance(w, LinearOp) else None
            for w in (w1, w2, w3)
        ]
        like = next((t for t in dense if t is not None), None)
        lin_t, sq_t, cross_t = (
            t if t is not None else self._materialize_weight(w, output_shape, input_shape, like=like)
            for t, w in zip(dense, (w1, w2, w3))
        )
        upper = direction == "<="
        # eps' is independent of eps, so c*eps'*eps^2 lies in [-|c| eps^2, |c| eps^2];
        # its worst case folds into the quadratic coefficient.
        cross_abs = cross_t.abs()
        quadratic = sq_t + cross_abs if upper else sq_t - cross_abs
        bound = self._quadratic_extremum(
            lin_t,
            quadratic,
            upper=upper,
            input_ndim=len(input_shape),
        )
        return bound, []

    def with_children(self, *new_children: Expr) -> "TupleExpr":
        """Return a new TupleExpr with the same flags but new children. This is used for expression rewriting during bound propagation.

        A QuadraticEpsilon has no children, so it is returned unchanged.
        """
        return self

    def to_string(self, *children_str: str, indent: int = 0) -> str:
        """Convert this expression to a string for debugging purposes. The children_str arguments are the string representations of the children expressions, in the same order as self.children."""
        # Render the wrapped epsilon, substituting its symbol with the tuple notation.
        return self.epsilon.to_string().replace("𝜀", "(𝜀, 𝜀^2, 𝜀' 𝜀^2)")
[docs]
class ZonosqExpr:
    """Affine sum regrouped by base epsilon into ``eps``, ``eps^2`` and ``eps' * eps^2`` terms.

    ``children_dict`` maps each base :class:`LpEpsilon` to a list of three
    LinearOps. They are the coefficients of component 0 (``eps``), 1 (``eps^2``)
    and 2 (``eps' * eps^2``), in the component order of :class:`QuadraticEpsilon`.
    """

    constant: Union[torch.Tensor, None]
    children_dict: dict[LpEpsilon, list[LinearOp]]

    def __init__(self, affine_sum: AffineSum):
        """Collect the terms of ``affine_sum`` by their underlying epsilon.

        Plain ``LpEpsilon`` children are recorded as the linear term (index 0).
        ``GetTupleItem`` children of a ``QuadraticEpsilon`` are recorded at their
        tuple index.

        Raises:
            ValueError: if ``affine_sum`` has a child of any other kind.
        """
        self.constant = affine_sum.constant
        self.children_dict = {}
        for child, op in affine_sum.children_dict.items():
            if isinstance(child, LpEpsilon):
                self.add(child, op, 0)
            elif isinstance(child, GetTupleItem) and isinstance(child.tuple_expr, QuadraticEpsilon):
                self.add(child.tuple_expr.epsilon, op, child.index)
            else:
                raise ValueError(f"Unexpected expression type: {type(child)}")

    def add(self, expr: LpEpsilon, lin: LinearOp, idx: int) -> None:
        """Accumulate ``lin`` into component ``idx`` of ``expr``'s coefficient list.

        The first time ``expr`` is seen, its list is created with three
        ``ZeroOp.like(lin)`` placeholders.
        """
        if expr not in self.children_dict:
            self.children_dict[expr] = [ZeroOp.like(lin), ZeroOp.like(lin), ZeroOp.like(lin)]
        self.children_dict[expr][idx] = self.children_dict[expr][idx] + lin

    def affine_sum(self, shape: torch.Size | None = None) -> AffineSum:
        """Rebuild an AffineSum whose children are ``QuadraticEpsilon`` components.

        Coefficients that are ``ZeroOp`` instances, or whose Jacobian is all
        zero, are dropped. If no term remains, a ``ConstVal`` is returned. It
        holds the tensor constant if there is one, and zeros otherwise.

        Args:
            shape: output shape for an all-zero result. When omitted, it is
                taken from the stored LinearOps' ``output_shape``.

        Raises:
            ValueError: if an all-zero result is needed but no shape is known.
        """
        terms = []
        probe = None  # first materialized Jacobian; supplies dtype/device for a zero result
        for base_eps, ops in self.children_dict.items():
            quad_eps = QuadraticEpsilon(base_eps)
            for idx, op in enumerate(ops):
                if isinstance(op, ZeroOp):
                    continue
                jac = op.jacobian()
                if probe is None:
                    probe = jac
                if not jac.any():
                    continue
                terms.append((op, quad_eps[idx]))
        if terms:
            return AffineSum(*terms, const=self.constant)
        if isinstance(self.constant, torch.Tensor):
            return ConstVal(self.constant)
        if shape is None:
            # The stored ops are summed into one AffineSum, so they share one
            # output shape; any of them can supply it.
            shape = next((ops[0].output_shape for ops in self.children_dict.values()), None)
        if shape is None:
            raise ValueError("Need an output shape for a zero ZonosqExpr.")
        if probe is None:
            return ConstVal(torch.zeros(shape))
        return ConstVal(torch.zeros(shape, dtype=probe.dtype, device=probe.device))