# Source code for boundlab.utils
r"""Utility Functions for BoundLab
This module provides helper functions used throughout the BoundLab framework.
Examples
--------
>>> from boundlab.utils import merge_name
>>> merge_name("x", "+", "y")
'(x + y)'
"""
import string
from typing import Callable, Literal, Sequence, TypeAlias, TypeVar, Union
from torch import nn
import onnx_ir as ir
# Generic element type used by the aliases below.
A = TypeVar("A")
# A 3-tuple whose three elements all share the same type ``A``.
Triple: TypeAlias = tuple[A, A, A]
import torch
import tempfile
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
# [docs]
def merge_name(name1, op: str, name2) -> str | None:
"""Merge two optional names into a single name for a composed operation.
Examples
--------
>>> merge_name("left", "@", "right")
'(left @ right)'
>>> merge_name(None, "@", "right") is None
True
"""
name1 = name1.name if hasattr(name1, "name") else name1
name2 = name2.name if hasattr(name2, "name") else name2
if type(name1) is not str or type(name2) is not str:
return None
if name1 is not None and name2 is not None:
return f"({name1} {op} {name2})"
return None
# [docs]
def onnx_export(
    f: Callable[..., torch.Tensor] | nn.Module,
    args: tuple[Union[torch.Size, list[int]], ...],
    kwargs: dict[str, Union[torch.Size, list[int]]] | None = None,
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
) -> ir.Model:
    """Export a PyTorch function or module to ONNX format.

    Every entry of ``args`` and ``kwargs`` is a *shape*, not a tensor: a
    zero-filled tensor of that shape is created and used as the example
    input handed to ``torch.onnx.export``.

    Parameters
    ----------
    f
        A plain callable or an ``nn.Module``. A plain callable is wrapped in
        a small ``nn.Module``; a module is switched to eval mode.
    args
        One shape per positional input of ``f``. Note that a single input of
        shape ``[3, 4]`` must be written ``([3, 4],)``.
    kwargs
        Optional mapping from keyword-argument name to shape.
    input_names, output_names
        Forwarded unchanged to ``torch.onnx.export``.

    Returns
    -------
    ir.Model
        The ``model`` attribute of the object returned by
        ``torch.onnx.export``.

    Examples
    --------
    >>> import torch
    >>> from boundlab.utils import onnx_export
    >>> def f(x):
    ...     return x @ x.T
    ...
    >>> model = onnx_export(f, ([3, 4],))
    >>> any(node.op_type == 'MatMul' for node in model.graph)
    True
    """
    if isinstance(f, nn.Module):
        f = f.eval()
    else:
        # Keep the user's callable under its own name: ``f`` is rebound to
        # the wrapper below, and a closure over ``f`` would then make the
        # wrapper call itself and recurse forever.
        fn = f

        class Wrapper(nn.Module):
            def forward(self, *args, **kwargs):
                return fn(*args, **kwargs)

        f = Wrapper()

    # Build zero-filled example inputs from the requested shapes.
    args_tensor = tuple(torch.zeros(shape) for shape in args)
    kwargs_tensor = {
        key: torch.zeros(shape) for key, shape in (kwargs or {}).items()
    }

    program = torch.onnx.export(
        f,
        args_tensor,
        # ``None`` when no keyword inputs were requested, matching the
        # exporter's own default.
        kwargs=kwargs_tensor or None,
        export_params=True,
        input_names=input_names,
        output_names=output_names,
    )
    # NOTE(review): relies on ``torch.onnx.export`` returning a program
    # object rather than ``None`` — confirm against the installed torch
    # version's default exporter.
    return program.model