r"""Polytope-Based Abstract Interpretation for Neural Networks.
This module provides CROWN-style abstract interpretation using linear
polytope relaxations of nonlinear activations. Each neuron is bounded
by a pair of linear envelopes:
.. math::
\lambda_\ell \odot x + b_\ell \;\le\; f(x) \;\le\; \lambda_u \odot x + b_u
The central :class:`PolyBoundGate` expression represents an abstract
function with fixed :math:`\pm 1` offsets; general CROWN-style bounds
are wrapped by rescaling around their midpoint.
Examples
--------
>>> import torch
>>> from torch import nn
>>> import boundlab.expr as expr
>>> import boundlab.poly as poly
>>> model = nn.Sequential(nn.Linear(4, 5), nn.ReLU(), nn.Linear(5, 3))
>>> op = poly.interpret(model)
>>> x = expr.ConstVal(torch.zeros(4)) + expr.LpEpsilon([4])
>>> y = op(x)
>>> y.ub().shape
torch.Size([3])
"""
from __future__ import annotations
import dataclasses
import inspect
from typing import Literal
import torch
from boundlab import expr as _expr
from boundlab.expr._affine import AffineSum, ConstVal
from boundlab.expr._core import Expr, ExprFlags
from boundlab.interp import ONNX_BASE_INTERPRETER, Interpreter
from boundlab.linearop import LinearOp
from boundlab.linearop._einsum import EinsumOp
# Module-level interpreter instance, constructed from ONNX_BASE_INTERPRETER.
# Handlers are installed on it by key via item assignment
# (``interpret[key] = handler``) further down in this module.
# NOTE(review): presumably operators that are not overridden here are served
# by the base interpreter — confirm against boundlab.interp.
interpret = Interpreter[Expr](ONNX_BASE_INTERPRETER)
"""Polytope-based interpreter.
Dispatches neural-network operators to CROWN-style linearizers that
produce :class:`PolyBoundGate`-wrapped expressions.
Examples
--------
>>> import torch
>>> from torch import nn
>>> import boundlab.expr as expr
>>> import boundlab.poly as poly
>>> op = poly.interpret(nn.Linear(2, 1))
>>> y = op(expr.ConstVal(torch.zeros(2)) + expr.LpEpsilon([2]))
>>> y.shape
torch.Size([1])
"""
[docs]
class PolyBoundGate(Expr):
    r"""Abstract gate bounded pointwise by a pair of linear polytopes on its child.

    Represents an elementwise function :math:`f(x)` whose output is constrained by

    .. math::

        \lambda_\ell \odot x - 1 \;\le\; f(x) \;\le\; \lambda_u \odot x + 1,

    where :math:`x` is the child expression and :math:`\lambda_u, \lambda_\ell`
    (``upper_lam``, ``lower_lam``) are concrete tensors of the same shape
    as the child.

    Backward propagation conceptually splits the incoming weight :math:`w`
    by sign. For direction ``"<="``:

    .. math::

        w \cdot f(x) \;\le\; (w_+ \odot \lambda_u + w_- \odot \lambda_\ell)\,x
            + \sum_j |w_{\cdot j}|,

    and symmetrically for ``">="`` with the slopes swapped and the bias negated.
    The implementation uses the equivalent form
    :math:`w \odot \bar\lambda \pm |w| \odot \delta` with
    :math:`\bar\lambda = (\lambda_u + \lambda_\ell)/2` and
    :math:`\delta = (\lambda_u - \lambda_\ell)/2`, so only ``|w|`` has to be
    formed rather than the positive and negative parts separately.
    """

    def __init__(self, child: Expr, upper_lam: torch.Tensor, lower_lam: torch.Tensor,
                 *, reason: str | None = None):
        """Build a gate over ``child`` with the given slope tensors.

        Parameters
        ----------
        child:
            Expression the gate is applied to.
        upper_lam, lower_lam:
            Slopes of the upper and lower envelope; both must have exactly
            the shape of ``child``.
        reason:
            Free-form label stored on the node. When omitted, the name of
            the calling function is used.

        Raises
        ------
        AssertionError
            If the three shapes disagree.
        """
        super().__init__(ExprFlags.NONE)
        assert child.shape == upper_lam.shape == lower_lam.shape, (
            f"Shape mismatch: child={child.shape}, "
            f"upper_lam={upper_lam.shape}, lower_lam={lower_lam.shape}"
        )
        self._child = child
        # Keep the raw slopes: ``with_children`` needs them to rebuild the
        # node, and they cannot be read back from the derived operators below.
        self.upper_lam = upper_lam
        self.lower_lam = lower_lam
        # Midpoint and half-width of the slope interval as elementwise
        # (Hadamard) operators; these are what ``backward`` composes with.
        self.lam_mean = EinsumOp.from_hardmard((upper_lam + lower_lam) / 2, len(child.shape))
        self.lam_halfdiff = EinsumOp.from_hardmard((upper_lam - lower_lam) / 2, len(child.shape))
        # inspect.stack()[1] is the frame of whoever invoked the constructor.
        self.reason = reason if reason is not None else str(inspect.stack()[1].function)

    @property
    def shape(self) -> torch.Size:
        """Shape of the gate output, identical to the child's shape."""
        return self._child.shape

    @property
    def children(self) -> tuple[Expr, ...]:
        """The single child expression, wrapped in a one-element tuple."""
        return (self._child,)

    def with_children(self, *new_children: Expr) -> "PolyBoundGate":
        """Return a new gate over ``new_children[0]`` with the same slopes and reason.

        Exactly one child must be supplied; any other count raises
        ``ValueError`` from the tuple unpacking.
        """
        (new_child,) = new_children
        return PolyBoundGate(new_child, self.upper_lam, self.lower_lam, reason=self.reason)

    def backward(self, weights: LinearOp, direction: Literal[">=", "<=", "=="]):
        """Propagate ``weights`` through the gate for the requested bound direction.

        Parameters
        ----------
        weights:
            Linear operator applied to this gate's output; converted with
            ``weights.einsum_op()`` before use.
        direction:
            ``"<="`` for an upper bound, ``">="`` for a lower bound.

        Returns
        -------
        ``None`` when ``direction`` is ``"=="`` (presumably signalling that
        the gate has no exact backward rule — confirm against the ``Expr``
        contract). Otherwise a pair ``(bias, [new_op])`` where ``new_op`` is
        the operator to apply to the child and ``bias`` is the constant term.
        """
        if direction == "==":
            return None
        einop: EinsumOp = weights.einsum_op()
        if direction == "<=":
            # w * mean + |w| * halfdiff  ==  w_+ * lam_u + w_- * lam_l
            new_op = einop @ self.lam_mean + einop.abs() @ self.lam_halfdiff
            # The +/-1 offsets contribute the 1-norm of the weight.
            bias = einop.norm_input(p=1).jacobian()
        else:
            # w * mean - |w| * halfdiff  ==  w_+ * lam_l + w_- * lam_u
            new_op = einop @ self.lam_mean - einop.abs() @ self.lam_halfdiff
            bias = -einop.norm_input(p=1).jacobian()
        return bias, [new_op]

    def to_string(self, child_str: str) -> str:
        """Render the node as ``PolyBoundGate(<child_str>)``."""
        return f"PolyBoundGate({child_str})"
