Source code for boundlab.diff.zono3.default.softmax

"""

softmax as ``exp → reduce-sum → reciprocal → element-wise product''

"""

import torch

from boundlab.diff.zono3 import expr
from boundlab.expr import Expr, ConstVal
from boundlab.expr._var import LpEpsilon
from boundlab.diff.expr import DiffExpr2, DiffExpr3
from boundlab.prop import ublb
from boundlab.zono.simplify import zono_simplify
from .bilinear import diff_bilinear_elementwise


def _clamp_expr(e: Expr, lo: float, hi: float, reason: str = "softmax_clamp") -> Expr:
    """Soundly intersect ``e``'s range with the known mathematical interval
    ``[lo, hi]``.

    Elements whose propagated bounds already lie inside ``[lo, hi]`` are kept
    verbatim (preserving the zonotope correlation that downstream cancellation
    relies on). Elements whose bounds escape the interval — e.g. a softmax
    value the zonotope reports as ``1e9`` because fp32 cancellation lost the
    ``Σ exp ≥ 1`` denominator floor — are replaced by the box hull of the
    intersection ``[max(l, lo), min(u, hi)]``. Sound either way: the true value
    lies in both ``[l, u]`` and ``[lo, hi]``, hence in their intersection.
    """
    if not isinstance(e, Expr):
        return e
    u, l = ublb(e)
    finite = torch.isfinite(l) & torch.isfinite(u)
    in_range = finite & (l >= lo - 1e-6) & (u <= hi + 1e-6)
    if bool(in_range.all()):
        return e
    new_l = torch.clamp(l, min=lo)
    new_u = torch.clamp(u, max=hi)
    # Elements whose propagated bounds are non-finite (inf coefficients
    # upstream) and elements whose intersection came out empty (fp slop or an
    # unsound upstream bound) both fall back to the full known range [lo, hi]:
    # the mathematical guarantee is the only information left to trust.
    degenerate = (~finite) | (new_l > new_u)
    new_l = torch.where(degenerate, torch.full_like(new_l, lo), new_l)
    new_u = torch.where(degenerate, torch.full_like(new_u, hi), new_u)
    center = (new_l + new_u) / 2
    radius = (new_u - new_l) / 2
    box = ConstVal(center) + radius * LpEpsilon(list(e.shape), reason=reason)
    if bool(finite.all()):
        # Every coefficient of ``e`` is finite, so the 0/1 blend is exact and
        # cannot produce ``0 * inf``.
        mask = in_range.to(center.dtype)
        return mask * e + (1.0 - mask) * box
    # ``e`` carries non-finite coefficients somewhere; a 0/1 blend against it
    # would evaluate ``0 * inf = NaN`` and poison *every* element. Drop the
    # correlations and return the pure box (this regime is already
    # catastrophic — those bounds were inf/NaN — and the box is also the
    # cheapest representation memory-wise: a single fresh generator).
    return box


def _floor_expr(e: Expr, lo: torch.Tensor, reason: str = "softmax_floor") -> Expr:
    """Soundly raise ``e``'s per-element lower bound to the known floor ``lo``.

    ``lo`` is a tensor of mathematically proven per-element lower bounds on
    the true value (e.g. ``Σ_j exp(x_j − x_i) ≥ 1``). Elements whose
    propagated lower bound already respects the floor are kept verbatim;
    elements below it are replaced by the box ``[max(l, lo), u]``. Elements
    with non-finite propagated bounds are left untouched (their coefficients
    are non-finite, so a 0/1 blend would NaN; they propagate loudly instead).
    """
    if not isinstance(e, Expr):
        return e
    u, l = ublb(e)
    finite = torch.isfinite(l) & torch.isfinite(u)
    need = finite & (l < lo)
    if not bool(need.any()):
        return e
    new_l = torch.maximum(l, lo)
    new_u = torch.maximum(u, new_l)  # trust the proven floor on contradiction
    center = torch.where(need, (new_l + new_u) / 2, torch.zeros_like(u))
    radius = torch.where(need, (new_u - new_l) / 2, torch.zeros_like(u))
    box = ConstVal(center) + radius * LpEpsilon(list(e.shape), reason=reason)
    # ``need ⊆ finite``: every element multiplied by 0 below has finite
    # coefficients, so the blend cannot produce ``0 * inf``.
    mask = (~need).to(u.dtype)
    return mask * e + (1.0 - mask) * box


