# Source code for boundlab.zono.softmax2

r"""Softmax2 handler for zonotope abstract interpretation.

Defines the two-argument function

.. math::

   \mathrm{softmax2}(x, y) = \frac{x}{1 + x\,\exp(y)}
"""

from __future__ import annotations

import torch
from torch import nn

from boundlab.expr._core import Expr
from boundlab.gradlin import gradlin
from boundlab.linearop._einsum import EinsumOp
from . import ZonoBounds, _register_linearizer


def softmax2(x, y):
    """Evaluate ``softmax2(x, y) = x / (1 + x * exp(y))`` elementwise.

    The value is computed as ``1 / (1 / x + exp(y))``, so ``x == 0`` yields
    ``0`` (``1 / 0 = inf``). An overflowing ``exp(y)`` also yields ``0``
    rather than ``inf / inf``. Any remaining NaNs are mapped to ``0``.

    Args:
        x: Tensor, or a Python number. The literal integer ``1`` takes a
            ``sigmoid(-y)`` fast path. Other Python numbers are promoted to a
            tensor on ``y``'s device, using ``y``'s dtype when it is floating
            point and the default dtype otherwise.
        y: Tensor.

    Returns:
        Tensor broadcast from ``x`` and ``y``.
    """
    if isinstance(x, int) and x == 1:
        # x == 1: 1 / (1 + exp(y)) == sigmoid(-y), evaluated stably by torch.
        return torch.sigmoid(-y)
    if not isinstance(x, torch.Tensor):
        # torch.reciprocal rejects Python scalars; promote x to match y.
        dtype = y.dtype if y.is_floating_point() else torch.get_default_dtype()
        x = torch.as_tensor(x, dtype=dtype, device=y.device)
    return torch.nan_to_num(
        torch.reciprocal(torch.reciprocal(x) + torch.exp(y)),
        nan=0.0,
    )

def softmax2dx(x, y):
    """Partial derivative of softmax2 with respect to ``x``.

    ``d/dx [x / (1 + x e^y)] = 1 / (1 + x e^y)^2``. NaNs are replaced by 0;
    one example is ``0 * inf`` when ``x == 0`` and ``exp(y)`` overflows.
    """
    denom = 1 + x * torch.exp(y)
    deriv = 1 / denom**2
    return torch.nan_to_num(deriv, nan=0.0)

def softmax2dy(x, y):
    """Partial derivative of softmax2 with respect to ``y``.

    ``d/dy [x / (1 + x e^y)] = -softmax2(x, y) * x e^y / (1 + x e^y)``.

    The second factor equals ``1 - softmax2(x, y) / x``. It is evaluated as
    ``1 - 1 / (1 + x e^y)`` so that ``x == 0`` gives ``0`` (the exact
    derivative) instead of ``0 / 0 = nan``.
    """
    s = softmax2(x, y)
    return -s * (1 - torch.reciprocal(1 + x * torch.exp(y)))

def softmax2dy_inv(x, lamy, sign=1):
    """Return ``y`` such that ``softmax2dy(x, y) == lamy``.

    With ``t = exp(y)``, the equation ``-x**2 * t / (1 + x * t)**2 == lamy``
    becomes the quadratic

        lamy * x * t**2 + (2 * lamy + x) * t + lamy / x == 0.

    For ``x > 0`` and ``-x / 4 <= lamy < 0`` this has two positive roots
    ``t_small <= 1 / x <= t_large``, with ``t_small * t_large == 1 / x**2``.

    Args:
        x: Tensor of ``x`` values (expected ``x > 0``).
        lamy: Tensor of target slopes, expected in ``[-x / 4, 0)``. Outside
            that range the discriminant is clamped to zero, and the result
            is only a finite stand-in, not a true solution.
        sign: ``1`` selects ``t_large``. Any other value (callers pass
            ``-1``) selects ``t_small``.

    Returns:
        ``log(t)``, with ``t`` clamped to ``[finfo.tiny, finfo.max]`` of the
        result dtype so the output is always finite.
    """
    disc = (x * (4 * lamy + x)).clamp(min=0.0)
    root = torch.sqrt(disc)
    # In-domain, 2*lamy + x >= x/2 > 0 and root >= 0, so this sum suffers
    # no cancellation; both roots are expressed through it.
    s = 2 * lamy + x + root

    near0 = lamy.abs() <= 1e-8
    if sign == 1:
        # Large root: t = -(2*lamy + x + root) / (2*lamy*x) ~ 1 / (-lamy) as lamy -> 0-.
        ratio = -(s / (2 * lamy * x))
        ratio_asym = 1.0 / (-lamy)
    else:
        # Small root via Vieta (t_small = 1 / (x**2 * t_large)). This avoids
        # the catastrophic cancellation in (2*lamy + x - root) as lamy -> 0-,
        # where t_small ~ -lamy / x**2.
        ratio = (-2 * lamy) / (x * s)
        ratio_asym = (-lamy) / (x * x)

    ratio = torch.where(near0, ratio_asym, ratio)
    finfo = torch.finfo(ratio.dtype)
    ratio = torch.nan_to_num(ratio, nan=finfo.tiny, posinf=finfo.max, neginf=finfo.tiny)
    ratio = ratio.clamp(min=finfo.tiny, max=finfo.max)
    return torch.log(ratio)


