Source code for boundlab.linearop._permute
"""Permutation and transposition LinearOp implementations."""
import torch
from boundlab.linearop._base import LinearOp, LinearOpFlags
from boundlab.utils import merge_name
[docs]
class PermuteOp(LinearOp):
    """Permute dimensions of the input tensor.

    Applies ``x.permute(*dims)``, so ``output_shape[i] == input_shape[dims[i]]``.

    Args:
        input_shape: Shape of the tensor the op is applied to.
        dims: A permutation of ``range(len(input_shape))``. Negative entries
            are accepted (as in ``torch.Tensor.permute``) and are stored in
            normalized, non-negative form in ``self.dims``.

    Raises:
        ValueError: If ``dims`` is not a permutation of the input's dimensions.
    """

    def __init__(self, input_shape: torch.Size, dims: tuple[int, ...]):
        ndim = len(input_shape)
        # Normalize negatives eagerly: ``vforward`` appends extra (batch)
        # dimensions after the permuted ones, so an un-normalized ``-1`` would
        # address a batch dimension instead of the last input dimension.
        normalized = [d + ndim if d < 0 else d for d in dims]
        # Reject duplicates / out-of-range / wrong-length dims up front;
        # otherwise ``inv_dims`` below would be silently wrong.
        if sorted(normalized) != list(range(ndim)):
            raise ValueError(
                f"PermuteOp: dims {tuple(dims)} is not a permutation of the "
                f"{ndim} dimensions of input_shape {tuple(input_shape)}"
            )
        self.dims = normalized
        # Inverse permutation: inv_dims[dims[i]] == i.
        self.inv_dims = [0] * ndim
        for i, d in enumerate(normalized):
            self.inv_dims[d] = i
        output_shape = torch.Size(input_shape[d] for d in normalized)
        super().__init__(
            input_shape,
            output_shape,
            flags=LinearOpFlags.IS_NON_NEGATIVE
            | LinearOpFlags.IS_PURE_EXPANDING
            | LinearOpFlags.IS_PURE_CONTRACTING,
        )

    def forward(self, x):
        """Return ``x`` with its dimensions reordered by ``self.dims``."""
        return x.permute(*self.dims)

    def backward(self, grad):
        """Map an output-shaped ``grad`` back to input layout.

        For a permutation the transpose is the inverse permutation, which is
        applied here via ``self.inv_dims``.
        """
        return grad.permute(*self.inv_dims)

    def vforward(self, x):
        """Batched :meth:`forward`.

        The leading ``len(self.dims)`` dimensions of ``x`` are permuted; any
        extra trailing (batch) dimensions stay in place at the end.
        """
        n = len(self.dims)
        batch_ndim = x.dim() - n
        perm = self.dims + [n + i for i in range(batch_ndim)]
        return x.permute(*perm)

    def vbackward(self, grad):
        """Batched :meth:`backward`.

        Any extra leading (batch) dimensions of ``grad`` stay in front; the
        trailing ``len(self.inv_dims)`` dimensions are un-permuted.
        """
        n = len(self.inv_dims)
        batch_ndim = grad.dim() - n
        perm = list(range(batch_ndim)) + [batch_ndim + d for d in self.inv_dims]
        return grad.permute(*perm)

    def __matmul__(self, other):
        """Fuse ``self @ other`` (apply ``other`` first, then ``self``).

        Two permutations compose into a single ``PermuteOp``; a permutation
        after an ``EinsumOp`` is folded into the einsum's output labels.
        Other operand types return ``NotImplemented``.
        """
        # Local import, presumably to avoid a circular import with _einsum.
        from ._einsum import EinsumOp

        if isinstance(other, PermuteOp):
            assert self.input_shape == other.output_shape
            # out[i] = mid[self.dims[i]] = in[other.dims[self.dims[i]]]
            new_dims = [other.dims[d] for d in self.dims]
            return PermuteOp(other.input_shape, tuple(new_dims))
        if isinstance(other, EinsumOp):
            assert self.input_shape == other.output_shape
            # Reorder the einsum's output labels instead of permuting afterwards.
            new_output_dims = [other.output_dims[d] for d in self.dims]
            result = EinsumOp(other.tensor, other.input_dims, new_output_dims, name=merge_name(self, "@", other))
            assert result.input_shape == other.input_shape, f"PermuteOp.__matmul__: input_shape {result.input_shape} != {other.input_shape}"
            assert result.output_shape == self.output_shape, f"PermuteOp.__matmul__: output_shape {result.output_shape} != {self.output_shape}"
            return result
        return NotImplemented

    def __rmatmul__(self, other):
        """Fuse ``other @ self`` (apply ``self`` first, then ``other``).

        When ``other`` is an ``EinsumOp`` the permutation is folded into the
        einsum's input labels; otherwise defer to ``LinearOp.__rmatmul__``.
        """
        # Local import, presumably to avoid a circular import with _einsum.
        from ._einsum import EinsumOp

        if isinstance(other, EinsumOp):
            assert self.output_shape == other.input_shape
            # Input dim k of self lands at output dim inv_dims[k], so relabel
            # the einsum's inputs through the inverse permutation.
            new_input_dims = [other.input_dims[j] for j in self.inv_dims]
            result = EinsumOp(other.tensor, new_input_dims, other.output_dims, name=merge_name(other, "@", self))
            assert result.input_shape == self.input_shape, f"PermuteOp.__rmatmul__: input_shape {result.input_shape} != {self.input_shape}"
            assert result.output_shape == other.output_shape, f"PermuteOp.__rmatmul__: output_shape {result.output_shape} != {other.output_shape}"
            return result
        return super().__rmatmul__(other)

    def __str__(self):
        return f"<permute {self.dims}>"
[docs]
class TransposeOp(PermuteOp):
    """Swap two dimensions of the input tensor — special case of PermuteOp.

    Args:
        input_shape: Shape of the tensor the op is applied to.
        dim0: First dimension to swap (negative values index from the end).
        dim1: Second dimension to swap (negative values index from the end).
    """

    def __init__(self, input_shape: torch.Size, dim0: int, dim1: int):
        # Remember the indices exactly as given; __str__ reports them.
        self.dim0, self.dim1 = dim0, dim1
        # Begin with the identity ordering and exchange the two slots. Plain
        # list indexing resolves negative dims against the input rank.
        order = list(range(len(input_shape)))
        first, second = order[dim0], order[dim1]
        order[dim0] = second
        order[dim1] = first
        super().__init__(input_shape, tuple(order))

    def __str__(self):
        return f"<transpose {self.dim0} {self.dim1}>"