Source code for boundlab.zono.softmax

"""Softmax handler for zonotope abstract interpretation.

Implements softmax following DeepT (Bonaert et al., 2021):

.. math::

   \\mathrm{softmax}(x)_i
     = \\frac{1}{\\sum_j \\exp(x_j - x_i)}

This avoids a bilinear product: only subtraction, exp, reduce-sum
and reciprocal are needed.
"""

from sys import stderr

import torch

from boundlab import expr, utils
from boundlab.expr._var import LpEpsilon
from boundlab.linearop._indices import GatherOp
from boundlab.zono.reciprocal import reciprocal_linearizer
from boundlab.zono.exp import exp_linearizer
from boundlab.expr._core import Expr
from boundlab.zono.softmax2 import softmax2_ibp, softmax2_linearizer
from boundlab.zono.tanh import tanh_linearizer

"""Softmax handler for zonotope abstract interpretation.

Implements softmax following DeepT (Bonaert et al., 2021):

.. math::

   \\mathrm{softmax}(x)_i
     = \\frac{1}{\\sum_j \\exp(x_j - x_i)}

This avoids a bilinear product: only subtraction, exp, reduce-sum
and reciprocal are needed.
"""

from sys import stderr

import torch

from boundlab import expr, utils
from boundlab.expr._var import LpEpsilon
from boundlab.linearop._indices import GatherOp
from boundlab.zono.reciprocal import reciprocal_linearizer
from boundlab.zono.exp import exp_linearizer
from boundlab.expr._core import Expr
from boundlab.zono.tanh import tanh_linearizer


def softmax_handler(x: Expr, dim: int = -1, dtype=None) -> Expr:
    r"""Zonotope softmax transformer using the DeepT decomposition.

    Softmax is rewritten as:

    .. math::

        \mathrm{softmax}(x)_i
            = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
            = \frac{1}{\sum_j \exp(x_j - x_i)}

    so only subtraction, exp, reduce-sum, and reciprocal are required
    (no bilinear element-wise product).

    The exp and the reciprocal are each replaced by the linear relaxation
    returned by ``exp_linearizer`` / ``reciprocal_linearizer`` plus a fresh
    noise symbol.  An output element falls back to the range ``[0, 1]``
    when any of its relaxation coefficients is non-finite or when one of
    its pairwise-difference bounds is not below 30.

    Args:
        x: Input expression.
        dim: Dimension along which to apply softmax (default: -1).  Only
            the last dimension is supported.
        dtype: Ignored (for API compatibility with torch.softmax).

    Returns:
        An expression over-approximating ``torch.softmax(x, dim=dim)``, or
        ``NotImplemented`` when ``x`` is not an ``Expr``.

    Raises:
        AssertionError: If ``dim`` does not refer to the last dimension.

    Examples
    --------
    >>> import torch
    >>> import boundlab.expr as expr
    >>> from boundlab.zono.softmax import softmax_handler
    >>> x = expr.ConstVal(torch.zeros(2, 3)) + 0.1 * expr.LpEpsilon([2, 3])
    >>> y = softmax_handler(x, dim=1)
    >>> y.shape
    torch.Size([2, 3])
    """
    if not isinstance(x, Expr):
        return NotImplemented
    if dim < 0:
        dim += len(x.shape)
    assert dim == len(x.shape) - 1, "softmax_handler only supports the last dimension"

    # The imported name is not referenced below.
    # NOTE(review): presumably kept for its import side effects -- confirm
    # before removing.
    from . import interpret  # noqa: F401

    # Per the original note, pairwise_diff gives d[..., i, j] = x[..., i] - x[..., j];
    # after negation diff[..., i, j] = x[..., j] - x[..., i], so that
    # softmax(x)[i] = 1 / sum_j exp(x[j] - x[i]) = 1 / sum_j exp(diff[..., i, j]).
    diff = -utils.pairwise_diff(x, dim)
    # Bounds of the pairwise differences.  (The bounds of x itself were
    # previously computed here as well and immediately discarded.)
    diff_ub, diff_lb = diff.ublb()

    # Linear relaxation of exp over [diff_lb, diff_ub].
    expbounds = exp_linearizer(diff_ub, diff_lb)
    bias = expbounds.bias
    error = expbounds.error_coeffs.tensor
    weights = expbounds.input_weights[0]

    # A pairwise term is usable only if all its coefficients are finite and
    # both of its bounds are below 30 (exp(30) is about 1e13).
    pair_ok = (
        torch.isfinite(weights)
        & torch.isfinite(error)
        & torch.isfinite(bias)
        & (diff_lb < 30)
        & (diff_ub < 30)
    )
    # Zero the unusable coefficients so they cannot inject inf/NaN into the sum.
    bias = torch.where(pair_ok, bias, 0)
    error = torch.where(pair_ok, error, 0)
    weights = torch.where(pair_ok, weights, 0)

    # One fresh noise symbol per *output element* (..., i): diff.shape[:-1].
    # (Was diff.shape[:-2], a dimension too small, which crashed per-epsilon
    # affine extraction; see run_quant_equiv._patch_zono_softmax_axes.)
    sum_exp = (
        (weights * diff).sum(dim=-1)
        + error.sum(dim=-1) * LpEpsilon(diff.shape[:-1])
        + bias.sum(dim=-1)
    )
    # An output element is usable only if every one of its pairwise terms is.
    out_ok = pair_ok.all(dim=-1)

    # Intersect the zonotope bounds of the denominator with the interval
    # bounds obtained by summing exp over the pairwise bounds.
    sum_exp_ub, sum_exp_lb = sum_exp.ublb()
    sum_exp_ub = torch.minimum(sum_exp_ub, torch.exp(diff_ub).sum(dim=-1))
    # The true denominator always contains the j == i term exp(0) = 1, so any
    # positive floor <= 1 is a sound lower bound; the floor keeps a fully
    # underflowed exp(lb).sum() from handing the reciprocal a zero (NaN).
    sum_exp_lb = torch.maximum(
        sum_exp_lb, torch.exp(diff_lb).sum(dim=-1)
    ).clamp(min=1e-30)

    # Linear relaxation of 1 / s over [sum_exp_lb, sum_exp_ub].
    bounds = reciprocal_linearizer(sum_exp_ub, sum_exp_lb)
    w = bounds.input_weights[0]
    mu = bounds.bias
    beta = bounds.error_coeffs.tensor
    out_ok = out_ok & torch.isfinite(w) & torch.isfinite(mu) & torch.isfinite(beta)
    w = torch.where(out_ok, w, 0)
    mu = torch.where(out_ok, mu, 0)
    beta = torch.where(out_ok, beta, 0)

    # Elements whose relaxation failed get the trivially sound softmax range
    # [0, 1] instead of being pinned to exactly 0, which excluded reachable
    # values (true softmax can be ~1 with a pairwise ub >= 30) and was unsound.
    # The constant centre takes dtype/device from the computed bounds so that
    # it matches them for non-default dtypes and devices.
    fallback_center = torch.full(
        tuple(sum_exp.shape),
        0.5,
        dtype=sum_exp_ub.dtype,
        device=sum_exp_ub.device,
    )
    fallback = 0.5 * LpEpsilon(sum_exp.shape) + fallback_center
    relaxed = w * sum_exp + mu + beta * LpEpsilon(sum_exp.shape)
    result = out_ok * relaxed + (~out_ok) * fallback

    # Tag for the simplex clamp in matmul_handler (Task 7): rows of this softmax
    # output sum to 1, so a following S@V is a per-token convex combination of V.
    # The attribute survives only on this node, which is fine — the next op
    # (the value matmul) receives it directly.
    try:
        result._simplex_rows = True
    except AttributeError:
        # Best effort: the result object does not accept new attributes.
        pass
    return result
