1use std::{cell::RefCell, borrow::{Borrow, BorrowMut}};
2
3use bumpalo::Bump;
4
5
6pub mod bits;
7
8use bits::BoxSliceExt;
9pub use bits::Bits;
10
11use crate::{debg, debg2, expr::Expr};
12
/// A node of the decision tree while it is being built.
pub enum SubProblem<'a> {
    /// A node that has not been decided yet: the bit set it still has to
    /// handle and the entropy value recorded for that bit set.
    Unsolved(Bits, f32),
    /// A leaf; the payload is an index into `TreeLearning::options`.
    Accept(usize),
    /// A branch. `expr` is an index into `TreeLearning::conditions`,
    /// `entropy` is the value stored when the split was chosen, `t` is the
    /// child for the bits where the condition holds and `f` the child for
    /// the remaining bits.
    Ite{ expr: usize, entropy: f32, t: SubProb<'a>, f: SubProb<'a> }
}
19
20impl<'a> SubProblem<'a> {
21 #[inline]
22 pub fn add_subproblems(&self, subproblem: &mut Vec<(SubProb<'a>, bool)>) {
24 if let SubProblem::Ite { expr, entropy, t, f } = self {
25 subproblem.push((f, true));
26 subproblem.push((t, true));
27 }
28 }
29}
30
/// Shared handle to a tree node: a reference to a `RefCell`-wrapped
/// [`SubProblem`], so a node can be rewritten in place through the handle.
pub type SubProb<'a> = &'a RefCell<SubProblem<'a>>;
32
/// State of one decision-tree learning run.
pub struct TreeLearning<'a, 'b> {
    /// Value handed to `bits::boxed_ones` when the root bit set is created
    /// (presumably the number of examples -- confirm against `bits`).
    pub size: usize,
    /// Root node of the tree.
    root: SubProb<'a>,
    /// Stack of nodes that `run` still has to process (popped from the end).
    pub subproblems: Vec<SubProb<'a>>,
    /// Upper bound for the node counter kept by `run`; exceeding it aborts the run.
    limit: usize,
    /// Candidate branch conditions, each paired with a bit set.
    pub conditions: &'b [(&'static Expr, Bits)],
    /// Candidate leaf expressions, each paired with a bit set.
    pub options: Vec<(&'static Expr, Bits)>,
    /// Arena in which tree nodes are allocated.
    pub bump: &'a Bump,
    /// Set to `true` by `run` once no pending node is left.
    pub solved: bool,
}
44
/// Outcome of `TreeLearning::select` for a single unsolved node.
pub enum SelectResult {
    /// Turn the node into a leaf; the payload is an index into the options.
    Accept(usize),
    /// Split the node: index of the chosen condition, its conditional
    /// entropy, then the `(bits, entropy)` pair of the side where the
    /// condition holds followed by the pair of the side where it does not.
    Ite(usize, f32, (Bits, f32), (Bits, f32)),
    /// No acceptable option or condition was found.
    Failed,
}
58
59
60impl<'a, 'b> TreeLearning<'a, 'b> {
61 pub fn new_in(size: usize, conditions: &'b [(&'static Expr, Bits)], options: Vec<(&'static Expr, Bits)>, bump: &'a Bump, limit: usize) -> Self {
66 let mut this = Self {
67 size,
68 root: bump.alloc(RefCell::new(SubProblem::Unsolved(bits::boxed_ones(size), 0.0))),
69 subproblems: Vec::new(),
70 conditions,
71 options,
72 bump,
73 solved: false,
74 limit
75 };
76 let root_entro = this.entropy(& bits::boxed_ones(size));
77 if let SubProblem::Unsolved(a, entropy) = &mut *this.root.borrow_mut() {
78 *entropy = root_entro;
79 }
80 this.subproblems.push(this.root);
81 this
82 }
83
84 #[inline]
85 pub fn entropy(&self, bits: & Bits) -> f32 {
87
88 let mut vec: Vec<_> = self.options.iter().enumerate().map(|(i, b)| {
89 let mut res = b.1.clone();
90 res.conjunction_assign(bits);
91 (i, res.count_ones(), res)
92 }).collect();
93 vec.sort_by_key(|a| u32::MAX - a.1);
94
95 let total = bits.count_ones();
96 let mut rest = bits.clone();
97 let mut rest_count = rest.count_ones();
98 let mut res = 0.0;
99 for (_, _, b) in vec {
100 rest.difference_assign(&b);
101 let count = rest_count - rest.count_ones();
102 let p = count as f32 / total as f32;
103 if p > 0.0 {
104 res += - p * p.log2();
105 }
106 rest_count = rest.count_ones();
107 }
108 res
109 }
110
111 pub fn cond_entropy(&self, bits: &Bits, condition: &Bits) -> (f32, (Bits, f32), (Bits, f32)) {
113 let total = bits.count_ones();
114 let mut and_bits = bits.clone();
115 and_bits.conjunction_assign(condition);
116 let and_entro = self.entropy(&and_bits);
117 let and_count = and_bits.count_ones();
118 let mut diff_bits = bits.clone();
119 diff_bits.difference_assign(condition);
120 let diff_entro = self.entropy(&diff_bits);
121 let diff_count = diff_bits.count_ones();
122 if and_count == 0 || diff_count == 0 {
123 (1e10, (and_bits, and_entro), (diff_bits, diff_entro))
124 } else {
125 (
126 (and_entro * and_count as f32 + diff_entro * diff_count as f32) / total as f32,
127 (and_bits, and_entro), (diff_bits, diff_entro)
128 )
129 }
130 }
131
132 #[inline]
133 pub fn select(&self, unsolved: &SubProblem<'a>) -> SelectResult {
135 if let SubProblem::Unsolved(bits, entro) = unsolved {
136 if *entro <= 0.0001 {
137 if let Some((i, _)) = self.options.iter().enumerate().find(|(_, x)| bits.subset(&x.1) ) {
138 return SelectResult::Accept(i)
139 }
140 }
141 let (i, (centro, tb, fb)) = self.conditions.iter().enumerate()
142 .map(|(i, (e, cb))| {
143 let ce = self.cond_entropy(bits, cb);
144 (i, ce)
145 })
146 .min_by(|a, b| a.1.0.partial_cmp(&b.1.0).unwrap())
147 .expect("At least have one condition.");
148 if centro - 0.00001 < *entro {
149 SelectResult::Ite(i, centro, tb, fb)
150 } else {
151 SelectResult::Failed
152 }
153 } else { panic!("last should be unsolved.") }
154 }
155
156 pub fn run(&mut self) -> bool {
158 let mut counter = 1;
159 while let Some(last) = self.subproblems.pop() {
160 let sel = self.select(&last.borrow());
161 match sel {
162 SelectResult::Accept(i) => {
163 *last.borrow_mut() = SubProblem::Accept(i);
164 }
165 SelectResult::Ite(expr, entropy, t, f) => {
166 let tb = self.bump.alloc(SubProblem::Unsolved(t.0, t.1).into());
167 let fb = self.bump.alloc(SubProblem::Unsolved(f.0, f.1).into());
168 self.subproblems.push(fb);
169 self.subproblems.push(tb);
170 *last.borrow_mut() = SubProblem::Ite{ expr, entropy, t: tb, f: fb };
171 counter += 2;
172 if counter > self.limit {
173 debg2!("{:?}", self);
174 return false;
175 }
176 }
177 SelectResult::Failed => {
178 debg2!("{:?}", self);
179 return false;
180 }
181 }
182 }
183 self.solved = true;
184 debg2!("{:?}", self);
185 true
186 }
187
188 fn fmt_recursive(&self, f: &mut std::fmt::Formatter<'_>, node: SubProb<'a>, indent: &mut String) -> std::fmt::Result {
190 match &*node.borrow() {
191 SubProblem::Unsolved(bits, entropy) =>
192 writeln!(f, "{indent}?? {} {:x?}", entropy, bits),
193 SubProblem::Accept(i) =>
194 writeln!(f, "{indent}{:?}", self.options[*i].0),
195 SubProblem::Ite { expr, entropy, t: tb, f: fb } => {
196 writeln!(f, "{indent}ite {:?} {:x?}", self.conditions[*expr].0, self.conditions[*expr].1)?;
197 indent.push_str(" ");
198 self.fmt_recursive(f, tb, indent)?;
199 self.fmt_recursive(f, fb, indent)?;
200 indent.pop(); indent.pop();
201 Ok(())
202 }
203 }
204 }
205 fn size_recursive(&self, node: SubProb<'a>) -> usize {
207 match &*node.borrow() {
208 SubProblem::Unsolved(bits, entropy) => 1,
209 SubProblem::Accept(i) => 1,
210 SubProblem::Ite { expr, entropy, t: tb, f: fb } => 1 + self.size_recursive(tb) + self.size_recursive(fb),
211 }
212 }
213 fn cover_recursive(&self, node: SubProb<'a>) -> Bits {
215 match &*node.borrow() {
216 SubProblem::Unsolved(bits, entropy) => bits.clone(),
217 SubProblem::Accept(i) => self.options[*i].1.clone(),
218 SubProblem::Ite { expr, entropy, t: tb, f: fb } => {
219 let mut t = self.cover_recursive(tb);
220 let mut f = self.cover_recursive(fb);
221 let bits = self.conditions[*expr].1.clone();
222 t.conjunction_assign(&bits);
223 f.difference_assign(&bits);
224 t.union_assign(&f);
225 t
226 }
227 }
228 }
229 fn expr_recursizve(&self, node: SubProb<'a>) -> &'static Expr {
231 match &*node.borrow() {
232 SubProblem::Unsolved(bits, entropy) => panic!("Still subproblem remain."),
233 SubProblem::Accept(i) => self.options[*i].0,
234 SubProblem::Ite { expr, entropy, t: tb, f: fb } => {
235 let t = self.expr_recursizve(tb);
236 let f = self.expr_recursizve(fb);
237 let cond = self.conditions[*expr].0;
238 cond.ite(t, f)
239 }
240 }
241 }
242 fn unsolved_recursive(&self, node: SubProb<'a>, result: &mut Vec<Box<[u128]>>) {
244 match &*node.borrow() {
245 SubProblem::Unsolved(bits, entropy) => {
246 result.push(bits.clone());
247 }
248 SubProblem::Accept(i) => {}
249 SubProblem::Ite { expr, entropy, t: tb, f: fb } => {
250 self.unsolved_recursive(tb, result);
251 self.unsolved_recursive(fb, result);
252 }
253 }
254 }
    /// Collects copies of the bit sets of all `Unsolved` nodes in the tree,
    /// in the order `unsolved_recursive` visits them starting at the root.
    fn unsolved(&self) -> Vec<Box<[u128]>> {
        let mut result = Vec::new();
        self.unsolved_recursive(self.root, &mut result);
        result
    }
    /// Returns the expression represented by the whole tree.
    ///
    /// # Panics
    ///
    /// Panics if the tree still contains an `Unsolved` node, as
    /// `expr_recursizve` does for such nodes.
    pub fn expr(&self) -> &'static Expr {
        self.expr_recursizve(self.root)
    }
266
    /// Returns the number of nodes in the tree, counting branches, leaves and
    /// still-unsolved nodes alike.
    pub fn result_size(&self) -> usize {
        self.size_recursive(self.root)
    }
272}
273
274impl<'a, 'b> std::fmt::Debug for TreeLearning<'a, 'b> {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 self.fmt_recursive(f, self.root, &mut "".into())
278 }
279}
280
281#[inline(always)]
282pub fn tree_learning<'a, 'b>(options: Vec<(&'static Expr, Bits)>, conditions: &'b [(&'static Expr, Bits)], size: usize, bump: &'a Bump, limit: usize) -> TreeLearning<'a, 'b> {
283 let mut tl = TreeLearning::new_in(size, conditions, options, bump, limit);
284 tl.run();
285 tl
286}
287