Source code for boundlab.interp

"""Abstract Interpretation Framework for Neural Network Verification.

Examples
--------
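Build an interpreter from a dispatcher that maps ONNX ``op_type`` names (plus
the special ``"Input"`` / ``"Initializer"`` keys used for graph inputs and
weights) to handler functions:

>>> import torch
>>> import onnx_ir as ir
>>> import boundlab.expr as expr
>>> from boundlab.interp import Interpreter
>>> itp = Interpreter({"Input": lambda x, **_: x, "Initializer": lambda x, **_: x, "Relu": lambda X: X})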
"""

from __future__ import annotations

import copy
import inspect
import math
from pathlib import Path
from typing import Any, Callable, Generic, TypeVar
import onnx_ir as ir
import torch
import beartype
import beartype.roar

from boundlab.expr._core import Expr, ExprFlags
from .onnx import onnx_export

__all__ = ["Interpreter", "ONNX_BASE_INTERPRETER", "onnx_export"]

E = TypeVar("E", bound=Expr)


# =====================================================================
# ONNX attribute / shape helpers (used by __call__ and ONNX_BASE_INTERPRETER)
# =====================================================================

def _onnx_attr_value(attr) -> Any:
    """Translate an ``onnx_ir`` attribute into the matching Python object.

    * FLOAT / INT / STRING   -> ``float`` / ``int`` / ``str``
    * FLOATS / INTS / STRINGS -> plain ``list`` of those
    * TENSOR -> ``torch.Tensor`` on the current default device
    * any other attribute kind -> ``None``
    """
    kind = attr.type
    attr_types = ir.AttributeType

    if kind == attr_types.TENSOR:
        # Copy out of the numpy view before handing the buffer to torch.
        array = attr.as_tensor().numpy().copy()
        return torch.from_numpy(array).to(torch.get_default_device())

    # (attribute kind, getter name, wrap result in a list?)
    readers = (
        (attr_types.FLOAT, "as_float", False),
        (attr_types.INT, "as_int", False),
        (attr_types.STRING, "as_string", False),
        (attr_types.FLOATS, "as_floats", True),
        (attr_types.INTS, "as_ints", True),
        (attr_types.STRINGS, "as_strings", True),
    )
    for expected, getter_name, as_list in readers:
        if kind == expected:
            # Only look the getter up once the kind matches.
            value = getattr(attr, getter_name)()
            return list(value) if as_list else value
    return None


def _unwrap_shape(x) -> list[int]:
    """Extract a concrete shape/axes list from a tensor."""
    if isinstance(x, torch.Tensor):
        return x.long().tolist()
    return list(x)


def _onnx_gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
    """ONNX Gemm: ``Y = alpha * (A' @ B') + beta * C``."""
    a = A.transpose(0, 1) if transA else A
    b = B.transpose(0, 1) if transB else B
    y = a @ b
    if alpha != 1.0:
        y = alpha * y
    if C is not None:
        y = y + (beta * C if beta != 1.0 else C)
    return y


def _onnx_flatten(X, axis=1):
    """ONNX Flatten: produce 2-D tensor ``[prod(dims[:axis]), prod(dims[axis:])]``."""
    first = math.prod(X.shape[:axis]) if axis > 0 else 1
    return X.reshape(first, -1)


def _onnx_reshape(data, shape, allowzero=0):
    """ONNX Reshape: reshape *data* to *shape*.

    Follows the ONNX rule for zeros in *shape*: with ``allowzero=0`` (the
    default) a ``0`` entry copies the corresponding dimension of *data*;
    with ``allowzero=1`` it is an explicit zero-sized dimension. A ``-1``
    entry is inferred as usual.
    """
    target = _unwrap_shape(shape)
    if not int(allowzero):
        in_shape = tuple(data.shape)
        target = [in_shape[i] if s == 0 else s for i, s in enumerate(target)]
    return data.reshape(target)


def _onnx_unsqueeze(data, axes):
    """ONNX Unsqueeze: insert size-1 dims at the given axes positions.

    Per the ONNX spec, *axes* index the **output**, whose rank is
    ``rank(data) + len(axes)``. Negative entries are therefore resolved
    against that output rank first. The dimensions are then inserted in
    ascending order, so each insertion lands at its final position.
    """
    requested = _unwrap_shape(axes)
    if isinstance(requested, int):  # a 0-d axes tensor unwraps to a bare int
        requested = [requested]
    out_rank = len(data.shape) + len(requested)
    positions = sorted(a + out_rank if a < 0 else a for a in requested)
    result = data
    for ax in positions:
        result = result.unsqueeze(ax)
    return result


def _onnx_squeeze(data, axes=None):
    """ONNX Squeeze: remove size-1 dims at *axes* (or all if omitted).

    Negative axes are resolved against ``rank(data)`` up front. The
    dimensions are then removed from the highest index down, so earlier
    removals never shift the positions of later ones.
    """
    if axes is None:
        return data.squeeze()
    requested = _unwrap_shape(axes)
    if isinstance(requested, int):  # a 0-d axes tensor unwraps to a bare int
        requested = [requested]
    rank = len(data.shape)
    positions = sorted((a + rank if a < 0 else a for a in requested), reverse=True)
    result = data
    for ax in positions:
        result = result.squeeze(ax)
    return result

def _onnx_constant(value=None, value_float=None, value_int=None, value_string=None, value_floats=None, value_ints=None, value_strings=None, **_):
    """ONNX Constant node: wrap the tensor attribute as a torch.Tensor."""

    if value is not None:
        return torch.tensor(value).to(torch.get_default_device()) if value is not None else None
    elif value_float is not None:
        return torch.tensor(value_float).to(torch.get_default_device())
    elif value_int is not None:
        return torch.tensor(value_int).to(torch.get_default_device())
    elif value_string is not None:
        return value_string
    elif value_floats is not None:
        return torch.tensor(value_floats).to(torch.get_default_device())
    elif value_ints is not None:
        return torch.tensor(value_ints).to(torch.get_default_device())
    elif value_strings is not None:
        return value_strings
    else:
        raise ValueError(f"ONNX Constant: no value provided among {locals()}")


_ONNX_DTYPE_MAP = {
    1: torch.float32,
    2: torch.uint8,
    3: torch.int8,
    5: torch.int16,
    6: torch.int32,
    7: torch.int64,
    9: torch.bool,
    10: torch.float16,
    11: torch.float64,
    16: torch.bfloat16,
}


def _onnx_cast(input, to):
    """ONNX Cast: return *input* converted to the torch dtype for ONNX id *to*.

    Raises ``ValueError`` when *to* has no entry in ``_ONNX_DTYPE_MAP``.
    """
    target_dtype = _ONNX_DTYPE_MAP.get(to)
    if target_dtype is not None:
        return input.to(target_dtype)
    raise ValueError(f"Unsupported ONNX dtype id: {to}")


def _normalize_reduce_axes(axes):
    if axes is None:
        return None
    # Under torch's default ONNX opset (>= 18) the reduce `axes` arrive as a
    # graph *input* tensor, which diff_net wraps as an Expr; route through
    # _as_const before unwrapping (mirrors the Slice / Gather handlers).
    return tuple(int(a) for a in _unwrap_shape(_as_const(axes)))


def _onnx_reduce_sum(data, axes=None, keepdims=1, noop_with_empty_axes=0):
    reduce_axes = _normalize_reduce_axes(axes)
    if reduce_axes == () and int(noop_with_empty_axes) == 1:
        return data
    return data.sum(dim=reduce_axes, keepdim=bool(keepdims))


def _onnx_reduce_mean(data, axes=None, keepdims=1, noop_with_empty_axes=0):
    reduce_axes = _normalize_reduce_axes(axes)
    if reduce_axes == () and int(noop_with_empty_axes) == 1:
        return data
    return data.mean(dim=reduce_axes, keepdim=bool(keepdims))


def _onnx_gather(data, indices, axis=0):
    """ONNX Gather: take entries of *data* along *axis* at *indices*.

    The result has shape
    ``data.shape[:axis] + indices.shape + data.shape[axis + 1:]``. As in the
    ONNX spec, *axis* may be negative (counted from the end), and index
    values may be negative (counted from the end of ``data.shape[axis]``).

    * ``DiffExpr3`` / ``DiffExpr2`` inputs are gathered component-wise.
    * :class:`Expr` inputs support 0-D indices (the axis is removed) and 1-D
      indices (the axis is kept, one entry per index, joined with
      :class:`boundlab.expr.Cat`). Higher-rank indices raise
      ``NotImplementedError``.
    * Plain tensors use :func:`torch.index_select`.
    """
    axis = int(axis)
    indices = _as_const(indices)

    # DiffExpr pairs: gather each component independently.
    try:
        from boundlab.diff.expr import DiffExpr3, DiffExpr2
    except ImportError:
        DiffExpr2 = DiffExpr3 = None
    if DiffExpr3 is not None and isinstance(data, DiffExpr3):
        return DiffExpr3(
            _onnx_gather(data.x, indices, axis),
            _onnx_gather(data.y, indices, axis),
            _onnx_gather(data.diff, indices, axis),
        )
    if DiffExpr2 is not None and isinstance(data, DiffExpr2):
        return DiffExpr2(
            _onnx_gather(data.x, indices, axis),
            _onnx_gather(data.y, indices, axis),
        )

    rank = len(data.shape)
    if axis < 0:
        axis += rank
    dim_size = int(data.shape[axis])
    idx = torch.as_tensor(indices).long()
    # ONNX allows negative indices; index_select (and integer indexing of an
    # Expr) is not guaranteed to, so map them into [0, dim_size).
    idx = torch.where(idx < 0, idx + dim_size, idx)

    if isinstance(data, Expr):
        def _take(i):
            # Integer-index *data* at position i along `axis` (drops the axis).
            slices = [slice(None)] * rank
            slices[axis] = int(i)
            return data[tuple(slices)]

        if idx.dim() == 0:
            # Scalar index: the gathered axis disappears.
            return _take(idx.item())
        if idx.dim() == 1:
            # 1-D index: the axis is kept with len(idx) entries -- including
            # a size-1 axis for a length-1 index.
            parts = [_take(i).unsqueeze(axis) for i in idx.tolist()]
            if len(parts) == 1:
                return parts[0]
            from boundlab.expr import Cat
            return Cat(*parts, dim=axis)
        raise NotImplementedError(
            f"ONNX Gather for Expr currently supports scalar/1D indices, got shape {tuple(idx.shape)}"
        )

    gathered = torch.index_select(data, axis, idx.reshape(-1))
    out_shape = list(data.shape[:axis]) + list(idx.shape) + list(data.shape[axis + 1:])
    return gathered.reshape(out_shape)


# NOTE: ``_onnx_concat`` is defined once, further down in this module (with
# DiffExpr support). A second, earlier definition here was always shadowed
# by that one at import time and has been removed as dead code.


def _onnx_einsum(*inputs, equation):
    """ONNX Einsum: dispatch to torch.einsum for constants, or build an
    :class:`EinsumOp` when exactly one operand is an :class:`Expr`.

    Constant operands are pre-contracted into a single tensor whose axes span
    ``union(x_labels, out_labels)``, then wrapped by :class:`EinsumOp` so it
    fuses with surrounding linear ops.

    DiffExpr operands (``DiffExpr2`` / ``DiffExpr3`` from
    ``boundlab.diff.expr``) are handled first, by recursing on components:

    * one DiffExpr operand: the einsum is evaluated separately on the ``x``,
      ``y`` and ``diff`` components (a ``DiffExpr2`` is first promoted to a
      ``DiffExpr3`` with ``diff = x - y``);
    * two DiffExpr operands: ``x`` and ``y`` outputs are computed directly
      (``DiffExpr2`` components go through ``_as_const``). When one operand
      is a ``DiffExpr3`` and the other a ``DiffExpr2``, ``diff`` uses the
      bilinear identity below; otherwise it is ``out_x - out_y``;
    * more than two: ``NotImplementedError``.

    Parameters
    ----------
    *inputs:
        Operands, one per comma-separated term on the left of *equation*.
    equation:
        Einsum equation with an explicit ``->`` output; spaces are ignored.

    Returns
    -------
    One of:

    * a ``torch.Tensor`` when no operand is an Expr or DiffExpr;
    * a ``DiffExpr3`` when a DiffExpr operand is present;
    * otherwise an ``AffineSum`` applying an ``EinsumOp`` to the single Expr
      operand.
    """
    # Split "a,b->c" into per-operand input labels and the output labels
    # (the two-name unpacking requires exactly one "->").
    equation = equation.replace(" ", "")
    lhs, rhs = equation.split("->")
    in_labels = lhs.split(",")
    assert len(in_labels) == len(inputs), f"Einsum equation mismatch: {equation}"

    # DiffExpr support: evaluate component-wise
    try:
        from boundlab.diff.expr import DiffExpr2, DiffExpr3
        diff_positions = [i for i, v in enumerate(inputs) if isinstance(v, (DiffExpr3, DiffExpr2))]
    except Exception:
        # Any failure here (e.g. boundlab.diff not importable) disables the
        # DiffExpr path; the () placeholders are never used because
        # diff_positions is empty.
        DiffExpr2 = DiffExpr3 = ()
        diff_positions = []

    if diff_positions:
        if len(diff_positions) == 1:
            # Einsum is linear in each operand, so with a single DiffExpr
            # operand each component can be contracted independently.
            di = diff_positions[0]
            diff = inputs[di]
            if isinstance(diff, DiffExpr2):
                diff = DiffExpr3(diff.x, diff.y, diff.x - diff.y)
            in_x = list(inputs); in_x[di] = diff.x
            in_y = list(inputs); in_y[di] = diff.y
            in_d = list(inputs); in_d[di] = diff.diff
            out_x = _onnx_einsum(*in_x, equation=equation)
            out_y = _onnx_einsum(*in_y, equation=equation)
            out_d = _onnx_einsum(*in_d, equation=equation)
            return DiffExpr3(out_x, out_y, out_d)
        elif len(diff_positions) == 2:
            # Bilinear case: two DiffExpr inputs.
            # Use identity: A.x ⊗ B.x − A.y ⊗ B.y = A.diff ⊗ B.x + A.y ⊗ (B.x − B.y)
            # Identify DiffExpr3 (has .diff) and DiffExpr2 positions.
            di3 = [i for i in diff_positions if isinstance(inputs[i], DiffExpr3)]
            di2 = [i for i in diff_positions if isinstance(inputs[i], DiffExpr2)]
            in_x = list(inputs)
            in_y = list(inputs)
            for i in diff_positions:
                in_x[i] = _as_const(inputs[i].x) if isinstance(inputs[i], DiffExpr2) else inputs[i].x
                in_y[i] = _as_const(inputs[i].y) if isinstance(inputs[i], DiffExpr2) else inputs[i].y
            out_x = _onnx_einsum(*in_x, equation=equation)
            out_y = _onnx_einsum(*in_y, equation=equation)
            if len(di3) == 1 and len(di2) == 1:
                # Bilinear identity: diff = A.diff ⊗ B.x + A.y ⊗ (B.x − B.y)
                a_idx, b_idx = di3[0], di2[0]
                a, b = inputs[a_idx], inputs[b_idx]
                bx = _as_const(b.x)
                by = _as_const(b.y)
                in_d1 = list(inputs)
                in_d1[a_idx] = a.diff
                in_d1[b_idx] = bx
                in_d2 = list(inputs)
                in_d2[a_idx] = a.y
                in_d2[b_idx] = bx - by
                out_d = _onnx_einsum(*in_d1, equation=equation) + _onnx_einsum(*in_d2, equation=equation)
            else:
                # Remaining two-operand combinations fall back to the plain
                # difference of the two sides.
                out_d = out_x - out_y
            return DiffExpr3(out_x, out_y, out_d)
        else:
            raise NotImplementedError(f"Einsum with {len(diff_positions)} DiffExpr inputs is not supported")

    expr_positions = [i for i, v in enumerate(inputs) if isinstance(v, Expr)]
    if not expr_positions:
        # Everything is concrete: evaluate directly.
        return torch.einsum(equation, *inputs)
    assert len(expr_positions) == 1, "Einsum with multiple Expr inputs is not supported"
    ei = expr_positions[0]
    x = inputs[ei]
    x_labels = in_labels[ei]
    const_labels = [in_labels[i] for i in range(len(inputs)) if i != ei]
    const_tensors = [inputs[i] for i in range(len(inputs)) if i != ei]
    out_labels = rhs

    # Axes of the fused constant tensor: the Expr's labels followed by any
    # new output labels, de-duplicated in first-occurrence order.
    t_labels = list(dict.fromkeys(x_labels + out_labels))
    # Size of each label, taken from the Expr first and then from constants.
    # Every label in t_labels needs an entry here (used by expand below).
    sizes = {l: s for l, s in zip(x_labels, x.shape)}
    for t, lbl in zip(const_tensors, const_labels):
        for l, s in zip(lbl, t.shape):
            sizes.setdefault(l, s)

    # Pre-contract all constants, keeping only labels that survive into
    # t_labels (constant-only labels are summed out by this einsum).
    const_label_set = set("".join(const_labels))
    contract_target = "".join(l for l in t_labels if l in const_label_set)
    if const_tensors:
        contract_eq = ",".join(const_labels) + "->" + contract_target
        tensor = torch.einsum(contract_eq, *const_tensors)
    else:
        tensor = torch.ones(())

    # Insert size-1 axes for t_labels absent from the constants. contract_target
    # follows t_labels' order, so inserting at index i keeps the axes aligned.
    # Then broadcast to the full label sizes.
    current = list(contract_target)
    for i, l in enumerate(t_labels):
        if l not in current:
            tensor = tensor.unsqueeze(i)
            current.insert(i, l)
    tensor = tensor.expand([sizes[l] for l in t_labels]).contiguous()

    # Positions of the Expr's axes and of the output axes within t_labels.
    input_dims = [t_labels.index(l) for l in x_labels]
    output_dims = [t_labels.index(l) for l in out_labels]

    from boundlab.linearop._einsum import EinsumOp
    from boundlab.expr._affine import AffineSum
    op = EinsumOp(tensor, input_dims, output_dims)
    return AffineSum((op, x))


def _onnx_conv(X, W, B=None, *, kernel_shape=None, strides=None, pads=None, dilations=None, group=1, auto_pad="NOTSET", **_):
    """ONNX 2-D Conv, limited to non-overlapping windows (``kernel == stride``).

    With disjoint windows the convolution is a plain tensor contraction.
    ``X`` (``[N, C_in, H, W]``) is viewed as ``[N, C_in, H/kH, kH, W/kW, kW]``
    and contracted against the weight ``[C_out, C_in, kH, kW]`` via
    :func:`_onnx_einsum`, which fuses with surrounding linear ops. The
    optional bias ``B`` is broadcast over batch and spatial axes.

    The following configurations fail an ``assert``:

    * a non-2-D kernel;
    * stride != kernel;
    * grouped convolution;
    * ``auto_pad`` other than NOTSET/VALID;
    * non-zero padding or non-unit dilation;
    * mismatched weight shape, or spatial dims not divisible by the kernel.
    """
    if kernel_shape is None:
        kernel_shape = list(W.shape[2:])
    assert len(kernel_shape) == 2, f"only 2D Conv is supported, got kernel_shape={kernel_shape}"
    kH, kW = (int(k) for k in kernel_shape)
    if strides is None:
        strides = [kH, kW]
    else:
        strides = [int(s) for s in strides]
    assert strides == [kH, kW], f"Conv requires kernel_size == stride, got kernel={kernel_shape}, stride={strides}"
    assert int(group) == 1, "grouped Conv is not supported"
    assert auto_pad in ("NOTSET", "VALID"), f"Conv auto_pad={auto_pad} is not supported"
    assert pads is None or all(int(p) == 0 for p in pads), f"Conv with padding is not supported, got pads={pads}"
    assert dilations is None or all(int(d) == 1 for d in dilations), f"Conv with dilation is not supported, got dilations={dilations}"

    batch, in_ch, H, Win = X.shape
    out_ch, w_in_ch, w_kh, w_kw = W.shape
    weights_fit = in_ch == w_in_ch and (kH, kW) == (w_kh, w_kw)
    assert weights_fit, f"Conv weight shape {tuple(W.shape)} incompatible with input {tuple(X.shape)}"
    assert H % kH == 0 and Win % kW == 0, f"input spatial dims ({H},{Win}) not divisible by kernel ({kH},{kW})"

    # Split each spatial axis into (window index, offset within window).
    patches = X.reshape(batch, in_ch, H // kH, kH, Win // kW, kW)
    out = _onnx_einsum(patches, W, equation="nchHwW,ocHW->nohw")
    if B is None:
        return out
    return out + B.reshape(1, out_ch, 1, 1)


def _onnx_slice(data, starts, ends, axes=None, steps=None, **_):
    """ONNX Slice: extract a sub-tensor of *data* with Python slicing.

    *starts*, *ends*, *axes* and *steps* may be raw tensors/lists or wrapped
    constants; all are unwrapped via ``_as_const``. Defaults:

    * omitted *axes* -> ``0 .. len(starts) - 1``;
    * omitted *steps* -> 1.

    Negative axes are wrapped modulo the rank. End values larger than the
    dimension size are clamped to it.
    """
    begin = _unwrap_shape(_as_const(starts))
    stop = _unwrap_shape(_as_const(ends))
    rank = len(data.shape)
    if axes is None:
        axis_list = list(range(len(begin)))
    else:
        axis_list = [a % rank for a in _unwrap_shape(_as_const(axes))]
    step_list = [1] * len(begin) if steps is None else _unwrap_shape(_as_const(steps))

    index = [slice(None)] * rank
    for ax, lo, hi, step in zip(axis_list, begin, stop, step_list):
        hi = min(hi, data.shape[ax])
        index[ax] = slice(lo, hi, None if step == 1 else step)
    return data[tuple(index)]


def _onnx_concat(*args, axis=0):
    """ONNX Concat: concatenate inputs along *axis*.

    Dispatch, in order:

    * If any input is a ``DiffExpr3``, every input is lifted to an
      ``(x, y, diff)`` triple. Each component is concatenated with
      :class:`boundlab.expr.Cat`, giving a new ``DiffExpr3``:

      - a ``DiffExpr3`` contributes its own ``x`` / ``y`` / ``diff``;
      - a ``DiffExpr2`` contributes ``x``, ``y`` and ``diff = x - y``
        (tensors wrapped as ``ConstVal``). This is the same promotion
        :func:`_onnx_einsum` applies;
      - any other input is shared by ``x`` and ``y`` with a zero ``diff``.

    * Else, if any input is an :class:`Expr`, plain tensors are wrapped as
      ``ConstVal`` and a :class:`Cat` is returned.
    * Else :func:`torch.cat` is used.
    """
    inputs = list(args)
    dim = int(axis)
    from boundlab.expr._affine import ConstVal

    def _lift(v):
        # Wrap concrete tensors so they can be concatenated with Exprs.
        return v if isinstance(v, Expr) else ConstVal(v)

    # Only the import is guarded; errors while building the result propagate.
    try:
        from boundlab.diff.expr import DiffExpr2, DiffExpr3
    except ImportError:
        DiffExpr2 = DiffExpr3 = None

    if DiffExpr3 is not None and any(isinstance(x, DiffExpr3) for x in inputs):
        from boundlab.expr import Cat
        x_parts, y_parts, d_parts = [], [], []
        for inp in inputs:
            if isinstance(inp, DiffExpr3):
                x_parts.append(inp.x)
                y_parts.append(inp.y)
                d_parts.append(inp.diff)
            elif isinstance(inp, DiffExpr2):
                x_parts.append(_lift(inp.x))
                y_parts.append(_lift(inp.y))
                # The two sides of a DiffExpr2 need not coincide, so its
                # difference is x - y, not zero.
                d_parts.append(_lift(inp.x - inp.y))
            else:
                v = inp if isinstance(inp, Expr) else ConstVal(inp) if isinstance(inp, torch.Tensor) else inp
                x_parts.append(v)
                y_parts.append(v)
                d_parts.append(ConstVal(torch.zeros(v.shape)))
        return DiffExpr3(Cat(*x_parts, dim=dim), Cat(*y_parts, dim=dim), Cat(*d_parts, dim=dim))

    if any(isinstance(x, Expr) for x in inputs):
        from boundlab.expr import Cat
        return Cat(*[_lift(x) for x in inputs], dim=dim)
    return torch.cat(inputs, dim=dim)


def _onnx_broadcast(X, Y):
    """Broadcast X and Y to compatible shapes (ONNX numpy-style rules)."""
    def _get_shape(v):
        if hasattr(v, 'shape'):
            return v.shape
        return ()
    x_shape = _get_shape(X)
    y_shape = _get_shape(Y)
    if x_shape == y_shape:
        return X, Y
    target = torch.broadcast_shapes(x_shape, y_shape)
    if hasattr(X, 'expand') and x_shape != target:
        X = X.expand(*target)
    elif isinstance(X, torch.Tensor) and X.shape != target:
        X = X.expand(target)
    if hasattr(Y, 'expand') and y_shape != target:
        Y = Y.expand(*target)
    elif isinstance(Y, torch.Tensor) and Y.shape != target:
        Y = Y.expand(target)
    return X, Y

def _onnx_expand(data, shape):
    """ONNX Expand: broadcast *data* to *shape*.

    ONNX ``Expand`` uses *multidirectional* (numpy) broadcasting, so the
    output shape is ``broadcast_shapes(data.shape, shape)``, not ``shape``
    itself. torch's ``Tensor.expand`` is unidirectional: a target dim must be
    ``-1``, or equal the input dim whenever that input dim is non-1.
    Feeding the raw ONNX target therefore fails whenever it carries a ``1``
    in a dim where the input is larger. One example is the export of
    ``x.expand(-1, k)`` as ``Expand(shape=[1, k])``.

    Raises
    ------
    TypeError
        If *data* must grow but has no ``expand`` method.
    """
    # Unwrap the shape operand through _as_const, as the Reshape and Slice
    # handlers do, so a ConstVal- or DiffExpr-wrapped constant is accepted.
    target = tuple(_unwrap_shape(_as_const(shape)))
    in_shape = tuple(getattr(data, 'shape', ()))
    out = tuple(torch.broadcast_shapes(in_shape, target))
    if in_shape == out:
        return data
    if not hasattr(data, 'expand'):
        raise TypeError(f"Cannot expand object of type {type(data)}")
    return data.expand(*out)

def _as_const(x):
    """Return the concrete value behind *x* for shape/index-style operands.

    Handling by type:

    * ``ConstVal``: return its ``value``.
    * ``DiffExpr2``: return the first element of ``get_const()`` when that
      is not ``None``; otherwise recurse into ``x.x``.
    * ``DiffExpr3``: recurse into ``x.x``.
    * Anything else, or when ``boundlab.diff`` cannot be imported: return
      *x* unchanged.
    """
    from boundlab.expr._affine import ConstVal as CV
    if isinstance(x, CV):
        return x.value
    try:
        from boundlab.diff.expr import DiffExpr2, DiffExpr3
        if isinstance(x, DiffExpr2):
            pair = x.get_const()
            if pair is None:
                return _as_const(x.x)
            return pair[0]
        if isinstance(x, DiffExpr3):
            return _as_const(x.x)
    except ImportError:
        pass
    return x


# =====================================================================
# FnList — multi-handler dispatch helper
# =====================================================================

class FnList(Generic[E]):
    """Ordered collection of handlers registered for one operator.

    Calling a ``FnList`` tries its handlers from the most recently added to
    the oldest. It returns the first result that is not ``NotImplemented``.
    Keyword arguments are filtered per handler (see :meth:`_call`).
    """

    def __init__(self, fns):
        """Build from another ``FnList`` or a list of callables (both
        shallow-copied), or from a single callable."""
        if isinstance(fns, FnList):
            self.fns = copy.copy(fns.fns)
        elif isinstance(fns, list):
            self.fns = copy.copy(fns)
        else:
            self.fns = [fns]

    @staticmethod
    def _call(fn: Callable[..., E], *args: E, **kwargs) -> E:
        """Call *fn*, dropping keyword arguments it cannot accept.

        If *fn* takes ``**kwargs``, or its signature cannot be inspected,
        every keyword is forwarded. Otherwise only keywords naming a
        positional-or-keyword or keyword-only parameter are forwarded.
        """
        if not kwargs:
            return fn(*args)

        try:
            params = inspect.signature(fn).parameters.values()
        except (TypeError, ValueError):
            return fn(*args, **kwargs)

        accepted_kwargs = set()
        for param in params:
            if param.kind is inspect.Parameter.VAR_KEYWORD:
                return fn(*args, **kwargs)
            if param.kind in (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            ):
                accepted_kwargs.add(param.name)

        return fn(
            *args,
            **{key: value for key, value in kwargs.items() if key in accepted_kwargs},
        )

    def __call__(self, *args: E, **kwargs) -> E:
        """Run the handlers newest-first; return the first non-``NotImplemented`` result.

        Raises
        ------
        TypeError
            If every handler (including a lone one) returns ``NotImplemented``.
        """
        for fn in reversed(self.fns):
            result = self._call(fn, *args, **kwargs)
            if result is not NotImplemented:
                return result
        raise TypeError(
            f"No matching handler found for arguments {args} {kwargs}: "
            f"all {len(self.fns)} handler(s) returned NotImplemented"
        )

    def __add__(self, other: Callable[..., E] | FnList) -> FnList:
        """Return a new ``FnList`` with *other*'s handler(s) appended."""
        if isinstance(other, FnList):
            return FnList(self.fns + other.fns)
        return FnList(self.fns + [other])

    def product(self, *other: FnList) -> FnList:
        """Combine this handler list with *other* into one that returns tuples.

        The combined handler calls component ``i`` (``self`` for ``i == 0``,
        then each entry of *other*) and returns the tuple of their results.
        Tuple-valued arguments are per-component values and are indexed with
        ``i``. Any other argument is passed unchanged to every component:
        ONNX attributes (lists/scalars), ``node_name``, or ``None`` for a
        missing optional input.
        """
        components = [self, *other]

        def _pick(value, i):
            return value[i] if isinstance(value, tuple) else value

        def zipped_fn(*args, **kwargs):
            return tuple(
                component(
                    *[_pick(a, i) for a in args],
                    **{k: _pick(v, i) for k, v in kwargs.items()},
                )
                for i, component in enumerate(components)
            )

        return FnList(zipped_fn)

class FnListChain:
    """Pipe the result of one callable into the next.

    The first callable receives the original positional and keyword
    arguments. Every later callable receives the previous result as its only
    positional argument, plus the same keyword arguments. Keywords are
    filtered per callable by ``FnList._call``.
    """

    def __init__(self, *fli: Callable[..., E]):
        self.fn_list = fli

    def __call__(self, *args: E, **kwargs) -> E:
        stages = self.fn_list
        value = FnList._call(stages[0], *args, **kwargs)
        for stage in stages[1:]:
            value = FnList._call(stage, value, **kwargs)
        return value

# =====================================================================
# Interpreter
# =====================================================================

class Interpreter(Generic[E]):
    """Abstract interpreter for ONNX graphs driven by an op-type dispatch table.

    ``self.dispatcher`` maps each ONNX ``op_type`` to a :class:`FnList` of
    handlers, plus the special ``"Input"`` and ``"Initializer"`` keys used
    for graph inputs and weights. Calling the interpreter on a model returns
    a function that evaluates the graph node by node with those handlers.
    """

    def __init__(self, dispatcher: dict[str, Callable[..., E]] | Interpreter):
        """Initialize an interpreter with a dispatcher.

        The dispatcher maps ONNX operator names to handler functions. Keys are
        the ONNX ``op_type`` strings (e.g. ``"Gemm"``, ``"Relu"``,
        ``"Reshape"``). Custom-domain ops (e.g. ``"DiffPair"`` from the
        ``boundlab`` domain) are also keyed by bare ``op_type``.

        Passing another :class:`Interpreter` copies its handler lists.
        Registering new handlers on the copy therefore does not affect the
        source.
        """
        source = dispatcher.dispatcher if isinstance(dispatcher, Interpreter) else dispatcher
        self.dispatcher = {k: FnList(v) for k, v in source.items()}

    def __getitem__(self, key) -> FnList[E]:
        """Return the handler list for *key* (``KeyError`` if none is registered)."""
        return self.dispatcher[key]

    def __setitem__(self, key, value):
        """Register handler(s) for *key*; each function of a :class:`FnList` is registered."""
        if isinstance(value, FnList):
            for fn in value.fns:
                self.register(key, fn)
        else:
            self.register(key, value)

    def register(self, key: str, value: Callable[..., E]):
        """Register a handler for an operator.

        Handlers are appended. :class:`FnList` tries the most recently
        registered handler first, so later registrations take precedence.
        """
        assert callable(value), "Handler must be callable"
        if key in self.dispatcher:
            self.dispatcher[key].fns.append(value)
        else:
            self.dispatcher[key] = FnList(value)

    def __contains__(self, key) -> bool:
        return key in self.dispatcher

    def items(self):
        """Return a view of ``(op_type, FnList)`` pairs."""
        return self.dispatcher.items()

    def __or__(self, other: Interpreter | dict[str, Callable[..., E]]) -> Interpreter:
        """Return a new interpreter: a copy of ``self`` with *other*'s handlers registered on top."""
        # The constructor already copies every handler list, so ``result``
        # can be extended without mutating ``self``.
        result = Interpreter(self.dispatcher)
        result |= other
        return result

    def __ior__(self, other: Interpreter | dict[str, Callable[..., E]]):
        """Register every handler of *other* on this interpreter, in place."""
        other = other if isinstance(other, Interpreter) else Interpreter(other)
        for k, v in other.items():
            for fn in v.fns:
                self.register(k, fn)
        return self

    def product(self, *other: Interpreter) -> Interpreter:
        """Return a new interpreter that produces tuples of results from this and other interpreters.

        Every op in ``self`` must also be registered in each of *other*
        (``KeyError`` otherwise).
        NOTE(review): ``interpret`` asserts handlers do not return tuples,
        so the product's handlers cannot currently run through ``__call__``.
        """
        return Interpreter({k: v.product(*[o[k] for o in other]) for k, v in self.dispatcher.items()})

    def and_then(self, other: Callable[[E], E], with_op_name: bool = False) -> Interpreter:
        """Return a new interpreter that applies another function to the output of this one.

        When *with_op_name* is true, *other* also receives the operator key as
        ``op_name=``.
        """
        result = {}

        class WithOpNameWrapper:
            def __init__(self, fn, op_name):
                self.fn = fn
                self.op_name = op_name

            def __call__(self, *args, **kwargs):
                return self.fn(*args, op_name=self.op_name, **kwargs)

        for k, v in self.dispatcher.items():
            if with_op_name:
                result[k] = FnListChain(v, WithOpNameWrapper(other, k))
            else:
                result[k] = FnListChain(v, other)
        return Interpreter(result)

    def __call__(
        self,
        model: ir.Model | str | Path,
        verbose: bool = False,
        output_env: bool = False,
    ) -> Callable[..., E]:
        """Build an expression-level interpreter for an ONNX model.

        Parameters
        ----------
        model:
            Accepted forms:

            * an ``onnx_ir.Model``;
            * a ``str`` / :class:`pathlib.Path` pointing to an ``.onnx`` file;
            * a ``torch.onnx.ONNXProgram`` (its ``.model`` is used);
            * a ``torch.export.ExportedProgram`` (converted with
              ``torch.onnx.export(..., dynamo=True)``).
        verbose:
            Print every node (inputs and attributes) and its handler result.
        output_env:
            If true, ``interpret`` returns ``(outputs, env)``, where ``env``
            maps value names to interpreted values. Intermediate values are
            then kept instead of being freed after their last use.

        The ONNX graph is walked in topological order (ONNX guarantees this).
        For each node:

        * Initializer inputs are converted to :class:`~torch.Tensor`, passed
          once (when ``__call__`` runs) through the ``"Initializer"``
          handler, and supplied as positional arguments.
        * Missing optional inputs (``None`` entries of ``node.inputs``) are
          passed as ``None``.
        * Node attributes are converted to Python scalars / lists and passed
          as keyword arguments, together with ``node_name``.
        * The dispatcher is keyed on the bare ``op_type`` (domain is ignored).

        Graph inputs are passed through the ``"Input"`` handler.

        Returns
        -------
        A callable ``interpret(*exprs)`` that maps input
        :class:`~boundlab.expr.Expr` objects to output expression(s): a
        single value for one graph output, a tuple otherwise.

        Raises
        ------
        KeyError
            From ``interpret``, when a node input is undefined or no handler
            is registered for a node's ``op_type``.

        Examples
        --------
        >>> import torch, tempfile, os
        >>> from boundlab.interp import Interpreter, ONNX_BASE_INTERPRETER
        >>> from boundlab.zono import interpret
        >>> import boundlab.expr as expr
        """
        if isinstance(model, torch.onnx.ONNXProgram):
            model = model.model
        elif isinstance(model, torch.export.ExportedProgram):
            model = torch.onnx.export(model, dynamo=True).model
        assert isinstance(model, (ir.Model, str, Path)), "Model must be an onnx_ir.Model, ExportedProgram, ONNXProgram, or file path"
        if isinstance(model, (str, Path)):
            model = ir.load(str(model))

        initializers = {
            init.name: self.dispatcher["Initializer"](
                torch.from_numpy(init.const_value.numpy().copy())
                .to(torch.get_default_device())
            )
            for init in model.graph.initializers.values()
        }
        initializer_names = set(initializers.keys())
        input_names = [
            inp.name for inp in model.graph.inputs if inp.name not in initializer_names
        ]
        output_names = [out.name for out in model.graph.outputs]
        assert all(isinstance(v, FnList) for v in self.dispatcher.values()), \
            "All handlers must be non-None."

        # Liveness analysis for env-tensor freeing. The abstract domains are
        # eager: each node's output is a fully-materialized expression holding
        # its own generator tensors, and ``env`` otherwise pins every one of
        # them alive for the whole pass (peak VRAM = sum over the entire graph,
        # even though most intermediates are dead). ``last_use[name]`` is the
        # index of the last node consuming ``name``; after that node runs we
        # drop ``env[name]`` so Python reclaims it (refcount -> 0 once no live
        # expression still references it). This changes no numbers — only when
        # memory is released — and turns peak VRAM into the live-set maximum.
        nodes = list(model.graph)
        last_use: dict[str, int] = {}
        for idx, node in enumerate(nodes):
            for inp in node.inputs:
                if inp is not None and inp.name:
                    last_use[inp.name] = idx
        keep_alive = set(output_names)

        def to_repr(x: Any) -> str:
            # Compact rendering for verbose output: tensors show shape and
            # max-abs value (skipped when that reduction is unavailable).
            if isinstance(x, torch.Tensor):
                if x.numel() == 0 or x.dtype == torch.bool:
                    return f"Tensor{list(x.shape)}"
                return f"Tensor{list(x.shape)}({x.abs().max().item():.4g})"
            return repr(x)

        def interpret(*exprs: E) -> E | tuple[E, ...]:
            env: dict[str, Any] = {}
            # Freeing must be disabled when the caller wants the full env back.
            free_dead = not output_env
            for name, e in zip(input_names, exprs):
                assert e is not None, name
                env[name] = self.dispatcher["Input"](e)

            for idx, node in enumerate(nodes):
                args = []
                for inp in node.inputs:
                    if inp is None:
                        args.append(None)
                        continue
                    inp_name = inp.name
                    if inp_name in env:
                        assert env[inp_name] is not None, inp_name
                        args.append(env[inp_name])
                    elif inp_name in initializers:
                        assert initializers[inp_name] is not None, inp_name
                        args.append(initializers[inp_name])
                    else:
                        raise KeyError(
                            f"Input '{inp_name}' not found for node "
                            f"'{node.op_type}' ({node.name!r})"
                        )
                kwargs = {
                    name: _onnx_attr_value(attr)
                    for name, attr in node.attributes.items()
                }
                kwargs["node_name"] = node.name

                if verbose:
                    out_str = ", ".join("%" + v.name for v in node.outputs if v is not None)
                    in_str = ", ".join("%" + v.name for v in node.inputs if v is not None)
                    kwargs_str = ", ".join(f"{k}={to_repr(v)}" for k, v in kwargs.items())
                    if kwargs_str and in_str:
                        kwargs_str = ", " + kwargs_str
                    print(f"{out_str} = {node.op_type}({in_str}{kwargs_str})")

                # Dispatch on op_type (domain is ignored).
                try:
                    handler = self.dispatcher[node.op_type]
                except KeyError:
                    raise KeyError(
                        f"No handler registered for op_type '{node.op_type}' "
                        f"(node {node.name!r})"
                    ) from None
                result = handler(*args, **kwargs)
                if verbose:
                    print(f"-> {to_repr(result)}")
                assert not isinstance(result, tuple), f"Handler for {node.op_type} returned a tuple, but only single outputs are supported. Got: {result}"

                # Bind outputs. NOTE(review): tuples are rejected above, so a
                # multi-output handler must return another indexable (e.g. a list).
                if len(node.outputs) == 1:
                    out = node.outputs[0]
                    if out is not None and out.name:
                        env[out.name] = result
                else:
                    for i, out in enumerate(node.outputs):
                        if out is not None and out.name:
                            env[out.name] = result[i]

                # Release env tensors whose last consumer was this node. We only
                # touch ``env`` (intermediate expressions + graph inputs);
                # initializers live in their own dict and stay resident, and
                # graph outputs are pinned by ``keep_alive``. ``del`` merely
                # drops env's reference — anything still needed downstream
                # survives via the consuming expression's own reference.
                if free_dead:
                    for inp in node.inputs:
                        if (inp is not None and inp.name
                                and inp.name not in keep_alive
                                and last_use.get(inp.name) == idx
                                and inp.name in env):
                            del env[inp.name]

            outputs = [env[name] for name in output_names]
            if output_env:
                return (outputs[0] if len(outputs) == 1 else tuple(outputs)), env
            return outputs[0] if len(outputs) == 1 else tuple(outputs)

        return interpret
# =====================================================================
# ONNX base interpreter
# =====================================================================

def _onnx_transpose(data, perm=None):
    """ONNX Transpose: permute axes by *perm*; reverse all axes when omitted.

    The explicit reversed permutation replaces ``Tensor.T``, which torch
    deprecates for tensors whose rank is not 2.
    """
    if perm is None:
        perm = list(range(len(data.shape)))[::-1]
    return data.permute(*perm)


ONNX_BASE_INTERPRETER = Interpreter({
    "Input": lambda x, **_: x,
    "Initializer": lambda x, **_: x,
    # ---- arithmetic (with broadcast) --------------------------------------
    "Add": lambda X, Y: (lambda a, b: a + b)(*_onnx_broadcast(X, Y)),
    "Sub": lambda X, Y: (lambda a, b: a - b)(*_onnx_broadcast(X, Y)),
    "Neg": lambda X: -X,
    "Mul": lambda X, Y: (lambda a, b: a * b)(*_onnx_broadcast(X, Y)),
    "Div": lambda X, Y: (lambda a, b: a / b)(*_onnx_broadcast(X, Y)),
    # ---- linear layers ------------------------------------------------
    "Gemm": _onnx_gemm,
    "MatMul": lambda A, B: A @ B,
    "Einsum": _onnx_einsum,
    "Conv": _onnx_conv,
    # ---- shape ops ----------------------------------------------------
    # Reshape forwards ``allowzero`` so the ONNX zero-dimension rule applies.
    "Reshape": lambda data, shape, allowzero=0, **_: _onnx_reshape(data, _as_const(shape), allowzero),
    "Flatten": _onnx_flatten,
    "Transpose": _onnx_transpose,
    "Unsqueeze": lambda data, axes: _onnx_unsqueeze(data, _as_const(axes)),
    "Squeeze": lambda data, axes=None: _onnx_squeeze(data, _as_const(axes) if axes is not None else None),
    "Gather": _onnx_gather,
    "Slice": _onnx_slice,
    "Concat": _onnx_concat,
    "Identity": lambda X: X,
    "Cast": _onnx_cast,
    "Expand": _onnx_expand,
    # ---- reductions ---------------------------------------------------
    "ReduceSum": _onnx_reduce_sum,
    "ReduceMean": _onnx_reduce_mean,
    # ---- constants ---------------------------------------------
    "Constant": _onnx_constant,
    "Reciprocal": lambda X: 1 / X,
})