Source code for boundlab.diff.zono3.default.relu

"""Differential ReLU lineariser (VeryDiff Table 1, 9-case split).

Cases 1–8 reduce to ``relu(x) − relu(y)`` via the standard triangle
relaxation applied independently to *x* and *y*.  The error epsilons
introduced for *x* and *y* are **shared** with the diff component, so
the diff expression tracks ``x_output − y_output`` exactly for those
cases (no extra approximation error).

Case 9 (both crossing) uses a tighter triangle relaxation directly on
``d = x − y`` with slope ``alpha = clamp(d_ub / (d_ub − d_lb), 0, 1)``
and half-width ``mu_d = 0.5 · max(d_ub, −d_lb)``.  A fresh epsilon is
introduced only for case-9 neurons.
"""

import torch

from boundlab import utils
from boundlab.expr._core import Expr
from boundlab.linearop._einsum import EinsumOp
from boundlab.prop import ublb
from boundlab.zono import ZonoBounds
from .. import DiffZonoBounds


[docs] def relu_linearizer( xs: list[Expr], ys: list[Expr], ds: list[Expr] ) -> DiffZonoBounds: """Return a :class:`DiffZonoBounds` for differential ReLU. *x_bounds* and *y_bounds* are standard triangle-relaxation zonotopes. For cases 1–8 the diff reuses the **same** epsilon variables as *x* and *y*, making ``diff_output = x_output − y_output`` exactly. For case 9 (both crossing) a fresh epsilon is introduced for the diff component. Examples -------- Active/active regime: diff x-weight equals the relu slope (1.0): >>> import torch >>> import boundlab.expr as expr >>> from boundlab.diff.zono3.default.relu import relu_linearizer >>> x = expr.ConstVal(torch.tensor([2.0])) + 0.5 * expr.LpEpsilon([1]) >>> y = expr.ConstVal(torch.tensor([1.0])) + 0.5 * expr.LpEpsilon([1]) >>> d = x - y >>> dzb = relu_linearizer(x, y, d) >>> dzb.x_bounds.input_weights[0].item() 1.0 Crossing/crossing regime is still sound for ``relu(x) - relu(y)``: >>> x = expr.ConstVal(torch.tensor([0.0])) + 0.8 * expr.LpEpsilon([1]) >>> y = expr.ConstVal(torch.tensor([0.1])) + 0.8 * expr.LpEpsilon([1]) >>> d = x - y >>> dzb = relu_linearizer(x, y, d) >>> dzb.diff_bounds.bias.shape torch.Size([1]) """ x, y, diff = xs[0], ys[0], ds[0] # for type checking x_ub, x_lb = ublb(x) y_ub, y_lb = ublb(y) d_ub, d_lb = ublb(diff) zeros = torch.zeros_like(x_ub) # ------------------------------------------------------------------ # Standard triangle relaxation for x and y independently # ------------------------------------------------------------------ lam_x = (torch.relu(x_ub) - torch.relu(x_lb)) / (x_ub - x_lb + 1e-30) mu_x = 0.5 * (torch.relu(x_ub) - lam_x * x_ub) lam_y = (torch.relu(y_ub) - torch.relu(y_lb)) / (y_ub - y_lb + 1e-30) mu_y = 0.5 * (torch.relu(y_ub) - lam_y * y_ub) lam_avg = 0.5 * (lam_x + lam_y) dx = lam_x - lam_avg dy = lam_avg - lam_y sd = lam_avg bias_d = mu_x - mu_y ex = mu_x ey = -mu_y err_d = torch.zeros_like(ex) # ------------------------------------------------------------------ # Case 9: both x and y crossing → triangle relaxation on d directly # ------------------------------------------------------------------ case9 = (x_ub > 0) & (x_lb < 0) & (y_ub > 0) & (y_lb < 0) lam_d = torch.clamp(d_ub / (d_ub - d_lb + 1e-30), 0.0, 1.0) nu_d = lam_d * torch.clamp(-d_lb, min=0.0) mu_d = 0.5 * torch.maximum(d_ub, -d_lb) # ------------------------------------------------------------------ # Diff component — masked per neuron by case9 # # Non-case9: diff = lam_x·x − lam_y·y + (mu_x − mu_y) # + mu_x·eps_x − mu_y·eps_y (shared, exact!) # Case9: diff = lam_d·d + (nu_d − mu_d) + mu_d·eps_d (fresh eps) # ------------------------------------------------------------------ dx = torch.where(case9, zeros, dx) # x input weight for diff dy = torch.where(case9, zeros, dy) # y input weight for diff ex = torch.where(case9, zeros, mu_x) # scale applied to shared eps_x ey = torch.where(case9, zeros, -mu_y) # scale applied to shared eps_y sd = torch.where(case9, lam_d, sd) # d input weight for diff bias_d = torch.where(case9, nu_d - mu_d, bias_d) err_d = torch.where(case9, mu_d, err_d) # fresh error for case-9 neurons if not utils.current_fake_mode(): assert torch.isfinite(lam_x).all() and torch.isfinite(mu_x).all() and torch.isfinite(lam_y).all() and torch.isfinite(mu_y).all(), "Non-finite values in ReLU linearizer x/y bounds" assert torch.isfinite(dx).all() and torch.isfinite(dy).all() and torch.isfinite(ex).all() and torch.isfinite(ey).all() \ and torch.isfinite(sd).all() and torch.isfinite(bias_d).all() and torch.isfinite(err_d).all(), "Non-finite values in ReLU linearizer outputs" return DiffZonoBounds( x_bounds=ZonoBounds(bias=mu_x, error_coeffs=mu_x, input_weights=[lam_x]), y_bounds=ZonoBounds(bias=mu_y, error_coeffs=mu_y, input_weights=[lam_y]), diff_bounds=ZonoBounds( bias=bias_d, error_coeffs=EinsumOp.from_hardmard(err_d, len(x_ub.shape)), input_weights=[sd], ), diff_x_error=EinsumOp.from_hardmard(ex, len(x_ub.shape)), diff_x_weights=[dx], diff_y_error=EinsumOp.from_hardmard(ey, len(x_ub.shape)), diff_y_weights=[dy], )