boundlab.diff.op.diff_pair

boundlab.diff.op.diff_pair(x, y)

Mark two tensors as a differentially-paired input.

This is a registered torch.library custom operator, so it is captured verbatim when the containing model is exported with torch.export.export(). During torch.onnx.export(), it is lowered to a custom-domain ONNX node boundlab::diff_pair. At the concrete-tensor level it returns x unchanged (a no-op).

When the exported graph is run through a differential interpreter (e.g. boundlab.diff.zono3.interpret) the diff_pair node is replaced by a DiffExpr2 that tracks both branches simultaneously through all subsequent operations.
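The idea of tracking both branches through every later operation can be illustrated with a toy, concrete analogue. Here each intermediate value is a pair of tensors, and every operation is applied to both. The class `ToyPair` is hypothetical and is not boundlab's `DiffExpr2`: it only propagates two concrete tensors and does not build any symbolic bound.

```python
from dataclasses import dataclass

import torch
from torch import nn

@dataclass
class ToyPair:
    """Hypothetical stand-in for a differential value: one tensor per branch."""
    a: torch.Tensor  # first branch (originating from x)
    b: torch.Tensor  # second branch (originating from y)

    def apply(self, fn):
        # Push both branches through the same operation.
        return ToyPair(fn(self.a), fn(self.b))

    def diff(self):
        # Difference between the branches at this point in the computation.
        return self.a - self.b

torch.manual_seed(0)
fc = nn.Linear(4, 3)
x, y = torch.randn(4), torch.randn(4)

# Through a linear layer the bias cancels, so the difference is W (x - y).
lin = ToyPair(x, y).apply(fc)
print(torch.allclose(lin.diff(), fc.weight @ (x - y), atol=1e-6))  # True

# Nonlinear steps are applied to each branch in turn.
out = lin.apply(torch.relu)
print(out.diff().shape)  # torch.Size([3])
```

Keeping both branches together lets quantities such as the output difference be read off at any node. A symbolic differential interpreter would track the same relationship over sets of inputs rather than a single concrete pair.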

Parameters:
  • x (torch.Tensor) – Tensor for the first network branch.

  • y (torch.Tensor) – Tensor for the second network branch; must have the same shape and dtype as x.
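A caller may want to enforce the shape and dtype requirement on `y` before pairing. The helper `check_pair_compatible` below is hypothetical. This page does not say whether `diff_pair` itself performs such a check.

```python
import torch

def check_pair_compatible(x: torch.Tensor, y: torch.Tensor) -> None:
    """Hypothetical pre-check mirroring the documented contract of diff_pair."""
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {tuple(x.shape)} vs {tuple(y.shape)}")
    if x.dtype != y.dtype:
        raise TypeError(f"dtype mismatch: {x.dtype} vs {y.dtype}")

check_pair_compatible(torch.zeros(4), torch.ones(4))  # same shape and dtype: passes
try:
    check_pair_compatible(torch.zeros(4), torch.zeros(3))
except ValueError as err:
    print(err)  # shape mismatch: (4,) vs (3,)
```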

Returns:

A tensor with the same shape and dtype as x. During tracing this is a fake tensor; when run on concrete tensors it holds the values of x. It carries no concrete information from y at runtime.

Return type:

torch.Tensor

Examples

Exporting a model that uses diff_pair:

>>> import torch
>>> from torch import nn
>>> from boundlab.diff.op import diff_pair
>>> class PairedModel(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.fc = nn.Linear(4, 3)
...     def forward(self, x, y):
...         p = diff_pair(x, y)
...         return self.fc(p)
>>> model = PairedModel()
>>> ep = torch.export.export(model, (torch.zeros(4), torch.zeros(4)))  # an ExportedProgram
>>> any("diff_pair" in str(n.target) for n in ep.graph.nodes)
True