Source code for boundlab.linearop._slicing

"""Slice-based indexing LinearOp implementations.

``GetSliceOp`` and ``SetSliceOp`` use a structured ``list[list[slice]]``
format where each dimension has a list of non-overlapping slices.
``len(input_shape) == len(slices)`` is enforced.
"""

import torch

from boundlab.linearop._base import LinearOp, LinearOpFlags, ComposedOp
from boundlab.utils import merge_name


def _normalize_slices(slices: list[list[slice]], shape: torch.Size) -> list[list[slice]]:
    """Normalize slices so each has concrete start/stop (no None)."""
    result = []
    for d, dim_slices in enumerate(slices):
        normalized = []
        for s in dim_slices:
            start, stop, step = s.indices(shape[d])
            assert step == 1, f"GetSliceOp only supports step=1, got step={step} in dim {d}"
            normalized.append(slice(start, stop))
        result.append(normalized)
    return result


def _output_size(dim_slices: list[slice]) -> int:
    """Total output size for a list of slices along one dimension."""
    return sum(s.stop - s.start for s in dim_slices)


def _is_full(dim_slices: list[slice], dim_size: int) -> bool:
    """Check if slices cover the full dimension."""
    return len(dim_slices) == 1 and dim_slices[0].start == 0 and dim_slices[0].stop == dim_size


