Source code for boundlab.diff.net

"""Build a paired ONNX model for differential interpretation.

``diff_net`` merges two structurally-identical ONNX models by inserting
``boundlab::diff_pair`` nodes that pair the weights from both networks before
each shared ``Gemm`` layer.  At concrete-tensor runtime the model is
equivalent to *net1*; when interpreted by
:data:`~boundlab.diff.zono3.interpret`, the ``diff_pair`` nodes create
:class:`~boundlab.diff.expr.DiffExpr2` expressions that propagate both
branches through all subsequent operations.
"""

from __future__ import annotations

import copy
from itertools import zip_longest
from pathlib import Path

import onnx_ir as ir


def _load_onnx(model) -> ir.Model:
    """Load an ONNX model from a path or pass through an existing onnx_ir.Model."""
    if isinstance(model, (str, Path)):
        return ir.load(str(model))
    if isinstance(model, ir.Model):
        return model
    raise TypeError(f"Expected path or onnx_ir.Model, got {type(model)}")


def _value_ref(name: str) -> ir.Value:
    return ir.Value(name=name)


def _new_initializer(name: str, source: ir.Value) -> ir.Value:
    return ir.Value(
        name=name,
        shape=copy.deepcopy(source.shape),
        type=copy.deepcopy(source.type),
        doc_string=source.doc_string,
        const_value=copy.deepcopy(source.const_value),
        metadata_props=copy.deepcopy(source.metadata_props),
    )


def _clone_node_detached(node: ir.Node) -> ir.Node:
    return ir.Node(
        node.domain,
        node.op_type,
        [(_value_ref(v.name) if v is not None else None) for v in node.inputs],
        attributes={k: copy.deepcopy(v) for k, v in node.attributes.items()},
        outputs=[ir.Value(name=v.name) for v in node.outputs],
        overload=node.overload,
        version=node.version,
        name=node.name,
        doc_string=node.doc_string,
        metadata_props=copy.deepcopy(node.metadata_props),
    )


def _node_input_name(node: ir.Node, idx: int) -> str:
    v = node.inputs[idx]
    assert v is not None and v.name is not None
    return v.name


[docs] def diff_net( net1: ir.Model | str | Path, net2: ir.Model | str | Path, ) -> ir.Model: """Merge two structurally-identical ONNX models by pairing initializer inputs. For each aligned node pair in ``net1``/``net2``, any input position where both sides consume initializers is replaced by a shared ``boundlab::diff_pair`` value in the merged graph. This generalizes beyond affine operators (e.g. ``Gemm`` / ``MatMul``) to any operator that reads parameters as initializers. At concrete runtime, ``diff_pair`` is a no-op that returns branch-1 values (net1 semantics). Under differential interpretation, the paired values are lifted into :class:`~boundlab.diff.expr.DiffExpr2`. Parameters ---------- net1, net2: ONNX models with the same graph structure. Each may be an :class:`onnx_ir.Model`, a ``str``, or a :class:`pathlib.Path`. Returns ------- A merged :class:`onnx_ir.Model`. Examples -------- >>> from boundlab.diff.net import diff_net >>> from pathlib import Path >>> nets = Path("compare/veridiff/examples/nets") >>> merged = diff_net(nets / "single_layer_ref.onnx", nets / "single_layer_alt.onnx") >>> any(n.domain == "boundlab" and n.op_type == "diff_pair" for n in merged.graph) True """ net1 = _load_onnx(net1) net2 = _load_onnx(net2) merged = net1.clone() # Build initializer dicts (name -> Value). init1 = merged.graph.initializers init2 = net2.graph.initializers # Align nodes by topological index. nodes1 = list(merged.graph) nodes2 = list(net2.graph) if len(nodes1) != len(nodes2): raise ValueError( f"Networks have different numbers of nodes: {len(nodes1)} vs {len(nodes2)}" ) new_nodes: list[ir.Node] = [] new_initializer_entries: dict[str, ir.Value] = {} # Cache: net2 initializer name -> cloned initializer name in merged model. cloned_init2_name: dict[str, str] = {} # Cache: (net1 init name, net2 init name) -> paired value name. paired_name: dict[tuple[str, str], str] = {} def _paired_input_value(name1: str, name2: str) -> ir.Value: key = (name1, name2) if key in paired_name: return _value_ref(paired_name[key]) if name2 not in cloned_init2_name: new_name = f"_diff_init2_{len(cloned_init2_name)}" new_initializer_entries[new_name] = _new_initializer(new_name, init2[name2]) cloned_init2_name[name2] = new_name else: new_name = cloned_init2_name[name2] out_name = f"_diff_paired_{len(paired_name)}" paired_name[key] = out_name new_nodes.append( ir.Node( "boundlab", "diff_pair", [_value_ref(name1), _value_ref(new_name)], outputs=[ir.Value(name=out_name)], ) ) return _value_ref(out_name) for idx, (node, node2) in enumerate(zip(nodes1, nodes2)): if node.op_type != node2.op_type or node.domain != node2.domain: raise ValueError( f"Node mismatch at index {idx}: " f"({node.domain}::{node.op_type}) vs ({node2.domain}::{node2.op_type})" ) if len(node.inputs) != len(node2.inputs): raise ValueError( f"Node input-arity mismatch at index {idx}: " f"{len(node.inputs)} vs {len(node2.inputs)}" ) remapped_inputs: list[ir.Value | None] = [] for inp1, inp2 in zip_longest(node.inputs, node2.inputs): if inp1 is None: remapped_inputs.append(None) continue name1 = inp1.name # Keep original input when net2 optional input is missing. if inp2 is None or inp2.name is None: remapped_inputs.append(_value_ref(name1)) continue name2 = inp2.name if name1 in init1 and name2 in init2: remapped_inputs.append(_paired_input_value(name1, name2)) else: remapped_inputs.append(_value_ref(name1)) new_nodes.append( ir.Node( node.domain, node.op_type, remapped_inputs, attributes={k: copy.deepcopy(v) for k, v in node.attributes.items()}, outputs=[ir.Value(name=o.name) for o in node.outputs], overload=node.overload, version=node.version, name=node.name, doc_string=node.doc_string, metadata_props=copy.deepcopy(node.metadata_props), ) ) # Rewrite graph in place on the cloned model. merged.graph.remove(list(merged.graph), safe=False) merged.graph.extend(new_nodes) merged.graph.initializers.update(new_initializer_entries) if paired_name: merged.graph.opset_imports["boundlab"] = max(merged.graph.opset_imports.get("boundlab", 0), 1) merged.opset_imports["boundlab"] = max(merged.opset_imports.get("boundlab", 0), 1) return merged
__all__ = ["diff_net"]