"""
Differential lineariser for ``heaviside_pruning``.
The ONNX op encodes a pruning mask applied only to the **y** network:
out_x = x
out_y = heaviside(scores_y) * y
out_d = out_x - out_y
We rewrite the mask as ``y - h(-s_y) * y`` so the only non-linearity is the
product ``h(-s_y) * y``. The helper below constructs an affine enclosure
h(s) * x ≈ w_s * s + w_x * x + bias ± err
following the rules in ``examples/vit/vit_plan.md`` (item 4):
* If ``ls + us > 0`` linearise ``x - h(-s) * x`` instead (forces ``ls + us``
non-positive).
* If ``lx + ux > 0`` linearise ``-h(s) * (-x)`` instead (forces ``lx + ux``
non-positive).
* With ``ls + us <= 0`` and ``lx + ux <= 0``:
- ``us < 0``: always zero.
- ``ls < 0 < us`` and ``ux <= 0``: ``lam = max(lx / -ls, ux / us)`` with
bounds ``lx <= h(s)*x - lam*s <= 0``.
- ``ls < 0 < us`` and ``ux > 0``: ``lam = min(ux / (ux - lx), -lx / (ux - lx))``
with bounds ``(1 - lam) * lx <= h(s)*x - lam*x <= (1 - lam) * ux``.
Both ``y`` and ``diff`` share the same approximation error epsilon so the mask
correlation is preserved between outputs.
"""
from __future__ import annotations
import torch
from boundlab import expr
from boundlab.expr._core import Expr
from boundlab.linearop._einsum import EinsumOp
from boundlab.prop import ublb
from boundlab.zono import ZonoBounds
from .. import DiffZonoBounds, interpret
_EPS = 1e-30
def _linearize_base(ls, us, lx, ux):
"""Handles ls+us<=0 and lx+ux<=0; no flips."""
zeros = torch.zeros_like(ls)
w_s = zeros.clone()
w_x = zeros.clone()
bias = zeros.clone()
err = zeros.clone()
mask_zero = us < 0
mask_case2 = (~mask_zero) & (ux <= 0)
lam_num1 = torch.where(mask_case2, lx, zeros)
lam_den1 = torch.where(mask_case2, -ls + _EPS, torch.ones_like(ls))
lam_num2 = torch.where(mask_case2, ux, zeros)
lam_den2 = torch.where(mask_case2, us + _EPS, torch.ones_like(us))
lam_case2 = torch.maximum(lam_num1 / lam_den1, lam_num2 / lam_den2)
w_s = torch.where(mask_case2, lam_case2, w_s)
bias = torch.where(mask_case2, 0.5 * lx, bias)
err = torch.where(mask_case2, -0.5 * lx, err)
mask_case3 = (~mask_zero) & (ux > 0)
denom3 = torch.where(mask_case3, ux - lx + _EPS, torch.ones_like(ux))
lam3 = torch.minimum(ux / denom3, (-lx) / denom3)
w_x = torch.where(mask_case3, lam3, w_x)
bias = torch.where(mask_case3, 0.5 * (1 - lam3) * (lx + ux), bias)
err = torch.where(mask_case3, 0.5 * (1 - lam3) * (ux - lx), err)
return w_s, w_x, bias, err
def _linearize_no_s_flip0(ls, us, lx, ux):
"""Assumes ls+us<=0; handles x-flip and base."""
zeros = torch.zeros_like(ls)
w_s = zeros.clone()
w_x = zeros.clone()
bias = zeros.clone()
err = zeros.clone()
mask_x_flip = lx + ux > 0
ws_xflip, wx_xflip, b_xflip, e_xflip = _linearize_base(ls[mask_x_flip], us[mask_x_flip],
-ux[mask_x_flip], -lx[mask_x_flip])
w_s[mask_x_flip] = -ws_xflip
w_x[mask_x_flip] = wx_xflip
bias[mask_x_flip] = -b_xflip
err[mask_x_flip] = e_xflip
mask_base = ~(mask_x_flip)
ws_base, wx_base, b_base, e_base = _linearize_base(ls[mask_base], us[mask_base],
lx[mask_base], ux[mask_base])
w_s[mask_base] = ws_base
w_x[mask_base] = wx_base
bias[mask_base] = b_base
err[mask_base] = e_base
return w_s, w_x, bias, err
def _linearize_no_s_flip(ls, us, lx, ux):
"""Assumes ls+us<=0; handles x-flip and base."""
zeros = torch.zeros_like(ls)
mask_x_flip = lx + ux > 0
lx2 = torch.where(mask_x_flip, -ux, lx)
ux2 = torch.where(mask_x_flip, -lx, ux)
ws_0, wx_0, b_x0, e_x0 = _linearize_base(ls, us, lx2, ux2)
w_s = torch.where(mask_x_flip, -ws_0, ws_0)
w_x = torch.where(mask_x_flip, wx_0, wx_0)
bias = torch.where(mask_x_flip, -b_x0, b_x0)
err = torch.where(mask_x_flip, e_x0, e_x0)
return w_s, w_x, bias, err
def _linearize_hsx0(ls: torch.Tensor, us: torch.Tensor,
lx: torch.Tensor, ux: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return (w_s, w_x, bias, err) such that
``h(s) * x ≈ w_s * s + w_x * x + bias ± err``.
Element-wise implementation mirroring the case split in ``vit_plan.md``.
"""
if ls.numel() == 0:
z = torch.zeros_like(ls)
return z, z, z, z
zeros = torch.zeros_like(ls)
w_s = zeros.clone()
w_x = zeros.clone()
bias = zeros.clone()
err = zeros.clone()
mask_s_flip = ls + us > 0
ws_flip, wx_flip, b_flip, e_flip = _linearize_no_s_flip(
-us[mask_s_flip], -ls[mask_s_flip], lx[mask_s_flip], ux[mask_s_flip]
)
w_s[mask_s_flip] = ws_flip
w_x[mask_s_flip] = 1 - wx_flip
bias[mask_s_flip] = -b_flip
err[mask_s_flip] = e_flip
mask_remaining = ~mask_s_flip
ws_rest, wx_rest, b_rest, e_rest = _linearize_no_s_flip(
ls[mask_remaining], us[mask_remaining], lx[mask_remaining], ux[mask_remaining]
)
w_s[mask_remaining] = ws_rest
w_x[mask_remaining] = wx_rest
bias[mask_remaining] = b_rest
err[mask_remaining] = e_rest
return w_s, w_x, bias, err
def _linearize_hsx(ls: torch.Tensor, us: torch.Tensor,
lx: torch.Tensor, ux: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return (w_s, w_x, bias, err) such that
``h(s) * x ≈ w_s * s + w_x * x + bias ± err``.
Element-wise implementation mirroring the case split in ``vit_plan.md``.
"""
if ls.numel() == 0:
z = torch.zeros_like(ls)
return z, z, z, z
zeros = torch.zeros_like(ls)
ls2 = torch.where(ls + us > 0, -us, ls)
us2 = torch.where(ls + us > 0, -ls, us)
mask_s_flip = ls + us > 0
ws_0, wx_0, b_0, e_0 = _linearize_no_s_flip(ls2, us2, lx, ux)
w_s = torch.where(mask_s_flip, ws_0, ws_0)
w_x = torch.where(mask_s_flip, 1 - wx_0, wx_0)
bias = torch.where(mask_s_flip, -b_0, b_0)
err = torch.where(mask_s_flip, e_0, e_0)
return w_s, w_x, bias, err
def _linearize_mask_term(s_y: Expr, y: Expr) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Linearise ``h(-s_y) * y``.
Returns weights for the original ``s_y`` (not ``-s_y``) and ``y``.
"""
s_ub, s_lb = ublb(s_y)
y_ub, y_lb = ublb(y)
# We linearise h(s) * x with s = -s_y, x = y
ls, us = -s_ub, -s_lb
w_s, w_x, bias, err = _linearize_hsx(ls, us, y_lb, y_ub)
# Convert back to sy (since s = -sy)
w_sy = -w_s
return w_sy, w_x, bias, err
[docs]
def diff_heaviside_pruning_handler(scores, data):
"""Differential handler for ``boundlab::heaviside_pruning``.
Args:
scores: DiffExpr3 / DiffExpr2 providing pruning scores. Only the
**y** component participates in masking.
data: DiffExpr3 / DiffExpr2 providing the tensor to prune.
"""
from boundlab.diff.expr import DiffExpr2, DiffExpr3
# Promote scores/data into DiffExpr3 when possible
if isinstance(scores, DiffExpr2):
sy = scores.y
elif not isinstance(scores, DiffExpr3):
# Treat constant scores as identical for both networks
from boundlab import expr as _expr
sy = _expr.ConstVal(scores) if isinstance(scores, torch.Tensor) else scores
if isinstance(data, DiffExpr2):
data = DiffExpr3(data.x, data.y, data.x - data.y)
elif not isinstance(data, DiffExpr3):
from boundlab import expr as _expr
data_expr = _expr.ConstVal(data) if isinstance(data, torch.Tensor) else data
data = DiffExpr3(data_expr, data_expr, data_expr * 0)
assert isinstance(sy, Expr) and isinstance(data, DiffExpr3), (
f"heaviside_pruning requires expressions convertible to DiffExpr3: {scores} vs {data}")
y = data.y
# Linearise t = h(-sy) * y → weights on sy and y
w_sy, w_y, bias_t, err_t = _linearize_mask_term(sy, y)
zeros = torch.zeros_like(bias_t)
# x component: passthrough of data.x
x_bounds = ZonoBounds(
bias=zeros,
error_coeffs=zeros,
input_weights=[0, torch.ones_like(w_y)],
)
# y component: y - t
y_bounds = ZonoBounds(
bias=-bias_t,
error_coeffs=err_t,
input_weights=[w_sy, torch.ones_like(w_y) - w_y],
)
# diff component: d + t
diff_bounds = ZonoBounds(
bias=bias_t,
error_coeffs=zeros,
input_weights=[0, torch.ones_like(w_y)], # weights on ds: (scores_diff, data_diff)
)
err_op = EinsumOp.from_hardmard(err_t, len(err_t.shape))
dzb = DiffZonoBounds(
x_bounds=x_bounds,
y_bounds=y_bounds,
diff_bounds=diff_bounds,
diff_x_error=EinsumOp.from_hardmard(zeros, len(err_t.shape)),
diff_x_weights=[0, 0],
diff_y_error=err_op,
diff_y_weights=[-w_sy, w_y],
)
from .. import _build_triple_from_dzb
xs = [0, data.x]
ys = [sy, data.y]
ds = [0, data.diff]
return _build_triple_from_dzb(dzb, xs, ys, ds)
# Register handler
interpret["HeavisidePruning"] = diff_heaviside_pruning_handler
__all__ = ["diff_heaviside_pruning_handler"]