Source code for boundlab.poly.tanh
"""Tanh linearizer for polytope abstract interpretation."""
import torch
from . import PolyBounds, _register_linearizer
[docs]
@_register_linearizer("tanh")
def tanh_linearizer(ub: torch.Tensor, lb: torch.Tensor) -> PolyBounds:
r"""CROWN relaxation of :math:`\tanh`.
``tanh`` is concave for :math:`x \ge 0` and convex for :math:`x \le 0`.
For each neuron we use the common CROWN envelope:
- When both bounds share a sign, the tight secant forms the concave-
side envelope and a tangent (at the bound midpoint) forms the
convex-side envelope.
- For sign-crossing intervals we use the minimum-slope line through
the respective endpoint on each side, giving a sound (though not
tightest-possible) relaxation.
"""
degen = torch.abs(ub - lb) < 1e-12
tl = torch.tanh(lb)
tu = torch.tanh(ub)
denom = (ub - lb).clamp(min=1e-30)
secant = (tu - tl) / denom
secant_bias_u = tu - secant * ub # tu − slope·ub == tl − slope·lb
# Tangent at the midpoint (slope 1 − tanh²(m)).
mid = 0.5 * (lb + ub)
tm = torch.tanh(mid)
tangent_slope = 1.0 - tm * tm
tangent_bias = tm - tangent_slope * mid
safe_slope = torch.minimum(1.0 - tl * tl, 1.0 - tu * tu)
non_negative = lb >= 0
non_positive = ub <= 0
upper_lam = torch.where(
non_positive,
secant,
torch.where(non_negative, tangent_slope, safe_slope),
)
upper_bias = torch.where(
non_positive,
secant_bias_u,
torch.where(non_negative, tangent_bias, tu - safe_slope * ub),
)
lower_lam = torch.where(
non_negative,
secant,
torch.where(non_positive, tangent_slope, safe_slope),
)
lower_bias = torch.where(
non_negative,
secant_bias_u,
torch.where(non_positive, tangent_bias, tl - safe_slope * lb),
)
exact_lam = 1.0 - tl * tl
exact_bias = tl - exact_lam * lb
upper_lam = torch.where(degen, exact_lam, upper_lam)
upper_bias = torch.where(degen, exact_bias, upper_bias)
lower_lam = torch.where(degen, exact_lam, lower_lam)
lower_bias = torch.where(degen, exact_bias, lower_bias)
return PolyBounds(
upper_lam=upper_lam,
upper_bias=upper_bias,
lower_lam=lower_lam,
lower_bias=lower_bias,
)