# Source code for boundlab.interp.onnx

"""ONNX export utilities for BoundLab.

Provides :func:`onnx_export` and :func:`register_onnx_translation` — tools for
converting PyTorch modules to ONNX IR models that the abstract interpretation
:class:`~boundlab.interp.Interpreter` can consume.

Custom ops (e.g. ``boundlab::diff_pair``) are handled via a two-step process:

1. An *onnxscript sentinel function* is registered as a placeholder for the
   custom torch op during ``torch.onnx.export``.
2. After export, :func:`_apply_sentinel_fixups` replaces every sentinel node
   with a proper primitive custom-domain ONNX node.

Call :func:`register_onnx_translation` from the module that defines the custom
torch op to hook into this mechanism.

Examples
--------
>>> import torch
>>> from boundlab.interp.onnx import onnx_export
>>> def f(x):
...     return x @ x.T
...
>>> model = onnx_export(f, ([3, 4],))
>>> list(model.graph)[0].op_type
'MatMul'
"""

from __future__ import annotations

from typing import Callable, Sequence, Union

import onnx_ir as ir
import torch
from torch import nn
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

__all__ = ["onnx_export", "register_onnx_translation"]

def onnx_export(
    f: Callable[..., torch.Tensor] | nn.Module,
    args: tuple[Union[torch.Size, list[int], torch.Tensor], ...],
    kwargs: dict[str, Union[torch.Size, list[int], torch.Tensor]] | None = None,
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
    optimize: bool | None = None,
) -> ir.Model:
    """Export a PyTorch function or module to an ONNX IR model.

    Inputs are described either by a shape (a list of ints or a
    :class:`torch.Size`) or by an actual tensor.  For every shape a
    zero-valued tensor is built with ``torch.zeros``; tensors are used as
    given.  The resulting example inputs are handed to
    ``torch.export.export`` and the exported program is then converted with
    ``torch.onnx.export``.

    Args:
        f: A callable or :class:`torch.nn.Module` to export.  A plain
            callable is wrapped in a small module first; the module is put
            into eval mode before export.
        args: Input shapes (or tensors), one per positional argument
            (e.g. ``([3, 4],)`` for a single rank-2 input).
        kwargs: Keyword-argument shapes (or tensors), keyed by argument
            name (rarely needed).  ``None`` or an empty dict means no
            keyword arguments.
        input_names: Optional names for the ONNX graph inputs.
        output_names: Optional names for the ONNX graph outputs.
        optimize: Forwarded unchanged as the ``optimize`` argument of
            ``torch.onnx.export``.

    Returns:
        An :class:`onnx_ir.Model` ready for abstract interpretation.

    Examples
    --------
    >>> import torch
    >>> from boundlab.interp.onnx import onnx_export
    >>> def f(x):
    ...     return x @ x.T
    ...
    >>> model = onnx_export(f, ([3, 4],))
    >>> list(model.graph)[0].op_type
    'MatMul'
    """

    def _example_input(spec):
        # Tensors pass through untouched; anything else is treated as a shape.
        return spec if isinstance(spec, torch.Tensor) else torch.zeros(spec)

    if isinstance(f, nn.Module):
        mod = f.eval()
    else:

        class Wrapper(nn.Module):
            def forward(self, *args, **kwargs):
                return f(*args, **kwargs)

        mod = Wrapper().eval()

    args_tensor = tuple(_example_input(spec) for spec in args)
    # ``kwargs`` defaults to None (not a shared ``{}``) and, unlike before,
    # is actually converted and forwarded to the exporter.
    kwargs_tensor = {
        name: _example_input(spec) for name, spec in (kwargs or {}).items()
    }

    with FakeTensorMode(
        allow_non_fake_inputs=True,
        shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True),
    ):
        # Pass None rather than an empty dict so the no-kwargs call is the
        # same as the historical two-argument call.
        program = torch.export.export(mod, args_tensor, kwargs_tensor or None)

    # NOTE(review): the ONNX conversion is assumed to run outside the
    # fake-tensor context so exported parameters stay real tensors — confirm
    # against the original indentation.
    onnx_program = torch.onnx.export(
        program,
        args=(),
        export_params=True,
        optimize=optimize,
        input_names=input_names,
        output_names=output_names,
        verbose=False,
    )
    return onnx_program.model