Source code for boundlab.diff.zonosq3

from __future__ import annotations

r"""Triple quadratic-zonotope differential verification.

This module mirrors :mod:`boundlab.diff.zono3` but uses
:mod:`boundlab.zonosq` as the base interpreter and uses quadratic-zonotope
bilinear handlers for differential ``Mul``/``MatMul``. It intentionally does
not provide a gradlin variant.
"""

import dataclasses

import torch

from boundlab import interp
from boundlab import zonosq
from boundlab.diff import expr
from boundlab.diff.expr import DiffExpr2, DiffExpr3
from boundlab.expr._affine import ConstVal
from boundlab.expr._core import Expr
from boundlab.expr._var import LpEpsilon
from boundlab.interp import Interpreter
from boundlab.linearop._base import LinearOp
from boundlab.utils import not0
from boundlab.zonosq import ZonoBounds


def _apply_weights(weights, inputs) -> Expr | None:
    result = None
    for w, e in zip(weights, inputs):
        if isinstance(w, int) and w == 0:
            continue
        term = w * e
        result = term if result is None else result + term
    return result


def _build_triple_from_dzb(
    dzb: "DiffZonosqBounds",
    xs: list[Expr],
    ys: list[Expr],
    ds: list[Expr],
    reason: str = "",
) -> DiffExpr3:
    x_sum = _apply_weights(dzb.x_bounds.input_weights, xs)
    x_result = ConstVal(dzb.x_bounds.bias) if x_sum is None else x_sum + dzb.x_bounds.bias
    eps_x = None
    if dzb.x_bounds.error_coeffs is not None:
        eps_x = LpEpsilon(dzb.x_bounds.error_coeffs.input_shape, reason=reason)
        x_result = x_result + dzb.x_bounds.error_coeffs(eps_x)

    y_sum = _apply_weights(dzb.y_bounds.input_weights, ys)
    y_result = ConstVal(dzb.y_bounds.bias) if y_sum is None else y_sum + dzb.y_bounds.bias
    eps_y = None
    if dzb.y_bounds.error_coeffs is not None:
        eps_y = LpEpsilon(dzb.y_bounds.error_coeffs.input_shape, reason=reason)
        y_result = y_result + dzb.y_bounds.error_coeffs(eps_y)

    d_result = ConstVal(dzb.diff_bounds.bias)

    if dzb.diff_x_weights != 0:
        s = _apply_weights(dzb.diff_x_weights, xs)
        if s is not None:
            d_result = d_result + s

    if dzb.diff_y_weights != 0:
        s = _apply_weights(dzb.diff_y_weights, ys)
        if s is not None:
            d_result = d_result + s

    d_in = _apply_weights(dzb.diff_bounds.input_weights, ds)
    if d_in is not None:
        d_result = d_result + d_in

    if eps_x is not None and not0(dzb.diff_x_error):
        d_result = d_result + dzb.diff_x_error(eps_x)
    if eps_y is not None and not0(dzb.diff_y_error):
        d_result = d_result + dzb.diff_y_error(eps_y)

    if dzb.diff_bounds.error_coeffs is not None:
        eps_d = LpEpsilon(dzb.diff_bounds.error_coeffs.input_shape, reason=reason)
        d_result = d_result + dzb.diff_bounds.error_coeffs(eps_d)

    from boundlab.prop import bound_width

    sub = x_result - y_result
    w_sub = bound_width(sub)
    w_d = bound_width(d_result)
    sub_finite = bool(torch.isfinite(w_sub).all())
    d_finite = bool(torch.isfinite(w_d).all())
    if sub_finite and d_finite:
        mask = (w_sub < w_d).float()
        d_result = mask * sub + (1.0 - mask) * d_result
    elif sub_finite:
        d_result = sub

    return DiffExpr3(x_result, y_result, d_result)


interpret = Interpreter[Expr | DiffExpr2 | DiffExpr3](zonosq.interpret)


