Source code for boundlab.expr._cat

r"""Concatenation and Stacking Operations

This module provides expressions for concatenating and stacking
child expressions along specified dimensions.
"""

from typing import Literal

import torch

from boundlab.expr._core import Expr, ExprFlags


class Cat(Expr):
    """Expression for concatenating child expressions along a dimension.

    The backward pass produces an embed LinearOp per child that zero-pads the
    child's contribution into the full cat output shape. The VJP of F.pad
    (narrow) is computed automatically.

    ``dim`` may be negative (counted from the end, as in :func:`torch.cat`);
    it is normalized in the constructor so ``self.dim`` is always a
    non-negative axis index.
    """

    def __init__(self, *children: Expr, dim: int = 0):
        """Build a concatenation of ``children`` along ``dim``.

        Args:
            *children: One or more expressions whose shapes agree on every
                axis except ``dim``.
            dim: Axis to concatenate along; negative values count from the end.

        Raises:
            AssertionError: If a child is not an ``Expr``, no children are
                given, or the shapes are incompatible.
            IndexError: If ``dim`` is out of range for the children's rank.
        """
        super().__init__(ExprFlags.IS_AFFINE)
        assert all(isinstance(c, Expr) for c in children), "All children of Cat must be Expr instances."
        assert len(children) >= 1, "Cat requires at least one child."
        ndim = len(children[0].shape)
        # Normalize negative dims up front. Otherwise ``shape[dim + 1:]`` with
        # dim=-1 slices the *whole* shape, and the F.pad slot index computed in
        # ``backward`` runs past the end of the pad list.
        if not -ndim <= dim < ndim:
            raise IndexError(
                f"Cat dim {dim} out of range for {ndim}-dimensional children."
            )
        dim %= ndim
        ref = children[0].shape
        assert all(
            c.shape[:dim] == ref[:dim] and c.shape[dim + 1:] == ref[dim + 1:]
            for c in children
        ), "All children must have matching shapes except along the concatenation dimension."
        self._children = tuple(children)
        self.dim = dim
        s = list(ref)
        s[dim] = sum(c.shape[dim] for c in children)
        self._shape = torch.Size(s)

    @property
    def shape(self) -> torch.Size:
        """Output shape: the children's shape with ``dim`` summed over children."""
        return self._shape

    @property
    def children(self) -> tuple[Expr, ...]:
        """The concatenated child expressions, in order."""
        return self._children

    def with_children(self, *new_children: Expr) -> "Cat":
        """Return a new ``Cat`` over ``new_children`` along the same ``dim``."""
        return Cat(*new_children, dim=self.dim)

    def backward(self, weights, direction: Literal[">=", "<=", "=="] = "=="):  # noqa: ARG002
        """Propagate weights to each child via zero-padding embed ops.

        Args:
            weights: A :class:`~boundlab.linearop.EinsumOp` accumulated weight.
            direction: Unused (Cat is always linear).

        Returns:
            ``(0, [child_weight_0, child_weight_1, ...])``
        """
        from boundlab.linearop import PadOp

        cat_size = self._shape[self.dim]
        # Cat preserves rank, so every child has the same ndim as the output.
        ndim = len(self._shape)
        # F.pad spec lists (left, right) pairs starting from the LAST dim,
        # so axis ``self.dim`` owns slots [slot, slot + 1].
        slot = 2 * (ndim - 1 - self.dim)
        child_ops = []
        offset = 0
        for child in self._children:
            size = child.shape[self.dim]
            pad_spec = [0] * (2 * ndim)
            pad_spec[slot] = offset                        # zeros before this child
            pad_spec[slot + 1] = cat_size - offset - size  # zeros after this child
            child_ops.append(weights @ PadOp(child.shape, pad_spec))
            offset += size
        return (0, child_ops)

    def to_string(self, *children_str: str) -> str:
        """Render as ``cat([c0, c1, ...], dim=d)``."""
        return f"cat([{', '.join(children_str)}], dim={self.dim})"
