# Source code for boundlab.zono

r"""Zonotope-Based Abstract Interpretation for Neural Networks

This module provides zonotope transformations for computing over-approximations
of neural network outputs under bounded input perturbations.

Examples
--------
>>> import torch
>>> from torch import nn
>>> import boundlab.expr as expr
>>> import boundlab.zono as zono
>>> model = nn.Sequential(nn.Linear(4, 5), nn.ReLU(), nn.Linear(5, 3))
>>> op = zono.interpret(model)
>>> x = expr.ConstVal(torch.zeros(4)) + expr.LpEpsilon([4])
>>> y = op(x)
>>> y.ub().shape
torch.Size([3])
"""

from __future__ import annotations

import dataclasses
import torch

from boundlab import expr
from boundlab.expr._affine import AffineSum, ConstVal
from boundlab.expr._core import Expr
from boundlab.expr._var import LpEpsilon
from boundlab.interp import ONNX_BASE_INTERPRETER, Interpreter
from boundlab.linearop import LinearOp
from boundlab.linearop._einsum import EinsumOp
from boundlab.utils import not0

interpret: Interpreter[Expr] = Interpreter[Expr](ONNX_BASE_INTERPRETER)
"""Zonotope-based interpreter.

Built on top of the ONNX base interpreter table. The zonotope-specific
handlers are added to it by this module: activation linearizers through
``_register_linearizer``, plus the ``MatMul`` and ``Softmax`` handlers.

Examples
--------
>>> import torch
>>> from torch import nn
>>> import boundlab.expr as expr
>>> import boundlab.zono as zono
>>> op = zono.interpret(nn.Linear(2, 1))
>>> y = op(expr.ConstVal(torch.zeros(2)) + expr.LpEpsilon([2]))
>>> y.shape
torch.Size([1])
"""


@dataclasses.dataclass
class ZonoBounds:
    """Affine relaxation of a layer, as produced by a zonotope linearizer.

    A ``ZonoBounds`` describes the over-approximation::

        bias + sum_i input_weights[i] * inputs[i] + error_coeffs(eps_new)

    Here ``eps_new`` is a fresh noise symbol that absorbs the linearization
    error.

    Notes
    -----
    ``input_weights`` has one entry per input expression to the linearizer.
    For unary ops such as ReLU, this is typically a single slope tensor.
    An entry may be the literal ``0``. Such entries are skipped (filtered
    with ``boundlab.utils.not0``) instead of being multiplied in.
    """

    bias: torch.Tensor  # The bias term of the zonotope
    # Coefficients applied to the fresh error symbol. A plain tensor is also
    # accepted; ``__post_init__`` converts it with ``EinsumOp.from_hardmard``
    # (presumably an elementwise/Hadamard scaling -- confirm in EinsumOp).
    error_coeffs: LinearOp | torch.Tensor
    # Elementwise (Hadamard) product weights of the input terms, one per input.
    # ``int`` entries are expected to be the literal 0 ("no contribution").
    input_weights: list[torch.Tensor | int]

    def __post_init__(self):
        # Let linearizers return raw tensors for the error coefficients.
        if isinstance(self.error_coeffs, torch.Tensor):
            self.error_coeffs = EinsumOp.from_hardmard(self.error_coeffs)

    def apply_without_error(self, *inputs: Expr) -> Expr | torch.Tensor:
        """Apply the affine part of the bounds to ``inputs``, ignoring the error term.

        Parameters
        ----------
        *inputs:
            One expression per entry of ``input_weights``, in the same order.

        Returns
        -------
        ``bias + sum(w * e)`` over the non-zero weights. When every weight is
        zero, this is just ``bias`` itself.

        Raises
        ------
        AssertionError
            If the number of inputs differs from ``len(input_weights)``.
        """
        assert len(inputs) == len(self.input_weights), \
            f"Expected {len(self.input_weights)} input expressions, got {len(inputs)}"
        result = self.bias
        for w, e in zip(self.input_weights, inputs):
            # Literal-0 weights mark inputs that do not contribute.
            if not0(w):
                result = result + w * e
        return result
