# Source code for boundlab.linearop._shape

"""Shape and indexing LinearOp implementations for bound propagation.

Each class implements explicit forward (the shape operation) and backward
(its adjoint/transpose) so that no automatic VJP is needed.
"""

import torch

from boundlab.linearop._base import LinearOp, LinearOpFlags, ScalarOp


def _meta_output_shape(fn, input_shape: torch.Size) -> torch.Size:
    """Compute output shape by tracing *fn* on a meta-device tensor."""
    return fn(torch.empty(input_shape, device="meta")).shape


# ---------------------------------------------------------------------------
# Reshape / view operations
# ---------------------------------------------------------------------------

class ReshapeOp(LinearOp):
    """Reshape (view) the input tensor to *target_shape*."""

    def __init__(self, input_shape: torch.Size, target_shape: tuple[int, ...]):
        self.target_shape = target_shape
        # Trace on a meta tensor so -1 entries in target_shape are resolved.
        out_shape = _meta_output_shape(
            lambda t: t.reshape(*target_shape), input_shape)
        super().__init__(input_shape, out_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.reshape(*self.target_shape)

    def backward(self, grad):
        # Adjoint of a reshape is the reshape back to the input shape.
        return grad.reshape(self.input_shape)

    def vforward(self, x):
        # Vectorized forward: batch dimensions TRAIL the operand dimensions.
        batch = x.shape[len(self.input_shape):]
        return x.reshape(*self.output_shape, *batch)

    def vbackward(self, grad):
        # Vectorized backward: batch dimensions LEAD the operand dimensions.
        batch = grad.shape[:-len(self.output_shape)]
        return grad.reshape(*batch, *self.input_shape)

    def __str__(self):
        return f"reshape({list(self.target_shape)})"
class FlattenOp(LinearOp):
    """Flatten dimensions [start_dim .. end_dim] into a single dimension.

    Both ``start_dim`` and ``end_dim`` are normalized to non-negative
    indices in ``__init__``.  This fixes a bug in the original code where a
    negative ``start_dim`` was stored raw: ``backward`` then called
    ``grad.unflatten(start_dim, ...)`` on a tensor with FEWER dimensions
    than the input, so the negative index resolved to the wrong axis.
    """

    def __init__(self, input_shape: torch.Size, start_dim: int = 0, end_dim: int = -1):
        ndim = len(input_shape)
        # Normalize both endpoints against the INPUT rank; backward operates
        # on the flattened (lower-rank) gradient where negative indices
        # would resolve differently.
        self.start_dim = start_dim if start_dim >= 0 else ndim + start_dim
        self.end_dim = end_dim if end_dim >= 0 else ndim + end_dim
        # Sizes of the flattened dims, needed to unflatten in backward.
        self.original_sizes = input_shape[self.start_dim:self.end_dim + 1]
        output_shape = _meta_output_shape(
            lambda x: x.flatten(start_dim, end_dim), input_shape)
        super().__init__(input_shape, output_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.flatten(self.start_dim, self.end_dim)

    def backward(self, grad):
        # Adjoint of flatten restores the original dimension sizes.
        return grad.unflatten(self.start_dim, self.original_sizes)

    def __str__(self):
        return f"flatten({self.start_dim}, {self.end_dim})"
class UnflattenOp(LinearOp):
    """Unflatten dimension *dim* into *sizes*.

    ``dim`` is normalized to a non-negative index.  This fixes a bug in the
    original code where a negative ``dim`` made
    ``self.end_dim = dim + len(sizes) - 1`` meaningless (e.g. ``dim=-1``
    with two sizes gave ``end_dim=0``, so ``backward``'s
    ``flatten(-1, 0)`` raised instead of undoing the unflatten).
    """

    def __init__(self, input_shape: torch.Size, dim: int, sizes: tuple[int, ...]):
        # Normalize against the input rank so end_dim arithmetic is valid.
        self.dim = dim if dim >= 0 else len(input_shape) + dim
        self.sizes = sizes
        # Last of the dims produced by the unflatten; backward flattens
        # [dim .. end_dim] back into one axis.
        self.end_dim = self.dim + len(sizes) - 1
        output_shape = _meta_output_shape(
            lambda x: x.unflatten(dim, sizes), input_shape)
        super().__init__(input_shape, output_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.unflatten(self.dim, self.sizes)

    def backward(self, grad):
        # Adjoint of unflatten is the corresponding flatten.
        return grad.flatten(self.dim, self.end_dim)

    def __str__(self):
        return f"unflatten({self.dim}, {list(self.sizes)})"
# ---------------------------------------------------------------------------
# Permutation / transposition
# ---------------------------------------------------------------------------
class PermuteOp(LinearOp):
    """Permute dimensions of the input tensor.

    Axes in *dims* are normalized to non-negative indices.  This fixes a
    bug in the original code where a negative entry (e.g. ``-1``) stored in
    ``self.dims`` broke ``vforward``: the permutation there is applied to a
    tensor with extra trailing batch dimensions, so ``-1`` selected a batch
    axis instead of the intended input axis.
    """

    def __init__(self, input_shape: torch.Size, dims: tuple[int, ...]):
        n = len(input_shape)
        # Normalize once so every later use (forward, vforward, inverse
        # construction) sees canonical non-negative axes.
        self.dims = [d if d >= 0 else n + d for d in dims]
        # Build the inverse permutation: position d of the input went to
        # position i of the output, so the adjoint sends i back to d.
        self.inv_dims = [0] * len(self.dims)
        for i, d in enumerate(self.dims):
            self.inv_dims[d] = i
        output_shape = torch.Size(input_shape[d] for d in self.dims)
        super().__init__(input_shape, output_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.permute(*self.dims)

    def backward(self, grad):
        # Adjoint of a permutation is its inverse permutation.
        return grad.permute(*self.inv_dims)

    def vforward(self, x):
        # Batch dimensions trail the operand dimensions; leave them fixed.
        n = len(self.dims)
        batch_ndim = x.dim() - n
        perm = self.dims + [n + i for i in range(batch_ndim)]
        return x.permute(*perm)

    def vbackward(self, grad):
        # Batch dimensions lead the operand dimensions; leave them fixed.
        n = len(self.inv_dims)
        batch_ndim = grad.dim() - n
        perm = list(range(batch_ndim)) + [batch_ndim + d for d in self.inv_dims]
        return grad.permute(*perm)

    def __str__(self):
        return f"permute({self.dims})"
class TransposeOp(PermuteOp):
    """Swap two dimensions of the input tensor — special case of PermuteOp."""

    def __init__(self, input_shape: torch.Size, dim0: int, dim1: int):
        self.dim0 = dim0
        self.dim1 = dim1
        # Identity permutation with the two requested axes exchanged.
        perm = list(range(len(input_shape)))
        perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
        super().__init__(input_shape, tuple(perm))

    def __str__(self):
        return f"transpose({self.dim0}, {self.dim1})"
# ---------------------------------------------------------------------------
# Squeeze / unsqueeze
# ---------------------------------------------------------------------------
class SqueezeOp(LinearOp):
    """Remove size-1 dimension(s).

    A specified ``dim`` is normalized to a non-negative index.  This fixes
    a bug in the original code: with a negative ``dim`` whose size is 1,
    the output shape was built by filtering ``if i != dim`` with
    non-negative ``i``, so no dimension was ever dropped and the recorded
    ``output_shape`` disagreed with what ``forward`` actually returns.
    """

    def __init__(self, input_shape: torch.Size, dim=None):
        # Normalize so the positional filter below matches enumerate() indices.
        if dim is not None and dim < 0:
            dim = len(input_shape) + dim
        self.dim = dim
        if dim is None:
            # squeeze() drops every size-1 dimension; remember them so
            # backward can reinsert each one.
            self._squeezed_dims = [i for i, s in enumerate(input_shape) if s == 1]
            self._is_noop = not self._squeezed_dims
            output_shape = torch.Size(s for s in input_shape if s != 1)
        else:
            # squeeze(dim) only removes the dim when its size is exactly 1.
            self._is_noop = input_shape[dim] != 1
            if self._is_noop:
                output_shape = input_shape
            else:
                output_shape = torch.Size(
                    s for i, s in enumerate(input_shape) if i != dim)
        super().__init__(input_shape, output_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.squeeze(self.dim) if self.dim is not None else x.squeeze()

    def backward(self, grad):
        if self._is_noop:
            return grad
        if self.dim is not None:
            return grad.unsqueeze(self.dim)
        # Reinsert the squeezed dims in ascending order so each index is
        # interpreted against the progressively grown tensor.
        for d in self._squeezed_dims:
            grad = grad.unsqueeze(d)
        return grad

    def __str__(self):
        return f"squeeze({self.dim})"
class UnsqueezeOp(LinearOp):
    """Insert a size-1 dimension at *dim*.

    ``dim`` is normalized to the canonical non-negative insertion index
    (valid range ``[-(ndim+1), ndim]``).  This fixes a bug in the original
    code where a negative ``dim`` broke both vectorized paths: ``vforward``
    unsqueezed AFTER the trailing batch dimensions, and ``vbackward``'s
    ``batch_ndim + dim`` landed inside the leading batch dimensions.
    """

    def __init__(self, input_shape: torch.Size, dim: int):
        # unsqueeze accepts dim in [-(n+1), n]; map negatives into [0, n].
        self.dim = dim if dim >= 0 else len(input_shape) + 1 + dim
        output_shape = _meta_output_shape(
            lambda x: x.unsqueeze(self.dim), input_shape)
        super().__init__(input_shape, output_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.unsqueeze(self.dim)

    def backward(self, grad):
        # Adjoint of inserting a size-1 dim is removing it again.
        return grad.squeeze(self.dim)

    def vforward(self, x):
        # Batch dims trail; a normalized dim always lands before them.
        return x.unsqueeze(self.dim)

    def vbackward(self, grad):
        # Batch dims lead; offset the (normalized) dim past them.
        batch_ndim = grad.dim() - len(self.output_shape)
        return grad.squeeze(batch_ndim + self.dim)

    def __str__(self):
        return f"unsqueeze({self.dim})"
# ---------------------------------------------------------------------------
# Expand / repeat / tile
# ---------------------------------------------------------------------------
class ExpandOp(LinearOp):
    """Broadcast-expand dimensions (adjoint sums over expanded dims)."""

    def __new__(cls, input_shape: torch.Size, sizes: tuple[int, ...]):
        # Expanding to the identical shape is the identity operator; return
        # a ScalarOp(1) instead of constructing a no-op ExpandOp.
        if torch.Size(sizes) == input_shape:
            return ScalarOp(1.0, input_shape)
        return super().__new__(cls)

    def __init__(self, input_shape: torch.Size, sizes: tuple[int, ...]):
        self.sizes = sizes
        output_shape = _meta_output_shape(lambda x: x.expand(*sizes), input_shape)
        n_new = len(output_shape) - len(input_shape)
        # Dims the adjoint must sum out: every brand-new leading dim, plus
        # each pre-existing size-1 dim that was broadcast wider.
        self._sum_dims: list[int] = list(range(n_new))
        self._sum_dims += [
            n_new + i
            for i, size in enumerate(input_shape)
            if size == 1 and output_shape[n_new + i] > 1
        ]
        super().__init__(input_shape, output_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.expand(*self.sizes)

    def backward(self, grad):
        # Adjoint of broadcasting is summation over the broadcast dims,
        # then a reshape to restore the kept size-1 dims.
        if self._sum_dims:
            grad = grad.sum(dim=self._sum_dims)
        return grad.reshape(self.input_shape)

    def __str__(self):
        return f"expand({list(self.sizes)})"
class RepeatOp(LinearOp):
    """Tile-repeat the tensor (adjoint folds and sums repeated blocks)."""

    def __init__(self, input_shape: torch.Size, sizes: tuple[int, ...]):
        self.sizes = sizes
        # torch.repeat left-pads the tensor with size-1 dims when sizes has
        # more entries than the tensor has dimensions.
        pad = len(sizes) - len(input_shape)
        self._padded_input_shape = torch.Size([1] * pad + list(input_shape))
        output_shape = torch.Size(
            size * reps
            for size, reps in zip(self._padded_input_shape, sizes))
        super().__init__(input_shape, output_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.repeat(*self.sizes)

    def backward(self, grad):
        # Each output axis of length reps*size is a concatenation of reps
        # copies, so it folds into a (reps, size) pair; summing the even
        # (repeat) axes accumulates the copies back onto the original.
        folded = []
        for reps, size in zip(self.sizes, self._padded_input_shape):
            folded += [reps, size]
        grad = grad.reshape(folded)
        grad = grad.sum(dim=list(range(0, len(folded), 2)))
        return grad.reshape(self.input_shape)

    def __str__(self):
        return f"repeat({list(self.sizes)})"
class TileOp(RepeatOp):
    """Alias for repeat with dimension-padding handled like ``torch.tile``."""

    def __init__(self, input_shape: torch.Size, sizes: tuple[int, ...]):
        # torch.tile (unlike repeat) left-pads *sizes* with 1s when the
        # tensor has more dimensions than sizes has entries.
        deficit = len(input_shape) - len(sizes)
        if deficit > 0:
            sizes = (1,) * deficit + tuple(sizes)
        super().__init__(input_shape, sizes)

    def __str__(self):
        return f"tile({list(self.sizes)})"
# ---------------------------------------------------------------------------
# Element-reordering (self-adjoint or simple inverse)
# ---------------------------------------------------------------------------
class FlipOp(LinearOp):
    """Reverse elements along *dims* (self-adjoint)."""

    def __init__(self, input_shape: torch.Size, dims):
        self.dims = dims
        # A flip never changes the shape.
        super().__init__(input_shape, input_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.flip(self.dims)

    def backward(self, grad):
        # A flip's permutation matrix is symmetric, so adjoint == forward.
        return grad.flip(self.dims)

    def __str__(self):
        return f"flip({self.dims})"
class RollOp(LinearOp):
    """Circular-shift elements (adjoint is the inverse shift)."""

    def __init__(self, input_shape: torch.Size, shifts, dims):
        self.shifts = shifts
        self.dims = dims
        # Adjoint of a circular shift is the shift in the opposite
        # direction; torch.roll accepts an int or a sequence of ints.
        if isinstance(shifts, int):
            self._inv_shifts = -shifts
        else:
            self._inv_shifts = [-s for s in shifts]
        # A roll never changes the shape.
        super().__init__(input_shape, input_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.roll(self.shifts, self.dims)

    def backward(self, grad):
        return grad.roll(self._inv_shifts, self.dims)

    def __str__(self):
        return f"roll({self.shifts}, {self.dims})"
# ---------------------------------------------------------------------------
# Diagonal
# ---------------------------------------------------------------------------
class DiagOp(LinearOp):
    """Extract or create a diagonal (1D↔2D)."""

    def __init__(self, input_shape: torch.Size, diagonal: int = 0):
        self.diagonal = diagonal
        # torch.diag is direction-overloaded: 1D input builds a diagonal
        # matrix, 2D input extracts a diagonal.  Remember which way we go.
        self._input_ndim = len(input_shape)
        output_shape = _meta_output_shape(
            lambda x: x.diag(diagonal), input_shape)
        super().__init__(input_shape, output_shape,
                         flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.diag(self.diagonal)

    def backward(self, grad):
        if self._input_ndim == 1:
            # Forward built a diagonal matrix (1D→2D); the adjoint simply
            # reads that diagonal back out.
            return grad.diag(self.diagonal)
        # Forward extracted a diagonal (2D→1D); the adjoint scatters the
        # gradient onto that diagonal of an all-zero matrix.
        out = torch.zeros(self.input_shape, dtype=grad.dtype, device=grad.device)
        positions = torch.arange(len(grad), device=grad.device)
        if self.diagonal >= 0:
            out[positions, positions + self.diagonal] = grad
        else:
            out[positions - self.diagonal, positions] = grad
        return out

    def __str__(self):
        return f"diag({self.diagonal})"