Source code for boundlab.expr._var
r"""Symbolic Variables for Bound Propagation
This module defines variable expressions representing bounded input
perturbations used in neural network verification.
"""
import math
from typing import Literal
import torch
from boundlab.expr._core import Expr, ExprFlags
from boundlab.linearop._base import LinearOp, LinearOpFlags
from boundlab.linearop._einsum import EinsumOp
[docs]
class LpEpsilon(Expr):
    r"""A noise symbol bounded by the :math:`\ell_p`-norm constraint.

    Represents a perturbation variable :math:`\boldsymbol{\epsilon}` satisfying:

    .. math:: \|\boldsymbol{\epsilon}\|_p \leq 1

    During backward propagation with direction ``"<="`` or ``">="``, the
    contribution is :math:`\pm\|\mathbf{w}\|_q` where :math:`\mathbf{w}`
    is the materialized weight tensor and :math:`q` is the dual norm of
    :math:`p` defined by :math:`\frac{1}{p} + \frac{1}{q} = 1`
    (with :math:`q = 1` for :math:`p = \infty` and :math:`q = \infty`
    for :math:`0 < p \le 1`).

    Only :class:`~boundlab.linearop.EinsumOp` weights are supported.
    """

    def __init__(self, *shape, name=None, p="inf"):
        r"""Create an :math:`\ell_p` noise symbol.

        Args:
            *shape: Shape of the symbol, given either as separate integers
                (``LpEpsilon(3, 4)``) or as one sequence
                (``LpEpsilon((3, 4))``, ``LpEpsilon(torch.Size([3, 4]))``).
                Calling with no shape arguments gives a 0-d (scalar) symbol.
            name: Optional label used by :meth:`to_string` instead of the id.
            p: Norm exponent: a positive number, ``math.inf``, or the
                string ``"inf"`` (the default).

        Raises:
            ValueError: If ``p`` is not a positive number (including NaN).
        """
        super().__init__(ExprFlags.SYMMETRIC_TO_0)
        # A single non-int argument is taken as the whole shape (the original
        # calling convention). Otherwise the positional ints form the shape;
        # passing them unpacked to torch.Size (torch.Size(3, 4) / torch.Size(3))
        # would raise TypeError.
        if len(shape) == 1 and not isinstance(shape[0], int):
            shape = shape[0]
        self._shape = torch.Size(shape)
        self.name = name
        if p == "inf":
            p = math.inf
        # `not p > 0` also rejects NaN, for which every comparison is False.
        if not p > 0:
            raise ValueError(f"LpEpsilon requires p > 0, got {p!r}")
        self.p = p
        if p == math.inf:
            self.q = 1
        elif p > 1:
            # Hoelder conjugate 1/p + 1/q = 1, written to avoid a double division.
            self.q = p / (p - 1)
        else:
            # For 0 < p <= 1 the p-ball contains every +/-e_i and lies inside
            # the l1 ball, so sup <w, eps> = ||w||_inf.
            self.q = math.inf

    @property
    def shape(self) -> torch.Size:
        """Shape of the noise symbol."""
        return self._shape

    def with_children(self) -> "LpEpsilon":
        """Return ``self``: a leaf symbol has no children to replace."""
        return self

    @property
    def children(self) -> tuple[()]:
        """Always the empty tuple; this symbol is a leaf."""
        return ()

    def backward(
        self,
        weights: LinearOp,
        direction: Literal[">=", "<=", "=="],
    ) -> tuple[torch.Tensor, list] | None:
        r"""Compute the dual-norm bound contribution.

        Args:
            weights: The accumulated linear weight. The docs restrict this to
                :class:`~boundlab.linearop.EinsumOp`, but no type check is
                done here. Any object providing ``jacobian()``,
                ``output_shape`` and ``input_shape`` is used as-is. The
                Jacobian is assumed to be laid out as output dims followed
                by input dims.
            direction: ``"<="`` returns :math:`+\|\mathbf{w}\|_q`;
                ``">="`` returns :math:`-\|\mathbf{w}\|_q`;
                ``"=="`` returns ``None``.

        Returns:
            ``(±norm, [])`` or ``None`` for ``"=="``. The norm is taken over
            the input dims of the Jacobian and has the weight's output shape.
            The list is always empty because this leaf has no children.

        Raises:
            ValueError: If ``direction`` is not one of ``">="``, ``"<="``
                or ``"=="``.
        """
        if direction == "==":
            return None
        if direction not in ("<=", ">="):
            raise ValueError(
                f"direction must be '>=', '<=' or '==', got {direction!r}"
            )
        jac = weights.jacobian()
        n_out = len(weights.output_shape)
        input_dims = tuple(range(n_out, n_out + len(weights.input_shape)))
        if input_dims:
            result = torch.linalg.vector_norm(jac, ord=self.q, dim=input_dims)
        else:
            # Scalar symbol: every p-norm of a scalar is |eps|, so the bound
            # is |w|. An empty `dim` would instead make torch reduce over
            # *all* dimensions, collapsing the output dims.
            result = jac.abs()
        return (result if direction == "<=" else -result, [])

    def to_string(self) -> str:
        """Render as ``<𝜀 [shape]>#name``, or with the hex id if unnamed."""
        if self.name is not None:
            return f"<𝜀 {list(self.shape)}>#{self.name}"
        return f"<𝜀 {list(self.shape)}>#{self.id:X}"