from __future__ import annotations
import dataclasses
import torch
from boundlab import expr
from boundlab.expr._core import Expr
@dataclasses.dataclass
class DiffExpr2:
"""A pair of expressions ``(x, y)`` for two-network differential tracking.
All linear operators apply element-wise to both components.
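
    Example (a minimal sketch: plain ``torch.Tensor`` components stand in
    for ``Expr`` here, since these operators simply forward to the
    components)::

        >>> p = DiffExpr2(torch.ones(2), torch.zeros(2))
        >>> out = p @ torch.eye(2)
        >>> out.x, out.y
        (tensor([1., 1.]), tensor([0., 0.]))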
"""
x: Expr
y: Expr
@property
def shape(self) -> torch.Size:
return self.x.shape
def _map(self, fn):
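        """Apply *fn* to both components."""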
return DiffExpr2(fn(self.x), fn(self.y))
def get_const(self) -> tuple[torch.Tensor, torch.Tensor] | None:
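        """Return the constant values of both components, or ``None`` if
        either component is symbolic."""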
x = self.x.get_const()
if x is not None:
y = self.y.get_const()
if y is not None:
return x, y
return None
# ------------------------------------------------------------------
# Arithmetic
# ------------------------------------------------------------------
def __add__(self, other):
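        # Accept the plain int 0 so that e.g. sum() over pairs works.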
if isinstance(other, int) and other == 0:
return self
if isinstance(other, (torch.Tensor, Expr)):
return DiffExpr2(self.x + other, self.y + other)
if isinstance(other, DiffExpr2):
return DiffExpr2(self.x + other.x, self.y + other.y)
return NotImplemented
def __radd__(self, other):
if isinstance(other, int) and other == 0:
return self
if isinstance(other, (torch.Tensor, Expr)):
return DiffExpr2(other + self.x, other + self.y)
return NotImplemented
def __neg__(self):
return DiffExpr2(-self.x, -self.y)
def __sub__(self, other):
if isinstance(other, (torch.Tensor, Expr)):
return DiffExpr2(self.x - other, self.y - other)
if isinstance(other, DiffExpr2):
return DiffExpr2(self.x - other.x, self.y - other.y)
return NotImplemented
def __rsub__(self, other):
if isinstance(other, (torch.Tensor, Expr)):
return DiffExpr2(other - self.x, other - self.y)
return NotImplemented
def __mul__(self, other):
if isinstance(other, (int, float, torch.Tensor)):
return DiffExpr2(self.x * other, self.y * other)
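        # A product of two symbolic factors is nonlinear; the result stays
        # linear only when one side is a known constant.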
if isinstance(other, Expr):
if (t := other.get_const()) is not None:
return DiffExpr2(self.x * t, self.y * t)
elif (tensors := self.get_const()) is not None:
return DiffExpr2(other * tensors[0], other * tensors[1])
if isinstance(other, DiffExpr2):
if (tensors := self.get_const()) is not None:
return DiffExpr2(other.x * tensors[0], other.y * tensors[1])
elif (tensors := other.get_const()) is not None:
return DiffExpr2(self.x * tensors[0], self.y * tensors[1])
return NotImplemented
def __rmul__(self, other):
if isinstance(other, (int, float, torch.Tensor)):
return DiffExpr2(other * self.x, other * self.y)
if isinstance(other, Expr):
return self.__mul__(other)
return NotImplemented
def __truediv__(self, other):
if isinstance(other, (int, float, torch.Tensor)):
return DiffExpr2(self.x / other, self.y / other)
return NotImplemented
def __matmul__(self, other):
if isinstance(other, torch.Tensor):
return DiffExpr2(self.x @ other, self.y @ other)
if isinstance(other, Expr):
if (t := other.get_const()) is not None:
return DiffExpr2(self.x @ t, self.y @ t)
elif (tensors := self.get_const()) is not None:
return DiffExpr2(tensors[0] @ other, tensors[1] @ other)
if isinstance(other, DiffExpr2):
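            # A pair @ pair matmul is linear only when one side is a
            # constant pair.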
if (tensors := self.get_const()) is not None:
                return DiffExpr2(tensors[0] @ other.x, tensors[1] @ other.y)
elif (tensors := other.get_const()) is not None:
return DiffExpr2(self.x @ tensors[0], self.y @ tensors[1])
return NotImplemented
def __rmatmul__(self, other):
if isinstance(other, torch.Tensor):
return DiffExpr2(other @ self.x, other @ self.y)
if isinstance(other, Expr):
if (t := other.get_const()) is not None:
return DiffExpr2(t @ self.x, t @ self.y)
elif (tensors := self.get_const()) is not None:
return DiffExpr2(other @ tensors[0], other @ tensors[1])
# if isinstance(other, DiffExpr3):
# # other is input triple (x, y, d); self is constant weight pair (W1, W2)
# # x@W1 − y@W2 = d@W1 + y@(W1−W2)
# wx, wy = _const_value(self.x), _const_value(self.y)
# if wx is not None and wy is not None:
# return DiffExpr3(
# other.x @ wx,
# other.y @ wy,
# other.diff @ wx + other.y @ (wx - wy),
# )
return NotImplemented
# ------------------------------------------------------------------
# Indexing
# ------------------------------------------------------------------
def __getitem__(self, indices):
return DiffExpr2(self.x[indices], self.y[indices])
def scatter(self, indices, output_shape):
return self._map(lambda e: e.scatter(indices, output_shape))
def gather(self, indices):
return self._map(lambda e: e.gather(indices))
# ------------------------------------------------------------------
# Shape ops
# ------------------------------------------------------------------
def reshape(self, *shape):
return self._map(lambda e: e.reshape(*shape))
def permute(self, *dims):
return self._map(lambda e: e.permute(*dims))
def transpose(self, dim0, dim1):
return self._map(lambda e: e.transpose(dim0, dim1))
def flatten(self, start_dim=0, end_dim=-1):
return self._map(lambda e: e.flatten(start_dim, end_dim))
def unflatten(self, dim, sizes):
return self._map(lambda e: e.unflatten(dim, sizes))
def squeeze(self, dim=None):
return self._map(lambda e: e.squeeze(dim))
def unsqueeze(self, dim):
return self._map(lambda e: e.unsqueeze(dim))
def narrow(self, dim, start, length):
return self._map(lambda e: e.narrow(dim, start, length))
def expand(self, *sizes):
return self._map(lambda e: e.expand(*sizes))
def repeat(self, *sizes):
return self._map(lambda e: e.repeat(*sizes))
def tile(self, *sizes):
return self._map(lambda e: e.tile(*sizes))
def flip(self, dims):
return self._map(lambda e: e.flip(dims))
def roll(self, shifts, dims):
return self._map(lambda e: e.roll(shifts, dims))
def diag(self, diagonal=0):
return self._map(lambda e: e.diag(diagonal))
@dataclasses.dataclass
class DiffExpr3:
"""A triple ``(x, y, diff)`` for differential zonotope verification.
``x`` and ``y`` track each network's activations independently.
``diff`` over-approximates ``f₁(x) − f₂(y)``.
For affine operations ``f(z) = W z + b``:
- ``x`` and ``y`` receive both weight and bias.
- ``diff`` receives only the weight (bias cancels: ``(Wx+b)−(Wy+b) = W(x−y)``).
For pure-linear operations (no bias), all three components are updated
identically.
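
    Example (a minimal sketch: plain ``torch.Tensor`` components stand in
    for ``Expr``)::

        >>> t = DiffExpr3(torch.ones(2), torch.zeros(2), torch.ones(2))
        >>> out = t @ torch.eye(2) + torch.ones(2)
        >>> out.diff  # the bias never reaches the diff component
        tensor([1., 1.])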
"""
x: Expr
y: Expr
diff: Expr
@property
def shape(self) -> torch.Size:
return self.x.shape
def _map_all(self, fn):
"""Apply *fn* to all three components (pure-linear ops)."""
return DiffExpr3(fn(self.x), fn(self.y), fn(self.diff))
# ------------------------------------------------------------------
# Arithmetic
# ------------------------------------------------------------------
def __add__(self, other):
if isinstance(other, int) and other == 0:
return self
if isinstance(other, (torch.Tensor, Expr)):
# Constant bias cancels in the diff component.
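            # (x + c) − (y + c) = x − y, so diff passes through unchanged.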
return DiffExpr3(self.x + other, self.y + other, self.diff)
if isinstance(other, DiffExpr2):
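            # (x + ox) − (y + oy) = (x − y) + (ox − oy).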
return DiffExpr3(
self.x + other.x,
self.y + other.y,
self.diff + (other.x - other.y),
)
if isinstance(other, DiffExpr3):
return DiffExpr3(self.x + other.x, self.y + other.y, self.diff + other.diff)
return NotImplemented
def __radd__(self, other):
if isinstance(other, int) and other == 0:
return self
if isinstance(other, (torch.Tensor, Expr)):
return DiffExpr3(other + self.x, other + self.y, self.diff)
return NotImplemented
def __neg__(self):
return DiffExpr3(-self.x, -self.y, -self.diff)
def __sub__(self, other):
if isinstance(other, int) and other == 0:
return self
if isinstance(other, (torch.Tensor, Expr)):
return DiffExpr3(self.x - other, self.y - other, self.diff)
if isinstance(other, DiffExpr2):
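            # (x − ox) − (y − oy) = (x − y) − (ox − oy).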
return DiffExpr3(
self.x - other.x,
self.y - other.y,
self.diff - (other.x - other.y),
)
if isinstance(other, DiffExpr3):
return DiffExpr3(self.x - other.x, self.y - other.y, self.diff - other.diff)
return NotImplemented
def __rsub__(self, other):
if isinstance(other, (torch.Tensor, Expr)):
# ``other − (x, y, d)`` negates all three then adds constant to x/y.
return DiffExpr3(other - self.x, other - self.y, -self.diff)
return NotImplemented
def __mul__(self, other):
if isinstance(other, (int, float, torch.Tensor)):
return self._map_all(lambda e: e * other)
if isinstance(other, Expr):
if (v := other.get_const()) is not None:
return self._map_all(lambda e: e * v)
if isinstance(other, DiffExpr2):
            if (tensors := other.get_const()) is not None:
# Bilinear diff identity: x*vx − y*vy = diff*vx + y*(vx − vy)
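                # (substitute x = y + diff; diff over-approximates x − y).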
return DiffExpr3(
self.x * tensors[0],
self.y * tensors[1],
self.diff * tensors[0] + self.y * (tensors[0] - tensors[1]),
)
return NotImplemented
def __rmul__(self, other):
if isinstance(other, (int, float, torch.Tensor)):
return self._map_all(lambda e: other * e)
if isinstance(other, Expr):
if (v := other.get_const()) is not None:
return self._map_all(lambda e: v * e)
if isinstance(other, DiffExpr2):
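            # Mirror of __mul__: vx*x − vy*y = vx*diff + (vx − vy)*y.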
if (tensors := other.get_const()) is not None:
return DiffExpr3(
tensors[0] * self.x,
tensors[1] * self.y,
tensors[0] * self.diff + (tensors[0] - tensors[1]) * self.y,
)
return NotImplemented
def __truediv__(self, other):
if isinstance(other, (int, float, torch.Tensor)):
return self._map_all(lambda e: e / other)
return NotImplemented
def __matmul__(self, other):
if isinstance(other, torch.Tensor):
return self._map_all(lambda e: e @ other)
if isinstance(other, Expr):
if (v := other.get_const()) is not None:
return self._map_all(lambda e: e @ v)
if isinstance(other, DiffExpr2):
# self is input triple (x, y, d); other is constant weight pair (W1, W2)
# x@W1 − y@W2 = d@W1 + y@(W1−W2)
if (tensors := other.get_const()) is not None:
return DiffExpr3(
self.x @ tensors[0],
self.y @ tensors[1],
self.diff @ tensors[0] + self.y @ (tensors[0] - tensors[1]),
)
return NotImplemented
def __rmatmul__(self, other):
if isinstance(other, torch.Tensor):
return self._map_all(lambda e: other @ e)
if isinstance(other, Expr):
if (v := other.get_const()) is not None:
return self._map_all(lambda e: v @ e)
if isinstance(other, DiffExpr2):
# other is constant weight pair (W1, W2); self is input triple (x, y, d)
# W1@x − W2@y = W1@d + (W1−W2)@y
if (tensors := other.get_const()) is not None:
return DiffExpr3(
tensors[0] @ self.x,
tensors[1] @ self.y,
tensors[0] @ self.diff + (tensors[0] - tensors[1]) @ self.y,
)
return NotImplemented
# ------------------------------------------------------------------
# Indexing
# ------------------------------------------------------------------
def __getitem__(self, indices):
        # A plain int selects one component, supporting tuple unpacking
        # (e.g. getitem nodes in a torch.export graph).
if isinstance(indices, int):
return (self.x, self.y, self.diff)[indices]
return self._map_all(lambda e: e[indices])
def scatter(self, indices, output_shape):
return self._map_all(lambda e: e.scatter(indices, output_shape))
def gather(self, indices):
return self._map_all(lambda e: e.gather(indices))
# ------------------------------------------------------------------
# Shape ops
# ------------------------------------------------------------------
def reshape(self, *shape):
return self._map_all(lambda e: e.reshape(*shape))
def permute(self, *dims):
return self._map_all(lambda e: e.permute(*dims))
def transpose(self, dim0, dim1):
return self._map_all(lambda e: e.transpose(dim0, dim1))
def flatten(self, start_dim=0, end_dim=-1):
return self._map_all(lambda e: e.flatten(start_dim, end_dim))
def unflatten(self, dim, sizes):
return self._map_all(lambda e: e.unflatten(dim, sizes))
def squeeze(self, dim=None):
return self._map_all(lambda e: e.squeeze(dim))
def unsqueeze(self, dim):
return self._map_all(lambda e: e.unsqueeze(dim))
def narrow(self, dim, start, length):
return self._map_all(lambda e: e.narrow(dim, start, length))
def expand(self, *sizes):
return self._map_all(lambda e: e.expand(*sizes))
def repeat(self, *sizes):
return self._map_all(lambda e: e.repeat(*sizes))
def tile(self, *sizes):
return self._map_all(lambda e: e.tile(*sizes))
def flip(self, dims):
return self._map_all(lambda e: e.flip(dims))
def roll(self, shifts, dims):
return self._map_all(lambda e: e.roll(shifts, dims))
def diag(self, diagonal=0):
return self._map_all(lambda e: e.diag(diagonal))
__all__ = ["DiffExpr2", "DiffExpr3"]