"""
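Differential softmax transformers.

Softmax is conceptually ``exp → reduce-sum → reciprocal → element-wise product``.
The handlers in this module use the equivalent DeepT-style form
``softmax_i = 1 / Σ_j exp(x_j − x_i)`` (pairwise differences → exp → reduce-sum
→ reciprocal), so no element-wise product stage is needed.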
"""
import torch
from boundlab.diff.zono3 import expr
from boundlab.expr import Expr, ConstVal
from boundlab.expr._var import LpEpsilon
from boundlab.diff.expr import DiffExpr2, DiffExpr3
from boundlab.prop import ublb
from boundlab.zono.simplify import zono_simplify
from .bilinear import diff_bilinear_elementwise
def _clamp_expr(e: Expr, lo: float, hi: float, reason: str = "softmax_clamp") -> Expr:
    """Soundly restrict ``e`` to the mathematically known interval ``[lo, hi]``.

    Per-element bounds of ``e`` are obtained from ``ublb`` (unpacked as
    ``upper, lower``).  Three regimes follow:

    * Every element has finite bounds inside ``[lo, hi]`` (with a ``1e-6``
      tolerance on both ends): ``e`` is returned unchanged.
    * All bounds are finite but some escape the interval: in-range elements
      are kept verbatim (preserving the zonotope correlations that downstream
      cancellation relies on) and the offending elements are swapped for a
      box built from a fresh ``LpEpsilon`` term spanning
      ``[max(l, lo), min(u, hi)]``.  If that intersection is empty the box
      spans the whole of ``[lo, hi]``.
    * Some bounds are non-finite: the entire expression is replaced by the
      box, because a 0/1 blend against non-finite coefficients would evaluate
      ``0 * inf = NaN``.  Non-finite elements get the full ``[lo, hi]`` box.

    The replacement is sound because the true value lies in both the
    propagated interval ``[l, u]`` and the known interval ``[lo, hi]``, hence
    in their intersection.

    Args:
        e: Expression to clamp.  Anything that is not an ``Expr`` is returned
            as-is.
        lo: Known lower end of the true value's range.
        hi: Known upper end of the true value's range.
        reason: Label forwarded to the ``LpEpsilon`` created for the box.

    Returns:
        ``e`` itself, a per-element blend of ``e`` and the box, or the box
        alone, depending on the regime above.
    """
    # Non-expression inputs pass straight through.
    if not isinstance(e, Expr):
        return e

    upper, lower = ublb(e)
    is_finite = torch.isfinite(lower) & torch.isfinite(upper)
    within = is_finite & (lower >= lo - 1e-6) & (upper <= hi + 1e-6)
    if bool(within.all()):
        # Nothing escapes the known range: keep every correlation.
        return e

    clipped_lo = torch.clamp(lower, min=lo)
    clipped_hi = torch.clamp(upper, max=hi)

    # Non-finite propagated bounds, or an empty intersection (fp slop or an
    # unsound upstream bound), leave only the mathematical guarantee to
    # trust, so those elements fall back to the full [lo, hi] range.
    fallback = (~is_finite) | (clipped_lo > clipped_hi)
    clipped_lo = torch.where(fallback, torch.full_like(clipped_lo, lo), clipped_lo)
    clipped_hi = torch.where(fallback, torch.full_like(clipped_hi, hi), clipped_hi)

    midpoint = (clipped_lo + clipped_hi) / 2
    half_width = (clipped_hi - clipped_lo) / 2
    box = ConstVal(midpoint) + half_width * LpEpsilon(list(e.shape), reason=reason)

    if not bool(is_finite.all()):
        # ``e`` carries non-finite coefficients somewhere, so multiplying it
        # by a 0/1 mask would produce NaN and poison every element.  Give up
        # the correlations and hand back the plain box.
        return box

    # All coefficients are finite here, so the 0/1 blend is exact.
    keep = within.to(midpoint.dtype)
    return keep * e + (1.0 - keep) * box
def _floor_expr(e: Expr, lo: torch.Tensor, reason: str = "softmax_floor") -> Expr:
    """Soundly lift the per-element lower bound of ``e`` to the floor ``lo``.

    ``lo`` holds mathematically proven per-element lower bounds on the true
    value (e.g. ``Σ_j exp(x_j − x_i) ≥ 1``).  Bounds of ``e`` come from
    ``ublb`` (unpacked as ``upper, lower``).

    * Elements whose bounds are finite and whose lower bound is below the
      floor are replaced by a box built from a fresh ``LpEpsilon`` term
      spanning ``[max(l, lo), max(u, max(l, lo))]``; the upper end is raised
      only when the propagated upper bound contradicts the floor.
    * Elements already respecting the floor are kept verbatim.
    * Elements with non-finite bounds are also kept verbatim: their
      coefficients are non-finite, so a 0/1 blend would produce NaN.

    Args:
        e: Expression to floor.  Anything that is not an ``Expr`` is returned
            as-is.
        lo: Tensor of proven lower bounds, compared element-wise against the
            propagated lower bound of ``e``.
        reason: Label forwarded to the ``LpEpsilon`` created for the box.

    Returns:
        ``e`` itself when no element needs lifting, otherwise a per-element
        blend of ``e`` and the box.
    """
    # Non-expression inputs pass straight through.
    if not isinstance(e, Expr):
        return e

    upper, lower = ublb(e)
    below_floor = torch.isfinite(lower) & torch.isfinite(upper) & (lower < lo)
    if not bool(below_floor.any()):
        return e

    raised_lo = torch.maximum(lower, lo)
    # If the propagated upper bound sits below the floor, trust the floor.
    raised_hi = torch.maximum(upper, raised_lo)

    # The box is only meaningful where an element is being replaced; it is
    # zeroed everywhere else.
    zeros = torch.zeros_like(upper)
    midpoint = torch.where(below_floor, (raised_lo + raised_hi) / 2, zeros)
    half_width = torch.where(below_floor, (raised_hi - raised_lo) / 2, zeros)
    box = ConstVal(midpoint) + half_width * LpEpsilon(list(e.shape), reason=reason)

    # Only elements with finite bounds are ever multiplied by 0, so the blend
    # cannot evaluate ``0 * inf``.
    keep = (~below_floor).to(upper.dtype)
    return keep * e + (1.0 - keep) * box
# (A stray "[docs]" Sphinx source-link marker stood here; it is HTML extraction residue, not code.)
def diff_softmax_handler(x, dim: int = -1, dtype=None, exp_handler=None, reciprocal_handler=None):
    r"""Differential softmax transformer.

    Dispatch on the type of *x*:

    * ``torch.Tensor`` / :class:`~boundlab.expr.ConstVal` — softmax is
      evaluated concretely and returned as a ``ConstVal``.
    * plain :class:`~boundlab.expr.Expr` — delegated to
      ``boundlab.zono.softmax.softmax_handler``.
    * :class:`~boundlab.diff.expr.DiffExpr2` — promoted to
      :class:`~boundlab.diff.expr.DiffExpr3` with ``diff = x - y``.
    * :class:`~boundlab.diff.expr.DiffExpr3` — handled with the DeepT-style
      rewrite ``softmax_i(v) = 1 / Σ_j exp(v_j − v_i)``: pairwise differences,
      exp, reduce-sum over ``j``, reciprocal.  The reciprocal *is* the
      softmax, so no element-wise product stage is involved.

    Before the reciprocal, the ``x`` and ``y`` denominators are raised to a
    proven positive floor; afterwards the ``x`` / ``y`` outputs are
    intersected with ``[0, 1]`` and the difference with ``[-1, 1]``.

    Args:
        x: Input tensor, expression, DiffExpr2 or DiffExpr3.
        dim: Softmax dimension (default: -1).  Negative values count from the
            end.  NOTE(review): an earlier revision of this docstring stated
            that only ``dim=1`` on 2D input is supported, while the body is
            written against a generic axis — confirm which is intended.
        dtype: Forwarded to the standard softmax handler on the plain
            ``Expr`` path; unused on every other path.
        exp_handler: Handler applied to the pairwise-difference tensor.
            Defaults to ``interpret["Exp"]``.
        reciprocal_handler: Handler applied to the floored denominator.
            Defaults to ``interpret["Reciprocal"]``.

    Returns:
        Expression or DiffExpr3 over-approximating softmax.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.diff.expr import DiffExpr3
    >>> from boundlab.diff.zono3.default.softmax import diff_softmax_handler
    >>> x = expr.ConstVal(torch.zeros(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
    >>> y = expr.ConstVal(torch.ones(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
    >>> t = DiffExpr3(x, y, x - y)
    >>> out = diff_softmax_handler(t, dim=1)
    >>> out.diff.shape
    torch.Size([2, 3])
    """
    from .. import interpret

    exp_handler = interpret["Exp"] if exp_handler is None else exp_handler
    reciprocal_handler = (
        interpret["Reciprocal"] if reciprocal_handler is None else reciprocal_handler
    )

    # ---- Non-differential inputs ------------------------------------------
    if isinstance(x, torch.Tensor):
        x = ConstVal(x)
    if isinstance(x, ConstVal):
        return ConstVal(torch.softmax(x.value, dim=dim))
    if isinstance(x, Expr):
        from boundlab.zono.softmax import softmax_handler as plain_softmax

        return plain_softmax(x, dim=dim, dtype=dtype)

    # ---- Differential inputs ----------------------------------------------
    if isinstance(x, DiffExpr2):
        x = DiffExpr3(x.x, x.y, x.x - x.y)
    assert isinstance(x, DiffExpr3), x

    rank = len(x.shape)
    if dim < 0:
        dim = rank + dim
    n = x.shape[dim]
    key_axis = dim + 1  # position of the summed "j" axis in the pairwise layout

    # Pairwise layout: a size-1 axis after ``dim`` gives the "my row" view
    # (..., N, 1, ...); a size-1 axis before ``dim`` gives the "all columns"
    # view (..., 1, N, ...).  Both are expanded explicitly to (..., N, N, ...)
    # because expression subtraction needs matching shapes.
    row_view = x.unsqueeze(key_axis)
    col_view = x.unsqueeze(dim)
    pair_shape = list(x.shape)
    pair_shape.insert(key_axis, n)
    row_full = row_view.expand(*pair_shape)
    col_full = col_view.expand(*pair_shape)

    # Entry [i, j] of the pairwise tensor is v_j − v_i; summing exp over j
    # yields the denominator D_i, which has the same shape as the input.
    pairwise_exp = exp_handler(col_full - row_full)
    denom = pairwise_exp.sum(dim=key_axis, keepdim=False)

    # ---- Proven denominator floor (soundness) ------------------------------
    # The reciprocal is only defined on strictly positive inputs; a loose
    # propagated lower bound that dips to <= 0 makes it emit ±inf envelopes
    # that then cascade into later layers.  Independently of the propagated
    # envelope,
    #     D_i = Σ_j exp(v_j − v_i) ≥ max(1, exp(max_j lb_j − ub_i)),
    # since the j = i term equals exp(0) = 1, all terms are positive, and for
    # any fixed k, D_i ≥ exp(v_k − v_i) ≥ exp(lb_k − ub_i).  Both networks
    # use the full (unmasked) softmax, so x and y share the same formula.
    def _proven_floor(branch):
        """Return the per-element floor ``max(1, exp(max_j lb_j − ub_i))``."""
        ub, lb = ublb(branch)
        gap = lb.amax(dim=dim, keepdim=True) - ub
        # Capping the exponent at 80 keeps the floor finite; a smaller
        # exponent only weakens the floor, so it remains a valid bound.
        return torch.exp(gap.clamp(max=80.0)).clamp(min=1.0)

    floor_x = _proven_floor(x.x)
    floor_y = _proven_floor(x.y)
    denom = DiffExpr3(
        _floor_expr(denom.x, floor_x, reason="softmax_denominator_floor"),
        _floor_expr(denom.y, floor_y, reason="softmax_denominator_floor"),
        denom.diff,
    )

    # The reciprocal of the denominator is the softmax itself.
    result = reciprocal_handler(denom)
    if not isinstance(result, DiffExpr3):
        return result

    # Softmax outputs lie in [0, 1] and their difference in [-1, 1].
    # Intersecting with these ranges is sound and stops residual reciprocal
    # blow-up from propagating; elements already in range are untouched.
    tag = "diff_softmax_handler"
    return DiffExpr3(
        zono_simplify(_clamp_expr(result.x, 0.0, 1.0), reason=tag),
        zono_simplify(_clamp_expr(result.y, 0.0, 1.0), reason=tag),
        zono_simplify(_clamp_expr(result.diff, -1.0, 1.0), reason=tag),
    )
# (A stray "[docs]" Sphinx source-link marker stood here; it is HTML extraction residue, not code.)
def diff_softmax_pruning_handler(scores, data, dim: int = -1, dtype=None,
                                 exp_handler=None, reciprocal_handler=None,
                                 heaviside_handler=None):
    r"""Differential transformer for ``boundlab::SoftmaxPruning``.

    Implements the mock pruning op :func:`boundlab.diff.op.softmax_pruning`:

    * branch x = ``softmax(data)`` ,
    * branch y = ``exp(data_i) / Σ_j heaviside(scores_j) exp(data_j)``
      (denominator-masked softmax) ,
    * diff = ``x − y`` .

    The DeepT decomposition ``softmax_i = 1 / Σ_j exp(data_j − data_i)`` is
    used: the pairwise tensor is exponentiated, every term of the denominator
    sum is passed through ``heaviside_handler`` together with the scores
    broadcast along the key axis, the terms are summed and the reciprocal is
    taken.  The reciprocal is then passed through
    ``heaviside_handler(scores, ·)`` once more (presumably zeroing pruned
    output positions of the y branch — confirm against the op definition).
    Finally the x / y outputs are intersected with ``[0, 1]`` and the
    difference with ``[-1, 1]``.

    When *scores* is constant, both denominators are first raised to proven
    positive floors; with non-constant scores no floor is applied.

    Args:
        scores: Pruning scores; same shape as *data*.  May be a plain
            :class:`~boundlab.expr.Expr` / tensor (identical for both
            networks) or a :class:`~boundlab.diff.expr.DiffExpr2` /
            :class:`~boundlab.diff.expr.DiffExpr3`.
        data: Input tensor / expression / DiffExpr2 / DiffExpr3 the softmax
            is taken over.  Tensors and plain expressions are promoted to a
            DiffExpr3 whose two branches are identical.
        dim: Softmax axis (default: -1).
        dtype: Ignored (API compatibility with ``torch.softmax``).
        exp_handler: Defaults to ``interpret["Exp"]``.
        reciprocal_handler: Defaults to ``interpret["Reciprocal"]``.
        heaviside_handler: Defaults to ``interpret["HeavisidePruning"]``.

    Returns:
        DiffExpr3 over-approximating the masked softmax pair.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.diff.expr import DiffExpr3
    >>> from boundlab.diff.zono3.default.softmax import diff_softmax_pruning_handler
    >>> z = expr.ConstVal(torch.zeros(2, 3)) + 0.05 * expr.LpEpsilon([2, 3])
    >>> scores = torch.tensor([[1.0, -1.0, 1.0], [1.0, -1.0, 1.0]])
    >>> out = diff_softmax_pruning_handler(scores, DiffExpr3(z, z, z - z), dim=1)
    >>> out.diff.shape
    torch.Size([2, 3])
    """
    from .. import interpret

    exp_handler = interpret["Exp"] if exp_handler is None else exp_handler
    reciprocal_handler = (
        interpret["Reciprocal"] if reciprocal_handler is None else reciprocal_handler
    )
    heaviside_handler = (
        interpret["HeavisidePruning"] if heaviside_handler is None else heaviside_handler
    )

    # ---- Normalise inputs ---------------------------------------------------
    if isinstance(scores, torch.Tensor):
        scores = ConstVal(scores)
    if isinstance(data, torch.Tensor):
        data = ConstVal(data)
    if isinstance(data, Expr):
        # A single expression feeds both networks.
        data = DiffExpr2(data, data)
    if isinstance(data, DiffExpr2):
        data = DiffExpr3(data.x, data.y, data.x - data.y)

    rank = len(data.shape)
    if dim < 0:
        dim = rank + dim
    n_keys = data.shape[dim]
    key_axis = dim + 1  # position of the key ("j") axis in the pairwise layout

    # ---- Pairwise exp(data_j − data_i) -------------------------------------
    query_view = data.unsqueeze(key_axis)  # (..., N, 1, ...)
    key_view = data.unsqueeze(dim)         # (..., 1, N, ...)
    pair_shape = list(data.shape)
    pair_shape.insert(key_axis, n_keys)
    query_full = query_view.expand(*pair_shape)
    key_full = key_view.expand(*pair_shape)
    pairwise_exp = exp_handler(key_full - query_full)

    # Scores vary only along the key axis: add a size-1 query axis at ``dim``
    # and broadcast to the pairwise shape before masking each denominator
    # term.
    score_grid = scores.unsqueeze(dim).expand(*pair_shape)
    masked_terms = heaviside_handler(score_grid, pairwise_exp)
    denom = masked_terms.sum(dim=key_axis, keepdim=False)

    # ---- Proven denominator floors (soundness) -----------------------------
    # Per the original audit note: the reciprocal stage must not receive an
    # envelope whose lower bound dips to <= 0, otherwise it produces
    # ~1e9-scale envelopes and later fp32 interval arithmetic mixing those
    # with O(1) terms cancels catastrophically.  Both denominators have
    # provable positive floors:
    #   x: Σ_j exp(x_j − x_i)     ≥ max(1, exp(max_j x_lb_j − x_ub_i))
    #   y: Σ_j h_j exp(y_j − y_i) ≥ exp(max_{kept j} y_lb_j − y_ub_i)
    # Exponents are capped at 80 to keep the floors finite; a smaller
    # exponent only weakens a floor, so it stays a valid lower bound.
    score_values = scores.get_const() if isinstance(scores, ConstVal) else None
    if score_values is not None:
        x_ub, x_lb = ublb(data.x)
        y_ub, y_lb = ublb(data.y)

        # NOTE(review): a score of exactly 0 counts as "kept" here; confirm
        # this matches the heaviside convention of the mask handler.
        kept = score_values >= 0
        neg_inf = torch.full_like(y_lb, float("-inf"))
        best_kept_lb = torch.where(kept, y_lb, neg_inf).amax(dim=dim, keepdim=True)
        y_floor = torch.exp((best_kept_lb - y_ub).clamp(max=80.0))

        best_lb = x_lb.amax(dim=dim, keepdim=True)
        x_floor = torch.exp((best_lb - x_ub).clamp(max=80.0)).clamp(min=1.0)

        denom = DiffExpr3(
            _floor_expr(denom.x, x_floor, reason="softmax_denominator_floor"),
            _floor_expr(denom.y, y_floor, reason="softmax_denominator_floor"),
            denom.diff,
        )

    # ---- Reciprocal, output mask, range clamp ------------------------------
    result = heaviside_handler(scores, reciprocal_handler(denom))
    assert isinstance(result, DiffExpr3)

    # Both branches are probabilities in [0, 1] and their difference lies in
    # [-1, 1].  Intersecting with these known ranges is sound and stops
    # spurious reciprocal blow-up from cascading through later layers.
    tag = "diff_softmax_pruning_handler"
    return DiffExpr3(
        zono_simplify(_clamp_expr(result.x, 0.0, 1.0), reason=tag),
        zono_simplify(_clamp_expr(result.y, 0.0, 1.0), reason=tag),
        zono_simplify(_clamp_expr(result.diff, -1.0, 1.0), reason=tag),
    )
# Names exported by ``from <module> import *``: the two public handlers.
__all__ = ["diff_softmax_handler", "diff_softmax_pruning_handler"]