def softmax2_ub(lamx: torch.Tensor, lamy: torch.Tensor, x_ub: torch.Tensor, x_lb: torch.Tensor, y_ub: torch.Tensor, y_lb: torch.Tensor) -> torch.Tensor:
    """Upper offset for the affine relaxation of softmax2 with slopes ``(lamx, lamy)``.

    Computes, elementwise, a value intended to upper-bound
    ``softmax2(x, y) - lamx * x - lamy * y`` over ``x in [x_lb, x_ub]`` and
    ``y in [y_lb, y_ub]``.

    For fixed ``y``, the unconstrained maximiser over ``x`` of
    ``softmax2(x, y) - lamx * x`` solves ``softmax2dx(x, y) == lamx``. That
    gives ``x = lambda0 * exp(-y)`` with ``lambda0 = 1 / sqrt(lamx) - 1``.
    ``yi_ub`` and ``yi_lb`` are the ``y`` values at which this maximiser
    equals ``x_ub`` and ``x_lb`` respectively. ``f(y)`` evaluates the
    objective at the maximiser clipped to ``x_ub`` or ``x_lb``.

    The result is the largest ``f(y) - lamy * y`` over a finite set of
    candidate ``y`` values:

    * the stationary points from ``softmax2dy_inv(..., sign=-1)`` for ``x_ub``
      and ``x_lb``, restricted to their pieces;
    * the two breakpoints;
    * the two interval endpoints.

    Args:
        lamx: slope for ``x``; must be finite.
        lamy: slope for ``y``; must be finite.
        x_ub, x_lb: box bounds for ``x``; ``x_ub`` is asserted ``>= 1e-8``.
        y_ub, y_lb: box bounds for ``y``.

    Returns:
        Tensor of upper offsets, asserted finite.
    """
    assert torch.isfinite(lamx).all(), "softmax2_ub: lamx must be finite"
    assert torch.isfinite(lamy).all(), "softmax2_ub: lamy must be finite"
    assert (x_ub >= 1e-8).all(), "softmax2_ub: x_ub must be positive"

    # Stationary points in y of softmax2(x_ub / x_lb, y) - lamy * y.
    ypos_ub = softmax2dy_inv(x_ub, lamy, sign=-1)
    ypos_lb = softmax2dy_inv(x_lb, lamy, sign=-1)
    # lambda0 * exp(-y) is the x solving softmax2dx(x, y) == lamx.
    sqrt_lamx = torch.sqrt(lamx)
    lambda0 = 1 / sqrt_lamx - 1
    finfo = torch.finfo(lamx.dtype)
    # Breakpoints in y where that maximiser equals x_ub / x_lb
    # (ratios clamped so the log stays finite).
    ratio_ub = (lambda0 / x_ub).clamp(min=finfo.tiny, max=finfo.max)
    ratio_lb = (lambda0 / x_lb).clamp(min=finfo.tiny, max=finfo.max)
    yi_ub = torch.log(ratio_ub)
    yi_lb = torch.log(ratio_lb)
    assert torch.isfinite(lambda0).all(), "softmax2_ub: lambda0 non-finite"
    assert torch.isfinite(ypos_ub).all(), "softmax2_ub: ypos_ub non-finite"
    assert torch.isfinite(ypos_lb).all(), "softmax2_ub: ypos_lb non-finite"

    # Restrict each stationary point to its own piece (x_ub piece: y <= yi_ub;
    # x_lb piece: y >= yi_lb), then clamp every candidate into [y_lb, y_ub].
    ypos_ub = torch.minimum(ypos_ub, yi_ub)
    ypos_ub = torch.clamp(ypos_ub, y_lb, y_ub)
    ypos_lb = torch.maximum(ypos_lb, yi_lb)
    ypos_lb = torch.clamp(ypos_lb, y_lb, y_ub)
    yi_ub = torch.clamp(yi_ub, y_lb, y_ub)
    yi_lb = torch.clamp(yi_lb, y_lb, y_ub)

    def f(y):
        # Unconstrained maximiser over x of softmax2(x, y) - lamx * x.
        x = lambda0 * torch.exp(-y)
        # Use x_ub below yi_ub and x_lb above yi_lb, otherwise the interior
        # maximiser. The 1e-12 guards presumably skip a piece whose breakpoint
        # was clamped onto the far end of [y_lb, y_ub] (TODO confirm intent).
        return torch.where((y <= yi_ub) & (yi_ub >= y_lb + 1e-12), softmax2(x_ub, y) - lamx * x_ub,
               torch.where((y >= yi_lb) & (yi_lb <= y_ub - 1e-12), softmax2(x_lb, y) - lamx * x_lb,
                        softmax2(x, y) - lamx * x))

    # Largest objective value over all candidate y values.
    ub = torch.stack([
        f(ypos_ub) - lamy * ypos_ub,
        f(ypos_lb) - lamy * ypos_lb,
        f(yi_lb) - lamy * yi_lb,
        f(yi_ub) - lamy * yi_ub,
        f(y_lb) - lamy * y_lb,
        f(y_ub) - lamy * y_ub,
    ], dim=0).max(dim=0).values

    # Fallback for |lamx| <= 1e-8, where lambda0 blows up: bound only at
    # x = x_ub and x = x_lb via softmax2_ub2.
    # NOTE(review): this ignores interior x; the slack is at most
    # lamx * (x_ub - x_lb) <= 1e-8 * (x_ub - x_lb).
    # torch.where evaluates both branches; clamp lamy for the ub2 helper branch
    # to avoid assertion failures when lamx is not in the fallback region.
    lamy_ub2 = torch.clamp(lamy, min=-1.0 + 1e-8, max=-1e-8)
    ub = torch.where(
        lamx.abs() <= 1e-8,
        torch.maximum(
            softmax2_ub2(lamy_ub2, x_ub, y_ub, y_lb) - lamx * x_ub,
            softmax2_ub2(lamy_ub2, x_lb, y_ub, y_lb) - lamx * x_lb,
        ),
        ub
    )

    assert torch.isfinite(ub).all(), "softmax2_ub: output became non-finite"
    return ub

