boundlab.diff.zonosq3.diff_softmax_pruning_handler
- boundlab.diff.zonosq3.diff_softmax_pruning_handler(scores, data, dim=-1, dtype=None, exp_handler=None, reciprocal_handler=None, heaviside_handler=None)
Differential transformer for boundlab::SoftmaxPruning. Implements the mock pruning op boundlab.diff.op.softmax_pruning(): branch x = softmax(data), branch y = exp(data_i) / Σ_j heaviside(scores_j) exp(data_j) (denominator-masked softmax), diff = x − y.
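To make the three branches concrete, here is a minimal sketch in plain Python that evaluates them on one row of ordinary numbers rather than on abstract expressions. The helper names `heaviside` and `softmax_pair` are hypothetical and not part of boundlab, and the step value at a score of exactly zero is an assumption, since the source does not state it.

```python
import math


def heaviside(s):
    # Assumption: 1 for s > 0, otherwise 0. The library's convention at
    # s == 0 is not stated in the source.
    return 1.0 if s > 0 else 0.0


def softmax_pair(scores, data):
    """Return (x, y, diff) for one row of concrete values (hypothetical helper)."""
    exps = [math.exp(d) for d in data]
    full = sum(exps)                                            # unmasked denominator
    kept = sum(heaviside(s) * e for s, e in zip(scores, exps))  # masked denominator
    x = [e / full for e in exps]                                # branch x: full softmax
    y = [e / kept for e in exps]                                # branch y: only the denominator is masked
    diff = [xi - yi for xi, yi in zip(x, y)]                    # diff = x - y
    return x, y, diff


# Same scores and centre values as one row of the documented example.
x, y, diff = softmax_pair([1.0, -1.0, 1.0], [0.0, 0.0, 0.0])
# Every entry of x is 1/3, every entry of y is 1/2, every entry of diff is -1/6.
```

Because only the denominator is masked, a pruned position still has a nonzero value in y, and y does not generally sum to 1.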
It reuses the DeepT decomposition of diff_softmax_handler() (softmax_i = 1 / Σ_j exp(data_j − data_i)) but inserts a const_heaviside_pruning() mask on each term of the denominator sum, so the score-based pruning is applied only to the y network while the x network keeps the full softmax. Only scores_j (the key axis dim) participates in masking; the mask is broadcast over the query axis.
- Parameters:
- Returns:
Expression / DiffExpr3 over-approximating the masked softmax pair.
Examples
>>> import torch
>>> import boundlab.expr as expr
>>> from boundlab.diff.expr import DiffExpr3
>>> from boundlab.diff.zono3.default.softmax import diff_softmax_pruning_handler
>>> z = expr.ConstVal(torch.zeros(2, 3)) + 0.05 * expr.LpEpsilon([2, 3])
>>> scores = torch.tensor([[1.0, -1.0, 1.0], [1.0, -1.0, 1.0]])
>>> out = diff_softmax_pruning_handler(scores, DiffExpr3(z, z, z - z), dim=1)
>>> out.diff.shape
torch.Size([2, 3])
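The masked decomposition described for the y branch, y_i = 1 / Σ_j heaviside(scores_j) exp(data_j − data_i), can be checked numerically against the direct denominator-masked form. The sketch below does this in plain Python on concrete values. Both function names are hypothetical, and the mask again assumes that a score greater than zero means "keep". It illustrates the algebraic identity only, not how the handler builds its over-approximation.

```python
import math


def masked_softmax_direct(scores, data):
    # y_i = exp(data_i) / sum_j mask_j * exp(data_j)
    exps = [math.exp(d) for d in data]
    kept = sum(e for s, e in zip(scores, exps) if s > 0)  # assumed mask: s > 0
    return [e / kept for e in exps]


def masked_softmax_decomposed(scores, data):
    # y_i = 1 / sum_j mask_j * exp(data_j - data_i)
    out = []
    for d_i in data:  # output position i
        # The mask depends only on j, so the same mask is reused for every i.
        denom = sum(
            math.exp(d_j - d_i)
            for s_j, d_j in zip(scores, data)
            if s_j > 0
        )
        out.append(1.0 / denom)
    return out


scores = [1.0, -1.0, 1.0]
data = [0.3, -1.2, 2.0]
direct = masked_softmax_direct(scores, data)
decomposed = masked_softmax_decomposed(scores, data)
max_gap = max(abs(a - b) for a, b in zip(direct, decomposed))
```

The two forms agree up to floating-point rounding, which is why the mask can be inserted term by term into the denominator sum of the decomposition.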