Source code for boundlab.expr._cat
r"""Concatenation and Stacking Operations
This module provides expressions for concatenating and stacking
child expressions along specified dimensions.
"""
from typing import Literal
import torch
from boundlab.expr._core import Expr, ExprFlags
from boundlab.expr._affine import AffineSum
[docs]
def Cat(*children: Expr, dim: int = 0) -> AffineSum:
    """Concatenate child expressions along an existing dimension.

    Implemented as an :class:`AffineSum` whose per-child operator is a
    :class:`~boundlab.linearop.SetSliceOp` that embeds the child into the
    correct slice of the full output shape.

    Args:
        *children: Expressions to concatenate, in order. All must have the
            same number of dimensions and identical sizes on every
            dimension other than ``dim``.
        dim: Dimension to concatenate along. Negative values count from the
            last dimension and must lie in ``[-ndim, ndim)``.

    Returns:
        The :class:`AffineSum` built from one ``(SetSliceOp, child)`` pair
        per child.

    Raises:
        AssertionError: If no children are given, a child is not an
            :class:`Expr`, ``dim`` is out of range, or the child shapes are
            incompatible.
    """
    from boundlab.linearop import SetSliceOp

    assert all(isinstance(c, Expr) for c in children), "All children of Cat must be Expr instances."
    assert len(children) >= 1, "Cat requires at least one child."

    ref_shape = children[0].shape
    ndim = len(ref_shape)

    # Reject out-of-range dims up front: a too-negative value would otherwise
    # stay negative after normalisation and silently address another axis.
    assert -ndim <= dim < ndim, (
        f"Cat dim {dim} is out of range for children with {ndim} dimension(s)."
    )
    if dim < 0:
        dim += ndim

    # The rank check must come first; prefix/suffix comparison alone can
    # accept a child with fewer dimensions than the first one.
    assert all(
        len(c.shape) == ndim
        and c.shape[:dim] == ref_shape[:dim]
        and c.shape[dim + 1:] == ref_shape[dim + 1:]
        for c in children
    ), "All children must have matching shapes except along the concatenation dimension."

    out_dims = list(ref_shape)
    out_dims[dim] = sum(c.shape[dim] for c in children)
    out_shape = torch.Size(out_dims)

    pairs = []
    offset = 0
    for c in children:
        size = c.shape[dim]
        # Full extent on every dimension except `dim`, where the child
        # occupies the window [offset, offset + size).
        slices = [[slice(0, s)] for s in out_shape]
        slices[dim] = [slice(offset, offset + size)]
        pairs.append((SetSliceOp(out_shape, slices), c))
        offset += size
    return AffineSum(*pairs)
[docs]
class Stack(Expr):
    """Expression for stacking child expressions along a new dimension.

    All children must have identical shapes. The backward pass produces
    an embed LinearOp per child that places the child at its index along
    the stacking dimension, with zeros elsewhere.

    ``dim`` follows the ``torch.stack`` convention: it indexes the *output*
    shape, so for children of rank ``r`` the valid range is ``[-(r + 1), r]``.
    It is stored on ``self.dim`` as a non-negative index.
    """

    def __init__(self, *children: Expr, dim: int = 0):
        """Build a stack of ``children`` along a new dimension ``dim``.

        Args:
            *children: Expressions to stack; at least one, all of the same
                shape.
            dim: Position of the new dimension in the output shape.
                Negative values count from the end of the output shape.

        Raises:
            AssertionError: If no children are given, a child is not an
                :class:`Expr`, the shapes differ, or ``dim`` is out of range.
        """
        assert all(isinstance(c, Expr) for c in children), "All children of Stack must be Expr instances."
        super().__init__(ExprFlags.IS_AFFINE)
        assert len(children) >= 1, "Stack requires at least one child."
        assert all(
            c.shape == children[0].shape for c in children
        ), "All children must have the same shape for Stack."

        # Normalise `dim` against the output rank. list.insert() treats -1 as
        # "before the last element" and clamps large indices, neither of
        # which matches stacking semantics, and backward() needs a
        # non-negative index to address the padding spec.
        out_ndim = len(children[0].shape) + 1
        assert -out_ndim <= dim < out_ndim, (
            f"Stack dim {dim} is out of range for an output with {out_ndim} dimension(s)."
        )
        if dim < 0:
            dim += out_ndim

        self._children = tuple(children)
        self.dim = dim
        s = list(children[0].shape)
        s.insert(dim, len(children))
        self._shape = torch.Size(s)

    @property
    def shape(self) -> torch.Size:
        """Shape of the stacked result (child shape with the new dimension inserted)."""
        return self._shape

    @property
    def children(self) -> tuple[Expr, ...]:
        """The stacked child expressions, in stacking order."""
        return self._children

    def with_children(self, *new_children: Expr) -> "Stack":
        """Return a new :class:`Stack` over ``new_children`` with the same ``dim``."""
        return Stack(*new_children, dim=self.dim)

    def backward(self, weights, direction: Literal[">=", "<=", "=="] = "=="):  # noqa: ARG002
        """Propagate weights to each child via unsqueeze+pad embed ops.

        Args:
            weights: A :class:`~boundlab.linearop.EinsumOp` accumulated weight.
            direction: Unused (Stack is always linear).

        Returns:
            ``(0, [child_weight_0, child_weight_1, ...])``
        """
        from boundlab.linearop import PadOp, UnsqueezeOp, ComposedOp

        n = len(self._children)
        ndim = len(self._shape)
        # Position of the stack dimension counted from the last dimension.
        # NOTE(review): this assumes PadOp takes a torch.nn.functional.pad
        # style spec (last dimension first, (before, after) pairs) — confirm.
        d_rev = ndim - 1 - self.dim

        child_ops = []
        for i, child in enumerate(self._children):
            # unsqueeze to add the stack dim, then pad to fill the full stack size
            unsq = UnsqueezeOp(child.shape, self.dim)
            pad_spec = [0] * (2 * ndim)
            pad_spec[2 * d_rev] = i              # slots before this child
            pad_spec[2 * d_rev + 1] = n - i - 1  # slots after this child
            pad = PadOp(unsq.output_shape, pad_spec)
            embed_op = ComposedOp(pad, unsq)
            child_ops.append(weights @ embed_op)
        return (0, child_ops)

    def to_string(self, *children_str: str) -> str:
        """Render as ``stack([...], dim=...)`` from the children's string forms."""
        return f"stack([{', '.join(children_str)}], dim={self.dim})"