Source code for boundlab.zono
r"""Zonotope-Based Abstract Interpretation for Neural Networks
This module provides zonotope transformations for computing over-approximations
of neural network outputs under bounded input perturbations.
Examples
--------
>>> import torch
>>> from torch import nn
>>> import boundlab.expr as expr
>>> import boundlab.zono as zono
>>> model = nn.Sequential(nn.Linear(4, 5), nn.ReLU(), nn.Linear(5, 3))
>>> op = zono.interpret(model)
>>> x = expr.ConstVal(torch.zeros(4)) + expr.LpEpsilon([4])
>>> y = op(x)
>>> y.ub().shape
torch.Size([3])
"""
from __future__ import annotations

import dataclasses
from collections.abc import Callable

import torch

from boundlab import expr
from boundlab.expr._affine import AffineSum, ConstVal
from boundlab.expr._core import Expr
from boundlab.expr._var import LpEpsilon
from boundlab.interp import ONNX_BASE_INTERPRETER, Interpreter
from boundlab.linearop import LinearOp
from boundlab.linearop._einsum import EinsumOp
from boundlab.utils import not0
interpret: Interpreter[Expr] = Interpreter[Expr](ONNX_BASE_INTERPRETER)
"""Zonotope interpreter shared by this package.

Constructed from ``ONNX_BASE_INTERPRETER`` (presumably used as the base
handler table -- confirm in ``boundlab.interp``). The rest of this module
installs zonotope-specific handlers on it by operation name, for example
``interpret["MatMul"] = matmul_handler``. Calling it on a ``torch.nn.Module``
yields a callable that propagates expressions through the model.

Examples
--------
>>> import torch
>>> from torch import nn
>>> import boundlab.expr as expr
>>> import boundlab.zono as zono
>>> op = zono.interpret(nn.Linear(2, 1))
>>> y = op(expr.ConstVal(torch.zeros(2)) + expr.LpEpsilon([2]))
>>> y.shape
torch.Size([1])
"""
[docs]
@dataclasses.dataclass
class ZonoBounds:
"""Data class representing zonotope bounds for a neural network layer.
Examples
--------
``input_weights`` has one entry per input expression to the linearizer.
For unary ops such as ReLU, this is typically a single slope tensor.
"""
bias: torch.Tensor # The bias term of the zonotope
error_coeffs: LinearOp
input_weights: list[torch.Tensor | 0] # Hadamard product weights of the input terms
def __post_init__(self):
if isinstance(self.error_coeffs, torch.Tensor):
self.error_coeffs = EinsumOp.from_hardmard(self.error_coeffs)
[docs]
def apply_without_error(self, *inputs: Expr) -> Expr:
"""Apply the zonotope bounds to given input tensors, ignoring the error term."""
assert len(inputs) == len(self.input_weights), \
f"Expected {len(self.input_weights)} input expressions, got {len(inputs)}"
result = self.bias
for w, e in zip(self.input_weights, inputs):
if not0(w):
result = result + w * e
return result
def _register_linearizer(name: str):
    """Build a decorator that installs a zonotope handler for operation ``name``.

    The decorated ``linearizer`` is called with the concrete bounds of every
    input, flattened as ``(ub_0, lb_0, ub_1, lb_1, ...)``, and must return a
    :class:`ZonoBounds`. The generated handler, stored as ``interpret[name]``,
    turns that relaxation into the expression::

        sum_i input_weights[i] * exprs[i] + bias + error_coeffs(eps)

    where ``eps`` is a fresh :class:`LpEpsilon` tagged with the linearizer's
    ``__name__``. The linearizer itself is returned unchanged, so it remains
    importable and directly callable.
    """

    def decorator(linearizer: Callable[..., ZonoBounds]) -> Callable[..., ZonoBounds]:
        def handler(*exprs: Expr) -> Expr:
            # All-constant inputs need no relaxation. Returning NotImplemented
            # presumably lets the interpreter fall back to another handler
            # (confirm in boundlab.interp).
            if all(isinstance(e, ConstVal) for e in exprs):
                return NotImplemented

            # Concrete bounds per input, flattened to ub0, lb0, ub1, lb1, ...
            ubs_lbs = [e.ublb() for e in exprs]
            bounds = linearizer(*[t for ub, lb in ubs_lbs for t in (ub, lb)])

            weights = bounds.input_weights
            # zip() below would silently drop unmatched terms, so require one
            # weight per input, matching ZonoBounds.apply_without_error.
            assert len(weights) == len(exprs), \
                f"Expected {len(exprs)} input weights, got {len(weights)}"
            # Literal-0 weights are placeholders without ``.shape``; only real
            # weight tensors are shape-checked.
            assert all(not not0(w) or w.shape == e.shape for w, e in zip(weights, exprs)), \
                "Input weights must match the shapes of the input expressions."

            # Apply slopes to input expressions
            result_expr = sum(w * e for w, e in zip(weights, exprs) if not0(w)) + bounds.bias
            # Introduce a fresh noise symbol for the approximation error
            new_eps = LpEpsilon(bounds.error_coeffs.input_shape, reason=linearizer.__name__)
            result_expr = result_expr + bounds.error_coeffs(new_eps)
            return result_expr

        interpret[name] = handler
        return linearizer

    return decorator
# =====================================================================
# Import activation modules — each calls _register_linearizer
# =====================================================================
# These imports must stay below ``_register_linearizer``: the submodules use it
# at import time to populate ``interpret``.
from .relu import relu_linearizer  # noqa: E402
from .exp import exp_linearizer  # noqa: E402
from .reciprocal import reciprocal_linearizer  # noqa: E402
from .tanh import tanh_linearizer  # noqa: E402

# ONNX activation handlers: alias ONNX op names onto the registered handlers.
interpret["Relu"] = interpret["relu"]
interpret["Tanh"] = interpret["tanh"]
# NOTE(review): these two alias in the opposite direction (lowercase name <-
# ONNX name), so "Exp" / "Reciprocal" must already be registered at this point.
# Confirm against the names used in .exp / .reciprocal.
interpret["exp"] = interpret["Exp"]
interpret["reciprocal"] = interpret["Reciprocal"]

# Bilinear handlers (supports Expr @ Expr and Expr * Expr)
from .bilinear import matmul_handler, bilinear_matmul, bilinear_elementwise  # noqa: F401, E402
interpret["MatMul"] = matmul_handler
# "Mul" is not overridden here; ``mul_handler`` is not imported into this module.
# interpret["Mul"] = mul_handler

from .softmax import softmax_handler, softmax_handler_basedon_softmax2  # noqa: F401, E402
# ONNX Softmax passes the reduction axis as ``axis``; softmax_handler takes ``dim``.
interpret["Softmax"] = lambda X, axis=-1: softmax_handler(X, dim=axis)
from .softmax2 import softmax2_handler, softmax2_linearizer  # noqa: E402

# Every name listed here must be bound in this module, otherwise
# ``from boundlab.zono import *`` raises AttributeError (hence no "mul_handler").
__all__ = [
    "interpret", "ZonoBounds",
    "relu_linearizer", "exp_linearizer", "reciprocal_linearizer", "tanh_linearizer",
    "bilinear_matmul", "bilinear_elementwise", "matmul_handler",
    "softmax_handler", "softmax2_handler", "softmax2_linearizer",
]