synthphonia/tree_learning/
mod.rs

1use std::{cell::RefCell, borrow::{Borrow, BorrowMut}};
2
3use bumpalo::Bump;
4
5
6pub mod bits;
7
8use bits::BoxSliceExt;
9pub use bits::Bits;
10
11use crate::{debg, debg2, expr::Expr};
12
13/// An enum representing subproblems within a decision tree learning process for string synthesis. 
14pub enum SubProblem<'a> {
15    Unsolved(Bits, f32),
16    Accept(usize),
17    Ite{ expr: usize, entropy: f32, t: SubProb<'a>, f: SubProb<'a> }
18}
19
20impl<'a> SubProblem<'a> {
21    #[inline]
22    /// Adds subproblems for exploration in the decision tree. 
23    pub fn add_subproblems(&self, subproblem: &mut Vec<(SubProb<'a>, bool)>) {
24        if let SubProblem::Ite { expr, entropy, t, f } = self {
25            subproblem.push((f, true));
26            subproblem.push((t, true));
27        }
28    }
29}
30
31pub type SubProb<'a> = &'a RefCell<SubProblem<'a>>;
32
33/// A struct encapsulating the state and parameters for a decision tree learning process in string synthesis. 
34pub struct TreeLearning<'a, 'b> {
35    pub size: usize,
36    root: SubProb<'a>,
37    pub subproblems: Vec<SubProb<'a>>,
38    limit: usize,
39    pub conditions: &'b [(&'static Expr, Bits)],
40    pub options: Vec<(&'static Expr, Bits)>,
41    pub bump: &'a Bump,
42    pub solved: bool,
43}
44
45/// An enum that captures the outcomes of decision-making processes for solving subproblems in decision trees. 
46/// 
47/// This enum comprises three variants: `Accept`, `Ite`, and `Failed`. 
48/// 
49/// 
50/// The `Accept` variant represents an outcome where a subproblem is successfully solved, containing a `usize` value typically indicating an index or identifier of the solution. 
51/// The `Ite` variant indicates a decision to use a conditional branch, often an 'if-then-else' construct, and includes a `usize` for identification, a `f32` for weight or probability, and two tuples of `(Bits, f32)` representing the branching conditions and associated probabilities. 
52/// The `Failed` variant signifies that a subproblem could not be resolved under the current conditions, indicating a failure in the decision-making process.
53pub enum SelectResult {
54    Accept(usize),
55    Ite(usize, f32, (Bits, f32), (Bits, f32)),
56    Failed,
57}
58
59
60impl<'a, 'b> TreeLearning<'a, 'b> {
61    // pub fn split_infomation(bits: Bits) -> f32 {
62
63    // }
64    /// Creates a new instance with specified parameters including size, conditions, options, memory allocator, and limit. 
65    pub fn new_in(size: usize, conditions: &'b [(&'static Expr, Bits)], options: Vec<(&'static Expr, Bits)>, bump: &'a Bump, limit: usize) -> Self {
66        let mut this = Self {
67            size,
68            root: bump.alloc(RefCell::new(SubProblem::Unsolved(bits::boxed_ones(size), 0.0))),
69            subproblems: Vec::new(),
70            conditions,
71            options,
72            bump,
73            solved: false,
74            limit
75        };
76        let root_entro = this.entropy(& bits::boxed_ones(size));
77        if let SubProblem::Unsolved(a, entropy) = &mut *this.root.borrow_mut() {
78            *entropy = root_entro;
79        }
80        this.subproblems.push(this.root);
81        this
82    }
83
84    #[inline]
85    /// Calculates the entropy of a given set of bits within the context of the `TreeLearning` algorithm's options. 
86    pub fn entropy(&self, bits: & Bits) -> f32 {
87        
88        let mut vec: Vec<_> = self.options.iter().enumerate().map(|(i, b)| {
89            let mut res = b.1.clone();
90            res.conjunction_assign(bits);
91            (i, res.count_ones(), res)
92        }).collect();
93        vec.sort_by_key(|a| u32::MAX - a.1);
94
95        let total = bits.count_ones();
96        let mut rest = bits.clone();
97        let mut rest_count = rest.count_ones();
98        let mut res = 0.0;
99        for (_, _, b) in vec {
100            rest.difference_assign(&b);
101            let count = rest_count - rest.count_ones();
102            let p = count as f32 / total as f32;
103            if p > 0.0 {
104                res += - p * p.log2();
105            }
106            rest_count = rest.count_ones();
107        }
108        res
109    }
110    
111    /// Calculates the conditional entropy of a given set of bits based on a specified condition bitset. 
112    pub fn cond_entropy(&self, bits: &Bits, condition: &Bits) -> (f32, (Bits, f32), (Bits, f32)) {
113        let total = bits.count_ones();
114        let mut and_bits = bits.clone();
115        and_bits.conjunction_assign(condition);
116        let and_entro = self.entropy(&and_bits);
117        let and_count = and_bits.count_ones();
118        let mut diff_bits = bits.clone();
119        diff_bits.difference_assign(condition);
120        let diff_entro = self.entropy(&diff_bits);
121        let diff_count = diff_bits.count_ones();
122        if and_count == 0 || diff_count == 0 {
123            (1e10, (and_bits, and_entro), (diff_bits, diff_entro))
124        } else {
125            (
126                (and_entro * and_count as f32 + diff_entro * diff_count as f32) / total as f32,
127                (and_bits, and_entro), (diff_bits, diff_entro)
128            )
129        }
130    }
131    
132    #[inline]
133    /// Determines the next action for an unsolved subproblem in the tree learning process. 
134    pub fn select(&self, unsolved: &SubProblem<'a>) -> SelectResult {
135        if let SubProblem::Unsolved(bits, entro) = unsolved {
136            if *entro <= 0.0001 {
137                if let Some((i, _)) = self.options.iter().enumerate().find(|(_, x)| bits.subset(&x.1) ) {
138                    return SelectResult::Accept(i)
139                }
140            }
141            let (i, (centro, tb, fb)) = self.conditions.iter().enumerate()
142                .map(|(i, (e, cb))| {
143                    let ce = self.cond_entropy(bits, cb);
144                    (i, ce)
145                })
146                .min_by(|a, b| a.1.0.partial_cmp(&b.1.0).unwrap())
147                .expect("At least have one condition.");
148            if centro - 0.00001 < *entro {
149                SelectResult::Ite(i, centro, tb, fb)
150            } else {
151                SelectResult::Failed
152            }
153        } else { panic!("last should be unsolved.") }
154    }
155
156    /// Executes the learning algorithm by iterating over the subproblems within the decision tree. 
157    pub fn run(&mut self) -> bool {
158        let mut counter = 1;
159        while let Some(last) = self.subproblems.pop() {
160            let sel = self.select(&last.borrow());
161            match sel {
162                SelectResult::Accept(i) => {
163                    *last.borrow_mut() = SubProblem::Accept(i);
164                }
165                SelectResult::Ite(expr, entropy, t, f) => {
166                    let tb = self.bump.alloc(SubProblem::Unsolved(t.0, t.1).into());
167                    let fb = self.bump.alloc(SubProblem::Unsolved(f.0, f.1).into());
168                    self.subproblems.push(fb);
169                    self.subproblems.push(tb);
170                    *last.borrow_mut() = SubProblem::Ite{ expr, entropy, t: tb, f: fb };
171                    counter += 2;
172                    if counter > self.limit { 
173                        debg2!("{:?}", self);
174                        return false;
175                    }
176                }
177                SelectResult::Failed => {
178                    debg2!("{:?}", self);
179                    return false;
180                }
181            }
182        }
183        self.solved = true;
184        debg2!("{:?}", self);
185        true
186    }
187
188    /// Facilitates the recursive formatting of the decision tree contained within the `TreeLearning` structure for display purposes. 
189    fn fmt_recursive(&self, f: &mut std::fmt::Formatter<'_>, node: SubProb<'a>, indent: &mut String) -> std::fmt::Result {
190        match &*node.borrow() {
191            SubProblem::Unsolved(bits, entropy) => 
192                writeln!(f, "{indent}?? {} {:x?}", entropy, bits),
193            SubProblem::Accept(i) => 
194                writeln!(f, "{indent}{:?}", self.options[*i].0),
195            SubProblem::Ite { expr, entropy, t: tb, f: fb } => {
196                writeln!(f, "{indent}ite {:?} {:x?}", self.conditions[*expr].0, self.conditions[*expr].1)?;
197                indent.push_str("  ");
198                self.fmt_recursive(f, tb, indent)?;
199                self.fmt_recursive(f, fb, indent)?;
200                indent.pop(); indent.pop();
201                Ok(())
202            }
203        }
204    }
205    /// Determines the size of the decision tree by recursively traversing through its nodes. 
206    fn size_recursive(&self, node: SubProb<'a>) -> usize {
207        match &*node.borrow() {
208            SubProblem::Unsolved(bits, entropy) => 1,
209            SubProblem::Accept(i) => 1,
210            SubProblem::Ite { expr, entropy, t: tb, f: fb } => 1 + self.size_recursive(tb) + self.size_recursive(fb),
211        }
212    }
213    /// Covers a decision tree recursively starting from a given node and determining the set of bits covered by the tree structure. 
214    fn cover_recursive(&self, node: SubProb<'a>) -> Bits {
215        match &*node.borrow() {
216            SubProblem::Unsolved(bits, entropy) => bits.clone(),
217            SubProblem::Accept(i) => self.options[*i].1.clone(),
218            SubProblem::Ite { expr, entropy, t: tb, f: fb } => {
219                let mut t = self.cover_recursive(tb);
220                let mut f = self.cover_recursive(fb);
221                let bits = self.conditions[*expr].1.clone();
222                t.conjunction_assign(&bits);
223                f.difference_assign(&bits);
224                t.union_assign(&f);
225                t
226            }
227        }
228    }
229    /// Returns the expression representation of a given node in the decision tree. 
230    fn expr_recursizve(&self, node: SubProb<'a>) -> &'static Expr {
231        match &*node.borrow() {
232            SubProblem::Unsolved(bits, entropy) => panic!("Still subproblem remain."),
233            SubProblem::Accept(i) => self.options[*i].0,
234            SubProblem::Ite { expr, entropy, t: tb, f: fb } => {
235                let t = self.expr_recursizve(tb);
236                let f = self.expr_recursizve(fb);
237                let cond = self.conditions[*expr].0;
238                cond.ite(t, f)
239            }
240        }
241    }
242    /// Recursively traverses through a decision tree to collect bits from unsolved subproblems. 
243    fn unsolved_recursive(&self, node: SubProb<'a>, result: &mut Vec<Box<[u128]>>) {
244        match &*node.borrow() {
245            SubProblem::Unsolved(bits, entropy) => {
246                result.push(bits.clone());
247            }
248            SubProblem::Accept(i) => {}
249            SubProblem::Ite { expr, entropy, t: tb, f: fb } => {
250                self.unsolved_recursive(tb, result);
251                self.unsolved_recursive(fb, result);
252            }
253        }
254    }
255    /// Returns a vector of boxed slices containing `u128` values that represent unsolved components of a decision tree. 
256    fn unsolved(&self) -> Vec<Box<[u128]>> {
257        let mut result = Vec::new();
258        self.unsolved_recursive(self.root, &mut result);
259        result
260    }
261    /// Returns the expression associated with the root of the decision tree. 
262    /// This function utilizes a recursive approach by invoking `expr_recursizve` on the tree's root node to retrieve the expression efficiently, leveraging the recursive structure to navigate through potentially complex tree configurations within the `TreeLearning` context.
263    pub fn expr(&self) -> &'static Expr {
264        self.expr_recursizve(self.root)
265    }
266    
267    /// Calculates the result size of a decision tree by recursively determining the size starting from the root node. 
268    /// This implementation utilizes the `size_recursive` function on the `root` to compute the cumulative size of the tree structure, which includes all subproblems, branches, and accepted solutions present in the tree.
269    pub fn result_size(&self) -> usize {
270        self.size_recursive(self.root)
271    }
272}
273
274impl<'a, 'b> std::fmt::Debug for TreeLearning<'a, 'b> {
275    /// Formats the decision tree within the `TreeLearning` instance for display. 
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        self.fmt_recursive(f, self.root, &mut "".into())
278    }
279}
280
281#[inline(always)]
282pub fn tree_learning<'a, 'b>(options: Vec<(&'static Expr, Bits)>, conditions: &'b [(&'static Expr, Bits)], size: usize, bump: &'a Bump, limit: usize) -> TreeLearning<'a, 'b> {
283    let mut tl = TreeLearning::new_in(size, conditions, options, bump, limit);
284    tl.run();
285    tl
286}
287