"""Miscellaneous shape LinearOp implementations.
Contains ops not covered by the dedicated reshape/permute/expand modules:
RepeatOp, TileOp, FlipOp, RollOp, DiagOp.
Re-exports reshape/permute/expand ops for backward compatibility.
"""
import torch
from boundlab.linearop._base import LinearOp, LinearOpFlags
# Re-export for backward compatibility
from boundlab.linearop._reshape import (
ReshapeOp, FlattenOp, UnflattenOp, SqueezeOp, UnsqueezeOp, _meta_output_shape,
)
from boundlab.linearop._permute import PermuteOp, TransposeOp
from boundlab.linearop._expand import ExpandOp
# ---------------------------------------------------------------------------
# Repeat / tile
# ---------------------------------------------------------------------------
class RepeatOp(LinearOp):
    """Tile-repeat the tensor like ``torch.Tensor.repeat``.

    The forward map is ``x.repeat(*sizes)``. The adjoint folds the output back
    into blocks shaped like the (left-padded) input and sums those blocks.

    Args:
        input_shape: Shape of the tensor the op is applied to.
        sizes: Repetition count per output dimension. It must have at least
            ``len(input_shape)`` entries. Extra leading entries create new
            leading dimensions, i.e. the input is treated as left-padded
            with size-1 dimensions.

    Raises:
        ValueError: If ``sizes`` has fewer entries than ``input_shape`` has
            dimensions. ``torch.Tensor.repeat`` rejects this too, but only
            when called. Checking here prevents a silently wrong
            ``output_shape`` from being recorded.
    """

    def __init__(self, input_shape: torch.Size, sizes: tuple[int, ...]):
        sizes = tuple(sizes)
        n_pad = len(sizes) - len(input_shape)
        if n_pad < 0:
            raise ValueError(
                f"RepeatOp: got {len(sizes)} repeat sizes for a "
                f"{len(input_shape)}-d input; need at least {len(input_shape)}"
            )
        self.sizes = sizes
        # Input shape left-padded with 1s so it lines up with `sizes`.
        self._padded_input_shape = torch.Size([1] * n_pad + list(input_shape))
        output_shape = torch.Size(
            s * r for s, r in zip(self._padded_input_shape, sizes))
        super().__init__(
            input_shape,
            output_shape,
            flags=LinearOpFlags.IS_NON_NEGATIVE | LinearOpFlags.IS_PURE_EXPANDING,
        )

    def forward(self, x):
        return x.repeat(*self.sizes)

    def backward(self, grad):
        # Each output dim of size r*s is laid out by repeat as r outer copies
        # of an s-long block, so view it as (r, s) and sum over the r axes.
        split_shape = []
        for r, s in zip(self.sizes, self._padded_input_shape):
            split_shape.extend([r, s])
        grad = grad.reshape(split_shape)
        grad = grad.sum(dim=list(range(0, len(split_shape), 2)))
        # Drop the leading padding dims (all size 1 after the sum).
        return grad.reshape(self.input_shape)

    def __str__(self):
        return f"<repeat {list(self.sizes)}>"
class TileOp(RepeatOp):
    """``torch.tile`` counterpart of :class:`RepeatOp`.

    ``torch.tile`` left-pads ``sizes`` with ones when it has fewer entries than
    the input has dimensions. After that padding the op is exactly a repeat.
    """

    def __init__(self, input_shape: torch.Size, sizes: tuple[int, ...]):
        missing = len(input_shape) - len(sizes)
        # Only pad when sizes is too short; otherwise forward it untouched.
        padded_sizes = (1,) * missing + tuple(sizes) if missing > 0 else sizes
        super().__init__(input_shape, padded_sizes)

    def __str__(self):
        return f"<tile {list(self.sizes)}>"
# ---------------------------------------------------------------------------
# Element-reordering (self-adjoint or simple inverse)
# ---------------------------------------------------------------------------
class FlipOp(LinearOp):
    """Reverse the order of elements along *dims*.

    Flipping is a permutation that is its own inverse, so the adjoint
    (``backward``) applies exactly the same flip as ``forward``.
    """

    def __init__(self, input_shape: torch.Size, dims):
        self.dims = dims
        # Entries are only reordered, so the output shape equals the input shape.
        super().__init__(input_shape, input_shape, flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        flipped = x.flip(self.dims)
        return flipped

    def backward(self, grad):
        flipped = grad.flip(self.dims)
        return flipped

    def __str__(self):
        return f"<flip {self.dims}>"
class RollOp(LinearOp):
    """Circularly shift elements, as ``torch.roll`` does.

    A circular shift is undone by shifting the same amount in the opposite
    direction, so the adjoint rolls by the negated shifts along the same dims.
    """

    def __init__(self, input_shape: torch.Size, shifts, dims):
        self.shifts = shifts
        self.dims = dims
        # Negate a scalar shift directly; negate a sequence element-wise.
        self._inv_shifts = (
            -shifts if isinstance(shifts, int) else [-step for step in shifts]
        )
        super().__init__(input_shape, input_shape, flags=LinearOpFlags.IS_NON_NEGATIVE)

    def forward(self, x):
        return x.roll(shifts=self.shifts, dims=self.dims)

    def backward(self, grad):
        return grad.roll(shifts=self._inv_shifts, dims=self.dims)

    def __str__(self):
        return f"<roll {self.shifts} {self.dims}>"
# ---------------------------------------------------------------------------
# Diagonal
# ---------------------------------------------------------------------------
class DiagOp(LinearOp):
    """``torch.diag`` as a linear op, in either direction.

    * 1-D input: embeds the vector on the ``diagonal``-th diagonal of a square
      matrix. The adjoint reads that diagonal back out.
    * 2-D input: extracts the ``diagonal``-th diagonal as a vector. The adjoint
      writes the vector onto that diagonal of an otherwise-zero matrix.

    Args:
        input_shape: Shape of the input. ``torch.diag`` accepts only 1-D or
            2-D tensors.
        diagonal: Diagonal offset. 0 is the main diagonal, positive values are
            above it and negative values below it.
    """

    def __init__(self, input_shape: torch.Size, diagonal: int = 0):
        self.diagonal = diagonal
        self._input_ndim = len(input_shape)
        output_shape = _meta_output_shape(
            lambda x: x.diag(diagonal), input_shape)
        super().__init__(
            input_shape,
            output_shape,
            flags=LinearOpFlags.IS_NON_NEGATIVE
            | LinearOpFlags.IS_PURE_EXPANDING
            | LinearOpFlags.IS_PURE_CONTRACTING,
        )

    def forward(self, x):
        return x.diag(self.diagonal)

    def backward(self, grad):
        if self._input_ndim == 1:
            # Forward built a matrix from a vector; read the same diagonal back.
            return grad.diag(self.diagonal)
        # Forward extracted a diagonal; scatter grad onto that diagonal of a
        # zero matrix of the input's shape (any offset sign is handled).
        zeros = torch.zeros(
            self.input_shape, dtype=grad.dtype, device=grad.device)
        return torch.diagonal_scatter(zeros, grad, self.diagonal)

    def __str__(self):
        return f"<diag {self.diagonal}>"