def softmax2_ub2(lam: torch.Tensor, x: torch.Tensor, y_ub: torch.Tensor, y_lb: torch.Tensor) -> torch.Tensor:
    """Upper-bound ``softmax2(x, y) - lam * y`` over ``y in [y_lb, y_ub]`` at fixed ``x``.

    Takes the maximum over three candidates:

    * the stationary point from ``softmax2dy_inv(x, lam, sign=-1)``,
      clamped into the interval;
    * the two interval endpoints.

    Args:
        lam: Slope for ``y``; must be finite and satisfy ``-1 <= lam <= 0``.
        x: Fixed ``x`` value(s).
        y_ub, y_lb: Interval bounds for ``y``.

    Returns:
        Elementwise upper bound, asserted finite.
    """
    assert torch.isfinite(lam).all(), "softmax2_ub2: lam must be finite"
    # Message now matches the inclusive check below.
    assert ((lam <= 0.0) & (lam >= -1.0)).all(), "softmax2_ub2: lam must satisfy -1 <= lam <= 0"

    ypos = softmax2dy_inv(x, lam, sign=-1)
    # Defensive: fall back to the left endpoint if the inverse is not finite.
    ypos = torch.where(torch.isfinite(ypos), ypos, y_lb)
    ypos = torch.clamp(ypos, y_lb, y_ub)

    ub = torch.stack([
        softmax2(x, ypos) - lam * ypos,
        softmax2(x, y_ub) - lam * y_ub,
        softmax2(x, y_lb) - lam * y_lb,
    ]).max(dim=0).values

    assert torch.isfinite(ub).all(), "softmax2_ub2: output became non-finite"
    return ub
    
