from __future__ import annotations
from boundlab import interp
from boundlab.diff import expr
from boundlab.diff.zono3 import gradlin
from boundlab.utils import not0
r"""Triple-Zonotope-Based Abstract Interpretation for Differential Verification

This module provides zonotope transformations for computing over-approximations
of the difference ``f₁(x) − f₂(x)`` between two structurally identical networks,
achieving tighter bounds than verifying each network independently.

The interpreter operates on **triples** :class:`~boundlab.diff.expr.DiffExpr3`
``(x, y, d)`` where:

- ``x``: expression tracking network 1's output zonotope,
- ``y``: expression tracking network 2's output zonotope,
- ``d``: expression tracking the *difference* ``f₁(x) − f₂(x)``.

Affine operations (addition, scalar multiplication, linear layers, shape ops)
are handled directly: the bias cancels in the diff component, and weight
matrices are applied to all three components (without bias for ``d``).

Non-linear operations (ReLU, …) use specialised differential linearisers
derived from VeryDiff (Teuber et al., 2024).

Examples
--------
Build a :class:`~boundlab.diff.expr.DiffExpr3` and propagate it through a model:

>>> import torch
>>> from torch import nn
>>> import boundlab.expr as expr
>>> from boundlab.diff.expr import DiffExpr3
>>> from boundlab.diff.zono3 import interpret
>>> model = nn.Sequential(nn.Linear(4, 5), nn.ReLU(), nn.Linear(5, 3))
>>> op = interpret(model)
>>> x = expr.ConstVal(torch.zeros(4)) + expr.LpEpsilon([4])
>>> y = expr.ConstVal(torch.ones(4)) + expr.LpEpsilon([4])
>>> d = x - y
>>> out = op(DiffExpr3(x, y, d))
>>> out.diff.ub().shape, out.diff.lb().shape
(torch.Size([3]), torch.Size([3]))
"""
import dataclasses
import torch
from boundlab.expr._core import Expr
from boundlab.expr._affine import ConstVal
from boundlab.expr._var import LpEpsilon
from boundlab.interp import Interpreter # noqa: F401
from boundlab.diff.expr import DiffExpr2, DiffExpr3
from boundlab.linearop._base import LinearOp
from boundlab.zono import ZonoBounds, interpret as std_interpret
# =====================================================================
# Expression builders
# =====================================================================
def _apply_weights(weights, inputs) -> Expr | None:
"""Return the weighted sum of inputs, skipping zero weights. Returns None if all zero."""
result = None
for w, e in zip(weights, inputs):
if isinstance(w, int) and w == 0:
continue
term = w * e
result = term if result is None else result + term
return result
def _build_triple_from_dzb(
    dzb: "DiffZonoBounds",
    xs: list[Expr],
    ys: list[Expr],
    ds: list[Expr],
    reason: str = "",
) -> DiffExpr3:
    """Assemble the output :class:`~boundlab.diff.expr.DiffExpr3` described by *dzb*.

    *xs*, *ys* and *ds* are parallel lists holding, for each input of the
    nonlinearity, its x-network, y-network and diff expression (one entry for
    unary ops, two for binary ops, ...).  Weight ``i`` of every weight list in
    *dzb* multiplies entry ``i`` of the matching list.

    The epsilon created for ``x_bounds.error_coeffs`` is the very same object
    that ``diff_x_error`` is applied to, and likewise for the y branch.
    Sharing those symbols lets the diff expression follow
    ``x_output − y_output`` **exactly** for neurons on the cases 1–8 path
    (no additional approximation error), which tightens bounds — notably for
    L2 perturbations and deep networks where the shared terms cancel later.

    *reason* is forwarded to every freshly created ``LpEpsilon``.
    """

    def _network_branch(bounds, operands):
        # Affine part ``sum(w_i * operand_i) + bias`` of a single network.
        weighted = _apply_weights(bounds.input_weights, operands)
        if weighted is None:
            branch = ConstVal(bounds.bias)
        else:
            branch = weighted + bounds.bias
        # Optional fresh noise symbol; handed back so the diff can reuse it.
        noise = None
        coeffs = bounds.error_coeffs
        if coeffs is not None:
            noise = LpEpsilon(coeffs.input_shape, reason=reason)
            branch = branch + coeffs(noise)
        return branch, noise

    # Epsilon creation order is x, then y, then (further down) the diff-only one.
    x_result, eps_x = _network_branch(dzb.x_bounds, xs)
    y_result, eps_y = _network_branch(dzb.y_bounds, ys)

    d_bounds = dzb.diff_bounds
    d_result = ConstVal(d_bounds.bias)

    # Direct contributions of the x / y inputs to the diff (``0`` means none).
    for cross_weights, operands in ((dzb.diff_x_weights, xs), (dzb.diff_y_weights, ys)):
        if cross_weights != 0:
            contribution = _apply_weights(cross_weights, operands)
            if contribution is not None:
                d_result = d_result + contribution

    # Contribution of the incoming diff expressions themselves.
    carried = _apply_weights(d_bounds.input_weights, ds)
    if carried is not None:
        d_result = d_result + carried

    # Shared errors: the same epsilon objects that already appear in
    # x_result / y_result, so they can cancel against them downstream.
    for noise, shared_op in ((eps_x, dzb.diff_x_error), (eps_y, dzb.diff_y_error)):
        if noise is not None and not0(shared_op):
            d_result = d_result + shared_op(noise)

    # Fresh diff-only error (e.g. case-9 triangle relaxation on d directly).
    fresh_coeffs = d_bounds.error_coeffs
    if fresh_coeffs is not None:
        eps_d = LpEpsilon(fresh_coeffs.input_shape, reason=reason)
        d_result = d_result + fresh_coeffs(eps_d)

    from boundlab.prop import bound_width

    # Element-wise, keep whichever of ``x − y`` and the dedicated diff
    # expression has the strictly smaller bound width.
    sub = x_result - y_result
    mask = (bound_width(sub) < bound_width(d_result)).float()
    d_result = mask * sub + (1.0 - mask) * d_result
    return DiffExpr3(x_result, y_result, d_result)
