// synthphonia/forward/executor.rs

1use std::{
2    cell::{Cell, RefCell, UnsafeCell}, collections::{hash_map::Entry, HashMap, HashSet}, default, f64::consts::E, fs, future::Future, pin::pin, sync::atomic::AtomicBool, task::Poll, time::{self, Duration, Instant}
3};
4
5use derive_more::{Constructor, Deref, From, Into};
6use futures::StreamExt;
7use itertools::Itertools;
8use simple_rc_async::{sync::{broadcast, broadcastque}, task::{self, JoinHandle}};
9
10use crate::{
11    backward::{ Deducer, DeducerEnum, Problem}, debg, debg2, expr::{
12         cfg::{Cfg, ProdRule}, context::Context, Expr
13    }, forward::{data::{size, substr}, enumeration::ProdRuleEnumerateExt, executor}, galloc::AllocForAny, info, log, parser::problem::PBEProblem, solutions::CONDITIONS, text::parsing::{ParseInt, TextObjData}, utils::UnsafeCellExt, value::{ConstValue, Type, Value}, warn
14};
15use crate::expr;
16use super::{bridge::Bridge, data::{self, all_eq, size::EV, Data}};
17
/// Callback shape for expression enumeration: receives each enumerated expression together
/// with its value.
///
/// Presumably `Err(())` asks the enumerator to stop, mirroring `Executor::enum_expr`, which has
/// the same signature and whose `Err` aborts `Executor::run` — TODO confirm for other users.
pub trait EnumFn = FnMut(Expr, Value) -> Result<(), ()>;
19
/// Global flag that pauses expression enumeration while it is `true`.
///
/// `Executor::enum_expr` checks it at the end of handling each expression and busy-waits
/// (`std::hint::spin_loop`) for as long as it stays set. Setting it from another thread
/// therefore suspends the enumerator, and clearing it lets enumeration continue.
/// Being an `AtomicBool`, it can be read and written from several threads without data races.
///
/// NOTE(review): despite the name, nothing here makes it terminate the search — it only
/// pauses. Confirm the intended semantics with the code that sets it.
pub static STOP_SIGNAL: AtomicBool = AtomicBool::new(false);
27
/// Holds all tasks waiting for a cost limit to be released.
///
/// Waiters subscribe to `sender` (see `inc_cost`), and `release_cost_limit` broadcasts the
/// release signals they wait for.
pub struct TaskWaitingCost {
    /// Broadcast queue carrying unit (`()`) release signals.
    sender: broadcastque::Sender<()>,
    /// Baseline subtracted from a problem's `used_cost` when `inc_cost` computes how many
    /// signals to wait for.
    ///
    /// NOTE(review): it is set to 0 in `new` and never updated anywhere visible here, so the
    /// subtraction is currently a no-op. Confirm whether it was meant to track the total
    /// cost released so far.
    cur_cost: usize,
}
33
34impl Default for TaskWaitingCost {
35    /// Constructs and returns a default instance by invoking an alternative constructor. 
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl TaskWaitingCost {
42    /// Creates a new instance of TaskWaitingCost. 
43    pub fn new() -> Self {
44        TaskWaitingCost { sender: broadcastque::channel(), cur_cost: 0  }
45    }
46    
47    /// Increments the cost associated with a task and updates the problem's used cost. 
48    /// 
49    /// It will wait for the specified amount of cost to be released before proceeding.
50    pub async fn inc_cost(&mut self, problem: &mut Problem, amount: usize) {
51        let mut rv: broadcastque::Reciever<()> = self.sender.reciever();
52        problem.used_cost += amount;
53        let amount = problem.used_cost as isize - self.cur_cost as isize;
54        if amount > 0 {
55            for _ in 0..amount {
56                let _ = rv.next().await;
57            }
58        }
59    }
60    
61    /// Releases a specified number of waiting slots on a task queue. 
62    /// 
63    /// This function sends a signal to the task queue, allowing the specified number of tasks to proceed.
64    pub fn release_cost_limit(&mut self, count: usize) {
65        self.sender.send((), count);
66    }
67}
68
69// pub(crate) struct OtherData {
70//     pub(crate) all_str_const: HashSet<&'static str>,
71//     // pub problems: UnsafeCell<HashMap<(usize, Value), TaskORc<&'static Expr>>>,
72// }
73
/// Coordinates a synthesis run: size-by-size enumeration of expressions for every
/// non-terminal, interleaved with the deduction tasks spawned for the root problem and its
/// subproblems.
///
/// Usage:
/// ```ignore
/// let exec = Executor::new(ctx, cfg);
/// let result = exec.solve_top_blocked();
/// let result = DefineFun { sig: problem.synthfun().sig.clone(), expr: result};
/// ```
pub struct Executor {
    /// Problem context; `ctx.output` is the target value of the root problem.
    pub ctx: Context,
    /// Grammar being enumerated; `cfg.config` also carries the run limits.
    pub cfg: Cfg,
    /// One deducer per non-terminal, indexed by non-terminal id.
    pub deducers: Vec<DeducerEnum>,
    /// Term Dispatcher data structures, one entry per non-terminal (indexed by `cur_nt`).
    pub data: Vec<Data>,
    /// A counter for the number of expressions enumerated.
    pub counter: Cell<usize>,
    /// A counter for the number of subproblems dispatched to deducers.
    pub subproblem_count: Cell<usize>,
    /// The expression size currently being enumerated.
    pub cur_size: Cell<usize>,
    /// Index of the non-terminal currently being enumerated.
    pub cur_nt: Cell<usize>,
    /// Queue of tasks waiting for the cost limit to be released.
    ///
    /// NOTE(review): previously marked "no longer used", yet `enum_expr` still calls
    /// `release_cost_limit` on it periodically. Confirm whether any task still waits on it.
    pub waiting_tasks: UnsafeCell<TaskWaitingCost>,
    /// Handle of the root deduction task.
    ///
    /// It holds a never-completing placeholder until a `solve_top_*` method installs the
    /// real task.
    pub top_task: UnsafeCell<JoinHandle<&'static Expr>>,
    /// Expressions (with their values) accepted during the current (size, non-terminal) pass.
    ///
    /// At the end of each pass they are drained into that non-terminal's size index.
    expr_collector: UnsafeCell<Vec<EV>>,
    /// Bridge to interact with other threads.
    pub bridge: Bridge,
    /// Timestamp taken at construction; the time limit is measured from it.
    pub start_time: time::Instant,
}
110
111impl Executor {
112    /// Retrieves the count of subproblems processed by the executor. 
113    pub fn problem_count(&self) -> usize{
114        self.subproblem_count.get()
115    }
116    /// Creates a new instance. 
117    pub fn new(ctx: Context, cfg: Cfg) -> Self {
118        let data = Data::new(&cfg, &ctx);
119        let deducers = (0..cfg.len()).map(|i, | DeducerEnum::from_nt(&cfg, &ctx, i)).collect_vec();
120        let exec = Self { counter: 0.into(), subproblem_count: 0.into(), ctx, cfg, data, deducers, expr_collector: Vec::new().into(),
121            cur_size: 0.into(), cur_nt: 0.into(), waiting_tasks: TaskWaitingCost::new().into(),
122            top_task: task::spawn(futures::future::pending()).into(), bridge: Bridge::new(),
123            start_time: Instant::now() };
124        TextObjData::build_trie(&exec);
125        exec
126    }
127    pub fn top_task(&self) -> &mut JoinHandle<&'static Expr> {
128        unsafe { self.top_task.as_mut() }
129    }
130    /// Collects expressions and their associated values. Save them into the `expr_collector` field.
131    pub fn collect_expr(&self, e: &'static Expr, v: Value) {
132        unsafe { self.expr_collector.as_mut().push((e, v)) }
133    }
134    /// Returns a mutable reference to the field `waiting_tasks`. 
135    pub fn waiting_tasks(&self) -> &mut TaskWaitingCost {
136        unsafe { self.waiting_tasks.as_mut() }
137    }
138    /// Extracts the contents of the `expr_collector` and returns them as a `Vec<EV>`. 
139    pub fn extract_expr_collector(&self) -> Vec<EV> {
140        UnsafeCellExt::replace(&self.expr_collector, Vec::new())
141    }
142    /// Provides a method to access the current data entry from the `data` vector within the Executor context. 
143    pub fn cur_data(&self) -> &Data {
144        &self.data[self.cur_nt.get()]
145    }
146    #[inline]
147    /// Solves a given synthesis problem asynchronously and returns a reference to an expression. 
148    pub async fn solve_task(&'static self, problem: Problem) -> &'static Expr {
149        if let Some(e) = self.data[problem.nt].all_eq.at(problem.value) {
150            return e;
151        }
152        self.subproblem_count.update(|x| x+1);
153        task::spawn(self.deducers[problem.nt].deduce(self, problem)).await
154    }
155    #[inline]
156    /// Asynchronously generates a conditional expression for a given problem and result. 
157    pub async fn generate_condition(&'static self, problem: Problem, result: &'static Expr) -> &'static Expr {
158        if problem.value.is_all_true() { return result; }
159        let left = pin!(self.solve_task(problem));
160        let right = pin!(self.solve_task(problem.with_value(problem.value.bool_not())));
161        let cond = futures::future::select(left, right).await;
162        match cond {
163            futures::future::Either::Left((c, _)) => 
164                expr!(Ite {c} {result} "").galloc(),
165            futures::future::Either::Right((c, _)) => 
166                expr!(Ite {c} "" {result}).galloc(),
167        }
168    }
169    /// Attempts to solve the top-level problem and manage its execution. 
170    pub fn solve_top_blocked(self) -> &'static Expr {
171        let problem = Problem::root(0, self.ctx.output);
172        let this = unsafe { (&self as *const Executor).as_ref::<'static>().unwrap() };
173        this.subproblem_count.update(|x| x+1);
174        *this.top_task() = task::spawn(this.deducers[problem.nt].deduce(this, problem));
175        let _ = this.run();
176        self.bridge.abort_all();
177        if let Poll::Ready(r) = this.top_task().poll_rc_nocx() {
178            r
179        } else { panic!("should not reach here.") }
180        // match problems.entry((nt, value)) {
181        //     Entry::Occupied(o) => o.get().clone(),
182        //     Entry::Vacant(e) => {
183        //         let t = ;
184        //         e.insert(t.clone());
185        //         t
186        //     }
187        // }
188    }
189
190    /// Attempts to solve the top problem with a limit within the `Executor`. 
191    pub fn solve_top_with_limit(self) -> Option<&'static Expr> {
192        let problem = Problem::root(0, self.ctx.output);
193        let this = unsafe { (&self as *const Executor).as_ref::<'static>().unwrap() };
194        this.subproblem_count.update(|x| x+1);
195        *this.top_task() = task::spawn(this.deducers[problem.nt].deduce(this, problem));
196        let _ = this.run();
197        self.bridge.abort_all();
198        if let Poll::Ready(r) = this.top_task().poll_rc_nocx() {
199            Some(r)
200        } else { None }
201    }
202
203    /// Retrieves the current size of the executor. 
204    pub fn size(&self) -> usize { self.cur_size.get() }
205    
206    /// Retrieves the current non-terminal index from the executor. 
207    pub fn nt(&self) -> usize { self.cur_nt.get() }
208
209    /// Returns the current value of the `counter` field. 
210    pub fn count(&self) -> usize { self.counter.get() }
211    
212    #[inline]
213    /// Handle when a new express is enumerated.
214    pub fn enum_expr(&'static self, e: Expr, v: Value) -> Result<(), ()> {
215        if self.counter.get() % 10000 == 1 {
216            if self.counter.get() % 300000 == 1 {
217                info!("Searching size={} [{}] - {:?} {:?} {}", self.cur_size.get(), self.counter.get(), e, v, self.subproblem_count.get());
218            }
219            self.waiting_tasks().release_cost_limit(self.cfg.config.increase_cost_limit);
220            self.bridge.check();
221        }
222        self.counter.update(|x| x + 1);
223        if self.ctx.output.ty() != Type::Bool && v.ty() == Type::Bool {
224            self.collect_condition(&e);
225        } else if let Some(e) = self.cur_data().update(self, e, v)? {
226            self.collect_expr(e,v);
227        }
228        if self.top_task().is_ready() || (Instant::now() - self.start_time).as_millis() >= self.cfg.config.time_limit as u128 {
229            return Err(());
230        }
231        while STOP_SIGNAL.load(std::sync::atomic::Ordering::Relaxed) { std::hint::spin_loop() }
232        Ok(())
233    }
234    /// Collects and inserts an expression into a shared collection of conditions `CONDITIONS` 
235    fn collect_condition(&'static self, e: &Expr) {
236        if let Some(x) = CONDITIONS.lock().as_mut() { x.insert(e) }
237    }
238    /// Start Enumeration
239    fn run(&'static self) -> Result<(), ()> {
240        let _ = self.extract_expr_collector();
241        for size in 1 ..self.cfg.config.size_limit {
242            for (nt, ntdata) in self.cfg.iter().enumerate() {
243                self.cur_size.set(size);
244                self.cur_nt.set(nt);
245                info!("Enumerating size={} nt={} with - {}", size, ntdata.name, self.counter.get());
246                self.cur_data().to.enumerate(self)?;
247                for rule in &ntdata.rules {
248                    rule.enumerate(self)?;
249                }
250                
251                self.cur_data().size.add(size, self.extract_expr_collector());
252            }
253        }
254        Ok(())
255    }
256    // pub fn get_problem(&'static self, p: Problem) -> TaskORc<&'static Expr> {
257    //     let hash = unsafe { self.other.problems.as_mut() };
258    //     match hash.entry(p.clone()) {
259    //         std::collections::hash_map::Entry::Occupied(p) => p.get().clone(),
260    //         std::collections::hash_map::Entry::Vacant(v) => {
261    //             let t = task::spawn(p.deduce(self)).tasko();
262    //             v.insert(t.clone());
263    //             t
264    //         }
265    //     }
266    // }
267    // pub fn block_on<T>(&'static self, t: TaskORc<T>) -> Option<T> {
268    //     task::with_top_task(t.clone().task(), || {
269    //         let _ = self.run();
270    //     });
271    //     match t.poll_unpin() {
272    //         Poll::Ready(res) => Some(res),
273    //         Poll::Pending => None,
274    //     }
275    // }
276    // pub fn run_with(&'static self, p: Problem) -> Option<&'static Expr> {
277    //     let t = self.get_problem(p);
278    //     task::with_top_task(t.task(), || {
279    //         let _ = self.run();
280    //     });
281    //     match t.poll_unpin() {
282    //         Poll::Ready(res) => Some(res),
283    //         Poll::Pending => None,
284    //     }
285    // }
286}
287