"""Reshape LinearOp implementations.
ReshapeOp is the base class; FlattenOp, UnflattenOp, SqueezeOp, and
UnsqueezeOp are thin subclasses that delegate to ReshapeOp logic.
"""
from functools import reduce
import torch
from boundlab.linearop._base import LinearOp, LinearOpFlags, ScalarOp
from boundlab.utils import merge_name
def _meta_output_shape(fn, input_shape: torch.Size) -> torch.Size:
"""Compute output shape by tracing *fn* on a meta-device tensor."""
return fn(torch.empty(input_shape, device="meta")).shape
[docs]
class ReshapeOp(LinearOp):
"""Reshape (view) the input tensor to *output_shape*."""
[docs]
def __init__(self, input_shape: torch.Size, output_shape: tuple[int, ...]):
if not isinstance(output_shape, torch.Size):
output_shape = _meta_output_shape(lambda x: x.reshape(*output_shape), input_shape)
self.target_shape = tuple(output_shape)
self.reshape_groups = []
self.dims_map = {}
self.dims_map_inv = {}
# Edge case: either side is a scalar (0-dim). Allowed only when
# both sides have numel == 1.
if len(input_shape) == 0 or len(output_shape) == 0:
input_numel = reduce(lambda x, y: x * y, input_shape, 1)
output_numel = reduce(lambda x, y: x * y, output_shape, 1)
assert input_numel == output_numel, \
f"ReshapeOp: cannot align {input_shape} -> {output_shape} (numel {input_numel} != {output_numel})"
if len(input_shape) > 0 or len(output_shape) > 0:
self.reshape_groups.append((0, len(input_shape) - 1, 0, len(output_shape) - 1))
super().__init__(input_shape, output_shape, flags=LinearOpFlags.IS_NON_NEGATIVE | LinearOpFlags.IS_PURE_EXPANDING | LinearOpFlags.IS_PURE_CONTRACTING)
return
i, j = 0, 0
input_win = 0
output_win = 0
while True:
input_numel = reduce(lambda x, y: x * y, input_shape[input_win:i+1], 1)
output_numel = reduce(lambda x, y: x * y, output_shape[output_win:j+1], 1)
if input_numel == output_numel:
if i > input_win or j > output_win:
self.reshape_groups.append((input_win, i, output_win, j))
else:
assert i == input_win and j == output_win
self.dims_map[i] = j
self.dims_map_inv[j] = i
i += 1
j += 1
input_win = i
output_win = j
if input_numel < output_numel:
i += 1
if input_numel > output_numel:
j += 1
if i >= len(input_shape) or j >= len(output_shape):
while i < len(input_shape) and input_shape[i] == 1:
i += 1
while j < len(output_shape) and output_shape[j] == 1:
j += 1
if input_win < i or output_win < j:
self.reshape_groups.append((input_win, max(i - 1, input_win), output_win, max(j - 1, output_win)))
assert i == len(input_shape) and j == len(output_shape), \
f"ReshapeOp: cannot align {input_shape} -> {output_shape} (i={i}, j={j})"
break
super().__init__(input_shape, output_shape, flags=LinearOpFlags.IS_NON_NEGATIVE | LinearOpFlags.IS_PURE_EXPANDING | LinearOpFlags.IS_PURE_CONTRACTING)
[docs]
def forward(self, x):
return x.reshape(self.target_shape)
[docs]
def backward(self, grad):
return grad.reshape(self.input_shape)
[docs]
def vforward(self, x):
extra = x.shape[len(self.input_shape):]
return x.reshape(self.target_shape + tuple(extra))
[docs]
def vbackward(self, grad):
extra = grad.shape[:-len(self.output_shape)] if len(self.output_shape) > 0 else grad.shape
return grad.reshape(tuple(extra) + tuple(self.input_shape))
def __matmul__(self, other):
from ._einsum import EinsumOp
if isinstance(other, ReshapeOp):
# self @ other = self(other(x)): maps other.input_shape → self.output_shape
return ReshapeOp(other.input_shape, self.output_shape)
if isinstance(other, EinsumOp):
for in_s, in_e, out_s, out_e in self.reshape_groups:
if any(i in other.mul_dims for i in other.output_dims[in_s:in_e + 1]):
return NotImplemented
op = other.permute_for_output()
assert all(i == p for i, p in enumerate(op.output_dims))
tensor = self.vforward(op.tensor)
def dims_map(d):
shift = len(self.output_shape) - len(self.input_shape)
return self.dims_map[d] if d < len(self.input_shape) else d + shift
output_dims = list(range(len(self.output_shape)))
input_dims = [dims_map(d) for d in op.input_dims]
result = EinsumOp(tensor, input_dims, output_dims, name=merge_name(self, "@", other))
assert result.input_shape == other.input_shape, f"ReshapeOp.__matmul__: input_shape {result.input_shape} != {other.input_shape}"
assert result.output_shape == self.output_shape, f"ReshapeOp.__matmul__: output_shape {result.output_shape} != {self.output_shape}"
return result
return NotImplemented
def __rmatmul__(self, other):
from ._einsum import EinsumOp
if isinstance(other, EinsumOp):
for in_s, in_e, out_s, out_e in self.reshape_groups:
if any(i in other.mul_dims for i in other.input_dims[out_s:out_e + 1]):
return super().__rmatmul__(other)
op = other.permute_for_input()
n_non_input = op.tensor.dim() - len(op.input_dims)
assert all(i + n_non_input == p for i, p in enumerate(op.input_dims))
tensor = self.vbackward(op.tensor)
def dims_map_inv(d):
return self.dims_map_inv[d] if d >= 0 else d
input_dims = list(range(n_non_input, n_non_input + len(self.input_shape)))
output_dims = [dims_map_inv(d - n_non_input) + n_non_input for d in op.output_dims]
result = EinsumOp(tensor, input_dims, output_dims, name=merge_name(other, "@", self))
assert result.input_shape == self.input_shape, f"ReshapeOp.__rmatmul__: input_shape {result.input_shape} != {self.input_shape}"
assert result.output_shape == other.output_shape, f"ReshapeOp.__rmatmul__: output_shape {result.output_shape} != {other.output_shape}"
return result
return super().__rmatmul__(other)
def __str__(self):
return f"<reshape {list(self.input_shape)} -> {list(self.target_shape)}>"
[docs]
class FlattenOp(ReshapeOp):
"""Flatten dimensions [start_dim .. end_dim] into a single dimension."""
[docs]
def __init__(self, input_shape: torch.Size, start_dim: int = 0, end_dim: int = -1):
self.start_dim = start_dim
self.end_dim = end_dim if end_dim >= 0 else len(input_shape) + end_dim
self.original_sizes = input_shape[self.start_dim:self.end_dim + 1]
target_shape = _meta_output_shape(
lambda x: x.flatten(start_dim, end_dim), input_shape)
super().__init__(input_shape, target_shape)
[docs]
def forward(self, x):
return x.flatten(self.start_dim, self.end_dim)
[docs]
def backward(self, grad):
return grad.unflatten(self.start_dim, self.original_sizes)
def __str__(self):
return f"<flatten {self.start_dim} {self.end_dim}>"
[docs]
class UnflattenOp(ReshapeOp):
"""Unflatten dimension *dim* into *sizes*."""
[docs]
def __init__(self, input_shape: torch.Size, dim: int, sizes: tuple[int, ...]):
self.dim = dim
self.sizes = sizes
self.end_dim = dim + len(sizes) - 1
target_shape = _meta_output_shape(
lambda x: x.unflatten(dim, sizes), input_shape)
super().__init__(input_shape, target_shape)
[docs]
def forward(self, x):
return x.unflatten(self.dim, self.sizes)
[docs]
def backward(self, grad):
return grad.flatten(self.dim, self.end_dim)
def __str__(self):
return f"<unflatten {self.dim} {list(self.sizes)}>"
[docs]
class SqueezeOp(ReshapeOp):
"""Remove size-1 dimension(s)."""
[docs]
def __init__(self, input_shape: torch.Size, dim=None):
self.dim = dim
if dim is not None:
self._is_noop = (input_shape[dim] != 1)
if self._is_noop:
target_shape = input_shape
else:
target_shape = torch.Size(
s for i, s in enumerate(input_shape) if i != dim)
else:
self._is_noop = all(s != 1 for s in input_shape)
self._squeezed_dims = [i for i, s in enumerate(input_shape) if s == 1]
target_shape = torch.Size(s for s in input_shape if s != 1)
super().__init__(input_shape, target_shape)
[docs]
def forward(self, x):
return x.squeeze(self.dim) if self.dim is not None else x.squeeze()
[docs]
def backward(self, grad):
if self._is_noop:
return grad
if self.dim is not None:
return grad.unsqueeze(self.dim)
for d in self._squeezed_dims:
grad = grad.unsqueeze(d)
return grad
def __matmul__(self, other):
from ._einsum import EinsumOp
if isinstance(other, EinsumOp):
assert self.input_shape == other.output_shape
if self.dim is not None:
squeezed = [] if self._is_noop else [self.dim]
else:
squeezed = list(self._squeezed_dims)
op = other
for pos in sorted(squeezed, reverse=True):
op = op.squeeze_output(pos)
result = EinsumOp(op.tensor, op.input_dims, op.output_dims, name=merge_name(self, "@", other))
assert result.input_shape == other.input_shape, f"SqueezeOp.__matmul__: input_shape {result.input_shape} != {other.input_shape}"
assert result.output_shape == self.output_shape, f"SqueezeOp.__matmul__: output_shape {result.output_shape} != {self.output_shape}"
return result
return NotImplemented
def __rmatmul__(self, other):
from ._einsum import EinsumOp
if isinstance(other, EinsumOp):
assert self.output_shape == other.input_shape
if self.dim is not None:
squeezed = [] if self._is_noop else [self.dim]
else:
squeezed = list(self._squeezed_dims)
op = other
for pos in sorted(squeezed):
op = op.unsqueeze_input(pos)
result = EinsumOp(op.tensor, op.input_dims, op.output_dims, name=merge_name(other, "@", self))
assert result.input_shape == self.input_shape, f"SqueezeOp.__rmatmul__: input_shape {result.input_shape} != {self.input_shape}"
assert result.output_shape == other.output_shape, f"SqueezeOp.__rmatmul__: output_shape {result.output_shape} != {other.output_shape}"
return result
return super().__rmatmul__(other)
def __str__(self):
return f"<squeeze {self.dim}>"
[docs]
class UnsqueezeOp(ReshapeOp):
"""Insert a size-1 dimension at *dim*."""
[docs]
def __init__(self, input_shape: torch.Size, dim: int):
if dim < 0:
dim += len(input_shape) + 1
assert 0 <= dim <= len(input_shape), f"Invalid unsqueeze dim {dim} for input shape {input_shape}"
self.dim = dim
target_shape = _meta_output_shape(lambda x: x.unsqueeze(dim), input_shape)
super().__init__(input_shape, target_shape)
[docs]
def forward(self, x):
return x.unsqueeze(self.dim)
[docs]
def backward(self, grad):
return grad.squeeze(self.dim)
[docs]
def vforward(self, x):
return x.unsqueeze(self.dim)
[docs]
def vbackward(self, grad):
batch_ndim = grad.dim() - len(self.output_shape)
return grad.squeeze(batch_ndim + self.dim)
def __matmul__(self, other):
from ._einsum import EinsumOp
if isinstance(other, EinsumOp):
assert self.input_shape == other.output_shape, f"{self}, {other}"
op = other.unsqueeze_output(self.dim)
result = EinsumOp(op.tensor, op.input_dims, op.output_dims, name=merge_name(self, "@", other))
assert result.input_shape == other.input_shape, f"UnsqueezeOp.__matmul__: input_shape {result.input_shape} != {other.input_shape}"
assert result.output_shape == self.output_shape, f"UnsqueezeOp.__matmul__: output_shape {result.output_shape} != {self.output_shape}"
return result
return NotImplemented
def __rmatmul__(self, other):
from ._einsum import EinsumOp
if isinstance(other, EinsumOp):
assert self.output_shape == other.input_shape
op = other.squeeze_input(self.dim)
result = EinsumOp(op.tensor, op.input_dims, op.output_dims, name=merge_name(other, "@", self))
assert result.input_shape == self.input_shape, f"UnsqueezeOp.__rmatmul__: input_shape {result.input_shape} != {self.input_shape}"
assert result.output_shape == other.output_shape, f"UnsqueezeOp.__rmatmul__: output_shape {result.output_shape} != {other.output_shape}"
return result
return super().__rmatmul__(other)
def __str__(self):
return f"<unsqueeze {list(self.input_shape)} {self.dim}>"