def softmax2_lb(lam: torch.Tensor, x: torch.Tensor, y_ub: torch.Tensor, y_lb: torch.Tensor) -> torch.Tensor:
    """Lower-bound ``softmax2(x, y) - lam * y`` over ``y in [y_lb, y_ub]`` at fixed ``x``.

    Takes the minimum over three candidates:

    * the stationary point from ``softmax2dy_inv(x, lam, sign=1)``,
      clamped into the interval;
    * the two interval endpoints.

    Args:
        lam: Slope for ``y``; must be finite and satisfy ``-1 <= lam <= 0``.
        x: Fixed ``x`` value(s).
        y_ub, y_lb: Interval bounds for ``y``.

    Returns:
        Elementwise lower bound, asserted finite.
    """
    assert torch.isfinite(lam).all(), "softmax2_lb: lam must be finite"
    # The previous message ("0 < lam < 1") contradicted this check.
    assert ((lam <= 0.0) & (lam >= -1.0)).all(), "softmax2_lb: lam must satisfy -1 <= lam <= 0"

    yneg = softmax2dy_inv(x, lam, sign=1)
    # Defensive: fall back to the right endpoint if the inverse is not finite.
    yneg = torch.where(torch.isfinite(yneg), yneg, y_ub)
    yneg = torch.clamp(yneg, y_lb, y_ub)

    lb = torch.stack([
        softmax2(x, yneg) - lam * yneg,
        softmax2(x, y_lb) - lam * y_lb,
        softmax2(x, y_ub) - lam * y_ub,
    ]).min(dim=0).values
    assert torch.isfinite(lb).all(), "softmax2_lb: output became non-finite"
    return lb


def _softmax2_bounds(lamx, lamy, x_ub, x_lb, y_ub, y_lb):
    """Return ``(ub, lb)`` offsets of ``softmax2(x, y) - lamx * x - lamy * y`` on the box.

    ``ub`` comes from ``softmax2_ub``. ``lb`` is the smaller of the
    ``softmax2_lb`` results at ``x_ub`` and at ``x_lb``, each shifted by
    ``-lamx * x``.
    """
    ub = softmax2_ub(lamx, lamy, x_ub, x_lb, y_ub, y_lb)
    lb = torch.minimum(
        softmax2_lb(lamy, x_ub, y_ub, y_lb) - lamx * x_ub,
        softmax2_lb(lamy, x_lb, y_ub, y_lb) - lamx * x_lb,
    )
    return ub, lb


