# Source code for boundlab.prop.eqprop
"""Exact propagation helper for symbolic expressions."""
import queue
import torch
import boundlab
from boundlab.linearop._base import LinearOp
# [docs]  (link marker left over from the rendered documentation page; not code)
def eqprop(x: "boundlab.expr.Expr") -> "boundlab.expr.Expr":
    """Propagate ``x`` backward through its expression DAG using ``==`` rules only.

    Works like `ublb`, but stops when `==` propagation is no longer
    possible and returns an expression.

    Each entry of ``x.all_subnodes()`` is paired with an accumulated linear
    operator (for a ``TupleExpr``: a list with one operator per component).
    The entries are visited from last to first:

    * ``GetTupleItem``: its operator is added onto slot ``node._index`` of
      the operator list belonging to ``node.tuple_expr``.
    * ``TupleExpr`` / ``Expr``: ``node.backward(..., direction='==')`` is
      called with the accumulated operator(s).  A truthy answer is unpacked
      as ``(c, child_ops)``; ``c`` is added to the result and each entry of
      ``child_ops`` is added onto the operator of the matching child.  A
      falsy answer means the node is kept in the result, wrapped in its
      accumulated operator(s).

    Args:
        x: Expression to propagate.  The last entry of ``x.all_subnodes()``
            is seeded with ``ScalarOp(1.0, x.shape)``.
            NOTE(review): this presumes the last subnode is ``x`` itself and
            that ``x`` is not a ``TupleExpr`` -- confirm against
            ``all_subnodes``.

    Returns:
        ``ConstVal(x.shape)`` plus every term contributed while walking the
        subnodes.

    Raises:
        ValueError: If a subnode is not an ``Expr`` instance.
    """
    from boundlab.expr import Expr
    from boundlab.expr._affine import AffineSum, ConstVal
    from boundlab.linearop import ScalarOp
    from boundlab.linearop._base import ZeroOp
    # x.simplify_ops_()
    from boundlab.expr._tuple import GetTupleItem, TupleExpr

    subnodes = x.all_subnodes()
    result = ConstVal(x.shape)

    # One accumulator per subnode; tuple-valued nodes get one per component.
    weights = []
    for sub in subnodes:
        if isinstance(sub, TupleExpr):
            weights.append([ZeroOp(x.shape, part_shape) for part_shape in sub.shape])
        else:
            weights.append(ZeroOp(x.shape, sub.shape))
    weights[-1] = ScalarOp(1.0, x.shape)

    def _scatter_to_children(owner, child_ops):
        # Add each propagated operator onto the accumulator of its child.
        for child, child_op in zip(owner.children, child_ops):
            weights[subnodes.index(child)] += child_op

    for position in reversed(range(len(subnodes))):
        node = subnodes[position]
        incoming = weights[position]

        if isinstance(node, GetTupleItem):
            # Forward the operator to the matching component of the tuple.
            weights[subnodes.index(node.tuple_expr)][node._index] += incoming
            continue

        if isinstance(node, TupleExpr):
            outcome = node.backward(*incoming, direction='==')
            if outcome:
                const_term, child_ops = outcome
                _scatter_to_children(node, child_ops)
                result = result + const_term
            else:
                # No exact rule: keep the tuple components as leaves.
                terms = [(part_op, node[k]) for k, part_op in enumerate(incoming)]
                result = result + AffineSum(*terms)
            continue

        if isinstance(node, Expr):
            outcome = node.backward(incoming, direction='==')
            if outcome:
                const_term, child_ops = outcome
                _scatter_to_children(node, child_ops)
                result = result + const_term
            else:
                # No exact rule: keep the node itself as a leaf.
                result = result + incoming(node)
            continue

        raise ValueError(f"Unexpected node type {type(node)} in expression DAG.")

    return result