# =====================================================================
# Interpreter
# =====================================================================
# Dispatch table for differential interpretation, constructed on top of the
# standard zonotope interpreter ``std_interpret``.  Handlers produced by
# ``linearizer_to_hander`` return ``NotImplemented`` when no differential
# operand is present.
# NOTE(review): presumably ``Interpreter`` then falls back to the
# ``std_interpret`` handler for that op — confirm against ``boundlab.interp``.
interpret = Interpreter[Expr | DiffExpr2 | DiffExpr3](std_interpret)
"""Differential-verification interpreter.

Feed it a :class:`~boundlab.diff.expr.DiffExpr3` ``(x, y, d)`` where *x* and
*y* are the two networks' zonotope expressions and *d* over-approximates their
difference, or a plain :class:`~boundlab.expr.Expr` for standard zonotope
interpretation.

Examples
--------
Differential mode (:class:`~boundlab.diff.expr.DiffExpr3` input):

>>> import torch
>>> from torch import nn
>>> import boundlab.expr as expr
>>> from boundlab.diff.expr import DiffExpr3
>>> from boundlab.diff.zono3 import interpret
>>> model = nn.Linear(4, 3)
>>> op = interpret(model)
>>> x = expr.ConstVal(torch.randn(4)) + expr.LpEpsilon([4])
>>> y = expr.ConstVal(torch.randn(4)) + expr.LpEpsilon([4])
>>> out = op(DiffExpr3(x, y, x - y))
>>> out.diff.ub().shape
torch.Size([3])

Fallback mode (plain :class:`~boundlab.expr.Expr`) matches standard zonotope
interpretation:

>>> z = expr.ConstVal(torch.randn(4)) + expr.LpEpsilon([4])
>>> z_out = op(z)
>>> z_out.ub().shape
torch.Size([3])
"""
@dataclasses.dataclass
class DiffZonoBounds:
    """Result of a differential lineariser for one non-linear operation.

    Bundles the relaxations of the operation on network 1 (``x``), on
    network 2 (``y``) and on their difference, plus the terms that tie the
    difference back to the x / y inputs and to the x / y error symbols.
    Consumed by ``_build_triple_from_dzb``.

    Wherever ``int`` appears in an annotation below, the only value the
    consumer treats specially is the sentinel ``0`` ("no contribution").
    """

    # Relaxation of the op on network 1: ``input_weights`` multiply the
    # x inputs, ``bias`` is added, and ``error_coeffs`` (may be ``None``)
    # is applied to a fresh epsilon.
    x_bounds: ZonoBounds
    # Same as ``x_bounds`` but for network 2 and the y inputs.
    y_bounds: ZonoBounds
    # Relaxation of the difference: ``bias`` is its constant term,
    # ``input_weights`` multiply the incoming diff expressions, and
    # ``error_coeffs`` (may be ``None``) is applied to a fresh diff-only epsilon.
    diff_bounds: ZonoBounds
    # Applied to the *same* epsilon created for ``x_bounds.error_coeffs``;
    # skipped when ``not0`` reports it as zero or no x epsilon exists.
    diff_x_error: LinearOp
    # Per-input weights on the x inputs added to the diff, or ``0`` for none.
    diff_x_weights: list[torch.Tensor | int] | int
    # Applied to the *same* epsilon created for ``y_bounds.error_coeffs``.
    diff_y_error: LinearOp
    # Per-input weights on the y inputs added to the diff, or ``0`` for none.
    diff_y_weights: list[torch.Tensor | int] | int