[docs]
@dataclasses.dataclass
class PolyBounds:
    r"""Pair of per-neuron linear envelopes for a unary nonlinearity.

    The four tensors encode the CROWN-style pointwise constraints

    .. math::

        \lambda_\ell \odot x + b_\ell \;\le\; f(x) \;\le\; \lambda_u \odot x + b_u.

    Every tensor is per-neuron and shares the activation's input/output
    shape.
    """

    # Slope of the upper envelope (lambda_u).
    upper_lam: torch.Tensor
    # Offset of the upper envelope (b_u).
    upper_bias: torch.Tensor
    # Slope of the lower envelope (lambda_l).
    lower_lam: torch.Tensor
    # Offset of the lower envelope (b_l).
    lower_bias: torch.Tensor
def _bounds_to_expr(x: Expr, bounds: "PolyBounds", *, eps: float = 1e-30,
                    reason: str | None = None) -> Expr:
    r"""Build an expression over ``x`` that satisfies ``bounds``.

    The relaxation is rewritten around its midpoint as

    .. math::

        f(x) \;=\; \bar{\lambda} \odot x + \bar{b}
                   + \beta \odot \mathrm{PolyBoundGate}(x, U, L),

    with :math:`\bar{\lambda}, \bar{b}` the averages of the upper/lower
    slopes and biases, :math:`\beta = (b_u - b_\ell)/2` the bias half-width,
    and :math:`U = -L = (\lambda_u - \lambda_\ell)/(2\beta)`.

    Parameters
    ----------
    x:
        Input expression of the nonlinearity.
    bounds:
        Linear envelopes to encode.
    eps:
        Neurons whose bias half-width is ``<= eps`` are treated as tight:
        their gate slopes and gate coefficient are set to zero, leaving only
        the affine midpoint term.
    reason:
        Label forwarded to :class:`PolyBoundGate`.

    Notes
    -----
    NOTE(review): for neurons treated as tight, the slope half-width
    :math:`(\lambda_u - \lambda_\ell)/2` is discarded as well, and a negative
    bias half-width also counts as tight. Confirm that linearizers never
    return differing slopes together with ``b_u <= b_l``.
    """
    upper_slope, upper_off = bounds.upper_lam, bounds.upper_bias
    lower_slope, lower_off = bounds.lower_lam, bounds.lower_bias

    # Affine midpoint of the two envelopes.
    mid_slope = 0.5 * (upper_slope + lower_slope)
    mid_off = 0.5 * (upper_off + lower_off)

    # Half-widths of the bias and slope intervals.
    half_off = 0.5 * (upper_off - lower_off)
    half_slope = 0.5 * (upper_slope - lower_slope)

    tight = half_off <= eps
    # Substitute 1 where the half-width is (near) zero so the division below
    # is well defined; those entries are masked out right after.
    divisor = torch.where(tight, torch.ones_like(half_off), half_off)
    gate_upper = torch.where(tight, torch.zeros_like(upper_slope), half_slope / divisor)
    gate_lower = -gate_upper
    gate_scale = torch.where(tight, torch.zeros_like(half_off), half_off)

    gate = PolyBoundGate(x, gate_upper, gate_lower, reason=reason)
    return mid_slope * x + mid_off + gate_scale * gate
