Procházet zdrojové kódy

task: add JoinHandle::join to wait for result

greatbridf před 10 měsíci
rodič
revize
f8ded5c5f2
2 změnil soubory, kde provedl 71 přidání a 18 odebrání
  1. 23 7
      src/kernel/task/scheduler.rs
  2. 48 11
      src/kernel/task/task.rs

+ 23 - 7
src/kernel/task/scheduler.rs

@@ -63,6 +63,28 @@ impl Task {
     }
 }
 
+impl<O> JoinHandle<O>
+where
+    O: Send,
+{
+    pub fn join(self) -> O {
+        let Self(output) = self;
+        let mut waker = Some(Waker::from(Task::current().clone()));
+
+        loop {
+            let mut locked = output.lock();
+            match locked.try_resolve() {
+                Some(output) => break output,
+                None => {
+                    if let Some(waker) = waker.take() {
+                        locked.register_waiter(waker);
+                    }
+                }
+            }
+        }
+    }
+}
+
 impl Scheduler {
     /// `Scheduler` might be used in various places. Do not hold it for a long time.
     ///
@@ -107,7 +129,7 @@ impl Scheduler {
         O: Send,
     {
         let (task, output) = Self::extract_handle(task);
-        TASKS.lock().insert(task.clone());
+        Task::add(task.clone());
         self.activate(&task);
 
         JoinHandle(output)
@@ -133,12 +155,6 @@ impl Scheduler {
         preempt::enable();
     }
 
-    pub fn schedule_noreturn() -> ! {
-        preempt::disable();
-        Self::schedule();
-        panic!("Scheduler::schedule_noreturn(): Should never return")
-    }
-
     pub async fn yield_now() {
         struct Yield(bool);
 

+ 48 - 11
src/kernel/task/task.rs

@@ -39,9 +39,53 @@ pub struct TaskHandle<Output: Send> {
     output: Arc<Spin<TaskOutput<Output>>>,
 }
 
+enum TaskOutputState<Output: Send> {
+    Waiting(Option<Waker>),
+    Finished(Option<Output>),
+    TakenOut,
+}
+
 pub struct TaskOutput<Output: Send> {
-    output: Option<Output>,
-    waker: Option<Waker>,
+    inner: TaskOutputState<Output>,
+}
+
+impl<Output> TaskOutput<Output>
+where
+    Output: Send,
+{
+    pub fn try_resolve(&mut self) -> Option<Output> {
+        let output = match &mut self.inner {
+            TaskOutputState::Waiting(_) => return None,
+            TaskOutputState::Finished(output) => output.take(),
+            TaskOutputState::TakenOut => panic!("Output already taken out"),
+        };
+
+        self.inner = TaskOutputState::TakenOut;
+        if let Some(output) = output {
+            Some(output)
+        } else {
+            unreachable!("Output should be present")
+        }
+    }
+
+    pub fn register_waiter(&mut self, waker: Waker) {
+        if let TaskOutputState::Waiting(inner_waker) = &mut self.inner {
+            inner_waker.replace(waker);
+        } else {
+            panic!("Output is not waiting");
+        }
+    }
+
+    pub fn commit_output(&mut self, output: Output) {
+        if let TaskOutputState::Waiting(inner_waker) = &mut self.inner {
+            if let Some(waker) = inner_waker.take() {
+                waker.wake();
+            }
+            self.inner = TaskOutputState::Finished(Some(output));
+        } else {
+            panic!("Output is not waiting");
+        }
+    }
 }
 
 /// A `Task` represents a schedulable unit.
@@ -111,8 +155,7 @@ impl Task {
         static ID: AtomicU32 = AtomicU32::new(0);
 
         let output = Arc::new(Spin::new(TaskOutput {
-            output: None,
-            waker: None,
+            inner: TaskOutputState::Waiting(None),
         }));
 
         let kernel_stack = KernelStack::new();
@@ -193,13 +236,7 @@ impl Task {
             let output_data = runnable.pinned_join(&waker);
 
             if let Some(output) = output.upgrade() {
-                let mut output = output.lock();
-                let old = output.output.replace(output_data);
-                debug_assert!(old.is_none(), "Output should be empty");
-
-                if let Some(waker) = output.waker.take() {
-                    waker.wake();
-                }
+                output.lock().commit_output(output_data);
             }
         }