# =====================================================================
# Lineariser registration
# =====================================================================
def linearizer_to_hander(linearizer):
    """Wrap a differential lineariser as an interpreter handler.

    This function does not register anything itself; the returned handler is
    meant to be stored in an interpreter under the relevant op name.

    *linearizer* is called as ``linearizer(xs, ys, ds)`` with three parallel
    lists, one entry per input of the nonlinearity, and must return a
    :class:`DiffZonoBounds`.  For unary activations (relu, tanh, …) each list
    has length 1; for binary operations each has length 2, with ``xs[i]`` /
    ``ys[i]`` / ``ds[i]`` being the *i*-th input's x-network, y-network and
    diff components.

    ``diff_bounds.input_weights[i]`` is the weight applied to ``ds[i]``;
    ``diff_x_weights[i]`` / ``diff_y_weights[i]`` are the weights applied to
    ``xs[i]`` / ``ys[i]``.  ``diff_x_error`` / ``diff_y_error`` are applied to
    the **same** epsilon variables introduced for ``x_bounds`` / ``y_bounds``,
    enabling exact diff tracking for cases where no fresh error is needed.

    How the handler treats each positional argument:

    - :class:`~boundlab.diff.expr.DiffExpr3`: ``x``, ``y`` and ``diff`` are
      used as given.
    - :class:`~boundlab.diff.expr.DiffExpr2`: ``x`` and ``y`` are used and the
      diff is synthesised as ``x − y``.
    - anything else: treated as an operand shared by both networks (the same
      object is used for x and y) with a constant diff.

    If *no* argument is a ``DiffExpr3`` / ``DiffExpr2`` the handler returns
    ``NotImplemented`` so that the caller can fall back to another handler.
    """
    # Label attached to the epsilons created for this op.  Callables without
    # ``__name__`` (e.g. ``functools.partial`` objects) fall back to their
    # type name instead of raising ``AttributeError`` at call time.
    reason = getattr(linearizer, "__name__", None) or type(linearizer).__name__

    def handler(*args):
        if not any(isinstance(a, (DiffExpr3, DiffExpr2)) for a in args):
            return NotImplemented
        xs, ys, ds = [], [], []
        for a in args:
            if isinstance(a, DiffExpr3):
                xs.append(a.x)
                ys.append(a.y)
                ds.append(a.diff)
            elif isinstance(a, DiffExpr2):
                xs.append(a.x)
                ys.append(a.y)
                ds.append(a.x - a.y)
            else:
                xs.append(a)
                ys.append(a)
                # NOTE(review): relies on ``boundlab.diff.expr`` exposing
                # ``ConstVal`` and on ``ConstVal(None)`` standing for a zero
                # constant — confirm.
                ds.append(expr.ConstVal(None))  # constant: diff is 0
        return _build_triple_from_dzb(linearizer(xs, ys, ds), xs, ys, ds, reason=reason)

    return handler
