# boundlab.linearop._indices

"""Indexing LinearOp implementations for bound propagation.

This module provides LinearOps for various indexing and slicing operations.
The key distinction is:

- **Slice-based** (`GetSliceOp`, `SetSliceOp`): Use Python slice/int indices.
  These are the generalized versions that subsume `NarrowOp`, `SelectOp`,
  `GetItemOp`, and `PadOp`.

- **Index-tensor-based** (`GetIndicesOp`, `SetIndicesOp`): Use tensor indices
  for advanced indexing where each element position is specified by tensors.

- **Gather/Scatter** (`GatherOp`, `ScatterOp`): Dimension-specific operations
  that gather or scatter along a single dimension using index tensors.
"""

import torch

from boundlab.linearop._base import LinearOp, LinearOpFlags


def _meta_output_shape(fn, input_shape: torch.Size) -> torch.Size:
    """Infer an operator output shape without materializing data.

    Args:
        fn: A shape-preserving callable that accepts a tensor.
        input_shape: Shape of the hypothetical input tensor.

    Returns:
        The output shape obtained by running ``fn`` on a meta-device tensor.
    """
    return fn(torch.empty(input_shape, device="meta")).shape


# ---------------------------------------------------------------------------
# Gather / Scatter operations
# ---------------------------------------------------------------------------


