Source code for boundlab.zono.relu
import torch
from boundlab.linearop._base import LinearOpFlags
from boundlab.linearop._einsum import EinsumOp
from boundlab.linearop._indices import SetIndicesOp
from . import ZonoBounds, _register_linearizer
[docs]
@_register_linearizer("relu")
def relu_linearizer(ub: torch.Tensor, lb: torch.Tensor) -> ZonoBounds:
    """Triangle relaxation of ReLU for zonotope abstract interpretation.

    For each neuron with input bounds [lb, ub]:

    - Dead (ub <= 0): output is 0, no contribution (slope 0, bias 0).
    - Active (lb >= 0): output equals input exactly (slope 1, bias 0).
    - Crossing (lb < 0 < ub): triangle relaxation with
        slope = ub / (ub - lb),
        bias  = -ub * lb / (2 * (ub - lb)),
        error = -ub * lb / (2 * (ub - lb)).

    Neurons whose bounds coincide (lb == ub) are classified as dead or
    active by sign, so they are relaxed exactly instead of going through a
    0 / 0 division.

    Args:
        ub: Element-wise upper bounds of the ReLU input.
        lb: Element-wise lower bounds of the ReLU input. Assumed to satisfy
            lb <= ub and to have the same shape as ``ub`` (TODO confirm at
            call sites).

    Returns:
        ZonoBounds with ``input_weights=[slope]``, ``bias=bias`` and
        ``error_coeffs`` built by ``EinsumOp.from_hardmard`` from the per-neuron
        error magnitudes (which equal ``bias`` for every neuron). Presumably the
        caller forms ``slope * x + bias + error_coeffs(eps)`` with fresh noise
        symbols ``eps`` in [-1, 1] — verify against the ZonoBounds consumer.
    """
    width = ub - lb
    has_width = width > 0
    # Divide only where the interval is non-degenerate; the placeholder 1 in
    # the denominator is never selected by the outer where. An eps-guarded
    # division would give slope 0 for a constant positive input (loose [0, c]
    # output) and NaN when the eps underflows (e.g. float16).
    safe_width = torch.where(has_width, width, torch.ones_like(width))
    lam = torch.where(
        has_width,
        (torch.relu(ub) - torch.relu(lb)) / safe_width,
        (lb >= 0).to(ub.dtype),  # zero-width: exact slope 1 (active) or 0 (dead)
    )
    # relu(ub) - lam * ub is the gap relu(x) - lam * x at x = ub, which for
    # these slopes is the largest gap over [lb, ub] (the smallest is 0).
    # Bias and error radius are both half of it, centering the band.
    mu = 0.5 * (torch.relu(ub) - lam * ub)
    # One error symbol per neuron. mu is exactly 0 for dead and active
    # neurons, so only crossing neurons carry a non-zero coefficient.
    # TODO: restrict the error symbols to crossing neurons (e.g. via
    # SetIndicesOp) to avoid allocating zero-coefficient symbols.
    error_op = EinsumOp.from_hardmard(mu, ub.dim())
    return ZonoBounds(bias=mu, error_coeffs=error_op, input_weights=[lam])