boundlab.diff.zono3.diff_softmax_handler

boundlab.diff.zono3.diff_softmax_handler(x, dim=-1, dtype=None, exp_handler=None, reciprocal_handler=None)

Differential softmax transformer.

When x is a DiffExpr3, the handler decomposes softmax(x)_i = exp(x_i) / Σ_j exp(x_j) into differential exp, reduce-sum, differential reciprocal, and differential element-wise product.
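As a rough illustration of this decomposition, the sketch below runs it on concrete numbers rather than zonotopes, so it is a point-valued analogue and not BoundLab's implementation. It assumes a DiffExpr3 can be modelled as a triple `(x, y, d)` of floats with `d = x - y`; the helpers `d_exp`, `d_sum`, `d_reciprocal`, `d_mul`, and `d_softmax_row` are hypothetical names. Each helper derives the difference component from an algebraic identity instead of subtracting the two outputs, and `d_softmax_row` applies the four steps along one row, i.e. the dim=1 axis of an (m, n) input.

```python
import math

# Hypothetical point-valued stand-in for DiffExpr3: a triple (x, y, d) with d = x - y.
# BoundLab propagates zonotope expressions instead; this only illustrates the algebra.

def d_exp(t):
    x, y, d = t
    # exp(x) - exp(y) = exp(y) * (exp(d) - 1); expm1 stays accurate for small d
    return math.exp(x), math.exp(y), math.exp(y) * math.expm1(d)

def d_sum(ts):
    # Reduce-sum is linear: the difference of the sums is the sum of the differences.
    return (sum(t[0] for t in ts), sum(t[1] for t in ts), sum(t[2] for t in ts))

def d_reciprocal(t):
    x, y, d = t
    # 1/x - 1/y = -(x - y) / (x * y), valid for x, y > 0 (sums of exps are positive)
    return 1.0 / x, 1.0 / y, -d / (x * y)

def d_mul(a, b):
    ax, ay, ad = a
    bx, by, bd = b
    # ax*bx - ay*by = ax*(bx - by) + (ax - ay)*by
    return ax * bx, ay * by, ax * bd + ad * by

def d_softmax_row(row):
    """Differential softmax over one row of (x, y, d) triples (dim=1 of an (m, n) input)."""
    exps = [d_exp(t) for t in row]        # differential exp
    inv = d_reciprocal(d_sum(exps))       # reduce-sum, then differential reciprocal
    return [d_mul(e, inv) for e in exps]  # differential element-wise product

def softmax(row):
    # Reference softmax for comparison (max-shifted for numerical stability).
    m = max(row)
    es = [math.exp(v - m) for v in row]
    s = sum(es)
    return [e / s for e in es]

# Two concrete 2x3 inputs, echoing the shapes used in the example below.
x = [[0.0, 0.1, -0.2], [0.3, 0.0, 0.05]]
y = [[1.0, 0.9, 1.1], [1.2, 1.0, 0.95]]
for xr, yr in zip(x, y):
    out = d_softmax_row([(a, b, a - b) for a, b in zip(xr, yr)])
    direct = [p - q for p, q in zip(softmax(xr), softmax(yr))]
    # Propagated difference vs. directly computed softmax(x) - softmax(y)
    print([round(t[2], 9) for t in out], [round(v, 9) for v in direct])
```

Carrying d explicitly, rather than subtracting two separately computed outputs, is the usual motivation for a differential representation: once the inputs are sets instead of points, bounds on x - y can be much tighter than the difference of independent bounds on x and y.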

When x is a plain Expr, the handler falls back to the standard softmax path; when x is a DiffExpr2, it first promotes x to a DiffExpr3.

Parameters:
  • x – Input expression (Expr, DiffExpr2, or DiffExpr3) with shape (m, n).

  • dim (int) – Softmax dimension (default: -1). Only dim=1 on 2D input is currently supported.

  • dtype – Ignored; accepted only for API compatibility.

Returns:

An expression, or a DiffExpr3 for differential input, that over-approximates softmax(x) along dim.
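To make "over-approximates" concrete: the result must enclose every value softmax can take over the input set. The sketch below demonstrates that soundness property with a deliberately simpler technique than zonotopes, plain interval (box) bounds, so it is not what this handler computes. It relies on softmax_i being increasing in x_i and decreasing in every other x_j, so the lower bound puts x_i at its lower end and all competitors at their upper ends, and vice versa for the upper bound. The function name `softmax_interval` is hypothetical and not part of BoundLab, and the radius-0.1 box is only loosely modelled on the example's inputs (any Lp ball of radius 0.1 with p ≥ 1 fits inside that box, so the bounds remain sound for it).

```python
import math
import random

def softmax(row):
    # Reference softmax (max-shifted for numerical stability).
    m = max(row)
    es = [math.exp(v - m) for v in row]
    s = sum(es)
    return [e / s for e in es]

def softmax_interval(lo, hi):
    """Hypothetical helper: per-coordinate bounds on softmax(x) for all lo <= x <= hi."""
    n = len(lo)
    lower, upper = [], []
    for i in range(n):
        # Worst case for coordinate i: x_i at its minimum, competitors at their maxima.
        others_hi = sum(math.exp(hi[j]) for j in range(n) if j != i)
        lower.append(math.exp(lo[i]) / (math.exp(lo[i]) + others_hi))
        # Best case: x_i at its maximum, competitors at their minima.
        others_lo = sum(math.exp(lo[j]) for j in range(n) if j != i)
        upper.append(math.exp(hi[i]) / (math.exp(hi[i]) + others_lo))
    return lower, upper

# Assumed input set: a box of radius 0.1 around a zero row of length 3.
center, eps = [0.0, 0.0, 0.0], 0.1
lo = [c - eps for c in center]
hi = [c + eps for c in center]
lower, upper = softmax_interval(lo, hi)

# Soundness check: softmax of points sampled from the box never leaves the bounds.
random.seed(0)
for _ in range(1000):
    x = [random.uniform(a, b) for a, b in zip(lo, hi)]
    s = softmax(x)
    assert all(l - 1e-12 <= v <= u + 1e-12 for l, v, u in zip(lower, s, upper))
```

Zonotope-based handlers aim for the same guarantee while also tracking correlations between outputs, which a per-coordinate box like this one discards.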

Examples

>>> import torch
>>> import boundlab.expr as expr
>>> from boundlab.diff.expr import DiffExpr3
>>> from boundlab.diff.zono3.default.softmax import diff_softmax_handler
>>> x = expr.ConstVal(torch.zeros(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
>>> y = expr.ConstVal(torch.ones(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
>>> t = DiffExpr3(x, y, x - y)
>>> out = diff_softmax_handler(t, dim=1)
>>> out.diff.shape
torch.Size([2, 3])