1use core::task::{ContextBuilder, LocalWaker};
2use std::{cell::{Cell, RefCell}, future::Future, pin::Pin, rc::Rc, task::{Context, LocalWake, Poll, Waker}};
3
4use crate::sync::oneshot::{self, Channel};
5
6struct TaskInner<T: Future + 'static> where T::Output : Unpin {
7 chan: oneshot::Channel<T::Output>,
8 fut: T
9}
10
11struct Task<T: Future + 'static>(RefCell<TaskInner<T>>) where T::Output : Unpin;
12
13impl<T: Future + 'static> Task<T> where T::Output : Unpin {
14 fn new(fut: T) -> Self {
15 Self(RefCell::new(TaskInner { chan: Channel::new(), fut }))
16 }
17 fn poll_fut(self: &Rc<Self>) -> Poll<T::Output> {
18 let local_waker = self.clone().into();
19 let mut cx = ContextBuilder::from_waker(Waker::noop()).local_waker(&local_waker).build();
20 let mut lock = self.0.borrow_mut();
21 let p = unsafe { Pin::new_unchecked(&mut lock.fut) };
22 p.poll(&mut cx)
23 }
24}
25
26
27trait TaskTrait<T: Unpin> {
28 fn into_waker(self: Rc<Self>) -> LocalWaker;
29 fn is_ready(&self) -> bool;
30 fn is_completed(&self) -> bool;
31 fn poll_rc_nocx(self: Rc<Self>) -> Poll<T>;
32 fn poll_rc(self: Rc<Self>, cx: &mut Context<'_>) -> Poll<T>;
33}
34
35impl<T: Future + 'static> TaskTrait<T::Output> for Task<T> where T::Output : Unpin {
36 fn into_waker(self: Rc<Self>) -> LocalWaker {
37 self.into()
38 }
39
40 fn is_ready(&self) -> bool {
41 self.0.borrow().chan.is_ready()
42 }
43 fn is_completed(&self) -> bool {
44 self.0.borrow().chan.is_completed()
45 }
46
47 fn poll_rc_nocx(self: Rc<Self>) -> Poll<T::Output> {
48 self.0.borrow_mut().chan.poll_ref_nocx()
49 }
50 fn poll_rc(self: Rc<Self>, cx: &mut Context<'_>) -> Poll<T::Output> {
51 self.0.borrow_mut().chan.poll_ref(cx)
52 }
53}
54
55impl<T: Future + 'static> LocalWake for Task<T> where T::Output : Unpin {
56 fn wake_by_ref(self: &Rc<Self>) {
57 if self.is_completed() || self.is_ready() { return; }
58 let result = self.poll_fut();
59 if let Poll::Ready(a) = result {
60 let waker = self.0.borrow_mut().chan.send(a);
61 waker.wake();
62 }
63 }
64 fn wake(self: Rc<Self>) {
65 self.wake_by_ref()
66 }
67}
68
69pub struct JoinHandle<T: Unpin>(Rc<dyn TaskTrait<T>>);
70
71impl<T: Unpin> JoinHandle<T> {
72 pub fn waker(&self) -> LocalWaker {
73 self.0.clone().into_waker()
74 }
75 pub fn is_ready(&self) -> bool {
76 self.0.is_ready()
77 }
78 pub fn is_completed(&self) -> bool {
79 self.0.is_completed()
80 }
81 pub fn poll_rc_nocx(&self) -> Poll<T> {
82 self.0.clone().poll_rc_nocx()
83 }
84}
85
86impl<T: Unpin> std::future::Future for JoinHandle<T> {
87 type Output = T;
88
89 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
90 self.0.clone().poll_rc(cx)
91 }
92}
93
thread_local! {
    /// Per-thread counter bumped by `spawn` and read by `number_of_tasks`.
    ///
    /// Const-initialised, so accesses skip the lazy-initialisation check.
    static NUM_TASKS: Cell<usize> = const { Cell::new(0) };
}
97
98pub fn spawn<T: Unpin>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T> {
99 NUM_TASKS.with(|x| x.update(|x| x+1));
100 let handle = JoinHandle(Rc::new(Task::new(fut)));
101 handle.waker().wake();
102 handle
103}
104
105pub fn number_of_tasks() -> usize {
106 NUM_TASKS.with(|x| x.get())
107}