def softmax_handler_basedon_softmax2(x: Expr, dim: int = -1, dtype=None) -> Expr:
    """Zonotope softmax built by folding pairwise differences through ``softmax2``.

    The input is flattened to two dimensions and the negated pairwise
    differences along the last axis are formed with the diagonal removed.
    A running expression, initialised to a constant one, is then updated
    once per remaining difference using the coefficients returned by
    ``softmax2_linearizer``.  After every update its interval bounds are
    intersected with the bounds returned by ``softmax2_ibp``.

    Args:
        x: Input expression; softmax is taken over its last dimension.
        dim: Softmax dimension (default: -1).  Must refer to the last
            dimension.
        dtype: Not used by this function.

    Returns:
        The final running expression, reshaped to the shape of ``x``.

    Raises:
        AssertionError: If ``dim`` does not refer to the last dimension.
    """
    rank = len(x.shape)
    axis = dim + rank if dim < 0 else dim
    assert axis == rank - 1, "softmax_handler_basedon_softmax2 only supports the last dimension"

    # Work on a 2-D view; the original shape is restored on return.
    input_shape = x.shape
    flat = x.reshape(-1, x.shape[-1])

    # Pairwise differences for the recurrence y = x_j - x_i.
    # (Shape annotations are carried over from the original comments.)
    pair = -utils.pairwise_diff(flat, dim=-1)  # [T, i, j]
    pair = pair.permute(1, 2, 0)  # [i, j, T]
    pair = utils.remove_diagonal(pair, dim1=0, dim2=1)  # [i, j, T]
    pair = pair.permute(1, 2, 0)  # [j, T, i]
    pair_ub, pair_lb = pair.ublb()

    # Reorder the steps along dim 0 by ascending |ub + lb| / 2, applying the
    # same permutation to the bounds and to the expression itself.
    sort_key = (pair_ub + pair_lb).abs() / 2
    order = torch.argsort(sort_key, dim=0)
    pair_ub = torch.gather(pair_ub, dim=0, index=order)
    pair_lb = torch.gather(pair_lb, dim=0, index=order)
    pair = GatherOp(pair.shape, dim=0, index=order)(pair)

    # Running value of the recurrence, starting from the constant 1.
    acc = expr.ConstVal(torch.ones(flat.shape))
    acc_ub, acc_lb = acc.ublb()

    num_steps = pair.shape[0]
    for step in range(num_steps):
        # Domain invariant of this recurrence: the running value stays in (0, 1].
        acc_ub = acc_ub.clamp(min=1e-8, max=1.0)
        acc_lb = acc_lb.clamp(min=1e-8, max=1.0)
        # Keep the interval well-formed (lb <= ub) after clamping.
        acc_lb = torch.minimum(acc_lb, acc_ub)

        step_ub = pair_ub[step]
        step_lb = pair_lb[step]
        relax = softmax2_linearizer(acc_ub, acc_lb, step_ub, step_lb)
        ibp_ub, ibp_lb = softmax2_ibp(acc_ub, acc_lb, step_ub, step_lb)

        lam_acc = relax.input_weights[0]
        lam_diff = relax.input_weights[1]
        offset = relax.bias
        err = relax.error_coeffs.tensor

        # Affine update plus one fresh noise symbol per element.
        acc = (
            lam_acc * acc
            + lam_diff * pair[step]
            + offset
            + err * LpEpsilon(acc.shape)
        )

        # Intersect the new zonotope bounds with the interval bounds.
        acc_ub, acc_lb = acc.ublb()
        acc_ub = torch.minimum(acc_ub, ibp_ub)
        acc_lb = torch.maximum(acc_lb, ibp_lb)

    return acc.reshape(input_shape)