def _register_linearizer(name: str):
    """Return a decorator that registers a zonotope linearizer under ``name``.

    The decorated ``linearizer`` is called with the upper and lower bound of
    every input expression, interleaved as ``ub0, lb0, ub1, lb1, ...``. It
    must return a :class:`ZonoBounds`, or any object exposing ``bias``,
    ``input_weights`` and ``error_coeffs``. The handler stored in
    ``interpret[name]`` then builds::

        sum_i input_weights[i] * exprs[i] + bias + error_coeffs(eps_new)

    Here ``eps_new`` is a fresh :class:`LpEpsilon` tagged with the
    linearizer's ``__name__``. Literal-0 weights are skipped.

    The decorator returns ``linearizer`` unchanged, so it remains directly
    callable and importable.
    """

    def decorator(linearizer: callable):
        def handler(*exprs: Expr) -> Expr:
            # All-constant inputs: decline and return NotImplemented
            # (presumably so the interpreter evaluates concretely instead).
            if all(isinstance(e, ConstVal) for e in exprs):
                return NotImplemented

            ubs_lbs = [e.ublb() for e in exprs]
            bounds = linearizer(*[t for ub, lb in ubs_lbs for t in (ub, lb)])
            weights = bounds.input_weights

            # One weight per input, as ZonoBounds.apply_without_error requires.
            # Without this check, zip() below would silently drop inputs.
            assert len(weights) == len(exprs), \
                f"Expected {len(exprs)} input weights, got {len(weights)}"
            # Literal-0 weights have no ``.shape``; they are skipped in the sum
            # below, so skip them in the shape check too.
            assert all(w.shape == e.shape for w, e in zip(weights, exprs) if not0(w)), \
                "Input weights must match the shapes of the input expressions."

            # Apply slopes to input expressions
            result_expr = sum(w * e for w, e in zip(weights, exprs) if not0(w)) + bounds.bias
            # Introduce a fresh noise symbol for the approximation error
            new_eps = LpEpsilon(bounds.error_coeffs.input_shape, reason=linearizer.__name__)
            return result_expr + bounds.error_coeffs(new_eps)

        interpret[name] = handler
        return linearizer

    return decorator


# =====================================================================
# Import activation modules — each calls _register_linearizer
# =====================================================================
# These imports must stay below ``_register_linearizer``: each submodule
# calls it when imported, so it has to be defined first.
from .relu import relu_linearizer
from .exp import exp_linearizer
from .reciprocal import reciprocal_linearizer
from .tanh import tanh_linearizer

# ONNX activation handlers.
# relu/tanh are registered under lowercase names and aliased to the ONNX
# names. exp/reciprocal appear to be registered under the ONNX names, so
# their aliases go the other way.
interpret["Relu"] = interpret["relu"]
interpret["Tanh"] = interpret["tanh"]
interpret["exp"] = interpret["Exp"]
interpret["reciprocal"] = interpret["Reciprocal"]

# Bilinear handlers (supports Expr @ Expr and Expr * Expr)
from .bilinear import matmul_handler, bilinear_matmul, bilinear_elementwise  # noqa: F401
interpret["MatMul"] = matmul_handler
# NOTE(review): no elementwise "Mul" handler is registered. ``mul_handler`` is
# not imported from ``.bilinear``, so it is also absent from ``__all__``.

from .softmax import softmax_handler, softmax_handler_basedon_softmax2
# ONNX Softmax uses ``axis``; the handler expects ``dim``.
interpret["Softmax"] = lambda X, axis=-1: softmax_handler(X, dim=axis)

from .softmax2 import softmax2_handler, softmax2_linearizer

__all__ = [
    "interpret",
    "ZonoBounds",
    "relu_linearizer",
    "exp_linearizer",
    "reciprocal_linearizer",
    "tanh_linearizer",
    "bilinear_matmul",
    "bilinear_elementwise",
    "matmul_handler",
    "softmax_handler",
    "softmax2_handler",
    "softmax2_linearizer",
]