Source code for boundlab.zono.relu
import torch
from boundlab.expr._core import Expr
from boundlab.linearop import LinearOp
from boundlab.linearop._base import LinearOpFlags
from boundlab.linearop._einsum import EinsumOp
from boundlab.linearop._indices import SetIndicesOp
from . import ZonoBounds, _register_linearizer
[docs]
@_register_linearizer("relu")
def relu_linearizer(expr: Expr) -> ZonoBounds:
    """Triangle relaxation of ReLU for zonotope abstract interpretation.

    For each neuron with input bounds ``[lb, ub]``:

    - Dead (``ub <= 0``): slope 0 and bias 0, no error term.
    - Active (``lb >= 0``): slope 1 and bias 0, no error term.
    - Crossing (neither of the above, i.e. ``lb < 0 < ub``): triangle
      relaxation with::

          slope = ub / (ub - lb)
          bias  = -ub * lb / (2 * (ub - lb))
          error = -ub * lb / (2 * (ub - lb))

    Args:
        expr: Expression whose ``lb()`` and ``ub()`` return the element-wise
            lower and upper bounds of the ReLU input.

    Returns:
        A ``ZonoBounds`` built from the per-element ``bias``, the error
        coefficient operator for the crossing neurons, and
        ``input_weights=[slope]``.
    """
    lb = expr.lb()
    ub = expr.ub()
    output_shape = ub.shape

    dead = ub <= 0
    active = lb >= 0
    cross = ~dead & ~active

    # Non-crossing entries may have ub == lb, which makes the raw quotient
    # 0/0 (or x/0). Those entries are masked out below, but a non-finite value
    # in the unselected branch of torch.where can still turn into NaN gradients
    # if the bounds are ever differentiated through. Dividing by 1 there keeps
    # every intermediate finite without changing any selected value.
    gap = ub - lb
    safe_gap = torch.where(cross, gap, torch.ones_like(gap))

    # Active entries start at slope 1, everything else at 0; crossing entries
    # are then overwritten with the triangle slope.
    slope = torch.where(active, torch.ones_like(ub), torch.zeros_like(ub))
    slope = torch.where(cross, ub / safe_gap, slope)

    # Same quantity serves as both the bias and the error magnitude of a
    # crossing neuron.
    cross_val = -ub * lb / (2 * safe_gap)
    bias = torch.where(cross, cross_val, torch.zeros_like(ub))

    # One error coefficient per crossing neuron.
    nonzero_idx = torch.nonzero(cross, as_tuple=True)
    length = nonzero_idx[0].shape[0]
    cross_coeffs = cross_val[nonzero_idx]

    # NOTE(review): SetIndicesOp presumably scatters a length-`length` vector
    # into `output_shape` at `nonzero_idx`, and EinsumOp.from_hardmard
    # presumably builds an element-wise scaling by `cross_coeffs` -- confirm
    # against boundlab.linearop.
    indices_op = SetIndicesOp(nonzero_idx, torch.Size((length,)), output_shape)
    hardmard_op = EinsumOp.from_hardmard(cross_coeffs, 1)
    # On crossing entries ub > 0 > lb, so -ub * lb / (2 * (ub - lb)) > 0.
    hardmard_op.flags |= LinearOpFlags.IS_NON_NEGATIVE

    return ZonoBounds(
        bias=bias,
        error_coeffs=indices_op @ hardmard_op,
        input_weights=[slope],
    )