class GetSliceOp(LinearOp):
    """Extract sliced regions from a tensor.

    Along every dimension the listed slices are read in order and
    concatenated, so the output size of a dimension is the sum of its slice
    lengths.

    Args:
        input_shape: Shape of the input tensor.
        slices: Per-dimension list of slices. ``len(slices) == len(input_shape)``.
    """

    def __init__(self, input_shape: torch.Size, slices: list[list[slice]]):
        """Normalize ``slices`` against ``input_shape`` and derive the output shape."""
        rank = len(input_shape)
        assert rank == len(slices), (
            f"len(input_shape)={rank} != len(slices)={len(slices)}"
        )
        self.slices = _normalize_slices(slices, input_shape)
        out_dims = [_output_size(self.slices[axis]) for axis in range(rank)]
        super().__init__(
            input_shape,
            torch.Size(out_dims),
            flags=LinearOpFlags.IS_NON_NEGATIVE
            | LinearOpFlags.IS_PURE_EXPANDING
            | LinearOpFlags.IS_PURE_CONTRACTING,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gather the configured slices from ``x``, one dimension at a time."""
        for axis, pieces in enumerate(self.slices):
            if _is_full(pieces, self.input_shape[axis]):
                continue
            views = [x.narrow(axis, p.start, p.stop - p.start) for p in pieces]
            # A single slice stays a plain narrow; several are concatenated.
            x = views[0] if len(views) == 1 else torch.cat(views, dim=axis)
        return x

    def _scatter_grad(self, grad: torch.Tensor, lead: int) -> torch.Tensor:
        """Place ``grad`` into zeros of the input extent, skipping ``lead`` leading dims."""
        for axis, pieces in enumerate(self.slices):
            if _is_full(pieces, self.input_shape[axis]):
                continue
            target = lead + axis
            if len(pieces) > 1:
                chunks = grad.split([p.stop - p.start for p in pieces], dim=target)
            else:
                chunks = [grad]
            full_shape = list(grad.shape)
            full_shape[target] = self.input_shape[axis]
            expanded = torch.zeros(full_shape, dtype=grad.dtype, device=grad.device)
            for p, chunk in zip(pieces, chunks):
                expanded.narrow(target, p.start, p.stop - p.start).add_(chunk)
            grad = expanded
        return grad

    def backward(self, grad: torch.Tensor) -> torch.Tensor:
        """Scatter ``grad`` back to the input shape; unselected positions are zero."""
        return self._scatter_grad(grad, 0)

    def vforward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to :meth:`forward`, which only touches the first ``len(slices)`` dims."""
        return self.forward(x)

    def vbackward(self, grad: torch.Tensor) -> torch.Tensor:
        """Like :meth:`backward`, treating dims beyond ``len(output_shape)`` as leading batch dims."""
        return self._scatter_grad(grad, grad.dim() - len(self.output_shape))

    def __matmul__(self, other):
        """Fuse GetSliceOp @ GetSliceOp or GetSliceOp @ EinsumOp."""
        if isinstance(other, GetSliceOp):
            # Apply this op's slices on top of the other op's slices.
            assert self.input_shape == other.output_shape
            return GetSliceOp(
                other.input_shape, _compose_get_slices(self.slices, other.slices)
            )

        from boundlab.linearop._einsum import EinsumOp

        if not isinstance(other, EinsumOp):
            return NotImplemented

        assert self.input_shape == other.output_shape
        # Collect the non-trivially sliced dims whose tensor dim is in mul_dims.
        mul_slice_dims = [
            d
            for d, dim_slices in enumerate(self.slices)
            if not _is_full(dim_slices, self.input_shape[d])
            and other.output_dims[d] in other.mul_dims
        ]
        return _apply_getslice_einsum(self, other, mul_slice_dims)

    def __rmatmul__(self, other):
        """Defer to the base-class implementation."""
        return super().__rmatmul__(other)

    def __str__(self):
        """Render as ``<getslice ...>`` with one comma-separated entry per dimension."""
        rendered = []
        for dim_slices in self.slices:
            if len(dim_slices) == 1:
                only = dim_slices[0]
                rendered.append(f"{only.start}:{only.stop}")
            else:
                inner = ",".join(f"{s.start}:{s.stop}" for s in dim_slices)
                rendered.append("[" + inner + "]")
        return f"<getslice {','.join(rendered)}>"
class SetSliceOp(LinearOp):
    """Embed input into zeros at specified slice positions.

    Along every dimension the input is consumed in consecutive chunks, one
    per slice, and each chunk is written to the matching slice of the output.

    Args:
        output_shape: Shape of the output tensor (zeros template).
        slices: Per-dimension list of slices. ``len(output_shape) == len(slices)``.
    """

    def __init__(self, output_shape: torch.Size, slices: list[list[slice]]):
        """Normalize ``slices`` against ``output_shape`` and derive the input shape."""
        rank = len(output_shape)
        assert rank == len(slices), (
            f"len(output_shape)={rank} != len(slices)={len(slices)}"
        )
        self.slices = _normalize_slices(slices, output_shape)
        in_dims = [_output_size(self.slices[axis]) for axis in range(rank)]
        super().__init__(
            torch.Size(in_dims),
            output_shape,
            flags=LinearOpFlags.IS_NON_NEGATIVE
            | LinearOpFlags.IS_PURE_EXPANDING
            | LinearOpFlags.IS_PURE_CONTRACTING,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Write ``x`` into a fresh zero tensor of ``output_shape``."""
        canvas = torch.zeros(self.output_shape, dtype=x.dtype, device=x.device)
        _scatter_slices(canvas, x, self.slices)
        return canvas

    def backward(self, grad: torch.Tensor) -> torch.Tensor:
        """Read the sliced regions back out of ``grad``."""
        return _gather_slices(grad, self.slices)

    def vforward(self, x: torch.Tensor) -> torch.Tensor:
        """Like :meth:`forward`, keeping dims of ``x`` beyond ``len(input_shape)`` as trailing dims."""
        trailing = x.shape[len(self.input_shape):]
        canvas = torch.zeros(*self.output_shape, *trailing, dtype=x.dtype, device=x.device)
        _scatter_slices(canvas, x, self.slices)
        return canvas

    def vbackward(self, grad: torch.Tensor) -> torch.Tensor:
        """Like :meth:`backward`, treating dims beyond ``len(output_shape)`` as leading batch dims."""
        lead = grad.dim() - len(self.output_shape)
        return _gather_slices_batched(grad, self.slices, lead)

    def __rmatmul__(self, other):
        """Fuse EinsumOp @ SetSliceOp."""
        from boundlab.linearop._einsum import EinsumOp

        fusable = isinstance(other, EinsumOp) and self.output_shape == other.input_shape
        if not fusable:
            return super().__rmatmul__(other)

        # Collect the non-trivially sliced dims whose tensor dim is in mul_dims.
        mul_slice_dims = [
            d
            for d, dim_slices in enumerate(self.slices)
            if not _is_full(dim_slices, self.output_shape[d])
            and other.input_dims[d] in other.mul_dims
        ]
        return _apply_einsum_setslice(other, self, mul_slice_dims)

    def __matmul__(self, other):
        """Fuse SetSliceOp @ SetSliceOp."""
        if not isinstance(other, SetSliceOp):
            return NotImplemented
        assert self.input_shape == other.output_shape
        return SetSliceOp(
            self.output_shape, _compose_set_slices(self.slices, other.slices)
        )

    def __str__(self):
        """Render as ``<setslice ...>`` with one comma-separated entry per dimension."""
        rendered = []
        for dim_slices in self.slices:
            if len(dim_slices) == 1:
                only = dim_slices[0]
                rendered.append(f"{only.start}:{only.stop}")
            else:
                inner = ",".join(f"{s.start}:{s.stop}" for s in dim_slices)
                rendered.append("[" + inner + "]")
        return f"<setslice {','.join(rendered)}>"
# --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _gather_slices(x: torch.Tensor, slices: list[list[slice]]) -> torch.Tensor: """Gather (forward of GetSliceOp).""" for d, dim_slices in enumerate(slices): if _is_full(dim_slices, x.shape[d]): continue if len(dim_slices) == 1: s = dim_slices[0] x = x.narrow(d, s.start, s.stop - s.start) else: parts = [x.narrow(d, s.start, s.stop - s.start) for s in dim_slices] x = torch.cat(parts, dim=d) return x def _gather_slices_batched(grad: torch.Tensor, slices: list[list[slice]], batch_ndim: int) -> torch.Tensor: """Gather with leading batch dims.""" for d, dim_slices in enumerate(slices): bd = batch_ndim + d if len(dim_slices) == 1 and dim_slices[0] == slice(0, grad.shape[bd]): continue if len(dim_slices) == 1: s = dim_slices[0] grad = grad.narrow(bd, s.start, s.stop - s.start) else: parts = [grad.narrow(bd, s.start, s.stop - s.start) for s in dim_slices] grad = torch.cat(parts, dim=bd) return grad def _scatter_slices(result: torch.Tensor, x: torch.Tensor, slices: list[list[slice]]) -> None: """Scatter x into result at slice positions (in-place).""" _recursive_scatter(result, x, slices, 0, [], []) def _recursive_scatter(result, x, slices, dim, result_indices, x_offsets): """Recursively scatter x into result across all slice combinations.""" if dim == len(slices): r_idx = tuple(result_indices) x_idx = tuple(x_offsets) result[r_idx].copy_(x[x_idx]) return dim_slices = slices[dim] x_pos = 0 for s in dim_slices: length = s.stop - s.start _recursive_scatter( result, x, slices, dim + 1, result_indices + [slice(s.start, s.stop)], x_offsets + [slice(x_pos, x_pos + length)] ) x_pos += length def _compose_get_slices(outer_slices, inner_slices): """Compose GetSlice @ GetSlice: apply outer slices to inner's output.""" result = [] for d in range(len(outer_slices)): # inner_slices[d] maps positions in inner.input 
to inner.output # outer_slices[d] selects from inner.output inner = inner_slices[d] outer = outer_slices[d] new_dim_slices = _compose_dim_slices(outer, inner) result.append(new_dim_slices) return result def _compose_dim_slices(outer: list[slice], inner: list[slice]) -> list[slice]: """Compose slices along one dimension. ``inner`` maps from the original tensor to an intermediate. ``outer`` selects from the intermediate. Result maps from the original tensor directly. """ # Build a mapping from intermediate positions to original positions # inner creates segments: inner[0] -> positions 0..len0, inner[1] -> len0..len0+len1, etc. result = [] for o_slice in outer: # Find which inner slices cover positions o_slice.start..o_slice.stop pos = 0 remaining_start = o_slice.start remaining_stop = o_slice.stop for i_slice in inner: i_len = i_slice.stop - i_slice.start seg_start = pos seg_stop = pos + i_len # Intersection of [remaining_start, remaining_stop) with [seg_start, seg_stop) inter_start = max(remaining_start, seg_start) inter_stop = min(remaining_stop, seg_stop) if inter_start < inter_stop: # Map back to original coordinates orig_start = i_slice.start + (inter_start - seg_start) orig_stop = i_slice.start + (inter_stop - seg_start) result.append(slice(orig_start, orig_stop)) pos += i_len return _merge_adjacent_slices(result) def _compose_set_slices(outer_slices, inner_slices): """Compose SetSlice @ SetSlice: embed inner's output into outer's output.""" result = [] for d in range(len(outer_slices)): outer = outer_slices[d] inner = inner_slices[d] new_dim_slices = _compose_dim_slices(inner, outer) result.append(new_dim_slices) return result def _merge_adjacent_slices(slices: list[slice]) -> list[slice]: """Merge adjacent/overlapping slices.""" if not slices: return [slice(0, 0)] result = [slices[0]] for s in slices[1:]: if s.start <= result[-1].stop: result[-1] = slice(result[-1].start, max(result[-1].stop, s.stop)) else: result.append(s) return result # 
--------------------------------------------------------------------------- # Fusion with EinsumOp # --------------------------------------------------------------------------- def _slice_tensor_along_dims(tensor, dims_map, slices_map): """Slice tensor along multiple dims. dims_map[output_d] = tensor_dim; slices_map[output_d] = dim_slices.""" for output_d, tensor_dim in dims_map.items(): dim_slices = slices_map[output_d] if len(dim_slices) == 1: s = dim_slices[0] tensor = tensor.narrow(tensor_dim, s.start, s.stop - s.start) else: parts = [tensor.narrow(tensor_dim, s.start, s.stop - s.start) for s in dim_slices] tensor = torch.cat(parts, dim=tensor_dim) return tensor def _apply_getslice_einsum(gs: GetSliceOp, einsum, mul_dims: list[int]): """Fuse/swap GetSliceOp @ EinsumOp. Slices the tensor on all non-trivial dims. For mul dims, also adds a GetSliceOp on the input side (swap). For dot/batch dims, no input op needed (fuse). """ from boundlab.linearop._einsum import EinsumOp dims_map = {} slices_map = {} for d, dim_slices in enumerate(gs.slices): if not _is_full(dim_slices, gs.input_shape[d]): dims_map[d] = einsum.output_dims[d] slices_map[d] = dim_slices tensor = _slice_tensor_along_dims(einsum.tensor, dims_map, slices_map) new_einsum = EinsumOp(tensor, einsum.input_dims, einsum.output_dims, name=merge_name(gs, "@", einsum)) assert new_einsum.output_shape == gs.output_shape, \ f"_apply_getslice_einsum: output_shape {new_einsum.output_shape} != {gs.output_shape}" if not mul_dims: assert new_einsum.input_shape == einsum.input_shape, \ f"_apply_getslice_einsum: input_shape {new_einsum.input_shape} != {einsum.input_shape}" return new_einsum # Build input-side slices for mul dims input_side_slices = [[slice(0, einsum.input_shape[d])] for d in range(len(einsum.input_shape))] for d in mul_dims: tensor_dim = einsum.output_dims[d] input_d = einsum.input_dims.index(tensor_dim) input_side_slices[input_d] = gs.slices[d] needs_input_slice = any( not 
_is_full(input_side_slices[d], einsum.input_shape[d]) for d in range(len(einsum.input_shape)) ) if needs_input_slice: input_gs = GetSliceOp(einsum.input_shape, input_side_slices) return ComposedOp(new_einsum, input_gs) return new_einsum def _apply_einsum_setslice(einsum, ss: SetSliceOp, mul_dims: list[int]): """Fuse/swap EinsumOp @ SetSliceOp. Slices the tensor on all non-trivial dims. For mul dims, also adds a SetSliceOp on the output side (swap). For dot/batch dims, no output op needed (fuse). """ from boundlab.linearop._einsum import EinsumOp dims_map = {} slices_map = {} for d, dim_slices in enumerate(ss.slices): if not _is_full(dim_slices, ss.output_shape[d]): dims_map[d] = einsum.input_dims[d] slices_map[d] = dim_slices tensor = _slice_tensor_along_dims(einsum.tensor, dims_map, slices_map) new_einsum = EinsumOp(tensor, einsum.input_dims, einsum.output_dims, name=merge_name(einsum, "@", ss)) assert new_einsum.input_shape == ss.input_shape, \ f"_apply_einsum_setslice: input_shape {new_einsum.input_shape} != {ss.input_shape}" if not mul_dims: assert new_einsum.output_shape == einsum.output_shape, \ f"_apply_einsum_setslice: output_shape {new_einsum.output_shape} != {einsum.output_shape}" return new_einsum # Build output-side slices for mul dims output_side_slices = [[slice(0, einsum.output_shape[d])] for d in range(len(einsum.output_shape))] for d in mul_dims: tensor_dim = einsum.input_dims[d] output_d = einsum.output_dims.index(tensor_dim) output_side_slices[output_d] = ss.slices[d] needs_output_slice = any( not _is_full(output_side_slices[d], einsum.output_shape[d]) for d in range(len(einsum.output_shape)) ) if needs_output_slice: output_ss = SetSliceOp(einsum.output_shape, output_side_slices) return ComposedOp(output_ss, new_einsum) return new_einsum