r"""Linear Operator Library for Expression Backpropagation.
This module defines linear operators used by BoundLab expressions during
symbolic transformation and backward bound propagation.
Key operators:
- :class:`ComposedOp`: Functional composition of linear maps
(:math:`A \circ B`), used to chain transformations efficiently.
- :class:`SumOp`: Pointwise sum of linear maps, used when multiple affine
contributions target the same expression.
- :class:`EinsumOp`: General tensor-linear map based on Einstein summation;
this is the most flexible primitive for dense affine transformations.
The module also exposes shape/indexing operators (reshape, permute, gather,
scatter, slicing, padding) that are all represented as :class:`LinearOp`
instances and can therefore be composed, summed, and propagated uniformly.
Examples
--------
Apply a shape operator to a concrete tensor:
>>> import torch
>>> from boundlab.linearop import ReshapeOp
>>> op = ReshapeOp(torch.Size([2, 3]), (3, 2))
>>> y = op.forward(torch.arange(6.0).reshape(2, 3))
>>> y.shape
torch.Size([3, 2])
"""
import torch
from ._base import LinearOp, ComposedOp, SumOp, ScalarOp, ZeroOp
from ._einsum import EinsumOp
# Reshape ops
from ._reshape import (
ReshapeOp,
FlattenOp,
UnflattenOp,
SqueezeOp,
UnsqueezeOp,
)
# Permute ops
from ._permute import PermuteOp, TransposeOp
# Expand
from ._expand import ExpandOp
# Remaining shape ops
from ._shape import (
RepeatOp,
TileOp,
FlipOp,
RollOp,
DiagOp,
)
# Slicing ops (new structured API)
from ._slicing import GetSliceOp, SetSliceOp
# Indexing ops (new dim-based API)
from ._indexing import GetIndicesOp, SetIndicesOp
# Gather/Scatter and convenience functions
from ._indices import (
GatherOp,
ScatterOp,
narrow_indices,
select_indices,
pad_indices,
pad_output_shape,
make_get_slices,
make_set_slices,
get_int_dims,
)
# ---------------------------------------------------------------------------
# Backwards-compatible aliases (convenience wrappers)
# ---------------------------------------------------------------------------
[docs]
class NarrowOp(GetSliceOp):
"""Select a contiguous slice along *dim*. (Alias for GetSliceOp)
Examples
--------
>>> import torch
>>> from boundlab.linearop import NarrowOp
>>> op = NarrowOp(torch.Size([5]), dim=0, start=1, length=3)
>>> op.forward(torch.tensor([0., 1., 2., 3., 4.]))
tensor([1., 2., 3.])
"""
[docs]
def __init__(self, input_shape, dim: int, start: int, length: int):
ndim = len(input_shape)
slices = [[slice(0, input_shape[d])] for d in range(ndim)]
slices[dim] = [slice(start, start + length)]
super().__init__(input_shape, slices)
self.dim = dim
self.start = start
self.length = length
def __str__(self):
return f"<narrow {self.dim} {self.start} {self.length}>"
[docs]
class SelectOp(LinearOp):
"""Select a single index along *dim*, removing that dimension.
Implemented as GetSliceOp (length-1 slice) composed with SqueezeOp.
Examples
--------
>>> import torch
>>> from boundlab.linearop import SelectOp
>>> op = SelectOp(torch.Size([2, 3]), dim=0, index=1)
>>> op.forward(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
tensor([4., 5., 6.])
"""
[docs]
def __init__(self, input_shape, dim: int, index: int):
from ._base import LinearOpFlags
self.dim = dim
self.index = index
if index < 0:
index += input_shape[dim]
ndim = len(input_shape)
slices = [[slice(0, input_shape[d])] for d in range(ndim)]
slices[dim] = [slice(index, index + 1)]
self._slice_op = GetSliceOp(input_shape, slices)
self._squeeze_op = SqueezeOp(self._slice_op.output_shape, dim=dim)
output_shape = self._squeeze_op.output_shape
super().__init__(input_shape, output_shape,
flags=LinearOpFlags.IS_NON_NEGATIVE | LinearOpFlags.IS_PURE_EXPANDING | LinearOpFlags.IS_PURE_CONTRACTING)
[docs]
def forward(self, x):
return self._squeeze_op.forward(self._slice_op.forward(x))
[docs]
def backward(self, grad):
return self._slice_op.backward(self._squeeze_op.backward(grad))
[docs]
def vforward(self, x):
return self._squeeze_op.vforward(self._slice_op.vforward(x))
[docs]
def vbackward(self, grad):
return self._slice_op.vbackward(self._squeeze_op.vbackward(grad))
def __matmul__(self, other):
composed = ComposedOp(self._squeeze_op, self._slice_op)
return composed @ other
def __rmatmul__(self, other):
composed = ComposedOp(self._squeeze_op, self._slice_op)
return other @ composed
def __str__(self):
return f"<select {self.dim}, {self.index}>"
[docs]
class GetItemOp:
"""Indexing / slicing via ``x[indices]``.
Factory: returns ``ReshapeOp @ GetSliceOp`` (or just ``GetSliceOp`` when
no integer indices remove a dim). Slice extracts the region; reshape
drops the size-1 dims for any int indices.
"""
[docs]
def __new__(cls, input_shape, indices):
int_dims = get_int_dims(indices)
slices = make_get_slices(input_shape, indices)
slice_op = GetSliceOp(input_shape, slices)
if not int_dims:
return slice_op
sliced_shape = slice_op.output_shape
int_dims_set = set(int_dims)
target_shape = torch.Size(s for i, s in enumerate(sliced_shape) if i not in int_dims_set)
return ReshapeOp(sliced_shape, target_shape) @ slice_op
[docs]
class PadOp(SetSliceOp):
"""Zero-pad an input tensor. (Alias for SetSliceOp)
Examples
--------
>>> import torch
>>> from boundlab.linearop import PadOp
>>> op = PadOp(torch.Size([3]), [1, 2])
>>> op.forward(torch.tensor([1., 2., 3.]))
tensor([0., 1., 2., 3., 0., 0.])
"""
[docs]
def __init__(self, input_shape, pad_spec: list[int]):
self._pad_spec = list(pad_spec)
output_shape = pad_output_shape(input_shape, pad_spec)
ndim = len(output_shape)
slices = []
for d in range(ndim):
d_rev = ndim - 1 - d
if 2 * d_rev + 1 < len(pad_spec):
pad_before = pad_spec[2 * d_rev]
slices.append([slice(pad_before, pad_before + input_shape[d])])
else:
slices.append([slice(0, output_shape[d])])
super().__init__(output_shape, slices)
def __str__(self):
return f"<pad {self._pad_spec}>"
__all__ = [
# Base classes
"LinearOp",
"ComposedOp",
"SumOp",
"ScalarOp",
"ZeroOp",
# Einsum
"EinsumOp",
# Reshape ops
"ReshapeOp",
"FlattenOp",
"UnflattenOp",
"SqueezeOp",
"UnsqueezeOp",
# Permute ops
"PermuteOp",
"TransposeOp",
# Expand
"ExpandOp",
# Other shape ops
"RepeatOp",
"TileOp",
"FlipOp",
"RollOp",
"DiagOp",
# Slicing ops
"GetSliceOp",
"SetSliceOp",
# Indexing ops
"GetIndicesOp",
"SetIndicesOp",
# Gather/Scatter
"GatherOp",
"ScatterOp",
# Convenience aliases
"NarrowOp",
"SelectOp",
"GetItemOp",
"PadOp",
# Utility functions
"narrow_indices",
"select_indices",
"pad_indices",
"pad_output_shape",
"make_get_slices",
"make_set_slices",
]