# Source code for boundlab.linearop

r"""Linear Operator Library for Expression Backpropagation.

This module defines linear operators used by BoundLab expressions during
symbolic transformation and backward bound propagation.

Key operators:

- :class:`ComposedOp`: Functional composition of linear maps
  (:math:`A \circ B`), used to chain transformations efficiently.
- :class:`SumOp`: Pointwise sum of linear maps, used when multiple affine
  contributions target the same expression.
- :class:`EinsumOp`: General tensor-linear map based on Einstein summation;
  this is the most flexible primitive for dense affine transformations.

The module also exposes shape/indexing operators (reshape, permute, gather,
scatter, slicing, padding) that are all represented as :class:`LinearOp`
instances and can therefore be composed, summed, and propagated uniformly.

Examples
--------
Apply a shape operator to a concrete tensor:

>>> import torch
>>> from boundlab.linearop import ReshapeOp
>>> op = ReshapeOp(torch.Size([2, 3]), (3, 2))
>>> y = op.forward(torch.arange(6.0).reshape(2, 3))
>>> y.shape
torch.Size([3, 2])
"""

from ._base import LinearOp, ComposedOp, SumOp, ScalarOp, ZeroOp
from ._einsum import EinsumOp
from ._shape import (
    ReshapeOp,
    FlattenOp,
    UnflattenOp,
    PermuteOp,
    TransposeOp,
    SqueezeOp,
    UnsqueezeOp,
    ExpandOp,
    RepeatOp,
    TileOp,
    FlipOp,
    RollOp,
    DiagOp,
)
from ._indices import (
    # Core indexing ops
    GatherOp,
    ScatterOp,
    GetIndicesOp,
    SetIndicesOp,
    GetSliceOp,
    SetSliceOp,
    # Convenience functions
    narrow_indices,
    select_indices,
    pad_indices,
    pad_output_shape,
)


# ---------------------------------------------------------------------------
# Backwards-compatible aliases (convenience wrappers)
# ---------------------------------------------------------------------------

class NarrowOp(GetSliceOp):
    """Take ``length`` consecutive entries along ``dim``, starting at ``start``.

    Convenience subclass of :class:`GetSliceOp`: the slice specification is
    produced by :func:`narrow_indices` and handed to the parent constructor.

    Examples
    --------
    >>> import torch
    >>> from boundlab.linearop import NarrowOp
    >>> op = NarrowOp(torch.Size([5]), dim=0, start=1, length=3)
    >>> op.forward(torch.tensor([0., 1., 2., 3., 4.]))
    tensor([1., 2., 3.])
    """

    def __init__(self, input_shape, dim: int, start: int, length: int):
        """Build the slice specification and remember the narrow arguments.

        Parameters
        ----------
        input_shape
            Shape of the operator's input. It is forwarded to
            :class:`GetSliceOp`, and its length (number of dimensions) is
            given to ``narrow_indices``.
        dim : int
            Dimension along which the slice is taken.
        start : int
            First position of the slice along ``dim``.
        length : int
            Number of entries kept along ``dim``.
        """
        rank = len(input_shape)
        slice_spec = narrow_indices(rank, dim, start, length)
        super().__init__(input_shape, slice_spec)
        # Recorded so that __str__ can report the original arguments.
        self.dim, self.start, self.length = dim, start, length

    def __str__(self):
        """Return the label ``narrow(dim, start, length)``."""
        label = f"narrow({self.dim}, {self.start}, {self.length})"
        return label
class SelectOp(GetSliceOp):
    """Pick one index along ``dim`` and drop that dimension from the result.

    Convenience subclass of :class:`GetSliceOp`: the index specification is
    produced by :func:`select_indices` and handed to the parent constructor.

    Examples
    --------
    >>> import torch
    >>> from boundlab.linearop import SelectOp
    >>> op = SelectOp(torch.Size([2, 3]), dim=0, index=1)
    >>> op.forward(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
    tensor([4., 5., 6.])
    """

    def __init__(self, input_shape, dim: int, index: int):
        """Build the index specification and remember the select arguments.

        Parameters
        ----------
        input_shape
            Shape of the operator's input. It is forwarded to
            :class:`GetSliceOp`, and its length (number of dimensions) is
            given to ``select_indices``.
        dim : int
            Dimension from which a single position is taken.
        index : int
            Position selected along ``dim``.
        """
        rank = len(input_shape)
        index_spec = select_indices(rank, dim, index)
        super().__init__(input_shape, index_spec)
        # Recorded so that __str__ can report the original arguments.
        self.dim, self.index = dim, index

    def __str__(self):
        """Return the label ``select(dim, index)``."""
        label = f"select({self.dim}, {self.index})"
        return label
