r"""Utility Functions for BoundLab

This module provides helper functions used throughout the BoundLab framework.

Examples
--------
>>> from boundlab.utils import merge_name
>>> merge_name("x", "+", "y")
'(x + y)'
"""
from __future__ import annotations

import copy
from collections import Counter
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, TypeAlias, TypeVar as _TypeVar, Union

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

if TYPE_CHECKING:
    from boundlab.expr import Expr
# Type variable for the element type of the homogeneous tuple alias below.
A = _TypeVar("A")
# Generic alias for a homogeneous 3-tuple: ``Triple[int]`` is ``tuple[int, int, int]``.
Triple: TypeAlias = tuple[A, A, A]
# Names exported by ``from boundlab.utils import *``; every other helper in
# this module must be imported explicitly by name.
__all__ = ["Triple", "merge_name"]
[docs]
def merge_name(name1, op: str, name2) -> str | None:
"""Merge two optional names into a single name for a composed operation.
Examples
--------
>>> merge_name("left", "@", "right")
'(left @ right)'
>>> merge_name(None, "@", "right") is None
True
"""
name1 = name1.name if hasattr(name1, "name") else name1
name2 = name2.name if hasattr(name2, "name") else name2
if type(name1) is not str or type(name2) is not str:
return None
if name1 is not None and name2 is not None:
return f"({name1} {op} {name2})"
return None
def is0(input) -> bool:
    """Report whether *input* is an ``int`` equal to zero.

    Only ``int`` values (including ``bool``) are inspected. Any other value,
    such as a float or a tensor (even an all-zero one), yields ``False``.
    """
    return isinstance(input, int) and input == 0
def not0(input) -> bool:
    """Report whether *input* is anything other than an ``int`` equal to zero.

    ``int`` values (including ``bool``) are compared against zero. Every
    non-``int`` value, such as a float or a tensor (even an all-zero one),
    yields ``True``.
    """
    return not isinstance(input, int) or input != 0
def multiple_diagnonal(tensor, dims: list[tuple[int, int]]) -> tuple[torch.Tensor, list[int]]:
    """Take several diagonals of *tensor*, one ``(dim1, dim2)`` pair at a time.

    For every pair, :func:`torch.diagonal` drops the two named dims and
    appends one trailing dim that holds their diagonal. The indices in
    ``dims`` always refer to the dims of the *original* tensor; positions
    are re-resolved after each step.

    Returns:
        ``(new_tensor, dim_map)``. ``dim_map[i]`` is the current position of
        original dim ``i``. Original dims that were fused into the same
        diagonal share that diagonal's (trailing) position.

    Examples:
        >>> import torch
        >>> from boundlab.utils import multiple_diagnonal
        >>> t = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
        >>> out, dim_map = multiple_diagnonal(t, [(0, 1)])
        >>> out.tolist()
        [1.0, 5.0, 9.0]
        >>> dim_map
        [0, 0]
    """
    positions = list(range(len(tensor.shape)))
    for first, second in dims:
        assert first != second
        p, q = positions[first], positions[second]
        tensor = torch.diagonal(tensor, dim1=p, dim2=q)
        tail = len(tensor.shape) - 1
        # Fused dims move to the new trailing diagonal. Every other dim slides
        # left once for each removed dim that sat before it (bools count as 0/1).
        positions = [
            tail if pos in (p, q) else pos - (pos > p) - (pos > q)
            for pos in positions
        ]
    return tensor, positions
def multiple_diag_embed(tensor, dims: "Counter[int]") -> tuple[torch.Tensor, list[list[int]]]:
    """Embed selected dims as (generalized) diagonals, expanding the tensor.

    ``dims`` maps an original dim index to a multiplicity ``count >= 2``
    (typically a :class:`collections.Counter` of dim indices). A dim of size
    ``k`` with multiplicity ``count`` is replaced by ``count`` consecutive
    dims of size ``k`` that are non-zero only on their common diagonal. The
    entry at ``(..., i, i, ..., i, ...)`` equals the input entry at
    ``(..., i, ...)``. Dim indices in ``dims`` refer to the *original*
    tensor's dims.

    Args:
        tensor: Input tensor.
        dims: Mapping ``original_dim -> count`` with every ``count >= 2``.

    Returns:
        ``(new_tensor, dim_map)``. ``dim_map[i]`` is the list of current
        positions of original dim ``i``: one position for dims not in
        ``dims``, and ``count`` consecutive positions for embedded dims.

    Examples:
        >>> import torch
        >>> from collections import Counter
        >>> from boundlab.utils import multiple_diag_embed
        >>> t = torch.tensor([1., 2., 3.])
        >>> out, dim_map = multiple_diag_embed(t, Counter([0, 0]))
        >>> out.tolist()
        [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]
        >>> dim_map
        [[0, 1]]
    """
    dim_map = [[i] for i in range(len(tensor.shape))]
    for dim, count in dims.items():
        assert count >= 2
        current = dim_map[dim]
        assert len(current) == 1
        d = current[0]
        # Move the target dim to the end, where torch.diag_embed reads its input.
        tensor = tensor.transpose(d, -1)
        last = len(tensor.shape) - 1
        # The transpose swapped positions d and last; mirror that in dim_map.
        dim_map = [[d if x == last else last if x == d else x for x in v] for v in dim_map]
        # Each diag_embed consumes the trailing dim and appends two diagonal
        # dims. After the first call the trailing dim is itself diagonal, so
        # repeating builds a ``count``-way generalized diagonal at the end.
        # (Inserting at dim1=d, dim2=d+1 instead would make later calls consume
        # an unrelated trailing dim whenever d is not already last.)
        for _ in range(count - 1):
            tensor = torch.diag_embed(tensor)
        # Move the ``count`` trailing diagonal dims to positions d .. d+count-1.
        ndim = len(tensor.shape)
        tensor = torch.movedim(
            tensor,
            tuple(range(ndim - count, ndim)),
            tuple(range(d, d + count)),
        )
        # Every other dim at position >= d now sits ``count`` places further right.
        dim_map = [[x + count if x >= d else x for x in v] for v in dim_map]
        dim_map[dim] = list(range(d, d + count))
    return tensor, dim_map
