boundlab.diff.zono3.diff_softmax_pruning_handler

boundlab.diff.zono3.diff_softmax_pruning_handler(scores, data, dim=-1, dtype=None, exp_handler=None, reciprocal_handler=None, heaviside_handler=None)

Differential transformer for boundlab::SoftmaxPruning.

Implements the mock pruning op boundlab.diff.op.softmax_pruning():

  • branch x = softmax(data),

  • branch y = exp(data_i) / Σ_j heaviside(scores_j) · exp(data_j) (denominator-masked softmax),

  • diff = x - y.
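For intuition, the concrete (non-abstract) semantics of the three outputs on a single 1-D row fit in a few lines of plain Python. This is a sketch, not boundlab code: the helper names are hypothetical, and the step convention heaviside(s) = 1 for s > 0, else 0, is an assumption because the behaviour at s = 0 is not stated here.

```python
import math


def heaviside(s):
    # Assumed step convention: 1 for strictly positive scores, 0 otherwise.
    return 1.0 if s > 0 else 0.0


def softmax(data):
    # Branch x: ordinary softmax, shifted by max(data) for numerical stability.
    m = max(data)
    exps = [math.exp(d - m) for d in data]
    total = sum(exps)
    return [e / total for e in exps]


def masked_softmax(scores, data):
    # Branch y: the numerator exp(data_i) is kept for every i; only the
    # denominator terms are gated by heaviside(scores_j).
    # The shift by max(data) cancels between numerator and denominator.
    m = max(data)
    exps = [math.exp(d - m) for d in data]
    total = sum(heaviside(s) * e for s, e in zip(scores, exps))
    # Raises ZeroDivisionError if every key is masked out.
    return [e / total for e in exps]


def softmax_pruning_pair(scores, data):
    # Hypothetical helper returning (x, y, diff) for one row.
    x = softmax(data)
    y = masked_softmax(scores, data)
    diff = [xi - yi for xi, yi in zip(x, y)]
    return x, y, diff


x, y, diff = softmax_pruning_pair([1.0, -1.0, 1.0], [0.0, 0.0, 0.0])
```

With data = [0, 0, 0] and scores = [1, -1, 1] (one row of the example below evaluated at the centre of z), x is 1/3 everywhere, y is 1/2 everywhere because only the denominator loses the masked term, and diff is -1/6 everywhere. Note that y need not sum to 1.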

It reuses the DeepT decomposition of diff_softmax_handler(), softmax_i = 1 / Σ_j exp(data_j - data_i), but inserts a const_heaviside_pruning() mask on each term of the denominator sum. Score-based pruning is therefore applied only to the y network, while the x network keeps the full softmax. Only scores_j along the key axis dim participates in masking; the mask is broadcast over the query axis.
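The decomposition can be checked numerically. The plain-Python sketch below evaluates both branches in the DeepT form 1 / Σ_j m_j · exp(data_j - data_i), with m_j = 1 for x and m_j = heaviside(scores_j) for y. It is illustrative only: the helper names are hypothetical, it assumes one score per key broadcast over every query row, and it uses the same assumed step convention (1 for s > 0, else 0).

```python
import math


def heaviside(s):
    # Assumed convention: 1 for s > 0, else 0.
    return 1.0 if s > 0 else 0.0


def deept_softmax_row(row, mask=None):
    # softmax_i = 1 / sum_j m_j * exp(row_j - row_i); m_j = 1 when mask is None.
    if mask is None:
        mask = [1.0] * len(row)
    return [
        1.0 / sum(m * math.exp(dj - di) for m, dj in zip(mask, row))
        for di in row
    ]


def deept_pair(scores_keys, data):
    # data is a [query][key] matrix; scores_keys holds one score per key and
    # the resulting mask is reused (broadcast) for every query row.
    mask = [heaviside(s) for s in scores_keys]
    x = [deept_softmax_row(row) for row in data]
    y = [deept_softmax_row(row, mask) for row in data]
    return x, y


data = [[0.5, -1.0, 2.0], [0.0, 0.0, 0.0]]
scores = [1.0, -1.0, 1.0]
x, y = deept_pair(scores, data)
```

Multiplying the numerator and denominator of exp(data_i) / Σ_j m_j · exp(data_j) by exp(-data_i) gives exactly this form, so it is a restatement rather than an approximation. In floating point, the exp(data_j - data_i) terms can overflow when data has a large spread.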

Parameters:
  • scores – Pruning scores; same shape as data. May be a plain Expr / tensor (identical for both networks) or a DiffExpr2 / DiffExpr3.

  • data – Input expression / DiffExpr3 the softmax is taken over.

  • dim (int) – Softmax axis (default: -1).

  • dtype – Ignored (API compatibility with torch.softmax).

Returns:

Expression / DiffExpr3 over-approximating the masked softmax pair.

Examples

>>> import torch
>>> import boundlab.expr as expr
>>> from boundlab.diff.expr import DiffExpr3
>>> from boundlab.diff.zono3.default.softmax import diff_softmax_pruning_handler
>>> z = expr.ConstVal(torch.zeros(2, 3)) + 0.05 * expr.LpEpsilon([2, 3])
>>> scores = torch.tensor([[1.0, -1.0, 1.0], [1.0, -1.0, 1.0]])
>>> out = diff_softmax_pruning_handler(scores, DiffExpr3(z, z, z - z), dim=1)
>>> out.diff.shape
torch.Size([2, 3])