Source code for boundlab.diff.zono3.relu

"""Differential ReLU lineariser (VeryDiff Table 1, 9-case split).

Cases 1–8 reduce to ``relu(x) − relu(y)`` via the standard triangle
relaxation applied independently to *x* and *y*.  For each such case
at most one of (mu_x, mu_y) is non-zero (they occupy disjoint regimes),
so summing the two error terms is lossless.

Case 9 (both crossing) uses a tighter triangle relaxation directly on
``d = x − y`` with slope ``alpha = clamp(d_ub / (d_ub − d_lb), 0, 1)``
and half-width ``mu_d = 0.5 · max(d_ub, −d_lb)``.
"""

import torch

from boundlab.expr._core import Expr
from boundlab.linearop._base import LinearOpFlags
from boundlab.linearop._einsum import EinsumOp
from boundlab.linearop._indices import SetIndicesOp
from boundlab.prop import ublb
from boundlab.zono import ZonoBounds
from boundlab.zono.relu import relu_linearizer as std_relu_linearizer

from . import _register_linearizer


[docs] @_register_linearizer("relu") def relu_linearizer( x: Expr, y: Expr, diff: Expr ) -> tuple[ZonoBounds, ZonoBounds, ZonoBounds]: """Return ``(x_bounds, y_bounds, diff_bounds)`` for differential ReLU. *x_bounds* and *y_bounds* are standard triangle-relaxation zonotopes. *diff_bounds* over-approximates ``relu(x) − relu(y)`` with ``input_weights = [sx, sy, sd]`` corresponding to inputs ``[x, y, diff]``. Examples -------- Active/active regime: diff weights are ``sx=1, sy=-1``, equivalent to ``x - y``: >>> import torch >>> import boundlab.expr as expr >>> from boundlab.diff.zono3.relu import relu_linearizer >>> x = expr.ConstVal(torch.tensor([2.0])) + 0.5 * expr.LpEpsilon([1]) >>> y = expr.ConstVal(torch.tensor([1.0])) + 0.5 * expr.LpEpsilon([1]) >>> d = x - y >>> _, _, d_bounds = relu_linearizer(x, y, d) >>> sx, sy = d_bounds.input_weights[0], d_bounds.input_weights[1] >>> d_expr = sx * x + sy * y + d_bounds.bias >>> torch.allclose(d_expr.ub(), d.ub(), atol=1e-5) True Crossing/crossing regime is still sound for ``relu(x) - relu(y)``: >>> x = expr.ConstVal(torch.tensor([0.0])) + 0.8 * expr.LpEpsilon([1]) >>> y = expr.ConstVal(torch.tensor([0.1])) + 0.8 * expr.LpEpsilon([1]) >>> d = x - y >>> _, _, d_bounds = relu_linearizer(x, y, d) >>> d_bounds.bias.shape torch.Size([1]) """ x_bounds = std_relu_linearizer(x) y_bounds = std_relu_linearizer(y) # ------------------------------------------------------------------ # Cases 1–8: diff ≈ relu(x) − relu(y) # ------------------------------------------------------------------ sx = x_bounds.input_weights[0] # slope_x (0 / 1 / λ_x) sy = -y_bounds.input_weights[0] # −slope_y bias = x_bounds.bias - y_bounds.bias # μ_x − μ_y err = x_bounds.bias.abs() + y_bounds.bias.abs() # |μ_x| + |μ_y| # ------------------------------------------------------------------ # Case 9: both crossing → triangle relaxation on d directly # ------------------------------------------------------------------ x_ub, x_lb = ublb(x) y_ub, y_lb = ublb(y) d_ub, d_lb = ublb(diff) x_cross = (x_ub > 0) & (x_lb < 0) y_cross = (y_ub > 0) & (y_lb < 0) ones = torch.ones_like(x_ub) zeros = torch.zeros_like(x_ub) d_denom = d_ub - d_lb degen = d_denom.abs() < 1e-15 case9 = x_cross & y_cross & ~degen lam = torch.clamp(d_ub / d_denom, 0.0, 1.0) nu = lam * torch.clamp(-d_lb, min=0.0) mu = 0.5 * torch.maximum(d_ub, -d_lb) # Case 9 overrides: zero out x/y slopes, use d directly with lambda slope. sx = torch.where(case9, zeros, sx) sy = torch.where(case9, zeros, sy) sd = torch.where(case9, lam, zeros) bias = torch.where(case9, nu - mu, bias) err = torch.where(case9, mu, err) # ------------------------------------------------------------------ # Build the sparse error LinearOp # ------------------------------------------------------------------ output_shape = x_ub.shape error_indices = torch.nonzero(case9, as_tuple=True) error_len = error_indices[0].shape[0] error_vals = err[error_indices] indices_op = SetIndicesOp(error_indices, torch.Size((error_len,)), output_shape) hadamard_op = EinsumOp.from_hardmard(error_vals, 1) hadamard_op.flags |= LinearOpFlags.IS_NON_NEGATIVE error_op = indices_op @ hadamard_op diff_bounds = ZonoBounds(bias=bias, error_coeffs=error_op, input_weights=[sx, sy, sd]) return x_bounds, y_bounds, diff_bounds