tower_abci/buffer4/
worker.rs

1use super::{
2    error::{Closed, ServiceError},
3    message::Message,
4};
5use futures::stream::StreamExt;
6use std::sync::{Arc, Mutex, Weak};
7use tokio::{
8    select,
9    sync::{mpsc, Semaphore},
10};
11use tokio_stream::wrappers::UnboundedReceiverStream;
12use tower::{Service, ServiceExt};
13use tracing::Instrument;
14
15pub struct Worker<T, Request>
16where
17    T: Service<Request>,
18    T::Error: Into<crate::BoxError>,
19{
20    service: T,
21    handle: Handle,
22    failed: Option<ServiceError>,
23
24    rx1: Option<mpsc::UnboundedReceiver<Message<Request, T::Future>>>,
25    rx2: Option<mpsc::UnboundedReceiver<Message<Request, T::Future>>>,
26    rx3: Option<mpsc::UnboundedReceiver<Message<Request, T::Future>>>,
27    rx4: Option<mpsc::UnboundedReceiver<Message<Request, T::Future>>>,
28
29    close1: Option<Weak<Semaphore>>,
30    close2: Option<Weak<Semaphore>>,
31    close3: Option<Weak<Semaphore>>,
32    close4: Option<Weak<Semaphore>>,
33}
34
35/// Get the error out
36#[derive(Debug, Clone)]
37pub(crate) struct Handle {
38    inner: Arc<Mutex<Option<ServiceError>>>,
39}
40
41impl<T, Request> Worker<T, Request>
42where
43    T: Service<Request>,
44    T::Error: Into<crate::BoxError>,
45{
46    #[allow(clippy::too_many_arguments)]
47    pub(crate) fn new(
48        service: T,
49        rx1: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
50        semaphore1: &Arc<Semaphore>,
51        rx2: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
52        semaphore2: &Arc<Semaphore>,
53        rx3: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
54        semaphore3: &Arc<Semaphore>,
55        rx4: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
56        semaphore4: &Arc<Semaphore>,
57    ) -> (Handle, Worker<T, Request>) {
58        let handle = Handle {
59            inner: Arc::new(Mutex::new(None)),
60        };
61
62        let close1 = Some(Arc::downgrade(semaphore1));
63        let close2 = Some(Arc::downgrade(semaphore2));
64        let close3 = Some(Arc::downgrade(semaphore3));
65        let close4 = Some(Arc::downgrade(semaphore4));
66
67        let worker = Worker {
68            service,
69            handle: handle.clone(),
70            failed: None,
71            rx1: Some(rx1),
72            rx2: Some(rx2),
73            rx3: Some(rx3),
74            rx4: Some(rx4),
75            close1,
76            close2,
77            close3,
78            close4,
79        };
80
81        (handle, worker)
82    }
83
84    fn shutdown(&mut self) {
85        if let Some(close1) = self.close1.take().as_ref().and_then(Weak::upgrade) {
86            tracing::debug!("buffer 1 closing; waking pending tasks");
87            close1.close();
88        }
89        if let Some(close2) = self.close2.take().as_ref().and_then(Weak::upgrade) {
90            tracing::debug!("buffer 2 closing; waking pending tasks");
91            close2.close();
92        }
93        if let Some(close3) = self.close3.take().as_ref().and_then(Weak::upgrade) {
94            tracing::debug!("buffer 3 closing; waking pending tasks");
95            close3.close();
96        }
97        if let Some(close4) = self.close4.take().as_ref().and_then(Weak::upgrade) {
98            tracing::debug!("buffer 4 closing; waking pending tasks");
99            close4.close();
100        }
101    }
102
103    fn failed(&mut self, error: crate::BoxError) {
104        tracing::debug!({ %error }, "service failed");
105        let error = ServiceError::new(error);
106        let mut inner = self.handle.inner.lock().unwrap();
107
108        if inner.is_some() {
109            unreachable!("cannot fail twice");
110        }
111
112        *inner = Some(error.clone());
113        drop(inner);
114        if let Some(chan) = self.rx1.as_mut() {
115            chan.close()
116        }
117        if let Some(chan) = self.rx2.as_mut() {
118            chan.close()
119        }
120        if let Some(chan) = self.rx3.as_mut() {
121            chan.close()
122        }
123        if let Some(chan) = self.rx4.as_mut() {
124            chan.close()
125        }
126
127        self.failed = Some(error);
128    }
129
130    async fn process(&mut self, msg: Message<Request, T::Future>) {
131        match self.service.ready().await {
132            Ok(svc) => {
133                tracing::trace!("dispatching request to service");
134                let response = svc.call(msg.request);
135                tracing::trace!("returning response future");
136                let _ = msg.tx.send(Ok(response));
137            }
138            Err(e) => {
139                self.failed(e.into());
140                let error = self.failed.as_ref().expect("just set error").clone();
141                let _ = msg.tx.send(Err(error));
142            }
143        }
144    }
145
146    pub(crate) async fn run(mut self) {
147        loop {
148            if let Some(ref failed) = self.failed {
149                tracing::trace!("flushing pending requests after worker failure");
150                // We've failed and closed all channels.
151                // Now we flush any pending channel entries.
152                flush_channel(failed, self.rx1.take()).await;
153                flush_channel(failed, self.rx2.take()).await;
154                flush_channel(failed, self.rx3.take()).await;
155                flush_channel(failed, self.rx4.take()).await;
156
157                self.shutdown();
158                return;
159            }
160
161            select! {
162                // Using a biased select means the channels will be polled
163                // in priority order, not in a random (fair) order.
164                biased;
165                msg = recv_option(self.rx1.as_mut()), if self.rx1.is_some() => {
166                    match msg {
167                        Some(msg) => {
168                            let span = msg.span.clone();
169                            self.process(msg).instrument(span).await
170                        }
171                        None => self.rx1 = None,
172                    }
173                }
174                msg = recv_option(self.rx2.as_mut()), if self.rx2.is_some() => {
175                    match msg {
176                        Some(msg) => {
177                            let span = msg.span.clone();
178                            self.process(msg).instrument(span).await
179                        }
180                        None => self.rx2 = None,
181                    }
182                }
183                msg = recv_option(self.rx3.as_mut()), if self.rx3.is_some() => {
184                    match msg {
185                        Some(msg) => {
186                            let span = msg.span.clone();
187                            self.process(msg).instrument(span).await
188                        }
189                        None => self.rx3 = None,
190                    }
191                }
192                msg = recv_option(self.rx4.as_mut()), if self.rx4.is_some() => {
193                    match msg {
194                        Some(msg) => {
195                            let span = msg.span.clone();
196                            self.process(msg).instrument(span).await
197                        }
198                        None => self.rx4 = None,
199                    }
200                }
201            };
202
203            if self.rx1.is_none() && self.rx2.is_none() && self.rx3.is_none() && self.rx4.is_none()
204            {
205                tracing::trace!("all senders closed, shutting down");
206                self.shutdown();
207                return;
208            }
209        }
210    }
211}
212
213async fn flush_channel<T, F>(
214    failed: &ServiceError,
215    rx: Option<mpsc::UnboundedReceiver<Message<T, F>>>,
216) {
217    if let Some(chan) = rx {
218        let mut s = UnboundedReceiverStream::new(chan);
219        while let Some(msg) = s.next().await {
220            let _guard = msg.span.enter();
221            tracing::trace!("notifying caller about worker failure");
222            let _ = msg.tx.send(Err(failed.clone()));
223        }
224    }
225}
226
227impl Handle {
228    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
229        self.inner
230            .lock()
231            .unwrap()
232            .as_ref()
233            .map(|svc_err| svc_err.clone().into())
234            .unwrap_or_else(|| Closed::new().into())
235    }
236}
237
238async fn recv_option<T>(x: Option<&mut mpsc::UnboundedReceiver<T>>) -> Option<T> {
239    x?.recv().await
240}