# =====================================================================
# Activation modules (imported last so helpers are already defined)
# =====================================================================
from .default import ( # noqa: E402
relu_linearizer,
tanh_linearizer,
exp_linearizer,
reciprocal_linearizer,
diff_mul_handler,
diff_matmul_handler,
diff_bilinear_elementwise,
diff_bilinear_matmul,
diff_softmax_handler,
diff_heaviside_pruning_handler,
)
# Unary activations: build one handler per lineariser, then register the same
# handler object under both op-name spellings — ATen (lowercase) and ONNX
# (capitalised).
_relu_diff = linearizer_to_hander(relu_linearizer)
_tanh_diff = linearizer_to_hander(tanh_linearizer)
_exp_diff = linearizer_to_hander(exp_linearizer)
_reciprocal_diff = linearizer_to_hander(reciprocal_linearizer)
interpret["relu"] = _relu_diff
interpret["Relu"] = _relu_diff
interpret["tanh"] = _tanh_diff
interpret["Tanh"] = _tanh_diff
interpret["exp"] = _exp_diff
interpret["Exp"] = _exp_diff
interpret["reciprocal"] = _reciprocal_diff
interpret["Reciprocal"] = _reciprocal_diff
# diff_pair: converts paired tensors (from boundlab::diff_pair ONNX nodes) to DiffExpr2.
# NOTE(review): imported mid-file with E402 suppressed, presumably to avoid a
# circular import with ``boundlab.diff.op`` — confirm.
from boundlab.diff.op import diff_pair_handler # noqa: E402
interpret["DiffPair"] = diff_pair_handler
def onnx_boardcasted(fn):
    """Return a wrapper that broadcasts the first two operands before calling *fn*.

    The wrapper passes ``X`` and ``Y`` through ``interp._onnx_broadcast`` and
    calls *fn* with whatever that returns (unpacked positionally), followed by
    any remaining positional and keyword arguments unchanged.
    """

    def _broadcast_then_call(X, Y, *args, **kwargs):
        operands = interp._onnx_broadcast(X, Y)
        return fn(*operands, *args, **kwargs)

    return _broadcast_then_call
# ONNX binary / reduction ops.  ``Mul`` and ``Div`` go through
# ``onnx_boardcasted`` so both operands are broadcast first; ``MatMul`` is
# registered without that wrapper.
interpret["Mul"] = onnx_boardcasted(diff_mul_handler)
interpret["MatMul"] = diff_matmul_handler
# ``a / b`` is expressed as ``a * reciprocal(b)``; the "Reciprocal" entry is
# looked up at call time, not at registration time.
# NOTE(review): if ``b`` is not a differential expression the reciprocal handler
# registered above returns ``NotImplemented``; this relies on
# ``interpret[...]`` resolving that fallback before the product is formed — confirm.
interpret["Div"] = onnx_boardcasted(lambda a, b: diff_mul_handler(a, interpret["Reciprocal"](b)))
# Adapt the ONNX ``axis`` attribute to the ``dim`` keyword of the softmax handler.
interpret["Softmax"] = lambda X, axis=-1: diff_softmax_handler(X, dim=axis)
# Second interpreter constructed from ``interpret``; ``gradlin.register_all``
# receives it together with the lineariser-to-handler adapter.
# NOTE(review): presumably this adds gradient-based linearisers on top of the
# entries above without modifying ``interpret`` itself — confirm in ``gradlin``.
interpret_gradlin = Interpreter(interpret)
gradlin.register_all(interpret_gradlin, linearizer_to_hander)
# Names exported by ``from <this module> import *``.
# NOTE(review): ``diff_mul_handler``, ``diff_matmul_handler`` and
# ``interpret_gradlin`` are bound in this module but not listed here, while the
# imported ``expr`` module is — confirm this is intentional.
__all__ = [
    "interpret",
    "DiffZonoBounds",
    "expr",
    "linearizer_to_hander",
    "relu_linearizer",
    "tanh_linearizer",
    "exp_linearizer",
    "reciprocal_linearizer",
    "diff_bilinear_elementwise",
    "diff_bilinear_matmul",
    "diff_softmax_handler",
    "diff_heaviside_pruning_handler",
]