Source code for boundlab.diff.zono3.default.softmax

"""

softmax as ``exp → reduce-sum → reciprocal → element-wise product''

"""

import torch

from boundlab.diff.zono3 import expr
from boundlab.expr import Expr, ConstVal
from boundlab.diff.expr import DiffExpr2, DiffExpr3
from .bilinear import diff_bilinear_elementwise


[docs] def diff_softmax_handler(x, dim: int = -1, dtype=None, exp_handler=None, reciprocal_handler=None): r"""Differential softmax transformer. When *x* is a :class:`~boundlab.diff.expr.DiffExpr3`, the handler decomposes softmax into differential exp, reduce-sum, differential reciprocal, and differential element-wise product. When *x* is a plain :class:`~boundlab.expr.Expr` or :class:`~boundlab.diff.expr.DiffExpr2`, falls back to the standard softmax path or promotes to DiffExpr3 first. Args: x: Input expression or DiffExpr3 with shape ``(m, n)``. dim: Softmax dimension (default: -1). Only ``dim=1`` on 2D input is currently supported. dtype: Ignored (API compatibility). Returns: Expression or DiffExpr3 over-approximating softmax. Examples -------- >>> import torch >>> import boundlab.expr as expr >>> from boundlab.diff.expr import DiffExpr3 >>> from boundlab.diff.zono3.default.softmax import diff_softmax_handler >>> x = expr.ConstVal(torch.zeros(2, 3)) + 0.1 * expr.LpEpsilon([2, 3]) >>> y = expr.ConstVal(torch.ones(2, 3)) + 0.1 * expr.LpEpsilon([2, 3]) >>> t = DiffExpr3(x, y, x - y) >>> out = diff_softmax_handler(t, dim=1) >>> out.diff.shape torch.Size([2, 3]) """ from .. import interpret if exp_handler is None: exp_handler = interpret["Exp"] if reciprocal_handler is None: reciprocal_handler = interpret["Reciprocal"] if isinstance(x, torch.Tensor): x = ConstVal(x) if isinstance(x, ConstVal): return ConstVal(torch.softmax(x.value, dim=dim)) if isinstance(x, Expr): from boundlab.zono.softmax import softmax_handler as std_softmax return std_softmax(x, dim=dim, dtype=dtype) if isinstance(x, DiffExpr2): x = DiffExpr3(x.x, x.y, x.x - x.y) assert isinstance(x, DiffExpr3), x ndim = len(x.shape) if dim < 0: dim = ndim + dim n = x.shape[dim] # DeepT-style rewrite: σ_i(ν) = 1 / Σ_j exp(ν_j - ν_i) # Vectorized: reshape ν to (..., N, 1) and (..., 1, N), broadcast-subtract # to get an (..., N, N) tensor of ν_j - ν_i, exp, sum over inner dim, # reciprocal. Output IS σ. No bilinear needed. # Insert a size-1 axis right AFTER dim and right BEFORE dim to create the # broadcast shapes. Work with dim as a positive axis. # x_i has shape (..., N, 1) — treat as "my row" # x_j has shape (..., 1, N) — treat as "all columns" x_i = x.unsqueeze(dim + 1) # shape: (..., N, 1, ...) x_j = x.unsqueeze(dim) # shape: (..., 1, N, ...) # BoundLab Expr subtraction requires matching shapes (no auto-broadcast). # Expand both to (..., N, N, ...) explicitly before subtract. broadcast_shape = list(x.shape) broadcast_shape.insert(dim + 1, n) # add the j axis at dim+1 x_i_exp = x_i.expand(*broadcast_shape) x_j_exp = x_j.expand(*broadcast_shape) x_shifted = x_j_exp - x_i_exp # exp of the pairwise-difference tensor. exp_shifted = exp_handler(x_shifted) # Sum along the j-axis (which is now at position dim + 1 since we inserted one # new axis at dim + 1 and one at dim, but the j-axis corresponds to dim + 1 # in the unsqueezed layout). Actually, after x.unsqueeze(dim+1) -> N at dim, # 1 at dim+1, then x.unsqueeze(dim) -> 1 at dim, N at dim+1. The SUM we want # is over j, which is the axis of size N that comes from x_j — position dim+1. sum_exp = exp_shifted.sum(dim=dim + 1, keepdim=False) # sum_exp now has shape (..., N, ...) — same as original x. # Reciprocal — this IS σ, no bilinear. result = reciprocal_handler(sum_exp) return result
__all__ = ["diff_softmax_handler"]