boundlab.diff.zono3.diff_softmax_handler
- boundlab.diff.zono3.diff_softmax_handler(x, dim=-1, dtype=None, exp_handler=None, reciprocal_handler=None)
Differential softmax transformer.
When x is a DiffExpr3, the handler decomposes softmax into a differential exp, a reduce-sum, a differential reciprocal, and a differential element-wise product. When x is a plain Expr or a DiffExpr2, the handler either falls back to the standard softmax path or first promotes x to a DiffExpr3.
- Parameters:
x – Input expression or DiffExpr3 with shape (m, n).
dim (int) – Softmax dimension (default: -1). Only dim=1 on 2D input is currently supported.
dtype – Ignored (kept for API compatibility).
- Returns:
Expression or DiffExpr3 over-approximating softmax.
Examples
>>> import torch
>>> import boundlab.expr as expr
>>> from boundlab.diff.expr import DiffExpr3
>>> from boundlab.diff.zono3.default.softmax import diff_softmax_handler
>>> x = expr.ConstVal(torch.zeros(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
>>> y = expr.ConstVal(torch.ones(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
>>> t = DiffExpr3(x, y, x - y)
>>> out = diff_softmax_handler(t, dim=1)
>>> out.diff.shape
torch.Size([2, 3])
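As a rough illustration of the decomposition described above, the sketch below runs the same four steps (differential exp, reduce-sum, differential reciprocal, differential element-wise product) on concrete floats for a single row, i.e. the dim=1 case of a 2D input. It is not boundlab code: each triple (vx, vy, vd) only mimics the (x, y, x - y) layout of DiffExpr3, and the helper names diff_exp, diff_sum, diff_reciprocal, diff_mul and diff_softmax_row are hypothetical. The actual handler propagates symbolic bound expressions, which this sketch does not attempt to model.

```python
import math

# Hypothetical, concrete-value illustration only (not the boundlab API).
# Each triple (vx, vy, vd) mirrors the (x, y, x - y) layout of DiffExpr3,
# but holds plain floats instead of bound expressions.


def diff_exp(t):
    """exp on both sides; the difference is derived from the tracked vd."""
    vx, vy, vd = t
    ey = math.exp(vy)
    # exp(x) - exp(y) = exp(y) * (exp(x - y) - 1)
    return math.exp(vx), ey, ey * math.expm1(vd)


def diff_sum(ts):
    """Reduce-sum is linear, so each component is summed independently."""
    return (
        sum(t[0] for t in ts),
        sum(t[1] for t in ts),
        sum(t[2] for t in ts),
    )


def diff_reciprocal(t):
    """1/x - 1/y = -(x - y) / (x * y), using the tracked difference."""
    vx, vy, vd = t
    return 1.0 / vx, 1.0 / vy, -vd / (vx * vy)


def diff_mul(a, b):
    """x1*x2 - y1*y2 = x1 * (x2 - y2) + (x1 - y1) * y2."""
    return a[0] * b[0], a[1] * b[1], a[0] * b[2] + a[2] * b[1]


def diff_softmax_row(xs, ys):
    """Softmax of one row via exp -> sum -> reciprocal -> element-wise product."""
    triples = [(x, y, x - y) for x, y in zip(xs, ys)]
    exps = [diff_exp(t) for t in triples]
    inv_total = diff_reciprocal(diff_sum(exps))
    return [diff_mul(e, inv_total) for e in exps]


if __name__ == "__main__":
    out = diff_softmax_row([0.0, 0.1, -0.1], [1.0, 0.9, 1.2])
    # Third component of each triple: softmax(x)_i - softmax(y)_i
    print([round(d, 6) for _, _, d in out])
```

Each step derives the output difference from the tracked difference vd (for example exp(x) - exp(y) = exp(y) * (exp(x - y) - 1)) instead of subtracting the two outputs. On exact floats both routes give the same number; the presumable benefit in a bound-propagation setting is that a tight enclosure of x - y can yield a tighter enclosure of the output difference than subtracting independent enclosures of each side, though how boundlab linearizes each step is not shown here.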