@_register_linearizer("Softmax2")
def softmax2_linearizer(
    x_ub: torch.Tensor,
    x_lb: torch.Tensor,
    y_ub: torch.Tensor,
    y_lb: torch.Tensor,
    niters: int = 1,
) -> ZonoBounds:
    """Gradlin-based linearizer for ``x / (1 + x * exp(y))``.

    Produces the slopes ``(lamx, lamy)`` and an offset interval
    ``[lb, ub]`` such that ``softmax2(x, y)`` is enclosed by
    ``lamx * x + lamy * y + [lb, ub]`` on the box.

    Args:
        x_ub, x_lb: Box bounds for ``x``.
        y_ub, y_lb: Box bounds for ``y``.
        niters: Number of Adam steps (lr ``1e-2``) used to shrink the mean
            width ``ub - lb``.

    Returns:
        ``ZonoBounds`` with:

        * ``bias = (ub + lb) / 2``;
        * error coefficients ``(ub - lb) / 2``;
        * ``input_weights = [lamx, lamy]``.
    """
    # Exact singleton box: return an exact affine form with zero error.
    if ((x_ub == x_lb) & (y_ub == y_lb)).all():
        x = x_ub
        y = y_ub
        lamx = softmax2dx(x, y)
        lamy = softmax2dy(x, y)
        mu = softmax2(x, y) - lamx * x - lamy * y
        beta = torch.zeros_like(mu)
        return ZonoBounds(
            bias=mu,
            error_coeffs=EinsumOp.from_hardmard(beta, len(x_ub.shape)),
            input_weights=[lamx, lamy],
        )

    # Initial slopes are taken from derivatives at the box corners.
    lamx = nn.Parameter(
        torch.minimum(softmax2dx(x_ub, y_ub), softmax2dx(x_lb, y_lb)),
    )
    lamy = nn.Parameter(
        torch.maximum(softmax2dy(x_ub, y_lb), softmax2dy(x_lb, y_ub)),
    )
    gradlin_optimizer = torch.optim.Adam([lamx, lamy], lr=1e-2)
    for _ in range(niters):
        # Project the parameters back into the region the bound helpers accept.
        with torch.no_grad():
            lamx.nan_to_num_(nan=1e-6, posinf=1e-6, neginf=1e-6)
            lamy.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
            lamx.clamp_(min=1e-12, max=1.0 - 1e-8)
            lamy.clamp_(min=-1.0 + 1e-8, max=-1e-8)
        assert torch.isfinite(lamx).all(), "softmax2_linearizer: lamx non-finite before optimizer step"
        assert torch.isfinite(lamy).all(), "softmax2_linearizer: lamy non-finite before optimizer step"
        gradlin_optimizer.zero_grad()
        ub, lb = _softmax2_bounds(lamx, lamy, x_ub, x_lb, y_ub, y_lb)
        assert torch.isfinite(ub).all() and torch.isfinite(lb).all(), \
            "softmax2_linearizer: ub/lb non-finite during optimization"
        assert (ub >= lb - 1e-6).all(), "softmax2_linearizer: ub < lb during optimization, which should be impossible"
        loss = (ub - lb).mean()
        assert torch.isfinite(loss), "softmax2_linearizer: loss non-finite"
        loss.backward()
        gradlin_optimizer.step()

    # Final projection of the optimised slopes (after the last Adam step).
    lamx = lamx.detach()
    lamy = lamy.detach()
    lamx = torch.nan_to_num(lamx, nan=1e-6, posinf=1e-6, neginf=1e-6)
    lamy = torch.nan_to_num(lamy, nan=-1e-3, posinf=-1e-8, neginf=-1.0 + 1e-8)
    lamx = torch.clamp(lamx, min=1e-12, max=1.0 - 1e-8)
    lamy = torch.clamp(lamy, min=-1.0 + 1e-8, max=-1e-8)
    ub, lb = _softmax2_bounds(lamx, lamy, x_ub, x_lb, y_ub, y_lb)

    # If numerics produced ub < lb, collapse those entries to their midpoint.
    bad = ub < lb
    if bad.any():
        mid = (ub + lb) / 2
        ub = torch.where(bad, mid, ub)
        lb = torch.where(bad, mid, lb)
    beta = (ub - lb) / 2
    mu = (ub + lb) / 2
    return ZonoBounds(
        bias=mu,
        error_coeffs=EinsumOp.from_hardmard(beta, len(x_ub.shape)),
        input_weights=[lamx, lamy],
    )
def softmax2_ibp(
    x_ub: torch.Tensor,
    x_lb: torch.Tensor,
    y_ub: torch.Tensor,
    y_lb: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Interval bounds of ``softmax2`` over the box ``[x_lb, x_ub] x [y_lb, y_ub]``.

    When ``x > 0``, the function increases with ``x`` and decreases with
    ``y``, so its extremes sit at opposite corners of the box:

    * upper bound: ``softmax2(x_ub, y_lb)``;
    * lower bound: ``softmax2(x_lb, y_ub)``.

    Returns:
        The pair ``(ub, lb)``.
    """
    return softmax2(x_ub, y_lb), softmax2(x_lb, y_ub)
def softmax2_handler(x: Expr, y: Expr) -> Expr:
    """Dispatch ``softmax2(x, y)`` to the interpreter entry registered as ``"Softmax2"``.

    Both operands must have identical shapes.
    """
    assert x.shape == y.shape, f"softmax2 expects matching shapes, got {x.shape} vs {y.shape}"
    # Function-scope import, as in the original; presumably this avoids an
    # import cycle with the package __init__ (unverified).
    from . import interpret
    handler = interpret["Softmax2"]
    return handler(x, y)