Source code for boundlab.linearop._indexing

"""Index-tensor-based LinearOp implementations.

``GetIndicesOp`` and ``SetIndicesOp`` index along a single dimension using
a tensor of indices, replacing that dimension with ``added_shape``.
"""

import torch

from boundlab.linearop._base import LinearOp, LinearOpFlags, ComposedOp
from boundlab.utils import merge_name


class GetIndicesOp(LinearOp):
    """Gather elements along *dim* using an index tensor.

    output_shape = input_shape[:dim] + added_shape + input_shape[dim+1:]

    Args:
        input_shape: Shape of the source tensor.
        dim: Dimension along which to index. Negative values count from the
            end of ``input_shape`` and are normalized on construction.
        indices: Index tensor with shape ``added_shape``, values in
            ``[0, input_shape[dim])``. Any integer dtype is accepted; it is
            cast to int64 (on the operand's device) when the op is applied.
        added_shape: Shape that replaces ``input_shape[dim]``. A 1-D shape
            keeps the dimension, a multi-dimensional shape expands it, and an
            empty shape (0-d ``indices``) removes it, like ``x[..., i, ...]``.
    """

    def __init__(self, input_shape: torch.Size, dim: int, indices: torch.Tensor,
                 added_shape: torch.Size):
        # Normalize so list/tuple shapes compare equal to ``indices.shape``.
        input_shape = torch.Size(input_shape)
        added_shape = torch.Size(added_shape)
        ndim = len(input_shape)
        norm_dim = dim + ndim if dim < 0 else dim
        assert 0 <= norm_dim < ndim, \
            f"dim={dim} out of range for input_shape={input_shape}"
        self.dim = norm_dim
        self.indices = indices
        self.added_shape = added_shape
        assert indices.shape == added_shape, \
            f"indices.shape={indices.shape} != added_shape={added_shape}"
        output_shape = torch.Size(
            list(input_shape[:norm_dim]) + list(added_shape) + list(input_shape[norm_dim + 1:])
        )
        super().__init__(input_shape, output_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE | LinearOpFlags.IS_PURE_EXPANDING)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gather ``x`` along ``self.dim``."""
        return self._gather(x, self.dim)

    def backward(self, grad: torch.Tensor) -> torch.Tensor:
        """Scatter-add ``grad`` back into ``input_shape[dim]`` slots (adjoint of forward)."""
        return self._scatter(grad, self.dim, self.input_shape[self.dim])

    def vforward(self, x: torch.Tensor) -> torch.Tensor:
        """Batched forward.

        ``self.dim`` is used unshifted, so any batch dims are presumably
        trailing (after the input dims) — behavior kept from the original.
        """
        return self._gather(x, self.dim)

    def vbackward(self, grad: torch.Tensor) -> torch.Tensor:
        """Batched backward; dims beyond ``output_shape`` are treated as leading batch dims."""
        batch_ndim = grad.dim() - len(self.output_shape)
        return self._scatter(grad, batch_ndim + self.dim, self.input_shape[self.dim])

    def _flat_indices(self, device: torch.device) -> torch.Tensor:
        """Return ``self.indices`` flattened to 1-D int64 on *device* (what gather/scatter need)."""
        return self.indices.reshape(-1).to(device=device, dtype=torch.long)

    def _gather(self, x: torch.Tensor, dim: int) -> torch.Tensor:
        """Gather along *dim*, replacing that axis with ``added_shape``.

        Works for any rank of ``added_shape`` (including 0-d indices): the
        indices are applied as one flat axis and the result is reshaped.
        """
        lead, trail = x.shape[:dim], x.shape[dim + 1:]
        flat = self._flat_indices(x.device)
        k = flat.numel()
        # Broadcast the flat indices over every axis of x except *dim*.
        view = [1] * x.dim()
        view[dim] = k
        idx = flat.view(view).expand(*lead, k, *trail)
        gathered = torch.gather(x, dim, idx)
        return gathered.reshape(*lead, *self.added_shape, *trail)

    def _scatter(self, grad: torch.Tensor, dim: int, source_size: int) -> torch.Tensor:
        """Scatter-add *grad* back along *dim* into an axis of size *source_size*.

        *grad* carries ``added_shape`` starting at *dim*; duplicate indices
        accumulate (``scatter_add_``), as required for the adjoint of a gather.
        """
        n_added = len(self.added_shape)
        lead, trail = grad.shape[:dim], grad.shape[dim + n_added:]
        flat = self._flat_indices(grad.device)
        k = flat.numel()
        # Collapse added_shape (of any rank, including 0) into one axis of size k.
        src = grad.reshape(*lead, k, *trail)
        view = [1] * src.dim()
        view[dim] = k
        idx = flat.view(view).expand_as(src)
        result = grad.new_zeros((*lead, source_size, *trail))
        result.scatter_add_(dim, idx, src)
        return result

    def __matmul__(self, other):
        """Fuse ``GetIndicesOp @ EinsumOp`` or compose two ``GetIndicesOp`` on the same dim.

        Returns ``NotImplemented`` for any other operand so Python can try the
        reflected operation.
        """
        # Local import, presumably to avoid an import cycle with _einsum.
        from boundlab.linearop._einsum import EinsumOp
        if isinstance(other, EinsumOp):
            assert self.input_shape == other.output_shape
            tensor_dim = other.output_dims[self.dim]
            return _apply_getindices_einsum(self, other, is_mul=tensor_dim in other.mul_dims)
        if (isinstance(other, GetIndicesOp) and self.dim == other.dim
                and len(other.added_shape) >= 1):
            assert self.input_shape == other.output_shape
            # other's output holds other.added_shape at dims dim.. and self selects
            # along the first of them:
            #   new[s..., a1...] = other.indices[self.indices[s...], a1...]
            # i.e. advanced indexing on other.indices' leading axis, giving shape
            # self.added_shape + other.added_shape[1:].
            sel = self.indices.to(device=other.indices.device, dtype=torch.long)
            new_indices = other.indices[sel]
            return GetIndicesOp(other.input_shape, self.dim, new_indices, new_indices.shape)
        return NotImplemented

    def __str__(self):
        return f"<getindices dim={self.dim} added={list(self.added_shape)}>"
class SetIndicesOp(LinearOp):
    """Scatter values along *dim* using an index tensor.

    input_shape = output_shape[:dim] + added_shape + output_shape[dim+1:]

    Positions not named by ``indices`` are zero; duplicate indices accumulate
    (the scatter uses ``scatter_add_``).

    Args:
        output_shape: Shape of the output tensor (zeros template).
        dim: Dimension along which to scatter. Negative values count from the
            end of ``output_shape`` and are normalized on construction.
        indices: Index tensor with shape ``added_shape``, values in
            ``[0, output_shape[dim])``. Any integer dtype is accepted; it is
            cast to int64 (on the operand's device) when the op is applied.
        added_shape: Shape that replaces ``output_shape[dim]`` in the input.
            An empty shape (0-d ``indices``) means the input lacks that axis.
    """

    def __init__(self, output_shape: torch.Size, dim: int, indices: torch.Tensor,
                 added_shape: torch.Size):
        # Normalize so list/tuple shapes compare equal to ``indices.shape``.
        output_shape = torch.Size(output_shape)
        added_shape = torch.Size(added_shape)
        ndim = len(output_shape)
        norm_dim = dim + ndim if dim < 0 else dim
        assert 0 <= norm_dim < ndim, \
            f"dim={dim} out of range for output_shape={output_shape}"
        self.dim = norm_dim
        self.indices = indices
        self.added_shape = added_shape
        assert indices.shape == added_shape, \
            f"indices.shape={indices.shape} != added_shape={added_shape}"
        input_shape = torch.Size(
            list(output_shape[:norm_dim]) + list(added_shape) + list(output_shape[norm_dim + 1:])
        )
        super().__init__(input_shape, output_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE | LinearOpFlags.IS_PURE_EXPANDING)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scatter-add ``x`` into a zero tensor of ``output_shape``."""
        return self._scatter(x, self.dim, self.output_shape[self.dim])

    def backward(self, grad: torch.Tensor) -> torch.Tensor:
        """Gather ``grad`` at ``indices`` (adjoint of forward)."""
        return self._gather(grad, self.dim)

    def vforward(self, x: torch.Tensor) -> torch.Tensor:
        """Batched forward.

        ``self.dim`` is used unshifted, so any batch dims are presumably
        trailing (after the input dims) — behavior kept from the original.
        """
        return self._scatter(x, self.dim, self.output_shape[self.dim])

    def vbackward(self, grad: torch.Tensor) -> torch.Tensor:
        """Batched backward; dims beyond ``output_shape`` are treated as leading batch dims."""
        batch_ndim = grad.dim() - len(self.output_shape)
        return self._gather(grad, batch_ndim + self.dim)

    def _flat_indices(self, device: torch.device) -> torch.Tensor:
        """Return ``self.indices`` flattened to 1-D int64 on *device* (what gather/scatter need)."""
        return self.indices.reshape(-1).to(device=device, dtype=torch.long)

    def _gather(self, x: torch.Tensor, dim: int) -> torch.Tensor:
        """Gather along *dim* (backward of scatter), replacing that axis with ``added_shape``."""
        lead, trail = x.shape[:dim], x.shape[dim + 1:]
        flat = self._flat_indices(x.device)
        k = flat.numel()
        # Broadcast the flat indices over every axis of x except *dim*.
        view = [1] * x.dim()
        view[dim] = k
        idx = flat.view(view).expand(*lead, k, *trail)
        gathered = torch.gather(x, dim, idx)
        # reshape (not unflatten) so a 0-d added_shape drops the axis.
        return gathered.reshape(*lead, *self.added_shape, *trail)

    def _scatter(self, x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
        """Scatter-add *x* (carrying ``added_shape`` at *dim*) into zeros of size *target_size*."""
        n_added = len(self.added_shape)
        lead, trail = x.shape[:dim], x.shape[dim + n_added:]
        flat = self._flat_indices(x.device)
        k = flat.numel()
        # Collapse added_shape (of any rank, including 0) into one axis of size k.
        src = x.reshape(*lead, k, *trail)
        view = [1] * src.dim()
        view[dim] = k
        idx = flat.view(view).expand_as(src)
        result = x.new_zeros((*lead, target_size, *trail))
        result.scatter_add_(dim, idx, src)
        return result

    def __rmatmul__(self, other):
        """Fuse ``EinsumOp @ SetIndicesOp``; otherwise defer to the base implementation."""
        # Local import, presumably to avoid an import cycle with _einsum.
        from boundlab.linearop._einsum import EinsumOp
        if isinstance(other, EinsumOp) and self.output_shape == other.input_shape:
            tensor_dim = other.input_dims[self.dim]
            return _apply_einsum_setindices(other, self, is_mul=tensor_dim in other.mul_dims)
        return super().__rmatmul__(other)

    def __str__(self):
        return f"<setindices dim={self.dim} added={list(self.added_shape)}>"
# ---------------------------------------------------------------------------
# Fusion with EinsumOp
# ---------------------------------------------------------------------------


def _index_tensor_along_dim(tensor, tensor_dim, indices):
    """Index *tensor* along *tensor_dim* with an index tensor of any rank.

    Returns a tensor of shape
    ``tensor.shape[:tensor_dim] + indices.shape + tensor.shape[tensor_dim + 1:]``:
    1-D indices keep the dimension, multi-dimensional indices expand it, and
    0-d indices remove it.
    """
    idx = indices.to(device=tensor.device, dtype=torch.long)
    # With a single advanced index, the index's dims are placed at the indexed
    # position, which is exactly the layout described above.
    return tensor[(slice(None),) * tensor_dim + (idx,)]


def _remap_dims_after_index(dims, tensor_dim, added_shape):
    """Remap a list of tensor dims after indexing tensor_dim with added_shape.

    tensor_dim in the original tensor becomes len(added_shape) dims starting at
    tensor_dim. All dims > tensor_dim shift by (len(added_shape) - 1). A
    reference to tensor_dim itself is left unchanged, so this is intended for
    dim lists that do not reference tensor_dim.
    """
    shift = len(added_shape) - 1
    return [d + shift if d > tensor_dim else d for d in dims]


def _expand_dims_after_index(dims, tensor_dim, n_added):
    """Remap tensor dims after tensor_dim was replaced by ``n_added`` dims.

    A reference to tensor_dim expands to the ``n_added`` consecutive dims
    that replaced it (none when ``n_added == 0``); dims > tensor_dim shift by
    ``n_added - 1``; smaller dims are unchanged.
    """
    shift = n_added - 1
    result = []
    for d in dims:
        if d == tensor_dim:
            result.extend(range(tensor_dim, tensor_dim + n_added))
        elif d > tensor_dim:
            result.append(d + shift)
        else:
            result.append(d)
    return result


def _apply_getindices_einsum(gi: "GetIndicesOp", einsum, is_mul: bool):
    """Fuse/swap GetIndicesOp @ EinsumOp.

    For dot/batch dims (is_mul=False): fuse by indexing the tensor, no input op.
    For mul dims (is_mul=True): index the tensor AND add a GetIndicesOp on the
    input side, so the input dim bound to the same tensor dim is gathered too.
    """
    from boundlab.linearop._einsum import EinsumOp

    tensor_dim = einsum.output_dims[gi.dim]
    n_added = len(gi.added_shape)
    # The indexed tensor already has gi.added_shape laid out at tensor_dim,
    # so no further reshape is needed.
    new_tensor = _index_tensor_along_dim(einsum.tensor, tensor_dim, gi.indices)
    new_output_dims = _expand_dims_after_index(einsum.output_dims, tensor_dim, n_added)
    if is_mul:
        # The input's tensor_dim axis is replaced by added_shape via input_gi
        # below, so it expands exactly like the output side.
        new_input_dims = _expand_dims_after_index(einsum.input_dims, tensor_dim, n_added)
    else:
        new_input_dims = _remap_dims_after_index(einsum.input_dims, tensor_dim, gi.added_shape)

    new_einsum = EinsumOp(new_tensor, new_input_dims, new_output_dims,
                          name=merge_name(gi, "@", einsum))
    assert new_einsum.output_shape == gi.output_shape, \
        f"_apply_getindices_einsum: output_shape {new_einsum.output_shape} != {gi.output_shape}"
    if is_mul:
        input_d = einsum.input_dims.index(tensor_dim)
        input_gi = GetIndicesOp(einsum.input_shape, input_d, gi.indices, gi.added_shape)
        assert new_einsum.input_shape == input_gi.output_shape, \
            f"_apply_getindices_einsum: input_shape {new_einsum.input_shape} != {input_gi.output_shape}"
        return ComposedOp(new_einsum, input_gi)
    assert new_einsum.input_shape == einsum.input_shape, \
        f"_apply_getindices_einsum: input_shape {new_einsum.input_shape} != {einsum.input_shape}"
    return new_einsum


def _apply_einsum_setindices(einsum, si: "SetIndicesOp", is_mul: bool):
    """Fuse/swap EinsumOp @ SetIndicesOp.

    For dot/batch dims (is_mul=False): fuse by indexing the tensor, no output op.
    For mul dims (is_mul=True): index the tensor AND add a SetIndicesOp on the
    output side, so the output dim bound to the same tensor dim is scattered.
    """
    from boundlab.linearop._einsum import EinsumOp

    tensor_dim = einsum.input_dims[si.dim]
    n_added = len(si.added_shape)
    # The indexed tensor already has si.added_shape laid out at tensor_dim,
    # so no further reshape is needed.
    new_tensor = _index_tensor_along_dim(einsum.tensor, tensor_dim, si.indices)
    new_input_dims = _expand_dims_after_index(einsum.input_dims, tensor_dim, n_added)
    if is_mul:
        # The output's tensor_dim axis is produced in added_shape form and
        # scattered back by output_si below, so it expands like the input side.
        new_output_dims = _expand_dims_after_index(einsum.output_dims, tensor_dim, n_added)
    else:
        new_output_dims = _remap_dims_after_index(einsum.output_dims, tensor_dim, si.added_shape)

    new_einsum = EinsumOp(new_tensor, new_input_dims, new_output_dims,
                          name=merge_name(einsum, "@", si))
    assert new_einsum.input_shape == si.input_shape, \
        f"_apply_einsum_setindices: input_shape {new_einsum.input_shape} != {si.input_shape}"
    if is_mul:
        output_d = einsum.output_dims.index(tensor_dim)
        output_si = SetIndicesOp(einsum.output_shape, output_d, si.indices, si.added_shape)
        assert new_einsum.output_shape == output_si.input_shape, \
            f"_apply_einsum_setindices: output_shape {new_einsum.output_shape} != {output_si.input_shape}"
        return ComposedOp(output_si, new_einsum)
    assert new_einsum.output_shape == einsum.output_shape, \
        f"_apply_einsum_setindices: output_shape {new_einsum.output_shape} != {einsum.output_shape}"
    return new_einsum