[docs] def diff_softmax_handler(x, dim: int = -1, dtype=None, exp_handler=None, reciprocal_handler=None): r"""Differential softmax transformer. When *x* is a :class:`~boundlab.diff.expr.DiffExpr3`, the handler decomposes softmax into differential exp, reduce-sum, differential reciprocal, and differential element-wise product. When *x* is a plain :class:`~boundlab.expr.Expr` or :class:`~boundlab.diff.expr.DiffExpr2`, falls back to the standard softmax path or promotes to DiffExpr3 first. Args: x: Input expression or DiffExpr3 with shape ``(m, n)``. dim: Softmax dimension (default: -1). Only ``dim=1`` on 2D input is currently supported. dtype: Ignored (API compatibility). Returns: Expression or DiffExpr3 over-approximating softmax. Examples -------- >>> import torch >>> import boundlab.expr as expr >>> from boundlab.diff.expr import DiffExpr3 >>> from boundlab.diff.zono3.default.softmax import diff_softmax_handler >>> x = expr.ConstVal(torch.zeros(2, 3)) + 0.1 * expr.LpEpsilon([2, 3]) >>> y = expr.ConstVal(torch.ones(2, 3)) + 0.1 * expr.LpEpsilon([2, 3]) >>> t = DiffExpr3(x, y, x - y) >>> out = diff_softmax_handler(t, dim=1) >>> out.diff.shape torch.Size([2, 3]) """ from .. import interpret if exp_handler is None: exp_handler = interpret["Exp"] if reciprocal_handler is None: reciprocal_handler = interpret["Reciprocal"] if isinstance(x, torch.Tensor): x = ConstVal(x) if isinstance(x, ConstVal): return ConstVal(torch.softmax(x.value, dim=dim)) if isinstance(x, Expr): from boundlab.zono.softmax import softmax_handler as std_softmax return std_softmax(x, dim=dim, dtype=dtype) if isinstance(x, DiffExpr2): x = DiffExpr3(x.x, x.y, x.x - x.y) assert isinstance(x, DiffExpr3), x ndim = len(x.shape) if dim < 0: dim = ndim + dim n = x.shape[dim] # DeepT-style rewrite: σ_i(ν) = 1 / Σ_j exp(ν_j - ν_i) # Vectorized: reshape ν to (..., N, 1) and (..., 1, N), broadcast-subtract # to get an (..., N, N) tensor of ν_j - ν_i, exp, sum over inner dim, # reciprocal. Output IS σ. No bilinear needed. # Insert a size-1 axis right AFTER dim and right BEFORE dim to create the # broadcast shapes. Work with dim as a positive axis. # x_i has shape (..., N, 1) — treat as "my row" # x_j has shape (..., 1, N) — treat as "all columns" x_i = x.unsqueeze(dim + 1) # shape: (..., N, 1, ...) x_j = x.unsqueeze(dim) # shape: (..., 1, N, ...) # BoundLab Expr subtraction requires matching shapes (no auto-broadcast). # Expand both to (..., N, N, ...) explicitly before subtract. broadcast_shape = list(x.shape) broadcast_shape.insert(dim + 1, n) # add the j axis at dim+1 x_i_exp = x_i.expand(*broadcast_shape) x_j_exp = x_j.expand(*broadcast_shape) x_shifted = x_j_exp - x_i_exp # exp of the pairwise-difference tensor. exp_shifted = exp_handler(x_shifted) # Sum along the j-axis (which is now at position dim + 1 since we inserted one # new axis at dim + 1 and one at dim, but the j-axis corresponds to dim + 1 # in the unsqueezed layout). Actually, after x.unsqueeze(dim+1) -> N at dim, # 1 at dim+1, then x.unsqueeze(dim) -> 1 at dim, N at dim+1. The SUM we want # is over j, which is the axis of size N that comes from x_j — position dim+1. sum_exp = exp_shifted.sum(dim=dim + 1, keepdim=False) # sum_exp now has shape (..., N, ...) — same as original x. # --- Semantic denominator floor (soundness) ----------------------------- # The reciprocal stage's domain is strictly positive; where the *zonotope* # lower bound of the denominator dips to <= 0 it emits ±inf envelopes (see # reciprocal_linearizer), which then cascade into the next layer as a # non-finite ReLU/activation input. But the denominator has a *provable* # positive floor independent of the (loose) propagated envelope: # D_i = Σ_j exp(ν_j − ν_i) ≥ max(1, exp(max_j ν_lb_j − ν_ub_i)) # because the j=i term is exp(0)=1 and every term is positive, so for any # fixed k, D_i ≥ exp(ν_k − ν_i) ≥ exp(ν_lb_k − ν_ub_i). Both networks use # the full (unmasked) softmax here, so the same floor applies to x and y. # The exponent is capped at 80 to keep the floor finite; lowering an # exponent only weakens the floor, so it stays a valid lower bound. Raising # a propagated lower bound to a proven floor is sound (the true value lies # in both the propagated interval and [floor, ∞)). x_ub_in, x_lb_in = ublb(x.x) y_ub_in, y_lb_in = ublb(x.y) x_floor = torch.exp( (x_lb_in.amax(dim=dim, keepdim=True) - x_ub_in).clamp(max=80.0) ).clamp(min=1.0) y_floor = torch.exp( (y_lb_in.amax(dim=dim, keepdim=True) - y_ub_in).clamp(max=80.0) ).clamp(min=1.0) sum_exp = DiffExpr3( _floor_expr(sum_exp.x, x_floor, reason="softmax_denominator_floor"), _floor_expr(sum_exp.y, y_floor, reason="softmax_denominator_floor"), sum_exp.diff, ) # Reciprocal — this IS σ, no bilinear. result = reciprocal_handler(sum_exp) # Both networks output softmax probabilities ∈ [0, 1]; their difference is # in [−1, 1]. Intersecting the propagated envelope with these known ranges # is sound and stops any residual reciprocal blow-up from cascading. (No-op # for elements already within range, e.g. at small eps.) if isinstance(result, DiffExpr3): result = DiffExpr3( zono_simplify(_clamp_expr(result.x, 0.0, 1.0), reason="diff_softmax_handler"), zono_simplify(_clamp_expr(result.y, 0.0, 1.0), reason="diff_softmax_handler"), zono_simplify(_clamp_expr(result.diff, -1.0, 1.0), reason="diff_softmax_handler"), ) return result
[docs] def diff_softmax_pruning_handler(scores, data, dim: int = -1, dtype=None, exp_handler=None, reciprocal_handler=None, heaviside_handler=None): r"""Differential transformer for ``boundlab::SoftmaxPruning``. Implements the mock pruning op :func:`boundlab.diff.op.softmax_pruning`: * branch x = ``softmax(data)`` , * branch y = ``exp(data_i) / Σ_j heaviside(scores_j) exp(data_j)`` (denominator-masked softmax) , * diff = ``x − y`` . It reuses the DeepT decomposition of :func:`diff_softmax_handler` (``softmax_i = 1 / Σ_j exp(data_j − data_i)``) but inserts a :func:`~boundlab.diff.zono3.default.heaviside.const_heaviside_pruning` mask on each term of the denominator sum, so the score-based pruning is only applied to the **y** network while the **x** network keeps the full softmax. Only ``scores_j`` (the *key* axis ``dim``) participates in masking; the mask is broadcast over the *query* axis. Args: scores: Pruning scores; same shape as *data*. May be a plain :class:`~boundlab.expr.Expr` / tensor (identical for both networks) or a :class:`~boundlab.diff.expr.DiffExpr2` / :class:`~boundlab.diff.expr.DiffExpr3`. data: Input expression / DiffExpr3 the softmax is taken over. dim: Softmax axis (default: -1). dtype: Ignored (API compatibility with ``torch.softmax``). Returns: Expression / DiffExpr3 over-approximating the masked softmax pair. Examples -------- >>> import torch >>> import boundlab.expr as expr >>> from boundlab.diff.expr import DiffExpr3 >>> from boundlab.diff.zono3.default.softmax import diff_softmax_pruning_handler >>> z = expr.ConstVal(torch.zeros(2, 3)) + 0.05 * expr.LpEpsilon([2, 3]) >>> scores = torch.tensor([[1.0, -1.0, 1.0], [1.0, -1.0, 1.0]]) >>> out = diff_softmax_pruning_handler(scores, DiffExpr3(z, z, z - z), dim=1) >>> out.diff.shape torch.Size([2, 3]) """ from .. import interpret if exp_handler is None: exp_handler = interpret["Exp"] if reciprocal_handler is None: reciprocal_handler = interpret["Reciprocal"] if heaviside_handler is None: heaviside_handler = interpret["HeavisidePruning"] if isinstance(scores, torch.Tensor): scores = ConstVal(scores) if isinstance(data, torch.Tensor): data = ConstVal(data) if isinstance(data, Expr): data = DiffExpr2(data, data) # promote to DiffExpr2 for the standard softmax path if isinstance(data, DiffExpr2): data = DiffExpr3(data.x, data.y, data.x - data.y) ndim = len(data.shape) if dim < 0: dim = ndim + dim n = data.shape[dim] # print(data) # DeepT-style rewrite (see ``diff_softmax_handler``): build the (..., N, N, ...) # tensor data_j − data_i, exp it, mask each j-term, sum over j, reciprocal. data_i = data.unsqueeze(dim + 1) # query axis: (..., N, 1, ...) data_j = data.unsqueeze(dim) # key axis: (..., 1, N, ...) broadcast_shape = list(data.shape) broadcast_shape.insert(dim + 1, n) # add the j (key) axis at dim + 1 data_i_exp = data_i.expand(*broadcast_shape) data_j_exp = data_j.expand(*broadcast_shape) exp_shifted = exp_handler(data_j_exp - data_i_exp) # scores depend only on the key axis j (position dim + 1). Insert a size-1 # query axis at dim and broadcast to the pairwise (..., N, N, ...) shape. scores_b = scores.unsqueeze(dim).expand(*broadcast_shape) # Mask the denominator terms — only the y network is pruned. #masked = heaviside_handler(scores_b, exp_shifted) #sum_exp = masked.sum(dim=dim + 1, keepdim=False) #new fixed mask masked = heaviside_handler(scores_b, exp_shifted) sum_exp = masked.sum(dim=dim + 1, keepdim=False) # --- Semantic denominator floors (soundness, audit round 2 / A1) -------- # The reciprocal stage must never see a loose envelope whose lower bound # dips to ≤ 0: its 1e-9 safety clamp then manufactures ``1/1e-9``-scale # (~1e9) envelopes, and downstream fp32 interval arithmetic mixing those # with O(1) terms cancels catastrophically (ulp(5e8) = 32) — observed as # ``out.diff`` excluding reachable values by up to ~1.0 at moderate # perturbation scales. Both denominators have *provable* positive floors: # x: Σ_j exp(x_j − x_i) ≥ max(1, exp(max_j x_lb_j − x_ub_i)) (j=i term) # y: Σ_j h_j exp(y_j − y_i) ≥ exp(max_{kept j} y_lb_j − y_ub_i) # Exponents are capped at 80 so the floor itself stays finite; lowering an # exponent only weakens the floor, so it remains a valid lower bound. s_const = scores.get_const() if isinstance(scores, ConstVal) else None if s_const is not None: x_ub_in, x_lb_in = ublb(data.x) y_ub_in, y_lb_in = ublb(data.y) kept = (s_const >= 0) big = torch.where(kept, y_lb_in, torch.full_like(y_lb_in, float("-inf"))) y_floor = torch.exp((big.amax(dim=dim, keepdim=True) - y_ub_in).clamp(max=80.0)) x_floor = torch.exp( (x_lb_in.amax(dim=dim, keepdim=True) - x_ub_in).clamp(max=80.0) ).clamp(min=1.0) sum_exp = DiffExpr3( _floor_expr(sum_exp.x, x_floor, reason="softmax_denominator_floor"), _floor_expr(sum_exp.y, y_floor, reason="softmax_denominator_floor"), sum_exp.diff, ) result = reciprocal_handler(sum_exp) result = heaviside_handler(scores, result) assert isinstance(result, DiffExpr3) # Both branches are softmax probabilities ∈ [0, 1]; their difference is in # [-1, 1]. The zonotope reciprocal can report values far outside this range # when wide attention logits drive Σ exp to fp32 magnitudes that swamp the # denominator's ``≥ 1`` floor. Intersecting with the known ranges is sound # and stops that spurious blow-up from cascading through later layers. result = DiffExpr3( zono_simplify(_clamp_expr(result.x, 0.0, 1.0), reason="diff_softmax_pruning_handler"), zono_simplify(_clamp_expr(result.y, 0.0, 1.0), reason="diff_softmax_pruning_handler"), zono_simplify(_clamp_expr(result.diff, -1.0, 1.0), reason="diff_softmax_pruning_handler"), ) # print(result) return result
#return reciprocal_handler(sum_exp) __all__ = ["diff_softmax_handler", "diff_softmax_pruning_handler"]