class GatherOp(LinearOp):
    """A LinearOp that implements ``torch.gather`` along a specified dimension.

    Forward: ``output[i][j][k] = input[i][index[i][j][k]][k]`` (for dim=1)
    Backward: Scatters gradient back using ``scatter_add``.
    """

    def __init__(self, input_shape: torch.Size, dim: int, index: torch.Tensor):
        # The gather output has exactly the index tensor's shape.
        self.dim = dim
        self.index = index
        super().__init__(
            input_shape,
            torch.Size(index.shape),
            flags=LinearOpFlags.IS_NON_NEGATIVE,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # An exact shape match means no extra batch dims; otherwise route to
        # the batched path (e.g. when called via vmap).
        if x.shape == self.input_shape:
            return torch.gather(x, self.dim, self.index)
        return self.vforward(x)

    def backward(self, grad: torch.Tensor) -> torch.Tensor:
        # Transpose of gather: accumulate grad entries back at their sources.
        if grad.shape != self.output_shape:
            return self.vbackward(grad)
        out = torch.zeros(self.input_shape, dtype=grad.dtype, device=grad.device)
        out.scatter_add_(self.dim, self.index, grad)
        return out

    def vforward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (*input_shape, *batch) -- batch dims are trailing here.
        trailing = x.shape[len(self.input_shape):]
        idx = self.index.reshape(*self.index.shape, *([1] * len(trailing)))
        idx = idx.expand(*self.index.shape, *trailing)
        return torch.gather(x, self.dim, idx)

    def vbackward(self, grad: torch.Tensor) -> torch.Tensor:
        # grad: (*batch, *output_shape) -- batch dims are leading here.
        n_batch = grad.ndim - len(self.output_shape)
        leading = grad.shape[:n_batch]
        idx = self.index.reshape(*([1] * n_batch), *self.index.shape)
        idx = idx.expand(*leading, *self.index.shape)
        out = torch.zeros(
            *leading, *self.input_shape, dtype=grad.dtype, device=grad.device
        )
        # Shift the gather dim past the leading batch dims.
        out.scatter_add_(n_batch + self.dim, idx, grad)
        return out

    def __str__(self):
        return f"gather(dim={self.dim}, index.shape={list(self.index.shape)})"
class ScatterOp(LinearOp):
    """A LinearOp that implements ``torch.scatter`` along a specified dimension.

    Forward: Creates zeros of output_shape, then scatters input values at
    index positions. Backward: Gathers gradient from the scattered positions.
    """

    def __init__(self, input_shape: torch.Size, dim: int, index: torch.Tensor, output_shape: torch.Size):
        self.dim = dim
        self.index = index
        # scatter sources one value per index entry, so shapes must agree.
        assert index.shape == input_shape, f"Index shape {index.shape} must match input shape {input_shape}"
        super().__init__(input_shape, output_shape, flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Extra batch dims (e.g. under vmap) go through the batched path.
        if x.shape == self.input_shape:
            out = torch.zeros(self.output_shape, dtype=x.dtype, device=x.device)
            out.scatter_(self.dim, self.index, x)
            return out
        return self.vforward(x)

    def backward(self, grad: torch.Tensor) -> torch.Tensor:
        # Transpose of scatter: read grad back from the scattered positions.
        if grad.shape == self.output_shape:
            return torch.gather(grad, self.dim, self.index)
        return self.vbackward(grad)

    def vforward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (*input_shape, *batch) -- broadcast the index over trailing dims.
        trailing = x.shape[len(self.input_shape):]
        idx = self.index.reshape(*self.index.shape, *([1] * len(trailing)))
        idx = idx.expand(*self.index.shape, *trailing)
        out = torch.zeros(
            *self.output_shape, *trailing, dtype=x.dtype, device=x.device
        )
        out.scatter_(self.dim, idx, x)
        return out

    def vbackward(self, grad: torch.Tensor) -> torch.Tensor:
        # grad: (*batch, *output_shape) -- broadcast the index over leading dims.
        n_batch = grad.ndim - len(self.output_shape)
        leading = grad.shape[:n_batch]
        idx = self.index.reshape(*([1] * n_batch), *self.index.shape)
        idx = idx.expand(*leading, *self.index.shape)
        return torch.gather(grad, n_batch + self.dim, idx)

    def __str__(self):
        return f"scatter(dim={self.dim}, index.shape={list(self.index.shape)})"
# --------------------------------------------------------------------------- # Advanced indexing with index tensors # ---------------------------------------------------------------------------
class GetIndicesOp(LinearOp):
    r"""Advanced indexing with a tuple of index tensors.

    Given ``indices = (i_0, i_1, ..., i_{d-1})``, forward evaluation computes:

    .. math::

        y = x[\text{indices}]

    where each index tensor has shape ``output_shape``. The transpose
    (backward) operation writes each entry of ``grad`` back to its indexed
    position in an all-zero tensor, with accumulation for repeated indices.

    Args:
        indices: Tuple of integer index tensors, one per input dimension.
        input_shape: Shape of the source tensor ``x``.
        output_shape: Shape of the indexed output tensor ``y``.

    Notes:
        If an index appears multiple times, gradients are summed at that
        position (``accumulate=True`` semantics).

        Batch-dim convention (matching GatherOp/ScatterOp in this module):
        ``vforward`` expects batch dims *trailing* the input shape, while
        ``vbackward`` expects them *leading* the output shape.
    """

    def __init__(self, indices: tuple[torch.Tensor, ...], input_shape: torch.Size, output_shape: torch.Size):
        self.indices = indices
        # Exactly one index tensor per input dim, each shaped like the output.
        assert isinstance(indices, tuple) and len(indices) == len(input_shape), \
            "Indices must be a tuple of the same length as input_shape."
        for idx in indices:
            assert idx.shape == output_shape, \
                f"Each index tensor must have shape {output_shape}, got {idx.shape}"
        super().__init__(input_shape, output_shape, flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Check if x has extra batch dimensions
        if x.shape != self.input_shape:
            return self.vforward(x)
        return x[self.indices]

    def backward(self, grad: torch.Tensor) -> torch.Tensor:
        # Check if grad has extra batch dimensions
        if grad.shape != self.output_shape:
            return self.vbackward(grad)
        result = torch.zeros(self.input_shape, dtype=grad.dtype, device=grad.device)
        # Use index_put_ with accumulate=True for correct gradient with repeated indices
        result.index_put_(self.indices, grad, accumulate=True)
        return result

    def vforward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (*input_shape, *batch)
        batch = x.shape[len(self.input_shape):]
        if not batch:
            return x[self.indices]
        # Expand indices to include batch dims: append singleton dims, then
        # broadcast so every batch slice is indexed with the same pattern.
        expanded_indices = []
        for idx in self.indices:
            for _ in batch:
                idx = idx.unsqueeze(-1)
            expanded_indices.append(idx.expand(*self.output_shape, *batch))
        return x[tuple(expanded_indices)]

    def vbackward(self, grad: torch.Tensor) -> torch.Tensor:
        # grad: (*batch, *output_shape)
        # NOTE(review): if output_shape were 0-d, grad.shape[:-0] yields an
        # empty batch -- presumably output_shape is always non-empty; confirm.
        batch = grad.shape[:-len(self.output_shape)]
        batch_ndim = len(batch)
        result = torch.zeros(*batch, *self.input_shape, dtype=grad.dtype, device=grad.device)
        # Same index tensors, broadcast across the leading batch dims.
        expanded_indices = []
        for idx in self.indices:
            for _ in batch:
                idx = idx.unsqueeze(0)
            expanded_indices.append(idx.expand(*batch, *self.output_shape))
        # Prepend batch indices: for batch dim i, an arange over its size,
        # reshaped so its extent lies on axis i and broadcast everywhere else.
        # Together with expanded_indices this addresses (batch..., input...).
        batch_indices = [
            torch.arange(b, device=grad.device).reshape(
                *([1] * i), b, *([1] * (batch_ndim - i - 1 + len(self.output_shape)))
            ).expand(*batch, *self.output_shape)
            for i, b in enumerate(batch)
        ]
        all_indices = tuple(batch_indices) + tuple(expanded_indices)
        result.index_put_(all_indices, grad, accumulate=True)
        return result

    def __str__(self):
        return f"get_indices({self.output_shape})"
class SetIndicesOp(LinearOp):
    """Scatter values to advanced index positions in a zero-initialized tensor.

    Forward creates ``result = zeros(output_shape)`` and assigns:
    ``result[indices] = input``. Backward gathers gradients at the same
    index positions.

    This operator is the transpose/adjoint counterpart of
    :class:`GetIndicesOp`.

    Note:
        Forward uses plain index assignment, so if an index position appears
        multiple times the last written value wins (no accumulation).
    """

    def __init__(self, indices: tuple[torch.Tensor, ...], input_shape: torch.Size, output_shape: torch.Size):
        self.indices = indices
        # One index tensor per *output* dim, each shaped like the input,
        # mirroring GetIndicesOp with input/output roles swapped.
        assert isinstance(indices, tuple) and len(indices) == len(output_shape), \
            "Indices must be a tuple of the same length as output_shape."
        for idx in indices:
            assert idx.shape == input_shape, \
                f"Each index tensor must have shape {input_shape}, got {idx.shape}"
        super().__init__(input_shape, output_shape, flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Check if x has extra batch dimensions
        if x.shape != self.input_shape:
            return self.vforward(x)
        result = torch.zeros(self.output_shape, dtype=x.dtype, device=x.device)
        result[self.indices] = x
        return result

    def backward(self, grad: torch.Tensor) -> torch.Tensor:
        # Check if grad has extra batch dimensions
        if grad.shape != self.output_shape:
            return self.vbackward(grad)
        return grad[self.indices]

    def vforward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (*input_shape, *batch) -- batch dims trailing.
        batch = x.shape[len(self.input_shape):]
        if not batch:
            return self.forward(x)
        # Broadcast each index tensor over the trailing batch dims.
        expanded_indices = []
        for idx in self.indices:
            for _ in batch:
                idx = idx.unsqueeze(-1)
            expanded_indices.append(idx.expand(*self.input_shape, *batch))
        result = torch.zeros(*self.output_shape, *batch, dtype=x.dtype, device=x.device)
        result[tuple(expanded_indices)] = x
        return result

    def vbackward(self, grad: torch.Tensor) -> torch.Tensor:
        # grad: (*batch, *output_shape) -- batch dims leading. With an empty
        # batch this degenerates gracefully to plain grad[indices].
        batch = grad.shape[:-len(self.output_shape)]
        expanded_indices = []
        for idx in self.indices:
            for _ in batch:
                idx = idx.unsqueeze(0)
            expanded_indices.append(idx.expand(*batch, *self.input_shape))
        # Per-batch-dim arange indices, broadcast so every (batch, input)
        # coordinate pair is addressed; see GetIndicesOp.vbackward.
        batch_indices = [
            torch.arange(b, device=grad.device).reshape(
                *([1] * i), b, *([1] * (len(batch) - i - 1 + len(self.input_shape)))
            ).expand(*batch, *self.input_shape)
            for i, b in enumerate(batch)
        ]
        all_indices = tuple(batch_indices) + tuple(expanded_indices)
        return grad[all_indices]

    def __str__(self):
        return f"set_indices({list(self.input_shape)} -> {list(self.output_shape)})"
# --------------------------------------------------------------------------- # Slice-based indexing (basic indexing with int/slice) # ---------------------------------------------------------------------------
class GetSliceOp(LinearOp):
    """Basic slicing via ``x[indices]`` where indices contains int/slice/None/Ellipsis.

    This is a generalization that subsumes:
    - NarrowOp: ``x.narrow(dim, start, length)`` → ``GetSliceOp`` with a slice at dim
    - SelectOp: ``x.select(dim, index)`` → ``GetSliceOp`` with an int at dim
    - GetItemOp: ``x[indices]`` → ``GetSliceOp`` directly

    Forward: ``output = input[indices]``
    Backward: Embeds gradient into zeros at the sliced positions.
    """

    def __init__(self, input_shape: torch.Size, indices):
        self.indices = indices
        # Run the indexing on a meta tensor to get the output shape for free.
        output_shape = _meta_output_shape(lambda x: x[indices], input_shape)
        super().__init__(input_shape, output_shape, flags=LinearOpFlags.IS_NON_NEGATIVE)
        # Pre-compute backward info for functional (vmap-compatible) backward
        self._backward_info = self._compute_backward_info(input_shape, indices, output_shape)

    def _compute_backward_info(self, input_shape, indices, output_shape):  # noqa: ARG002
        """Pre-compute info needed for functional backward using F.pad.

        Returns None when the indexing cannot be expressed via unsqueeze +
        F.pad (strided slices, ``None``/newaxis, or advanced index types);
        backward() then falls back to in-place index assignment.
        """
        # Normalize indices to a tuple
        if not isinstance(indices, tuple):
            indices = (indices,)
        ndim = len(input_shape)
        # BUGFIX: expand a (single) Ellipsis into the explicit full slices it
        # stands for BEFORE aligning entries with input dims. The previous
        # per-position loop recorded integer indices after an Ellipsis against
        # the wrong dimension.
        if any(idx is Ellipsis for idx in indices):
            pos = next(i for i, idx in enumerate(indices) if idx is Ellipsis)
            # None entries add output dims but consume no input dim.
            n_consuming = sum(1 for idx in indices if idx is not Ellipsis and idx is not None)
            fill = ndim - n_consuming
            if fill < 0:
                return None
            indices = indices[:pos] + (slice(None),) * fill + indices[pos + 1:]
        # BUGFIX: None (newaxis) inserts output dims without consuming input
        # dims; the F.pad fast path cannot express that, so use the fallback
        # (the old code silently built a wrong pad_spec in this case).
        if any(idx is None for idx in indices):
            return None
        if len(indices) > ndim:
            return None
        # Unindexed trailing dims behave like full slices.
        normalized = list(indices) + [slice(None)] * (ndim - len(indices))
        pad_spec = []      # F.pad spec: pairs (before, after), last dim first
        int_indices = []   # (input_dim, index) pairs for integer indices
        for d in reversed(range(ndim)):
            idx = normalized[d]
            if isinstance(idx, int):
                # Integer index removes this dim from the output; backward
                # re-inserts it (unsqueeze) then pads around the single slot.
                pad_before = idx if idx >= 0 else input_shape[d] + idx
                pad_spec.extend([pad_before, input_shape[d] - pad_before - 1])
                int_indices.append((d, idx))
            elif isinstance(idx, slice):
                start, stop, step = idx.indices(input_shape[d])
                if step != 1:
                    return None  # Can't use F.pad for step != 1
                pad_spec.extend([start, input_shape[d] - stop])
            else:
                return None  # advanced indexing (tensor/list) not supported here
        int_indices.reverse()  # ascending input-dim order
        return {
            'pad_spec': pad_spec,
            'int_indices': int_indices,
            'normalized': normalized,
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[self.indices]

    def backward(self, grad: torch.Tensor) -> torch.Tensor:
        """Backward using F.pad for vmap compatibility."""
        import torch.nn.functional as F
        info = self._backward_info
        if info is None:
            # Fallback for complex indexing (step != 1, newaxis, etc.)
            result = torch.zeros(self.input_shape, dtype=grad.dtype, device=grad.device)
            result[self.indices] = grad
            return result
        ndim = len(self.input_shape)
        result = grad
        # Re-insert dims removed by integer indices.
        # BUGFIX: unsqueeze at negative positions (counted from the end), from
        # the highest dim down. Positive positions taken in reverse order were
        # out of range for multiple integer indices (e.g. indices=(0, 1) on a
        # 2-D input), and misplaced dims when grad carries leading batch dims.
        for dim, _ in reversed(info['int_indices']):
            result = result.unsqueeze(dim - ndim)
        # Then apply padding if any
        if any(p != 0 for p in info['pad_spec']):
            result = F.pad(result, info['pad_spec'])
        return result

    def vforward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (*input_shape, *batch): basic indexing touches only leading dims,
        # so trailing batch dims pass through untouched.
        # NOTE(review): an Ellipsis in self.indices would absorb the trailing
        # batch dims here -- confirm callers never combine the two.
        return x[self.indices]

    def vbackward(self, grad: torch.Tensor) -> torch.Tensor:
        # backward() is purely functional (unsqueeze at negative dims + F.pad),
        # so it also handles grads with leading batch dims / vmap.
        return self.backward(grad)

    def __str__(self):
        return f"getslice({_format_indices(self.indices)})"
class SetSliceOp(LinearOp):
    """Embed input into zeros at specified slice positions.

    This is the adjoint/transpose of GetSliceOp and generalizes PadOp.

    Forward: Creates zeros of output_shape, sets ``result[indices] = input``.
    Backward: Extracts gradient at the sliced positions.
    """

    def __init__(self, indices, input_shape: torch.Size, output_shape: torch.Size):
        self.indices = indices
        super().__init__(input_shape, output_shape, flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Extra batch dims (e.g. when called via vmap) take the batched path.
        if x.shape == self.input_shape:
            out = torch.zeros(self.output_shape, dtype=x.dtype, device=x.device)
            out[self.indices] = x
            return out
        return self.vforward(x)

    def backward(self, grad: torch.Tensor) -> torch.Tensor:
        # Adjoint of embedding: read grad back at the sliced positions.
        if grad.shape == self.output_shape:
            return grad[self.indices]
        return self.vbackward(grad)

    def vforward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (*input_shape, *batch) -- slicing only touches leading dims, so
        # trailing batch dims ride along unchanged.
        trailing = x.shape[len(self.input_shape):]
        if not trailing:
            return self.forward(x)
        out = torch.zeros(*self.output_shape, *trailing, dtype=x.dtype, device=x.device)
        out[self.indices] = x
        return out

    def vbackward(self, grad: torch.Tensor) -> torch.Tensor:
        # grad: (*batch, *output_shape) -- skip the leading batch dims with
        # full slices, then apply the stored indices.
        leading = grad.shape[:-len(self.output_shape)] if self.output_shape else grad.shape
        if not leading:
            return self.backward(grad)
        prefix = tuple(slice(None) for _ in leading)
        tail = self.indices if isinstance(self.indices, tuple) else (self.indices,)
        return grad[prefix + tail]

    def __str__(self):
        return f"setslice({_format_indices(self.indices)})"
# --------------------------------------------------------------------------- # Convenience constructors # ---------------------------------------------------------------------------
def narrow_indices(ndim: int, dim: int, start: int, length: int) -> tuple:
    """Create slice indices equivalent to ``tensor.narrow(dim, start, length)``."""
    # Full slices everywhere, with the narrowing slice dropped in at `dim`
    # (list assignment keeps support for negative dim values).
    sel = [slice(None) for _ in range(ndim)]
    sel[dim] = slice(start, start + length)
    return tuple(sel)
def select_indices(ndim: int, dim: int, index: int) -> tuple:
    """Create slice indices equivalent to ``tensor.select(dim, index)``."""
    # Full slices everywhere except a bare int at `dim`, which drops that dim
    # (list assignment keeps support for negative dim values).
    sel = [slice(None) for _ in range(ndim)]
    sel[dim] = index
    return tuple(sel)
def pad_indices(input_shape: torch.Size, pad_spec: list[int]) -> tuple:
    """Create slice indices for embedding input into a padded output.

    The pad_spec follows PyTorch's F.pad convention:
    [left, right, top, bottom, ...] applied from the last dimension backwards.
    """
    ndim = len(input_shape)
    sel = []
    for d, size in enumerate(input_shape):
        rev = ndim - 1 - d  # position of this dim in F.pad's reversed ordering
        if 2 * rev + 1 < len(pad_spec):
            before = pad_spec[2 * rev]
            sel.append(slice(before, before + size))
        else:
            # Dim not covered by pad_spec: leave it unpadded.
            sel.append(slice(None))
    return tuple(sel)
def pad_output_shape(input_shape: torch.Size, pad_spec: list[int]) -> torch.Size:
    """Compute output shape after padding."""
    ndim = len(input_shape)
    sizes = list(input_shape)
    for d in range(ndim):
        rev = ndim - 1 - d  # pad_spec runs from the last dimension backwards
        if 2 * rev + 1 < len(pad_spec):
            sizes[d] = sizes[d] + pad_spec[2 * rev] + pad_spec[2 * rev + 1]
    return torch.Size(sizes)
# --------------------------------------------------------------------------- # Helper functions # --------------------------------------------------------------------------- def _format_indices(indices) -> str: """Format indices for string representation.""" if not isinstance(indices, tuple): indices = (indices,) parts = [] for idx in indices: if isinstance(idx, slice): start = "" if idx.start is None else str(idx.start) stop = "" if idx.stop is None else str(idx.stop) step = "" if idx.step is None else f":{idx.step}" parts.append(f"{start}:{stop}{step}") elif idx is None: parts.append("None") elif idx is Ellipsis: parts.append("...") else: parts.append(str(idx)) return ", ".join(parts)