# Source code for boundlab.zono  (Sphinx HTML export header, kept as a comment)
r"""Zonotope-Based Abstract Interpretation for Neural Networks
This module provides zonotope transformations for computing over-approximations
of neural network outputs under bounded input perturbations.
Examples
--------
>>> import torch
>>> from torch import nn
>>> import boundlab.expr as expr
>>> import boundlab.zono as zono
>>> model = nn.Sequential(nn.Linear(4, 5), nn.ReLU(), nn.Linear(5, 3))
>>> op = zono.interpret(model)
>>> x = expr.ConstVal(torch.zeros(4)) + expr.LpEpsilon([4])
>>> y = op(x)
>>> y.ub().shape
torch.Size([3])
"""
from __future__ import annotations

import dataclasses
from collections.abc import Callable

import torch

from boundlab import expr
from boundlab.expr._affine import AffineSum
from boundlab.expr._core import Expr
from boundlab.expr._var import LpEpsilon
from boundlab.interp import ONNX_BASE_INTERPRETER, Interpreter
from boundlab.linearop import LinearOp
# Module-level zonotope interpreter. Handlers are installed into it below
# (and by the activation sub-modules) via ``_register_linearizer`` /
# direct item assignment; it extends ``ONNX_BASE_INTERPRETER``.
interpret = Interpreter(ONNX_BASE_INTERPRETER)
"""Zonotope-based interpreter.
Examples
--------
>>> import torch
>>> from torch import nn
>>> import boundlab.expr as expr
>>> import boundlab.zono as zono
>>> op = zono.interpret(nn.Linear(2, 1))
>>> y = op(expr.ConstVal(torch.zeros(2)) + expr.LpEpsilon([2]))
>>> y.shape
torch.Size([1])
"""
# [docs]  (Sphinx HTML export anchor, kept as a comment)
@dataclasses.dataclass
class ZonoBounds:
    """Data class representing zonotope bounds for a neural network layer.

    Examples
    --------
    ``input_weights`` has one entry per input expression to the linearizer.
    For unary ops such as ReLU, this is typically a single slope tensor.
    """

    # Constant offset added after applying the input weights.
    bias: torch.Tensor
    # Linear map applied to a fresh noise symbol; its image soundly bounds
    # the error introduced by the linearization.
    error_coeffs: LinearOp
    # Hadamard-product weight per input expression. The int literal ``0`` is
    # a sentinel meaning "this input contributes nothing" (the handler skips
    # it), hence ``int`` — the original ``torch.Tensor | 0`` used a value
    # where a type is required.
    input_weights: list[torch.Tensor | int]
def _register_linearizer(name: str):
    """Return a decorator registering *linearizer* as the handler for ``name``.

    The generated handler evaluates the linearization
    ``sum_i w_i * x_i + bias + error_coeffs(eps)`` where ``w_i`` are the
    per-input Hadamard weights returned by the linearizer and ``eps`` is a
    fresh noise symbol soundly over-approximating the linearization error.
    """
    def decorator(linearizer: Callable[..., ZonoBounds]):
        def handler(*exprs: Expr) -> Expr:
            bounds = linearizer(*exprs)
            # The int literal 0 is a "contributes nothing" sentinel; only
            # tensor weights carry a shape to validate. (Calling ``.shape``
            # on the sentinel, as the original did, raises AttributeError.)
            assert all(
                not isinstance(w, torch.Tensor) or w.shape == e.shape
                for w, e in zip(bounds.input_weights, exprs)
            ), "Input weights must match the shapes of the input expressions."
            # Apply slopes to the surviving input expressions. Detect the
            # sentinel by type: ``w != 0`` on a multi-element tensor raises
            # "Boolean value of Tensor ... is ambiguous".
            result_expr = sum(
                w * e
                for w, e in zip(bounds.input_weights, exprs)
                if isinstance(w, torch.Tensor) or w != 0
            ) + bounds.bias
            # Introduce a fresh noise symbol for the approximation error.
            new_eps = LpEpsilon(bounds.error_coeffs.input_shape)
            result_expr = result_expr + bounds.error_coeffs(new_eps)
            return result_expr

        interpret[name] = handler
        return linearizer
    return decorator
# =====================================================================
# Import activation modules — each calls _register_linearizer
# =====================================================================
# Importing each activation sub-module executes its
# ``@_register_linearizer(...)`` decoration, installing the corresponding
# handler into ``interpret`` as an import side effect.
from .relu import relu_linearizer
from .exp import exp_linearizer
from .reciprocal import reciprocal_linearizer
from .tanh import tanh_linearizer
# ONNX activation handlers
# ONNX op names are capitalized; alias them to the handlers registered above.
interpret["Relu"] = interpret["relu"]
interpret["Tanh"] = interpret["tanh"]
# Bilinear matmul handler (supports Expr @ Expr)
from .bilinear import matmul_handler, bilinear_matmul, bilinear_elementwise # noqa: F401
interpret["MatMul"] = matmul_handler
# Softmax
from .softmax import softmax_handler
# Adapt the ONNX ``axis`` attribute to the handler's ``dim`` keyword.
interpret["Softmax"] = lambda X, axis=-1: softmax_handler(X, dim=axis)
# Public API: the interpreter, the bounds container, the individual
# linearizers, and the bilinear/softmax handlers re-exported from sub-modules.
__all__ = [
    "interpret", "ZonoBounds",
    "relu_linearizer", "exp_linearizer", "reciprocal_linearizer", "tanh_linearizer",
    "bilinear_matmul", "bilinear_elementwise", "matmul_handler", "softmax_handler",
]