synthphonia/forward/enumeration/
mod.rs1use enum_dispatch::enum_dispatch;
2use itertools::Itertools;
3
4
5
6use crate::{expr::{cfg::ProdRule, ops::{Op1, Op1Enum, Op2, Op2Enum, Op3, Op3Enum}, Expr}, galloc::AllocForAny};
10use ext_trait::extension;
11use super::executor::Executor;
12
13
14pub trait Enumerator1 : Op1 {
16 #[inline(always)]
17 fn enumerate(&self, this: &'static Op1Enum, exec: &'static Executor, opnt: [usize; 1]) -> Result<(), ()> {
18 enumerate1(self, this, exec, opnt)
19 }
20}
21#[inline(always)]
22
23pub fn enumerate1(s: &impl Op1, this: &'static Op1Enum, exec: &'static Executor, opnt: [usize; 1]) -> Result<(), ()> {
24 if exec.size() <= s.cost() { return Ok(()); }
25 for (e, v) in exec.data[opnt[0]].size.get_all(exec.size() - s.cost()) {
26 let expr = Expr::Op1(this, e);
27 if let (true, value) = s.try_eval(*v) {
28 exec.enum_expr(expr, value)?;
29 }
30 }
31 Ok(())
32}
33
34pub trait Enumerator2 : Op2 {
36 #[inline(always)]
37 fn enumerate(&self, this: &'static Op2Enum, exec: &'static Executor, nt: [usize; 2]) -> Result<(), ()> {
38 enumerate2(self, this, exec, nt)
39 }
40}
41#[inline(always)]
42pub fn enumerate2(s: &impl Op2, this: &'static Op2Enum, exec: &'static Executor, nt: [usize; 2]) -> Result<(), ()> {
43 if exec.size() <= s.cost() { return Ok(()); }
44 let total = exec.size() - s.cost();
45 for (i, (e1, v1)) in exec.data[nt[0]].size.get_all_under(total) {
46 for (e2, v2) in exec.data[nt[1]].size.get_all(total - i) {
47 let expr = Expr::Op2(this, e1, e2);
48 if let (true, value) = s.try_eval(*v1, *v2) {
49 exec.enum_expr(expr, value)?;
50 }
51 }
52 }
53 Ok(())
54}
55
56pub trait Enumerator3 : Op3 {
58 #[inline(always)]
59 fn enumerate(&self, this: &'static Op3Enum, exec: &'static Executor, nt: [usize; 3]) -> Result<(), ()> {
60 enumerate3(self, this, exec, nt)
61 }
62}
63#[inline(always)]
64pub fn enumerate3(s: &impl Op3, this: &'static Op3Enum, exec: &'static Executor, nt: [usize; 3]) -> Result<(), ()> {
65 if exec.size() < s.cost() { return Ok(()); }
66 let total = exec.size() - s.cost();
67 for (i, (e1, v1)) in exec.data[nt[0]].size.get_all_under(total) {
68 for (j, (e2, v2)) in exec.data[nt[1]].size.get_all_under(total - i) {
69 for (e3, v3) in exec.data[nt[2]].size.get_all(total - i - j) {
70 let expr = Expr::Op3(this, e1, e2, e3);
71 if let (true, value) = s.try_eval(*v1, *v2, *v3) {
72 exec.enum_expr(expr, value)?;
73 }
74 }
75 }
76 }
77 Ok(())
78}
79
80impl Enumerator1 for Op1Enum {
81 #[inline]
82 fn enumerate(&self, this: &'static Op1Enum, exec: &'static Executor, opnt: [usize; 1]) -> Result<(), ()> {
83 macro_rules! _do {($($op:ident)*) => {$(
84 if let Self::$op(a) = self {
85 return a.enumerate(this, exec, opnt);
86 }
87 )*};}
88 crate::for_all_op1!();
89 panic!()
90 }
91}
92
93impl Enumerator2 for Op2Enum {
94 #[inline]
95 fn enumerate(&self, this: &'static Op2Enum, exec: &'static Executor, opnt: [usize; 2]) -> Result<(), ()> {
96 macro_rules! _do {($($op:ident)*) => {$(
97 if let Self::$op(a) = self {
98 return a.enumerate(this, exec, opnt);
99 }
100 )*};}
101 crate::for_all_op2!();
102 panic!()
103 }
104}
105
106impl Enumerator3 for Op3Enum {
107 #[inline]
108 fn enumerate(&self, this: &'static Op3Enum, exec: &'static Executor, opnt: [usize; 3]) -> Result<(), ()> {
109 macro_rules! _do {($($op:ident)*) => {$(
110 if let Self::$op(a) = self {
111 return a.enumerate(this, exec, opnt);
112 }
113 )*};}
114 crate::for_all_op3!();
115 panic!()
116 }
117}
118
119#[extension(pub trait ProdRuleEnumerateExt)]
120impl ProdRule {
121 fn enumerate(&self, exec: &'static Executor) -> Result<(), ()> {
123 match self {
124 ProdRule::Const(c) => {
125 if exec.size() == 1 {
126 exec.enum_expr(Expr::Const(*c), c.value(exec.ctx.len()))?;
127 }
128 Ok(())
129 }
130 ProdRule::Var(v) => {
131 if exec.size() == 1 {
132 exec.enum_expr(Expr::Var(*v), *exec.ctx.get(*v).unwrap())?;
133 }
134 Ok(())
135 }
136 ProdRule::Op1(op1, nt1) => {
137 op1.enumerate(op1, exec, [*nt1])
138 }
139 ProdRule::Op2(op2, nt1, nt2) => {
140 op2.enumerate(op2, exec, [*nt1, *nt2])
141 }
142 ProdRule::Op3(op3, nt1, nt2, nt3) => {
143 op3.enumerate(op3, exec, [*nt1, *nt2, *nt3])
144 }
145 ProdRule::Nt(_) => todo!(),
146 }
147 }
148}
149
150
151