Source code for boundlab.linearop._expand

"""Expand LinearOp — implemented as EinsumOp for broadcast expansion."""

import torch

from boundlab.linearop._base import ScalarOp, ComposedOp


class ExpandOp:
    """Broadcast-expand a shape, in the style of ``torch.Tensor.expand``.

    ``ExpandOp(input_shape, output_shape)`` acts as a factory: ``__new__``
    returns an already-built linear op, never an ``ExpandOp`` instance.

    * Identical shapes give ``ScalarOp(1.0, input_shape)``.
    * Equal ranks give the EinsumOp built by ``_make_expand_einsum``. Its
      template tensor is ``torch.tensor(1.0)`` expanded to the output shape
      (stride-zero, so only one element is stored).
    * When ``len(input_shape) < len(output_shape)``, the input is treated as
      if it had extra leading size-1 dims. The einsum expands from that padded
      shape and is then composed (``@``) with one ``UnsqueezeOp(..., dim=0)``
      per missing dim.
    """

    def __new__(cls, input_shape: torch.Size, output_shape: torch.Size):
        """Return the linear op that expands ``input_shape`` to ``output_shape``.

        Args:
            input_shape: Source shape. Any int sequence is accepted and
                converted to ``torch.Size``.
            output_shape: Target shape. It must have at least as many dims as
                ``input_shape``.

        Returns:
            One of three things: ``ScalarOp(1.0, input_shape)`` when the shapes
            are identical, the result of ``_make_expand_einsum`` when the ranks
            match, or that result right-composed with ``UnsqueezeOp`` instances
            when leading dims must be added.

        Raises:
            AssertionError: If ``output_shape`` has fewer dims than
                ``input_shape``, or (inside ``_make_expand_einsum``) if a dim
                differs from its target size without being 1.
        """
        input_shape, output_shape = (
            shape if isinstance(shape, torch.Size) else torch.Size(shape)
            for shape in (input_shape, output_shape)
        )

        n_new = len(output_shape) - len(input_shape)
        assert n_new >= 0, (
            f"ExpandOp: output cannot have fewer dims than input ({len(output_shape)} < {len(input_shape)})"
        )

        if n_new == 0:
            # Same rank: either a no-op or a pure same-rank broadcast.
            if input_shape == output_shape:
                return ScalarOp(1.0, input_shape)
            return _make_expand_einsum(input_shape, output_shape)

        # Imported lazily. NOTE(review): presumably this avoids a circular
        # import between linearop submodules; unverified from here.
        from boundlab.linearop._reshape import UnsqueezeOp

        # Expand from the input viewed with n_new extra leading size-1 dims.
        padded_input = torch.Size([1] * n_new + list(input_shape))
        result = _make_expand_einsum(padded_input, output_shape)

        # Step k right-composes an unsqueeze from padded_input[k:] to
        # padded_input[k-1:], which adds a leading 1. After the last step the
        # composed op accepts input_shape directly.
        for k in range(1, n_new + 1):
            result = result @ UnsqueezeOp(padded_input[k:], dim=0)
        return result
def _make_expand_einsum(input_shape: torch.Size, output_shape: torch.Size):
    """Build an EinsumOp that broadcast-expands ``input_shape`` to ``output_shape``.

    The einsum tensor is ``torch.tensor(1.0)`` expanded to ``output_shape``.
    This is a stride-zero view, so only one element is stored. Output dim
    ``d`` is tensor dim ``d``. For each input dim ``d``:

    * if ``input_shape[d] == output_shape[d]``, it shares tensor dim ``d``;
    * otherwise it must be 1, and a fresh trailing size-1 dim is appended to
      the tensor and used as that input dim's index.

    Input dims are paired with output dims by position from the left.
    ``ExpandOp`` always calls this with equal ranks, padding the input first.
    NOTE(review): a shorter ``input_shape`` would be aligned with the
    *leading* output dims, unlike torch's right-aligned broadcasting. Confirm
    that no other caller relies on this.

    Args:
        input_shape: Shape fed into the op.
        output_shape: Shape produced by the op.

    Returns:
        An ``EinsumOp`` named ``"<expand [in...] -> [out...]>"``.

    Raises:
        AssertionError: If an input dim is neither equal to its output dim
            nor 1.
        IndexError: If ``input_shape`` has more dims than ``output_shape``
            and the loop reaches a dim with no output counterpart.
    """
    from boundlab.linearop._einsum import EinsumOp

    # Pass the shape sequence itself, not ``*output_shape``. Unpacking an
    # empty (0-d) shape would call ``expand()`` with no size argument, which
    # torch rejects. Passing the sequence works for every rank.
    tensor = torch.tensor(1.0).expand(output_shape)
    output_dims = list(range(len(output_shape)))
    input_dims = []
    for d, in_size in enumerate(input_shape):
        # Index instead of zip() so an over-long input still raises IndexError
        # rather than being silently truncated.
        out_size = output_shape[d]
        if in_size == out_size:
            # Shared dim: input dim d lines up with output dim d.
            input_dims.append(d)
            continue
        assert in_size == 1, \
            f"ExpandOp: dim {d} has input size {input_shape[d]} != 1 and != output size {output_shape[d]}"
        # Broadcast dim: index it by a new size-1 tensor dim placed after
        # all the output dims.
        new_dim = tensor.dim()
        tensor = tensor.unsqueeze(new_dim)
        input_dims.append(new_dim)
    return EinsumOp(
        tensor,
        input_dims,
        output_dims,
        name=f"<expand {list(input_shape)} -> {list(output_shape)}>",
    )