boundlab.diff.op.diff_pair
- boundlab.diff.op.diff_pair(x, y)
Mark two tensors as a differentially-paired input.
This is a registered torch.library custom operator, so it is captured verbatim when the containing model is exported with torch.export.export(). During torch.onnx.export(), it is lowered to a custom-domain ONNX node boundlab::diff_pair. At the concrete-tensor level it returns x unchanged (a no-op).

When the exported graph is run through a differential interpreter (e.g. boundlab.diff.zono3.interpret), the diff_pair node is replaced by a DiffExpr2 that tracks both branches simultaneously through all subsequent operations.

- Parameters:
x (torch.Tensor) – Tensor for the first network branch.
y (torch.Tensor) – Tensor for the second network branch; must have the same shape and dtype as x.
- Returns:
A fake tensor with the same shape and dtype as x; carries no concrete information from y at runtime.
- Return type:
Examples
Exporting a model that uses diff_pair:

>>> import torch
>>> from torch import nn
>>> from boundlab.diff.op import diff_pair
>>> class PairedModel(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.fc = nn.Linear(4, 3)
...     def forward(self, x, y):
...         p = diff_pair(x, y)
...         return self.fc(p)
>>> model = PairedModel()
>>> gm = torch.export.export(model, (torch.zeros(4), torch.zeros(4)))
>>> any("diff_pair" in str(n.target) for n in gm.graph.nodes)
True
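
To give some intuition for what "tracks both branches simultaneously" can mean, the following is a minimal pure-Python sketch. It is not boundlab's implementation: ToyPair, toy_diff_pair, toy_linear and matvec are hypothetical names invented for this illustration, and the real DiffExpr2 may represent its state quite differently. The sketch only shows the general idea of carrying both branch values and their difference through an affine layer, where the bias cancels in the difference.

```python
from dataclasses import dataclass
from typing import List

Vector = List[float]
Matrix = List[List[float]]


@dataclass
class ToyPair:
    """Hypothetical stand-in for a paired value; NOT boundlab's DiffExpr2."""
    x: Vector     # value on the first branch
    y: Vector     # value on the second branch
    diff: Vector  # x - y, tracked explicitly alongside both branches


def toy_diff_pair(x: Vector, y: Vector) -> ToyPair:
    """Pair two same-length vectors (illustrative analogue of diff_pair)."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    return ToyPair(x=list(x), y=list(y), diff=[a - b for a, b in zip(x, y)])


def matvec(w: Matrix, v: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def toy_linear(p: ToyPair, w: Matrix, b: Vector) -> ToyPair:
    """Apply v -> W v + b to both branches.

    The difference is propagated as W (x - y): the bias b cancels.
    """
    fx = [s + bi for s, bi in zip(matvec(w, p.x), b)]
    fy = [s + bi for s, bi in zip(matvec(w, p.y), b)]
    return ToyPair(x=fx, y=fy, diff=matvec(w, p.diff))


w = [[1.0, 2.0], [0.0, -1.0]]
b = [0.5, 0.5]
out = toy_linear(toy_diff_pair([1.0, 2.0], [1.0, 1.0]), w, b)
```

Here out.diff is computed from the input difference alone, yet it agrees with subtracting the two branch outputs elementwise. Propagating the difference directly, rather than recomputing it from two separately tracked branches, is the kind of relationship a differential interpreter can exploit.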