Source code for boundlab.poly.softmax

r"""Softmax handler for polytope abstract interpretation.

Implements softmax through the same DeepT decomposition used by the
zonotope backend. Dividing the numerator and denominator of
:math:`\exp(x_i) / \sum_j \exp(x_j)` by :math:`\exp(x_i)` gives

.. math::

   \mathrm{softmax}(x)_i
     = \frac{1}{\sum_j \exp(x_j - x_i)}

This keeps the transformation in terms of pairwise subtraction,
exponential, reduce-sum and reciprocal.
"""

import torch

from boundlab import utils
from boundlab.expr._core import Expr
from . import _bounds_to_expr
from .reciprocal import reciprocal_linearizer


def softmax_handler(x: Expr, dim: int = -1, dtype=None) -> Expr:
    """Polytope softmax transformer via the DeepT decomposition.

    Rewrites ``softmax(x)_i`` as ``1 / sum_j exp(x_j - x_i)`` and builds the
    abstraction from a pairwise difference, the registered ``Exp`` handler,
    a reduce-sum over the last axis and the reciprocal linearizer.

    Args:
        x: Input expression. Non-``Expr`` inputs return ``NotImplemented``
            so that another handler can take over.
        dim: Dimension to normalize over; negative values count from the
            end. Only the last dimension is supported.
        dtype: Accepted for signature compatibility with ``torch.softmax``;
            it is not used by this handler.

    Returns:
        The expression produced by ``_bounds_to_expr`` from ``sum_exp`` and
        the reciprocal linearization of its tightened bounds.

    Raises:
        AssertionError: If ``dim`` does not refer to the last dimension.
    """
    if not isinstance(x, Expr):
        return NotImplemented

    ndim = len(x.shape)
    requested_dim = dim
    if dim < 0:
        dim += ndim
    # Raised explicitly instead of via ``assert`` so the check survives
    # ``python -O``: everything below reduces over the last axis only.
    # AssertionError is kept for backward compatibility with existing callers;
    # the message keeps the original text as its prefix.
    if dim != ndim - 1:
        raise AssertionError(
            "softmax_handler only supports the last dimension"
            f" (got dim={requested_dim} for an input of rank {ndim})"
        )

    # Per the module formula, the entries should be x_j - x_i
    # (negated pairwise difference).
    # NOTE(review): orientation of utils.pairwise_diff assumed — TODO confirm.
    diff = -utils.pairwise_diff(x, dim)

    # Function-scope import; presumably avoids an import cycle with the
    # package-level handler registry — confirm.
    from . import interpret

    exp_diff = interpret["Exp"](diff)
    sum_exp = exp_diff.sum(dim=-1)

    # Tighten the denominator bounds using the exact interval image of exp.
    # exp is monotone, so exp(diff_lb) / exp(diff_ub) bound each summand.
    diff_ub, diff_lb = diff.ublb()
    sum_exp_ub, sum_exp_lb = sum_exp.ublb()
    exact_sum_ub = torch.exp(diff_ub).sum(dim=-1)
    exact_sum_lb = torch.exp(diff_lb).sum(dim=-1)
    sum_exp_ub = torch.minimum(sum_exp_ub, exact_sum_ub)
    sum_exp_lb = torch.maximum(sum_exp_lb, exact_sum_lb)

    # Entries whose bounds are not finite (e.g. exp overflow) are replaced by
    # 1 so the reciprocal linearizer only sees finite inputs.
    # NOTE(review): this collapses the denominator interval to [1, 1] for those
    # entries; when the true upper bound is large (reported as +inf) that
    # interval does not contain the real denominator. Consider a sound fallback.
    finite_mask = torch.isfinite(sum_exp_ub) & torch.isfinite(sum_exp_lb)
    sum_exp_ub = torch.where(finite_mask, sum_exp_ub, torch.ones_like(sum_exp_ub))
    sum_exp_lb = torch.where(finite_mask, sum_exp_lb, torch.ones_like(sum_exp_lb))
    # Keep the lower bound strictly positive so 1/s stays finite.
    sum_exp_lb = torch.clamp(sum_exp_lb, min=1e-30)

    bounds = reciprocal_linearizer(sum_exp_ub, sum_exp_lb)
    return _bounds_to_expr(sum_exp, bounds, reason=reciprocal_linearizer.__name__)