# Source code for boundlab.expr._affine

from __future__ import annotations

r"""Linear Operations for Expressions

This module provides ``AffineSum``, a fused expression class that represents
a sum of EinsumOp-weighted children: Σ_i op_i(child_i).
It replaces separate linear-sequence and add-node structures from the
previous design.
"""

import operator
import sys
from typing import Literal

import torch

from boundlab.expr._core import Expr, ExprFlags
from boundlab.linearop import EinsumOp, LinearOp

class AffineSum(Expr):
    r"""An expression representing a sum of linear operations applied to children.

    Represents :math:`\sum_i \mathrm{op}_i(x_i)` where each :math:`\mathrm{op}_i`
    is a :class:`~boundlab.linearop.EinsumOp`.

    During construction, if a child is itself an :class:`AffineSum`, its pairs
    are absorbed by composing the outer op with each inner op via ``@`` (eager
    contraction). This ensures the expression tree is always flat — no
    :class:`AffineSum` node ever has an :class:`AffineSum` child.

    Attributes:
        children_dict: Mapping from each child :class:`Expr` to the
            :class:`~boundlab.linearop.LinearOp` applied to it. Duplicate
            children are merged by adding their ops.
        constant: Accumulated constant term (a tensor), or ``None``.
    """

    def __new__(cls, *pairs: tuple, const=None, **_kw):
        if cls is not AffineSum:
            # ConstVal (and other subclasses) construct themselves directly.
            return object.__new__(cls)
        if len(pairs) == 0 or all(isinstance(child, ConstVal) for _, child in pairs):
            # All-constant result → return a ConstVal shell;
            # AffineSum.__init__ will populate .constant, and ConstVal.__init__
            # syncs .value afterwards.
            return object.__new__(ConstVal)
        return object.__new__(AffineSum)

    def __init__(self, *pairs: tuple, const=None):
        """Construct an AffineSum.

        Args:
            *pairs: Sequence of ``(op, child)`` pairs where ``op`` is a
                :class:`~boundlab.linearop.EinsumOp` and ``child`` is an
                :class:`Expr` or :class:`torch.Tensor`.
            const: Optional constant term to seed the sum with.
        """
        super().__init__(ExprFlags.IS_AFFINE)
        self.constant = const
        # Pre-process before allocating ID so ConstVal wrappers get lower IDs.
        self.children_dict: dict[Expr, LinearOp] = {}
        for op, child in pairs:
            # BUGFIX: the tensor check must run *before* the Expr assertion.
            # Previously the assert preceded this branch, so a plain-tensor
            # child (supported per the docstring) tripped the assert and the
            # tensor branch was unreachable.
            if isinstance(child, torch.Tensor):
                # A raw tensor child folds straight into the constant term.
                self._add_constant(op.forward(child))
                continue
            assert isinstance(child, Expr), "Tuple expressions are not supported as children of AffineSum; use multiple arguments instead."
            if isinstance(child, AffineSum):
                # Distribute op through child's pairs: (op ∘ child_op_i, grandchild_i)
                if child.constant is not None:
                    self._add_constant(op.forward(child.constant))
                for grandchild, child_op in child.children_dict.items():
                    self._add_expr(op @ child_op, grandchild)
            else:
                self._add_expr(op, child)
        output_shapes = {op.output_shape for op in self.children_dict.values()}
        if self.constant is not None:
            output_shapes.add(self.constant.shape)
        if output_shapes:
            assert len(output_shapes) == 1, \
                f"All ops must share the same output shape; got {output_shapes}."
            self._shape = output_shapes.pop()
        else:
            # BUGFIX: no pairs and no constant (e.g. an empty ConstVal shell,
            # whose get_const() deliberately returns 0) previously tripped the
            # shape assertion; fall back to a scalar shape instead.
            self._shape = torch.Size()
        # Propagate flags from children.
        if self.children_dict:
            children = self.children_dict.keys()
            if all(ExprFlags.SYMMETRIC_TO_0 in child.flags for child in children):
                self.flags |= ExprFlags.SYMMETRIC_TO_0
            if all(child.flags & ExprFlags.IS_CONST for child in children):
                self.flags |= ExprFlags.IS_CONST
        else:
            # No children at all → the expression is a pure constant.
            self.flags |= ExprFlags.IS_CONST

    def _add_constant(self, const: torch.Tensor):
        """Accumulate a constant term into this AffineSum."""
        if const is not None:
            self.constant = self.constant + const if self.constant is not None else const

    def _add_expr(self, op: LinearOp, child: Expr):
        """Accumulate an ``(op, child)`` contribution into this AffineSum.

        If the child is already present, its ops are merged by addition.
        """
        if child in self.children_dict:
            self.children_dict[child] = self.children_dict[child] + op
        else:
            self.children_dict[child] = op

    @property
    def shape(self) -> torch.Size:
        """Common output shape of all ops (and the constant, if any)."""
        return self._shape

    @property
    def children(self) -> tuple[Expr, ...]:
        """Child expressions, in insertion order."""
        return tuple(self.children_dict.keys())

    def with_children(self, *new_children: Expr) -> "AffineSum":
        """Return a new AffineSum with the same ops but new children."""
        return AffineSum(*zip(self.children_dict.values(), new_children))

    def backward(self, weights, direction: Literal[">=", "<=", "=="]) \
            -> tuple:
        """Propagate weights backward: each child gets weights ∘ op_i.

        Args:
            weights: A :class:`~boundlab.linearop.EinsumOp` accumulated weight.
            direction: Bound direction (unused — Linear is always linear).

        Returns:
            ``(bias, [weights @ op_i for op_i in self.children_dict.values()])``.
        """
        bias = 0
        if self.constant is not None:
            bias = weights.forward(self.constant)
        return (bias, [weights @ op for op in self.children_dict.values()])

    def to_string(self, *children_str: str) -> str:
        """Render as ``op_0(c_0) + op_1(c_1) + ... [+ <Const>]``."""
        parts = [f"{op}({cs})"
                 for op, cs in zip(self.children_dict.values(), children_str)]
        if self.constant is not None:
            # Plain literal (was an f-string with no placeholders).
            parts.append("<Const>")
        return " + ".join(parts)

    def jacobian_ops_(self):
        """Replace each op in-place with its Jacobian op."""
        self.children_dict = {child: op.jacobian_op()
                              for child, op in self.children_dict.items()}
