Source code for boundlab.poly.relu

"""ReLU linearizer for polytope abstract interpretation.

CROWN-style relaxation of :math:`\\mathrm{ReLU}(x) = \\max(0, x)`.
"""

import torch

from . import PolyBounds, _register_linearizer


@_register_linearizer("relu")
def relu_linearizer(ub: torch.Tensor, lb: torch.Tensor) -> PolyBounds:
    r"""CROWN relaxation of ReLU.

    For each neuron with input bounds :math:`[\ell, u]`:

    - Dead (:math:`u \le 0`): ``f(x) = 0``.
    - Active (:math:`\ell \ge 0`): ``f(x) = x``.
    - Crossing (:math:`\ell < 0 < u`): tight upper envelope through
      :math:`(\ell, 0)` and :math:`(u, u)`; lower bound given by the tangent
      of ReLU at the interval midpoint :math:`c = (\ell + u)/2`, i.e. the
      line ``y = x`` when :math:`c \ge 0` and ``y = 0`` otherwise.

    When a neuron satisfies both the dead and the active test
    (:math:`\ell = u = 0`), the active case takes precedence.

    Parameters
    ----------
    ub : torch.Tensor
        Element-wise upper bounds :math:`u` on the ReLU input. Note that the
        upper bound is the *first* argument.
    lb : torch.Tensor
        Element-wise lower bounds :math:`\ell` on the ReLU input.

    Returns
    -------
    PolyBounds
        Per-element slopes (``upper_lam``, ``lower_lam``) and intercepts
        (``upper_bias``, ``lower_bias``) of the upper and lower linear
        relaxations.

    Examples
    --------
    >>> import torch
    >>> from boundlab.poly.relu import relu_linearizer
    >>> ub = torch.tensor([2.0, -1.0, 1.0])
    >>> lb = torch.tensor([-1.0, -2.0, 0.5])
    >>> b = relu_linearizer(ub, lb)
    >>> b.upper_lam.shape
    torch.Size([3])
    """
    zero = torch.zeros_like(ub)
    one = torch.ones_like(ub)

    active = lb >= 0
    dead = ub <= 0
    crossing = ~(active | dead)

    # Upper bound: secant through (lb, 0) and (ub, ub) on crossing intervals.
    # The denominator is masked rather than padded with an epsilon:
    #   * on crossing entries ub - lb > 0, so the exact secant slope is used
    #     (an additive epsilon shrinks the slope and lets the "upper" line dip
    #     below ReLU at x = ub);
    #   * on all other entries the quotient is discarded below, and dividing
    #     by 1 keeps it finite so no inf/NaN leaks into the forward value or
    #     into gradients flowing back through the division.
    width = torch.where(crossing, ub - lb, one)
    cross_slope = ub / width
    cross_upper_bias = -cross_slope * lb

    # Lower bound: tangent at midpoint c = (lb + ub)/2. ReLU has slope 1 for
    # c >= 0 and slope 0 otherwise; the intercept relu(c) - slope * c
    # evaluates to zero for every finite c.
    center = 0.5 * (ub + lb)
    tangent_slope = torch.where(center >= 0, one, zero)
    tangent_bias = torch.relu(center) - tangent_slope * center

    upper_lam = torch.where(active, one, torch.where(crossing, cross_slope, zero))
    upper_bias = torch.where(crossing, cross_upper_bias, zero)
    lower_lam = torch.where(active, one, torch.where(crossing, tangent_slope, zero))
    # ``active`` and ``crossing`` are disjoint, so a single mask suffices.
    lower_bias = torch.where(crossing, tangent_bias, zero)

    return PolyBounds(
        upper_lam=upper_lam,
        upper_bias=upper_bias,
        lower_lam=lower_lam,
        lower_bias=lower_bias,
    )