# Source code for boundlab.zono.softmax

"""Softmax handler for zonotope abstract interpretation.

Implements softmax as a composed operation:
  softmax(x)_j = exp(x_j) / sum_k exp(x_k)

Following DeepT (Bonaert et al., 2021), uses numerical stabilization
by subtracting the center's max value before applying exp.
"""

import torch

from boundlab.expr._core import Expr
from boundlab.expr._affine import ConstVal
from boundlab.expr._var import LpEpsilon
from .bilinear import bilinear_elementwise


def softmax_handler(x: Expr, dim: int = -1, dtype=None) -> Expr:
    r"""Zonotope softmax transformer built from primitive handlers.

    Softmax is decomposed as:

    .. math::
        \mathrm{softmax}(x)_j = \frac{\exp(x_j)}{\sum_k \exp(x_k)}

    The implementation applies:
    ``exp -> reduce-sum -> reciprocal -> element-wise product``.
    For stability, it first shifts by the center maximum along the
    softmax dimension.  Currently, only 2D inputs with ``dim == 1``
    are supported.

    Args:
        x: Input expression with shape (m, n).
        dim: Dimension along which to apply softmax (default: -1).
        dtype: Ignored (for API compatibility with torch.softmax).

    Returns:
        An expression over-approximating ``torch.softmax(x, dim=dim)``.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.softmax import softmax_handler
    >>> x = expr.ConstVal(torch.zeros(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
    >>> y = softmax_handler(x, dim=1)
    >>> y.shape
    torch.Size([2, 3])
    """
    ndim = len(x.shape)
    if dim < 0:
        dim = ndim + dim
    assert ndim == 2 and dim == 1, \
        f"Softmax currently only supports 2D tensors along last dim, got shape {x.shape} dim {dim}"

    n = x.shape[dim]

    # Numerical stability: shift by the center's max along the softmax dim
    # (DeepT-style stabilization; subtracting a constant leaves softmax
    # unchanged but keeps exp from overflowing).
    x_center = x.center()
    x_max = x_center.max(dim=dim, keepdim=True).values  # (m, 1)
    x_shifted = x - x_max.expand(*x.shape)  # affine, same shape as x

    # Import the registered handlers from the zonotope interpreter.
    # Deferred to call time to avoid a circular import with the package
    # __init__ that builds the `interpret` registry.
    from . import interpret
    exp_handler = interpret["exp"]
    reciprocal_handler = interpret["reciprocal"]

    # Apply exp element-wise.
    exp_x = exp_handler(x_shifted)

    # Sum along dim 1: exp_x @ ones(n, 1) -> (m, 1).
    # BUGFIX: build the reduction vector with the input's dtype/device.
    # `torch.ones(n, 1)` defaults to float32 on CPU, which silently
    # downcasts float64 inputs and fails outright for CUDA-resident
    # zonotopes (device mismatch in the matmul).
    ones = torch.ones(n, 1, dtype=x_center.dtype, device=x_center.device)
    sum_exp = exp_x @ ones

    # Reciprocal: 1 / sum_exp -> (m, 1).
    inv_sum = reciprocal_handler(sum_exp)

    # Broadcast inv_sum to match exp_x's shape: (m, 1) -> (m, n).
    inv_sum_expanded = inv_sum.expand(*exp_x.shape)

    # Element-wise product exp_x * inv_sum is bilinear in the noise
    # symbols, so it goes through the dedicated bilinear transformer.
    result = bilinear_elementwise(exp_x, inv_sum_expanded)
    return result