def _register_linearizer(name: str):
    r"""Return a decorator that registers a linearizer under ``name`` in :data:`interpret`.

    The decorated function receives the concrete upper and lower bound
    tensors ``(ub, lb)`` of the single input expression and returns a
    :class:`PolyBounds`. The decorator stores a handler in
    ``interpret[name]`` and hands the linearizer back unchanged.

    The stored handler

    * returns ``NotImplemented`` when every argument is a ``ConstVal``
      (presumably so the interpreter falls back to another handler — confirm);
    * asserts that exactly one expression was passed;
    * obtains concrete bounds from the expression's ``ublb()`` method and
      calls the linearizer with them;
    * asserts that all four :class:`PolyBounds` tensors have the
      expression's shape;
    * wraps the result through :func:`_bounds_to_expr`, using the
      linearizer's ``__name__`` as the gate's ``reason``.
    """
    def decorator(linearizer):
        def handler(*exprs: Expr) -> Expr:
            # Nothing symbolic to relax when all inputs are constants.
            has_symbolic = any(not isinstance(arg, ConstVal) for arg in exprs)
            if not has_symbolic:
                return NotImplemented

            assert len(exprs) == 1, \
                "Only unary linearizers are supported; got {} inputs.".format(len(exprs))
            operand = exprs[0]

            upper, lower = operand.ublb()
            relaxed = linearizer(upper, lower)

            # Checked lazily in the order upper_lam, lower_lam, upper_bias,
            # lower_bias, stopping at the first mismatch.
            assert all(
                getattr(relaxed, field).shape == operand.shape
                for field in ("upper_lam", "lower_lam", "upper_bias", "lower_bias")
            ), "PolyBounds tensors must match the input expression shape."

            return _bounds_to_expr(operand, relaxed, reason=linearizer.__name__)

        interpret[name] = handler
        return linearizer

    return decorator
# =====================================================================
# Import activation modules — each calls _register_linearizer
# =====================================================================
# NOTE(review): these imports sit below the definitions of `interpret` and
# `_register_linearizer`, presumably because the submodules use them at
# import time — confirm before moving them to the top of the file.
from .relu import relu_linearizer
from .exp import exp_linearizer
from .reciprocal import reciprocal_linearizer
from .tanh import tanh_linearizer
from .square import square_linearizer
# ONNX activation handlers: expose the handlers stored under the lowercase
# keys "relu"/"tanh" (expected to have been registered by the imports above)
# under the capitalized ONNX operator names as well.
interpret["Relu"] = interpret["relu"]
interpret["Tanh"] = interpret["tanh"]
# Softmax and bilinear (MatMul / Mul) handlers are imported as ready-made
# handlers and registered explicitly below.
from .softmax import softmax_handler
from .bilinear import matmul_handler, mul_handler
# Adapt the `axis` keyword (default -1) to softmax_handler's `dim` keyword.
interpret["Softmax"] = lambda X, axis=-1: softmax_handler(X, dim=axis)
interpret["MatMul"] = matmul_handler
interpret["Mul"] = mul_handler
# Names exported by `from <this module> import *`.
__all__ = [
"PolyBoundGate",
"PolyBounds",
"interpret",
"relu_linearizer",
"exp_linearizer",
"reciprocal_linearizer",
"tanh_linearizer",
"square_linearizer",
"softmax_handler",
"matmul_handler",
"mul_handler",
]