class Stack(Expr):
    """Expression for stacking child expressions along a new dimension.

    All children must have identical shapes. The backward pass produces an
    embed LinearOp per child that places the child at its index along the
    stacking dimension, with zeros elsewhere.

    ``dim`` may be negative (counted from the end of the *output* shape, as in
    :func:`torch.stack`); it is normalized in the constructor so ``self.dim``
    is always a non-negative axis index.
    """

    def __init__(self, *children: Expr, dim: int = 0):
        """Build a stack of ``children`` along a new axis at ``dim``.

        Args:
            *children: One or more expressions with identical shapes.
            dim: Position of the new axis in the output; negative values count
                from the end of the output shape (rank = child rank + 1).

        Raises:
            AssertionError: If a child is not an ``Expr``, no children are
                given, or the shapes differ.
            IndexError: If ``dim`` is out of range for the output rank.
        """
        assert all(isinstance(c, Expr) for c in children), "All children of Stack must be Expr instances."
        super().__init__(ExprFlags.IS_AFFINE)
        assert len(children) >= 1, "Stack requires at least one child."
        assert all(
            c.shape == children[0].shape for c in children
        ), "All children must have the same shape for Stack."
        out_ndim = len(children[0].shape) + 1
        # Normalize against the OUTPUT rank. ``list.insert(-1, n)`` would place
        # the new axis before the last dim (torch.stack(dim=-1) appends it), and
        # ``list.insert`` silently clamps out-of-range positive indices.
        if not -out_ndim <= dim < out_ndim:
            raise IndexError(
                f"Stack dim {dim} out of range for {out_ndim}-dimensional output."
            )
        dim %= out_ndim
        self._children = tuple(children)
        self.dim = dim
        s = list(children[0].shape)
        s.insert(dim, len(children))
        self._shape = torch.Size(s)

    @property
    def shape(self) -> torch.Size:
        """Output shape: the child shape with ``len(children)`` inserted at ``dim``."""
        return self._shape

    @property
    def children(self) -> tuple[Expr, ...]:
        """The stacked child expressions, in order."""
        return self._children

    def with_children(self, *new_children: Expr) -> "Stack":
        """Return a new ``Stack`` over ``new_children`` along the same ``dim``."""
        return Stack(*new_children, dim=self.dim)

    def backward(self, weights, direction: Literal[">=", "<=", "=="] = "=="):  # noqa: ARG002
        """Propagate weights to each child via unsqueeze+pad embed ops.

        Args:
            weights: A :class:`~boundlab.linearop.EinsumOp` accumulated weight.
            direction: Unused (Stack is always linear).

        Returns:
            ``(0, [child_weight_0, child_weight_1, ...])``
        """
        from boundlab.linearop import PadOp, UnsqueezeOp, ComposedOp

        n = len(self._children)
        ndim = len(self._shape)
        # F.pad spec lists (left, right) pairs starting from the LAST dim,
        # so the stack axis owns slots [slot, slot + 1].
        slot = 2 * (ndim - 1 - self.dim)
        child_ops = []
        for i, child in enumerate(self._children):
            # Unsqueeze to add the stack dim, then pad it out to the full stack
            # size so child ``i`` sits at index ``i`` with zeros elsewhere.
            unsq = UnsqueezeOp(child.shape, self.dim)
            pad_spec = [0] * (2 * ndim)
            pad_spec[slot] = i
            pad_spec[slot + 1] = n - i - 1
            pad = PadOp(unsq.output_shape, pad_spec)
            child_ops.append(weights @ ComposedOp(pad, unsq))
        return (0, child_ops)

    def to_string(self, *children_str: str) -> str:
        """Render as ``stack([c0, c1, ...], dim=d)``."""
        return f"stack([{', '.join(children_str)}], dim={self.dim})"