from __future__ import annotations
r"""Triple quadratic-zonotope differential verification.
This module mirrors :mod:`boundlab.diff.zono3` but uses
:mod:`boundlab.zonosq` as the base interpreter and uses quadratic-zonotope
bilinear handlers for differential ``Mul``/``MatMul``. It intentionally does
not provide a gradlin variant.
"""
import dataclasses
import torch
from boundlab import interp
from boundlab import zonosq
from boundlab.diff import expr
from boundlab.diff.expr import DiffExpr2, DiffExpr3
from boundlab.expr._affine import ConstVal
from boundlab.expr._core import Expr
from boundlab.expr._var import LpEpsilon
from boundlab.interp import Interpreter
from boundlab.linearop._base import LinearOp
from boundlab.utils import not0
from boundlab.zonosq import ZonoBounds
def _apply_weights(weights, inputs) -> Expr | None:
result = None
for w, e in zip(weights, inputs):
if isinstance(w, int) and w == 0:
continue
term = w * e
result = term if result is None else result + term
return result
def _build_triple_from_dzb(
dzb: "DiffZonosqBounds",
xs: list[Expr],
ys: list[Expr],
ds: list[Expr],
reason: str = "",
) -> DiffExpr3:
x_sum = _apply_weights(dzb.x_bounds.input_weights, xs)
x_result = ConstVal(dzb.x_bounds.bias) if x_sum is None else x_sum + dzb.x_bounds.bias
eps_x = None
if dzb.x_bounds.error_coeffs is not None:
eps_x = LpEpsilon(dzb.x_bounds.error_coeffs.input_shape, reason=reason)
x_result = x_result + dzb.x_bounds.error_coeffs(eps_x)
y_sum = _apply_weights(dzb.y_bounds.input_weights, ys)
y_result = ConstVal(dzb.y_bounds.bias) if y_sum is None else y_sum + dzb.y_bounds.bias
eps_y = None
if dzb.y_bounds.error_coeffs is not None:
eps_y = LpEpsilon(dzb.y_bounds.error_coeffs.input_shape, reason=reason)
y_result = y_result + dzb.y_bounds.error_coeffs(eps_y)
d_result = ConstVal(dzb.diff_bounds.bias)
if dzb.diff_x_weights != 0:
s = _apply_weights(dzb.diff_x_weights, xs)
if s is not None:
d_result = d_result + s
if dzb.diff_y_weights != 0:
s = _apply_weights(dzb.diff_y_weights, ys)
if s is not None:
d_result = d_result + s
d_in = _apply_weights(dzb.diff_bounds.input_weights, ds)
if d_in is not None:
d_result = d_result + d_in
if eps_x is not None and not0(dzb.diff_x_error):
d_result = d_result + dzb.diff_x_error(eps_x)
if eps_y is not None and not0(dzb.diff_y_error):
d_result = d_result + dzb.diff_y_error(eps_y)
if dzb.diff_bounds.error_coeffs is not None:
eps_d = LpEpsilon(dzb.diff_bounds.error_coeffs.input_shape, reason=reason)
d_result = d_result + dzb.diff_bounds.error_coeffs(eps_d)
from boundlab.prop import bound_width
sub = x_result - y_result
w_sub = bound_width(sub)
w_d = bound_width(d_result)
sub_finite = bool(torch.isfinite(w_sub).all())
d_finite = bool(torch.isfinite(w_d).all())
if sub_finite and d_finite:
mask = (w_sub < w_d).float()
d_result = mask * sub + (1.0 - mask) * d_result
elif sub_finite:
d_result = sub
return DiffExpr3(x_result, y_result, d_result)
interpret = Interpreter[Expr | DiffExpr2 | DiffExpr3](zonosq.interpret)
[docs]
@dataclasses.dataclass
class DiffZonosqBounds:
x_bounds: ZonoBounds
y_bounds: ZonoBounds
diff_bounds: ZonoBounds
diff_x_error: LinearOp
diff_x_weights: list[torch.Tensor | 0] | 0
diff_y_error: LinearOp
diff_y_weights: list[torch.Tensor | 0] | 0
# Compatibility name used by the shared default linearizer wrappers.
DiffZonoBounds = DiffZonosqBounds
[docs]
def linearizer_to_hander(linearizer):
def handler(*args):
if not any(isinstance(a, (DiffExpr3, DiffExpr2)) for a in args):
return NotImplemented
xs, ys, ds = [], [], []
for a in args:
if isinstance(a, DiffExpr3):
xs.append(a.x)
ys.append(a.y)
ds.append(a.diff)
elif isinstance(a, DiffExpr2):
xs.append(a.x)
ys.append(a.y)
ds.append(a.x - a.y)
else:
xs.append(a)
ys.append(a)
ds.append(ConstVal(torch.zeros(tuple(a.shape))))
return _build_triple_from_dzb(
linearizer(xs, ys, ds), xs, ys, ds, reason=linearizer.__name__
)
return handler
from .default import ( # noqa: E402
const_heaviside_pruning,
const_topk_pruning,
diff_bilinear_elementwise,
diff_bilinear_matmul,
diff_matmul_handler,
diff_mul_handler,
diff_softmax_handler,
diff_softmax_pruning_handler,
exp_linearizer,
reciprocal_linearizer,
relu_linearizer,
tanh_linearizer,
)
_relu_diff = linearizer_to_hander(relu_linearizer)
_tanh_diff = linearizer_to_hander(tanh_linearizer)
_exp_diff = linearizer_to_hander(exp_linearizer)
_reciprocal_diff = linearizer_to_hander(reciprocal_linearizer)
interpret["relu"] = _relu_diff
interpret["Relu"] = _relu_diff
interpret["tanh"] = _tanh_diff
interpret["Tanh"] = _tanh_diff
interpret["exp"] = _exp_diff
interpret["Exp"] = _exp_diff
interpret["reciprocal"] = _reciprocal_diff
interpret["Reciprocal"] = _reciprocal_diff
interpret["HeavisidePruning"] = const_heaviside_pruning
interpret["TopKPruning"] = (
lambda scores, data, k, dim=-1, largest=True:
const_topk_pruning(scores, data, k, dim=dim, largest=largest)
)
from boundlab.diff.op import diff_pair_handler # noqa: E402
interpret["DiffPair"] = diff_pair_handler
def onnx_boardcasted(fn):
return lambda X, Y, *args, **kwargs: interp.FnList._call(
fn, *interp._onnx_broadcast(X, Y), *args, **kwargs
)
interpret["Mul"] = onnx_boardcasted(diff_mul_handler)
interpret["MatMul"] = diff_matmul_handler
interpret["Div"] = onnx_boardcasted(
lambda a, b: diff_mul_handler(
a, interpret["Reciprocal"](ConstVal(b) if isinstance(b, torch.Tensor) else b)
)
)
interpret["Softmax"] = lambda X, axis=-1: diff_softmax_handler(
X,
dim=axis,
exp_handler=interpret["Exp"],
reciprocal_handler=interpret["Reciprocal"],
)
interpret["SoftmaxPruning"] = lambda scores, data, dim=-1: diff_softmax_pruning_handler(
scores,
data,
dim=dim,
exp_handler=interpret["Exp"],
reciprocal_handler=interpret["Reciprocal"],
heaviside_handler=interpret["HeavisidePruning"],
)
__all__ = [
"interpret",
"DiffZonosqBounds",
"DiffZonoBounds",
"expr",
"linearizer_to_hander",
"relu_linearizer",
"tanh_linearizer",
"exp_linearizer",
"reciprocal_linearizer",
"diff_bilinear_elementwise",
"diff_bilinear_matmul",
"diff_softmax_handler",
"diff_softmax_pruning_handler",
]