"""Operators for differential verification.
Registers ``boundlab::diff_pair`` as a :mod:`torch.library` custom operator so
that it can be captured by :func:`torch.export.export`. The operator takes two
tensors of the same shape and returns a single *fake* tensor of the same shape
— a no-op at the concrete-tensor level whose sole purpose is to mark two
branches as a *paired* input for differential abstract interpretation.
When a :class:`~boundlab.interp.Interpreter` processes an exported graph, the
``diff_pair`` node is converted to a
:class:`~boundlab.diff.expr.DiffExpr2` by the registered handler.
"""
import torch
import torch.library
from boundlab.diff.expr import DiffExpr2
from boundlab.expr._affine import ConstVal
# =====================================================================
# Custom operator registration
# =====================================================================
_lib = torch.library.Library("boundlab", "DEF")
_lib.define("diff_pair(Tensor x, Tensor y) -> Tensor")
_lib.impl("diff_pair", lambda x, _: x, "CPU")
_lib.impl("diff_pair", lambda x, _: x, "CUDA")
# Shape/dtype inference for torch.export tracing.
# `register_fake` is the current API; fall back to the legacy `impl_abstract`.
_register_fake = getattr(torch.library, "register_fake", torch.library.impl_abstract)
_register_fake("boundlab::diff_pair")(lambda x, _: torch.empty_like(x))
# =====================================================================
# Public API
# =====================================================================
[docs]
def diff_pair(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Mark two tensors as a differentially-paired input.
This is a registered :mod:`torch.library` custom operator, so it is
captured verbatim when the containing model is exported with
:func:`torch.export.export`. During :func:`torch.onnx.export`, it is lowered
to a custom-domain ONNX node ``boundlab::diff_pair``. At the
concrete-tensor level it returns ``x`` unchanged (a no-op).
When the exported graph is run through a differential interpreter
(e.g. :data:`boundlab.diff.zono3.interpret`) the ``diff_pair`` node is
replaced by a :class:`~boundlab.diff.expr.DiffExpr2` that tracks both
branches simultaneously through all subsequent operations.
Args:
x: Tensor for the first network branch.
y: Tensor for the second network branch; must have the same shape
and dtype as ``x``.
Returns:
A fake tensor with the same shape and dtype as ``x``; carries no
concrete information from ``y`` at runtime.
Examples
--------
Exporting a model that uses ``diff_pair``:
>>> import torch
>>> from torch import nn
>>> from boundlab.diff.op import diff_pair
>>> class PairedModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.fc = nn.Linear(4, 3)
... def forward(self, x, y):
... p = diff_pair(x, y)
... return self.fc(p)
>>> model = PairedModel()
>>> gm = torch.export.export(model, (torch.zeros(4), torch.zeros(4)))
>>> any("diff_pair" in str(n.target) for n in gm.graph.nodes)
True
"""
if torch.onnx.is_in_onnx_export():
# During ONNX export, emit an explicit custom-domain node so the
# resulting model preserves differential pairing semantics.
return torch.onnx.ops.symbolic(
"boundlab::diff_pair",
(x, y),
dtype=x.dtype,
shape=x.shape,
version=1,
)
return torch.ops.boundlab.diff_pair(x, y)
# =====================================================================
# DiffLinear
# =====================================================================
import torch.nn as nn
[docs]
class DiffLinear(nn.Module):
"""Two parallel linear layers whose outputs are paired via :func:`diff_pair`.
At the concrete-tensor level this is equivalent to running ``fc1(x)``
(``fc2``'s output is discarded at runtime via the ``diff_pair`` no-op).
When the model is exported and interpreted by the differential interpreter
(e.g. :data:`boundlab.diff.zono3.interpret`), the ``diff_pair`` node is
lifted into a :class:`~boundlab.diff.expr.DiffExpr2` that tracks both
branches simultaneously.
Args:
fc1: First linear layer.
fc2: Second linear layer; must have the same ``in_features``,
``out_features``, and dtype as ``fc1``.
Examples
--------
>>> import torch
>>> from torch import nn
>>> from boundlab.diff.op import DiffLinear
>>> fc1 = nn.Linear(4, 3)
>>> fc2 = nn.Linear(4, 3)
>>> model = DiffLinear(fc1, fc2)
>>> out = model(torch.zeros(4))
>>> out.shape
torch.Size([3])
"""
[docs]
def __init__(self, fc1: nn.Linear, fc2: nn.Linear):
super().__init__()
assert fc1.in_features == fc2.in_features, "fc1 and fc2 must share in_features"
assert fc1.out_features == fc2.out_features, "fc1 and fc2 must share out_features"
self.fc1 = fc1
self.fc2 = fc2
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
weight = diff_pair(self.fc1.weight, self.fc2.weight)
assert (self.fc1.bias is not None) == (self.fc2.bias is not None), "fc1 and fc2 must both have bias or both have no bias"
if self.fc1.bias is not None:
return x @ weight.t() + diff_pair(self.fc1.bias, self.fc2.bias)
else:
return x @ weight.t()
# =====================================================================
# Interpreter handler (used by boundlab.diff.zono3.interpret)
# =====================================================================
def diff_pair_handler(x, y) -> DiffExpr2:
"""Interpreter handler: convert a ``diff_pair`` node to a DiffExpr2.
Registered in :data:`boundlab.diff.zono3.interpret` when this module is
imported.
"""
if isinstance(x, torch.Tensor):
x = ConstVal(x)
if isinstance(y, torch.Tensor):
y = ConstVal(y)
return DiffExpr2(x, y)
__all__ = ["diff_pair", "DiffLinear"]