@dataclass
class EQCondition:
    """A conjunction of index equalities, stored as disjoint equivalence classes.

    Each element of ``eqclasses`` is a tuple of integer indices that are all
    asserted equal. For example ``{(0, 2), (1, 3)}`` reads ``0 = 2 & 1 = 3``.
    Classes with fewer than two members carry no constraint and are dropped
    on construction. Overlapping classes are rejected.
    """

    # Disjoint equivalence classes; after __post_init__ each is a tuple
    # with at least two indices.
    eqclasses: set[tuple[int, ...]]

    def __post_init__(self):
        """Normalize ``eqclasses`` into a set of disjoint tuples.

        Any sequence is accepted per class (e.g. lists) and converted to a
        tuple so it can be stored in a set.

        Raises:
            ValueError: if two non-trivial classes share an index.
        """
        new_eqclasses: set[tuple[int, ...]] = set()
        for eqclass in self.eqclasses:
            eqclass = tuple(eqclass)
            if len(eqclass) <= 1:
                continue
            if any(set(eqclass).intersection(c) for c in new_eqclasses):
                raise ValueError(f"EQCondition cannot have overlapping eqclasses: {eqclass} overlaps with existing classes.")
            new_eqclasses.add(eqclass)
        self.eqclasses = new_eqclasses

    def _add_tuple(self, tup: tuple[int, ...]) -> "EQCondition":
        """Return a new condition that additionally asserts all of ``tup`` equal.

        Every existing class sharing an index with ``tup`` is merged with it.
        Because existing classes are pairwise disjoint, only classes that
        intersect ``tup`` itself can be affected. A tuple bridging several
        classes therefore merges all of them into one (sorted) class.
        """
        tup = tuple(tup)
        members = set(tup)
        overlapping = [c for c in self.eqclasses if not members.isdisjoint(c)]
        if not overlapping:
            return EQCondition(self.eqclasses | {tup})
        merged = members.union(*overlapping)
        kept = {c for c in self.eqclasses if members.isdisjoint(c)}
        return EQCondition(kept | {tuple(sorted(merged))})

    def _sub_tuple(self, tup: tuple[int, ...]) -> "EQCondition":
        """Return a condition where the indices of ``tup`` are collapsed onto ``tup[0]``.

        Classes disjoint from ``tup`` are kept unchanged. Every other class is
        replaced by ``tup[0]`` followed by its members that are not in ``tup``;
        classes left with a single member are dropped.
        """
        new_eqclasses = set()
        for eqclass in self.eqclasses:
            assert type(eqclass) == tuple
            if all(x not in tup for x in eqclass):
                new_eqclasses.add(eqclass)
            else:
                new_eqclasses.add(tuple([tup[0]] + [x for x in eqclass if x not in tup]))
        return EQCondition(new_eqclasses)

    def __add__(self, other: "EQCondition") -> "EQCondition":
        """Return the conjunction of both conditions, merging classes transitively."""
        result = EQCondition(set())
        for eqclass in self.eqclasses:
            result = result._add_tuple(eqclass)
        for eqclass in other.eqclasses:
            result = result._add_tuple(eqclass)
        return result

    def __le__(self, other: "EQCondition") -> bool:
        """True when every class of ``self`` lies inside some class of ``other``.

        That is, ``other`` implies ``self`` (``other`` is at least as strong).
        """
        for eqclass in self.eqclasses:
            if not any(set(eqclass).issubset(set(eqother)) for eqother in other.eqclasses):
                return False
        return True

    def __ge__(self, other: "EQCondition") -> bool:
        """True when ``self`` implies ``other`` (``self`` is at least as strong)."""
        return other.__le__(self)

    def __sub__(self, other: "EQCondition") -> "EQCondition":
        """Remove the equalities of ``other`` from the stronger condition ``self``.

        Each class of ``other`` is collapsed onto its first index (see
        ``_sub_tuple``). The remaining classes express what ``self`` asserts
        beyond ``other``.

        Raises:
            AssertionError: if ``self`` does not imply ``other``.
        """
        assert self >= other, f"Cannot subtract {other} from {self} since {self} is not a stronger condition than {other}"
        result = copy.deepcopy(self)
        for eqclass in other.eqclasses:
            result = result._sub_tuple(eqclass)
        return result

    def __str__(self):
        return " & ".join(" = ".join(str(i) for i in eqclass) for eqclass in self.eqclasses)

    def to_pairs(self):
        """Expand each class into ``(first, other)`` pairs, anchored at its first index."""
        pairs = list()
        for eqclass in self.eqclasses:
            for i in eqclass[1:]:
                pairs.append((eqclass[0], i))
        return pairs

    def all_pairs(self):
        """True when no class has more than two indices."""
        return all(len(tup) <= 2 for tup in self.eqclasses)