class ConstVal(AffineSum):
    """Expression representing a constant tensor value.

    Implemented as an AffineSum with no children and only a constant term.
    When used as a child of another AffineSum, the constant is automatically
    absorbed via eager contraction.
    """

    def __init__(self, value=None, name=None, *_pairs, const=None):
        # Three call patterns:
        #   1. ConstVal(tensor[, name])        — direct construction
        #   2. ConstVal(const=tensor)          — from AffineSum(const=x), no pairs
        #   3. ConstVal((op,ch), ..., const=x) — from AffineSum(*pairs, const=x),
        #      all-ConstVal children; value is the first pair tuple, _pairs are
        #      the rest
        if isinstance(value, tuple) or _pairs:
            # Pattern 3: routed from AffineSum with pairs.
            all_pairs = ((value,) + _pairs) if value is not None else _pairs
            AffineSum.__init__(self, *all_pairs, const=const)
        else:
            # Pattern 1 or 2.
            actual = value if const is None else const
            AffineSum.__init__(self, const=actual)
        # Sync .value with whatever constant AffineSum.__init__ accumulated.
        self.value = self.constant
        self.name = name

    def to_string(self) -> str:
        """Render as ``#const <name>`` or ``#const <hex id>``."""
        if self.name is not None:
            return f"#const {self.name}"
        return f"#const <{self.id:X}>"

    def get_const(self):
        """Return the wrapped constant, or 0 (additive identity) if unset."""
        return 0 if self.value is None else self.value

    # ------------------------------------------------------------------
    # Constant folding.  Every binary operator follows the same pattern:
    # fold eagerly into a new ConstVal when the other operand is a tensor
    # or another ConstVal, otherwise defer to the generic Expr operator.
    # Factored into _fold to remove ten verbatim copies of the same branch
    # structure; fallbacks stay lazily evaluated in each dunder so behavior
    # is unchanged when the superclass path is taken.
    # ------------------------------------------------------------------

    def _fold(self, other, fn):
        """Return ``ConstVal(fn(const, other_const))`` if foldable, else None."""
        if isinstance(other, torch.Tensor):
            return ConstVal(fn(self.get_const(), other))
        if isinstance(other, ConstVal):
            return ConstVal(fn(self.get_const(), other.get_const()))
        return None

    def __add__(self, other):
        folded = self._fold(other, operator.add)
        return super().__add__(other) if folded is None else folded

    def __radd__(self, other):
        folded = self._fold(other, lambda s, o: o + s)
        return super().__radd__(other) if folded is None else folded

    def __neg__(self):
        return ConstVal(-self.get_const())

    def __sub__(self, other):
        folded = self._fold(other, operator.sub)
        return super().__sub__(other) if folded is None else folded

    def __rsub__(self, other):
        folded = self._fold(other, lambda s, o: o - s)
        return super().__rsub__(other) if folded is None else folded

    def __truediv__(self, other):
        folded = self._fold(other, operator.truediv)
        return super().__truediv__(other) if folded is None else folded

    def __rtruediv__(self, other):
        folded = self._fold(other, lambda s, o: o / s)
        return super().__rtruediv__(other) if folded is None else folded

    def __abs__(self):
        return ConstVal(abs(self.get_const()))

    def __mul__(self, other):
        folded = self._fold(other, operator.mul)
        return super().__mul__(other) if folded is None else folded

    def __rmul__(self, other):
        folded = self._fold(other, lambda s, o: o * s)
        return super().__rmul__(other) if folded is None else folded

    def __matmul__(self, other):
        folded = self._fold(other, operator.matmul)
        return super().__matmul__(other) if folded is None else folded

    def __rmatmul__(self, other):
        folded = self._fold(other, lambda s, o: o @ s)
        return super().__rmatmul__(other) if folded is None else folded

    def _apply_op(self, op):
        # Apply a LinearOp directly to the constant, folding eagerly.
        return ConstVal(op.forward(self.get_const()))