class GetItemOp(GetSliceOp):
    """Operator form of ``x[indices]`` indexing / slicing.

    Alias for :class:`GetSliceOp`; the only thing defined here is the
    ``getitem(...)`` string representation.
    """

    def __str__(self):
        """Return ``getitem(<formatted indices>)``."""
        # The private formatter is imported at call time rather than at module
        # level, so it is looked up only when a string form is requested.
        from ._indices import _format_indices

        # ``self.indices`` is not assigned in this class -- presumably it is
        # set by GetSliceOp.__init__ (TODO confirm).
        rendered = _format_indices(self.indices)
        return f"getitem({rendered})"
class PadOp(SetSliceOp):
    """Zero-pad an input tensor. (Alias for SetSliceOp)

    The output shape and the target indices are derived from ``pad_spec`` by
    :func:`pad_output_shape` and :func:`pad_indices` and then handed to
    :class:`SetSliceOp`.

    Examples
    --------
    >>> import torch
    >>> from boundlab.linearop import PadOp
    >>> op = PadOp(torch.Size([3]), [1, 2])
    >>> op.forward(torch.tensor([1., 2., 3.]))
    tensor([0., 1., 2., 3., 0., 0.])
    """

    def __init__(self, input_shape, pad_spec: list[int]):
        """Derive the padded shape and target indices from ``pad_spec``.

        Parameters
        ----------
        input_shape
            Shape of the operator's input (before padding).
        pad_spec : list[int]
            Padding amounts. In the 1-D example above ``[1, 2]`` adds one
            zero before and two zeros after the data.
            NOTE(review): the ordering for multi-dimensional inputs is
            defined by ``pad_indices`` / ``pad_output_shape`` -- confirm there.
        """
        # Normalise exactly once. The argument used to be copied with
        # ``list(pad_spec)`` and then passed again, un-copied, to both helpers;
        # a one-shot iterable would already be exhausted by then, and the
        # stored spec could differ from the one the helpers actually saw.
        spec = list(pad_spec)
        # Separate copy for __str__, so the label does not alias the list
        # given to the helpers.
        self._pad_spec = list(spec)
        output_shape = pad_output_shape(input_shape, spec)
        indices = pad_indices(input_shape, spec)
        super().__init__(indices, input_shape, output_shape)

    def __str__(self):
        """Return the label ``pad(<pad_spec>)``."""
        return f"pad({self._pad_spec})"
# Public API of ``boundlab.linearop``, grouped by where each name comes from.
__all__ = [
    # Base operators
    "LinearOp",
    "ComposedOp",
    "SumOp",
    "ScalarOp",
    "ZeroOp",
    # Einstein-summation operator
    "EinsumOp",
    # Shape operators
    "ReshapeOp",
    "FlattenOp",
    "UnflattenOp",
    "PermuteOp",
    "TransposeOp",
    "SqueezeOp",
    "UnsqueezeOp",
    "ExpandOp",
    "RepeatOp",
    "TileOp",
    "FlipOp",
    "RollOp",
    "DiagOp",
    # General indexing operators
    "GatherOp",
    "ScatterOp",
    "GetIndicesOp",
    "SetIndicesOp",
    "GetSliceOp",
    "SetSliceOp",
    # Convenience aliases defined in this module
    "NarrowOp",
    "SelectOp",
    "GetItemOp",
    "PadOp",
    # Index-building helper functions
    "narrow_indices",
    "select_indices",
    "pad_indices",
    "pad_output_shape",
]