def current_fake_mode():
    """Return the innermost active dispatch mode if it is a ``FakeTensorMode``.

    Returns ``None`` when no dispatch mode is active or when the innermost
    one is some other kind of mode.
    """
    active = torch.utils._python_dispatch._get_current_dispatch_mode()
    if isinstance(active, FakeTensorMode):
        return active
    return None
def pairwise_diff(x: Expr, dim: int = -1) -> Expr:
    """Build a pairwise-difference expression of *x* along axis *dim*.

    Two broadcast copies of *x* are formed, and the result is ``r - l``:

    * ``l`` inserts a new axis at *dim*, so the original axis moves to
      ``dim + 1``.
    * ``r`` inserts a new axis at ``dim + 1``, so the original axis stays at
      *dim*.

    NOTE(review): assuming ``expand_on(d, N)`` broadcasts the size-1 axis
    ``d`` to length ``N`` (its definition is not visible here), ``r - l`` is
    ``d[..., i, j, ...] = x[..., i, ...] - x[..., j, ...]``. An earlier
    description of this helper stated the opposite sign
    (``x[..., j, ...] - x[..., i, ...]``). It also said the result is a
    single ``EinsumOp``, chosen to stop ``SumOp`` from merging two similar
    ``ExpandOp`` terms. The code, however, subtracts two expanded terms.
    Confirm which sign callers expect, and whether that merging concern
    still applies.

    Args:
        x: Input expression.
        dim: Axis to take pairwise differences over. Negative values count
            from the end.

    Returns:
        ``r - l``, which carries one more axis than *x* (from ``unsqueeze``).
    """
    # Normalize a negative axis so that ``dim + 1`` below is well defined.
    if dim < 0:
        dim += len(x.shape)
    N = x.shape[dim]
    # New axis at ``dim``; the original axis of x is now at ``dim + 1``.
    l = x.unsqueeze(dim).expand_on(dim, N)
    # New axis at ``dim + 1``; the original axis of x stays at ``dim``.
    r = x.unsqueeze(dim+1).expand_on(dim + 1, N)
    return r - l
def remove_diagonal(tensor: Union[torch.Tensor, "Expr"], dim1: int = 0, dim2: int = 1) -> Union[torch.Tensor, "Expr"]:
    """Drop the diagonal entries of the two leading (square) dims.

    For an input of shape ``(N, N, *rest)`` the result has shape
    ``(N, N - 1, prod(rest))``. Row ``i`` holds ``tensor[i, j]`` for every
    ``j != i``, in increasing ``j``. Trailing dims are flattened into one last
    dim, so a 2-D input yields shape ``(N, N - 1, 1)``.

    Only ``dim1=0, dim2=1`` is currently supported.

    Implementation: in the row-major flattening of an ``N x N`` grid, the
    diagonal sits at flat indices ``0, N + 1, 2(N + 1), ...``. The first
    element is dropped and the remaining ``N*N - 1`` entries are viewed as
    ``(N - 1, N + 1)`` rows. This puts every remaining diagonal element in
    the last column, which is then sliced off.

    Raises:
        AssertionError: if ``dim1``/``dim2`` are not ``0``/``1``, or if the two
            leading dims differ in size.
    """
    import math

    assert dim1 == 0 and dim2 == 1
    assert tensor.shape[0] == tensor.shape[1]
    N = tensor.shape[0]
    # Use an explicit trailing size instead of -1. reshape cannot infer a -1
    # dim for zero-element intermediates (e.g. N == 1 once the first element
    # is dropped, or empty trailing dims).
    rest = math.prod(tensor.shape[2:])
    if N == 0:
        return tensor.reshape(0, 0, rest)
    tensor = tensor.reshape(N * N, rest)
    return tensor[1:].reshape(N - 1, N + 1, rest)[:, :-1].reshape(N, N - 1, rest)
def inverse_permutation(perm: list[int]) -> list[int]:
    """Invert a permutation: the result maps ``perm[i]`` back to ``i``."""
    inverse = [0] * len(perm)
    for src, dst in zip(range(len(perm)), perm):
        inverse[dst] = src
    return inverse