1use std::{cell::{RefCell, UnsafeCell}, cmp::{max, min}, collections::{HashMap, HashSet}, pin::{pin, Pin}, rc::Rc, sync::Arc, task::{Poll, Waker}};
2
3use bumpalo::collections::CollectIn;
4use figment::util::diff_paths;
5use futures::{future::select, FutureExt};
6use futures_core::Future;
7use itertools::Itertools;
8use simple_rc_async::task::{self, JoinHandle};
9
10use crate::{async_closure, closure, debg, expr::{ context::Context, ops::Op1Enum, Expr}, forward::executor::Executor, info, utils::select_ret5, value::Type, DEBUG};
11use crate::{galloc::{self, AllocForAny, AllocForExactSizeIter, AllocForIter}, never, utils::{pending_if, select_all, select_ret, select_ret3, select_ret4, UnsafeCellExt}, value::Value};
12
13use crate::expr;
14use super::{Deducer, Problem};
15
/// Shared, growable set of spawned task handles that can be awaited as a single future.
///
/// Clones share the same underlying vector (the `Clone` impl only clones the `Arc`), so
/// handles pushed through any clone are seen by whichever clone is being polled. Polling
/// resolves with the value of the first handle found ready.
///
/// Mutation goes through an `UnsafeCell`, which is `!Sync`; that makes
/// `Arc<UnsafeCell<_>>` neither `Send` nor `Sync`, so this type is confined to one thread.
/// NOTE(review): `Rc` would express that single-threaded sharing more directly than `Arc`.
pub struct HandleRcVec<T: Unpin>(Arc<UnsafeCell<Vec<JoinHandle<T>>>>);
22
23impl<T: Unpin> Clone for HandleRcVec<T> {
24 fn clone(&self) -> Self {
28 Self(self.0.clone())
29 }
30}
31
32impl<T: Unpin> Future for HandleRcVec<T> {
33 type Output=T;
34
35 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
41 for v in unsafe{ self.0.as_mut()}.iter_mut() {
42 if let Poll::Ready(a) = v.poll_unpin(cx) {
43 return Poll::Ready(a);
44 }
45 }
46 Poll::Pending
47 }
48}
49
50impl<T: Unpin> Default for HandleRcVec<T> {
51 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl<T: Unpin> HandleRcVec<T> {
61 pub fn new() -> Self {
65 Self(Arc::new(UnsafeCell::new(Vec::new())))
66 }
67 pub fn extend_iter(&self, v: impl Iterator<Item=JoinHandle<T>>) {
70 for f in v {
71 unsafe{self.0.as_mut()}.push(f);
72 }
73 }
74 pub fn len(&self) -> usize {
77 unsafe { self.0.as_mut().len() }
78 }
79}
80
/// Deduction strategies for a string-typed nonterminal.
///
/// Index-valued fields use `usize::MAX` as a "not configured" sentinel (see `StrDeducer::new`).
#[derive(Debug)]
pub struct StrDeducer {
    /// Index of the string nonterminal this deducer serves; `deduce` asserts it equals `prob.nt`.
    pub nt: usize,
    /// NOTE(review): not read by any method visible in this file — purpose to be confirmed.
    pub split_once: (usize, usize),
    /// `.0`: minimum piece count (max over examples of `matches(delim) + 1`) required before a
    /// `Join` is attempted; `usize::MAX` disables join. `.1`: nonterminal solved for the list of pieces.
    pub join: (usize, usize),
    /// `.1`: nonterminal solved for boolean conditions wrapped in `Ite`. `.0` is not read in this file.
    pub ite_concat: (usize, usize),
    /// `.0`: list nonterminal indexed into (`usize::MAX` disables indexing).
    /// `.1`: nonterminal whose values are the `i64` position vectors.
    pub index: (usize, usize),
    /// Unary formatting operators tried by `fmt`, each paired with the nonterminal solved for its argument.
    pub formatter: Vec<(Op1Enum, usize)>,
    /// NOTE(review): not read by any method visible in this file — purpose to be confirmed.
    pub decay_rate: usize,
}
98
99impl StrDeducer {
100 pub fn new(nt: usize) -> Self {
102 Self { nt, split_once: (usize::MAX, 0), join: (usize::MAX, 0), ite_concat: (usize::MAX, usize::MAX), index: (usize::MAX, usize::MAX), formatter: Vec::new(), decay_rate: usize::MAX }
103 }
104}
105
impl Deducer for StrDeducer {
    /// Races several string-synthesis strategies for `prob` and returns the first expression produced.
    ///
    /// Competing sources, combined with `select_ret4`:
    /// - `eq`: the `all_eq` entry for exactly `prob.value` on this nonterminal.
    /// - listener events that spawn strategy tasks into `futures` (`split1`/`join` per
    ///   substring delimiter, `ite_concat` per prefix, `index` per containing list,
    ///   `join_empty_str` once).
    /// - `futures`: the spawned strategy tasks.
    /// - one `fmt` attempt per configured formatter.
    ///
    /// Panics if `prob.nt != self.nt` or if `prob.value` is not a string value.
    async fn deduce(&'static self, exec: &'static Executor, prob: Problem) -> &'static crate::expr::Expr {
        assert!(self.nt == prob.nt);
        assert!(prob.value.ty() == Type::Str, "Expected a string value, got: {:?}", prob.value);
        // Alias of `self`, used inside the nested listener closures below.
        let this = self;
        let mut eq = pin!(exec.data[self.nt].all_eq.acquire(prob.value));
        debg!("Deducing subproblem: {} {:?}", self.nt, prob.value);
        // Fast path: if the equality lookup is already resolved, skip all strategies.
        if let Poll::Ready(r) = futures::poll!(&mut eq) { return r; }

        // Shared sink for spawned strategy tasks; clones handed to the listeners push into it.
        let futures = HandleRcVec::new();

        // For each delimiter reported by the substring listener, spawn split-once and join attempts.
        let substr_event = closure! { clone futures, clone prob; async move {
            if exec.data[self.nt].substr().is_some() {
                exec.data[self.nt].substr().unwrap().listen_for_each(prob.value, closure! { clone futures, clone prob; move |delimiter: Value| {
                    futures.extend_iter(this.split1(exec, prob, delimiter).into_iter());
                    futures.extend_iter(this.join(exec, prob, delimiter).into_iter());
                    // Presumably `None` means "keep listening" — confirm against `listen_for_each`.
                    None::<&'static Expr>
                }}).await
            } else { never!(&'static Expr) }
        }};

        // For each prefix reported by the prefix listener, spawn an ite/concat attempt.
        let prefix_event = closure! { clone futures, clone prob; async move {
            if exec.data[self.nt].prefix().is_some() {
                exec.data[self.nt].prefix().unwrap().listen_for_each(prob.value, move |prefix: Value| {
                    futures.extend_iter(this.ite_concat(exec, prob, prefix).into_iter());
                    None::<&'static Expr>
                }).await
            } else { never!(&'static Expr) }
        }};

        // Indexing is only tried when configured, for cheap problems (used_cost < 3), and when
        // the list nonterminal has a `contains` listener.
        let index_event = closure! { clone futures, clone prob; async move {
            if self.index.0 != usize::MAX && prob.used_cost < 3 && exec.data[self.index.0].contains.is_some() {
                exec.data[self.index.0].contains.as_ref().unwrap().listen_for_each(prob.value, move |list: Value| {
                    futures.extend_iter(this.index(exec, prob, list).into_iter());
                    None::<&'static Expr>
                }).await
            } else { never!(&'static Expr) }
        }};

        // Join-with-"" (split into characters) is only considered when join is configured, the
        // problem is cheap enough, every character is alphanumeric, and some example is longer
        // than 2 bytes.
        let join_empty_str_cond = self.join.0 < usize::MAX && prob.used_cost <= 8 &&
            prob.value.to_str().iter().all(|x| x.chars().all(|c| c.is_alphanumeric())) &&
            prob.value.to_str().iter().any(|x| x.len() > 2);

        // Despite its name, this event drives `join_empty_str`: it waits once on the `len`
        // listener of the `join.1` nonterminal, then spawns the attempt. It never resolves itself.
        // NOTE(review): the awaited value `v` is unused — confirm the wait is the intended trigger.
        let map_event = pin!(closure! {clone futures; async move {
            if join_empty_str_cond {
                let v = exec.data[self.join.1].len().unwrap().listen_once(prob.value).await;
                futures.extend_iter(this.join_empty_str(exec, prob).into_iter());
            }
            never!(&'static Expr)
        }});
        // One formatter attempt per configured (operator, nonterminal) pair.
        let iter = self.formatter.iter().map(|x| self.fmt(prob, x, exec));

        let substr_event = pin!(substr_event);
        let prefix_event = pin!(prefix_event);
        let index_event = pin!(index_event);
        let events = select_ret4(prefix_event, substr_event, map_event, index_event);

        // First expression from any source wins.
        let result = select_ret4(eq, events, futures, pin!(select_all(iter))).await;
        result
    }
}
169
170
171
172impl StrDeducer {
173
174 #[inline]
175 fn split1(&'static self, exec: &'static Executor, mut prob: Problem, delimiter: Value) -> Option<JoinHandle<&'static Expr>> {
177 let delimiter = delimiter.to_str();
178 let v = prob.value.to_str();
179 let contain_count: usize = v.iter().zip(delimiter.iter()).filter(|(x, y)| if !y.is_empty() { x.contains(*y) } else { false }).count();
180 Some(task::spawn(async move {
184 let (a, b, cases) = split_once(v, delimiter);
185 if !cases.is_all_true() || self.ite_concat.1 == usize::MAX { return never!() }
186 exec.waiting_tasks().inc_cost(&mut prob, 1).await;
187
188 debg!("StrDeducer::split1 {v:?} {delimiter:?}");
189
190 let left = exec.solve_task(prob.with_value(a)).await;
191 let right = exec.solve_task(prob.with_value(b)).await;
192
193 let mut result = exec.data[prob.nt].all_eq.get(delimiter.into());
194 if self.ite_concat.1 != usize::MAX {
195 result = self.generate_condition(exec, prob.with_nt(self.ite_concat.1, cases), result).await;
196 }
197 if !a.is_all_empty() {
198 result = expr!(Concat {left} {result}).galloc();
199 }
200 if !b.is_all_empty() {
201 return expr!(Concat {result} {right}).galloc();
202 }
203 result
204 }))
205 }
206 #[inline]
207 pub async fn generate_condition(&'static self, exec: &'static Executor, prob: Problem, result: &'static Expr) -> &'static Expr {
209 if prob.value.is_all_true() { return result; }
210 let left = pin!(exec.solve_task(prob));
211 let right = pin!(exec.solve_task(prob.with_value(prob.value.bool_not())));
212 let cond = futures::future::select(left, right).await;
213 match cond {
214 futures::future::Either::Left((c, _)) =>
215 expr!(Ite {c} {result} "").galloc(),
216 futures::future::Either::Right((c, _)) =>
217 expr!(Ite {c} "" {result}).galloc(),
218 }
219 }
220 #[inline]
221 pub fn ite_concat(&'static self, exec: &'static Executor, mut prob: Problem, prefix: Value) -> Option<JoinHandle<&'static Expr>> {
223 let v: &[&str] = prob.value.to_str();
224 let prefix: &[&str] = prefix.to_str();
225 let start_count: usize = v.iter().zip(prefix.iter()).map(|(x, y)| if x.starts_with(*y) { y.len() } else { 0 }).sum();
226 let eq_count: usize = v.iter().zip(prefix.iter()).map(|(x, y)| if x == y { y.len() } else { 0 }).sum();
227
228 Some(task::spawn(async move {
232 debg!("StrDeducer::ite_concat {} {:?} {:?} {start_count} {eq_count}", prob.nt, v, prefix);
233 let (a, b) = ite_concat_split(v, prefix);
234
235 exec.waiting_tasks().inc_cost(&mut prob, 1).await;
236
237 let right = exec.solve_task(prob.with_value(b)).await;
238
239 let mut result = exec.data[prob.nt].all_eq.get(prefix.into());
240 result = self.generate_condition(exec, prob.with_nt(self.ite_concat.1, a), result).await;
241 if !b.is_all_empty() {
242 result = expr!(Concat {result} {right}).galloc();
243 }
244 result
245 }))
246 }
247
248 pub fn index(&'static self, exec: &'static Executor, mut prob: Problem, list: Value) -> Option<JoinHandle<&'static Expr>> {
249 let v: &[&str] = prob.value.to_str();
250 let list : &[&[&str]] = list.to_liststr();
251
252 let indices = v.iter().zip(list.iter()).map(|(x, y)| {
253 y.iter().position(|&z| z == *x).unwrap_or(y.len()) as i64
254 }).galloc_scollect();
255 if self.index.0 == usize::MAX { return None; }
256 Some(task::spawn(async move {
257 debg!("StrDeducer::index {} {:?} {:?} {:?} {} ", prob.nt, v, list, indices, self.index.1);
258 let indices = exec.data[self.index.1].all_eq.acquire(indices.into()).await;
261 let mut result = exec.data[self.index.0].all_eq.get(list.into());
262 expr!(At {result} {indices}).galloc()
263 }))
264 }
265
266 #[inline]
267 fn join(&'static self, exec: &'static Executor, mut prob: Problem, delimiter: Value) -> Option<JoinHandle<&'static Expr>> {
269 let delimiter = delimiter.to_str();
270 let v = prob.value.to_str();
271 if prob.used_cost >= 5 { return None; }
272
273 let contain_count: usize = v.iter().zip(delimiter.iter()).map(|(x, y)| x.matches(y).count() + 1).max().unwrap_or(10000);
274 if contain_count < self.join.0 { return None; }
275
276
277 Some(task::spawn(async move {
278 debg!("StrDeducer::join {v:?} {delimiter:?} {} {}", prob.used_cost, contain_count);
279 let a = value_split(v, delimiter);
282
283 let list = exec.solve_task(prob.with_nt(self.join.1, a)).await;
284
285 let mut delim = exec.data[prob.nt].all_eq.get(delimiter.into());
286 expr!(Join {list} {delim}).galloc()
287 }))
288 }
289 #[inline]
290 fn join_empty_str(&'static self, exec: &'static Executor, mut prob: Problem) -> Option<JoinHandle<&'static Expr>> {
292 debg!("StrDeducer::join_empty_str {:?}", prob.value);
293
294 Some(task::spawn(async move {
295 exec.waiting_tasks().inc_cost(&mut prob, 1).await;
296 let v = prob.value.to_str();
297 let li = v.iter().map(|x| (0..x.len()).map(|i| &x[i..i+1]).galloc_scollect() ).galloc_scollect();
298 let list = exec.solve_task(prob.with_nt(self.join.1, li.into())).await;
299 expr!(Join {list} "").galloc()
300 }))
301 }
302 #[inline]
312 async fn fmt(&self, mut problem: Problem, formatter: &(Op1Enum, usize), exec: &'static Executor) -> &'static Expr {
314 let v = problem.value.to_str();
315 if let Some((op, a, b, cond)) = formatter.0.format_all(v) {
316 debg!("StrDeducer::fmt {v:?} {formatter:?}");
317 if !cond.is_all_true() { exec.waiting_tasks().inc_cost(&mut problem, 1).await; }
318 else { exec.waiting_tasks().inc_cost(&mut problem, 1).await; }
319
320 let inner = exec.solve_task(problem.with_nt(formatter.1, a)).await;
321 let rest = exec.solve_task(problem.with_nt(self.nt, b)).await;
322
323 let mut result = Expr::Op1(op.clone().galloc(), inner).galloc();
324 if self.ite_concat.1 != usize::MAX {
325 result = exec.generate_condition(problem.with_nt(self.ite_concat.1, cond), result).await;
326 }
327 result = expr!(Concat {result} {rest}).galloc();
328 if DEBUG.get() {
329 assert_eq!(result.eval(&exec.ctx), Value::Str(v), "Expression: {:?} {:?}", result, a);
330 }
331 result
332 } else { never!() }
333 }
334}
335
336pub fn split_once(s: &'static [&'static str], delimiter: &'static [&'static str]) -> (Value, Value, Value) {
338 assert!(s.len() == delimiter.len());
339 let mut a = galloc::new_bvec(s.len());
340 let mut b = galloc::new_bvec(s.len());
341 let mut cases = galloc::new_bvec(s.len());
342 for (x, y) in s.iter().zip(delimiter.iter()) {
343 if y.is_empty() {
344 a.push("");
345 b.push(*x);
346 cases.push(true)
347 } else if let Some((l, r)) = x.split_once(*y) {
348 a.push(l);
349 b.push(r);
350 cases.push(true)
351 } else {
352 a.push(x);
353 b.push("");
354 cases.push(false)
355 }
356 }
357 (Value::Str(a.into_bump_slice()), Value::Str(b.into_bump_slice()), Value::Bool(cases.into_bump_slice()))
358}
359
360pub fn ite_concat_split(s: &'static [&'static str], delimiter: &'static [&'static str]) -> (Value, Value) {
367 assert!(s.len() == delimiter.len());
368 let mut a = galloc::new_bvec(s.len());
369 let mut b = galloc::new_bvec(s.len());
370 for (x, y) in s.iter().zip(delimiter.iter()) {
371 let v = x.starts_with(y);
372 a.push(v);
373 if v {
374 b.push(&x[y.len()..])
375 } else {
376 b.push(x)
377 }
378 }
379 (Value::Bool(a.into_bump_slice()), Value::Str(b.into_bump_slice()))
380}
381
382pub fn value_split(s: &'static [&'static str], delimiter: &'static [&'static str]) -> Value {
387 Value::ListStr(s.iter().zip(delimiter.iter()).map(|(x, y)| x.split(y).galloc_collect()).galloc_collect())
388}