[docs] @dataclasses.dataclass class DiffZonosqBounds: x_bounds: ZonoBounds y_bounds: ZonoBounds diff_bounds: ZonoBounds diff_x_error: LinearOp diff_x_weights: list[torch.Tensor | 0] | 0 diff_y_error: LinearOp diff_y_weights: list[torch.Tensor | 0] | 0
# Compatibility name used by the shared default linearizer wrappers. DiffZonoBounds = DiffZonosqBounds
[docs] def linearizer_to_hander(linearizer): def handler(*args): if not any(isinstance(a, (DiffExpr3, DiffExpr2)) for a in args): return NotImplemented xs, ys, ds = [], [], [] for a in args: if isinstance(a, DiffExpr3): xs.append(a.x) ys.append(a.y) ds.append(a.diff) elif isinstance(a, DiffExpr2): xs.append(a.x) ys.append(a.y) ds.append(a.x - a.y) else: xs.append(a) ys.append(a) ds.append(ConstVal(torch.zeros(tuple(a.shape)))) return _build_triple_from_dzb( linearizer(xs, ys, ds), xs, ys, ds, reason=linearizer.__name__ ) return handler
from .default import ( # noqa: E402 const_heaviside_pruning, const_topk_pruning, diff_bilinear_elementwise, diff_bilinear_matmul, diff_matmul_handler, diff_mul_handler, diff_softmax_handler, diff_softmax_pruning_handler, exp_linearizer, reciprocal_linearizer, relu_linearizer, tanh_linearizer, ) _relu_diff = linearizer_to_hander(relu_linearizer) _tanh_diff = linearizer_to_hander(tanh_linearizer) _exp_diff = linearizer_to_hander(exp_linearizer) _reciprocal_diff = linearizer_to_hander(reciprocal_linearizer) interpret["relu"] = _relu_diff interpret["Relu"] = _relu_diff interpret["tanh"] = _tanh_diff interpret["Tanh"] = _tanh_diff interpret["exp"] = _exp_diff interpret["Exp"] = _exp_diff interpret["reciprocal"] = _reciprocal_diff interpret["Reciprocal"] = _reciprocal_diff interpret["HeavisidePruning"] = const_heaviside_pruning interpret["TopKPruning"] = ( lambda scores, data, k, dim=-1, largest=True: const_topk_pruning(scores, data, k, dim=dim, largest=largest) ) from boundlab.diff.op import diff_pair_handler # noqa: E402 interpret["DiffPair"] = diff_pair_handler def onnx_boardcasted(fn): return lambda X, Y, *args, **kwargs: interp.FnList._call( fn, *interp._onnx_broadcast(X, Y), *args, **kwargs ) interpret["Mul"] = onnx_boardcasted(diff_mul_handler) interpret["MatMul"] = diff_matmul_handler interpret["Div"] = onnx_boardcasted( lambda a, b: diff_mul_handler( a, interpret["Reciprocal"](ConstVal(b) if isinstance(b, torch.Tensor) else b) ) ) interpret["Softmax"] = lambda X, axis=-1: diff_softmax_handler( X, dim=axis, exp_handler=interpret["Exp"], reciprocal_handler=interpret["Reciprocal"], ) interpret["SoftmaxPruning"] = lambda scores, data, dim=-1: diff_softmax_pruning_handler( scores, data, dim=dim, exp_handler=interpret["Exp"], reciprocal_handler=interpret["Reciprocal"], heaviside_handler=interpret["HeavisidePruning"], ) __all__ = [ "interpret", "DiffZonosqBounds", "DiffZonoBounds", "expr", "linearizer_to_hander", "relu_linearizer", "tanh_linearizer", "exp_linearizer", "reciprocal_linearizer", "diff_bilinear_elementwise", "diff_bilinear_matmul", "diff_softmax_handler", "diff_softmax_pruning_handler", ]