// tower_abci/buffer4/service.rs

1use super::{
2    future::ResponseFuture,
3    message::Message,
4    worker::{Handle, Worker},
5};
6
7use futures::ready;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
11use tokio_util::sync::PollSemaphore;
12use tower::Service;
13
14/// Adds an mpsc buffer in front of an inner service.
15///
16/// See the module documentation for more details.
17#[derive(Debug)]
18pub struct Buffer<T, Request>
19where
20    T: Service<Request>,
21{
22    // Note: this actually _is_ bounded, but rather than using Tokio's bounded
23    // channel, we use Tokio's semaphore separately to implement the bound.
24    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,
25    // When the buffer's channel is full, we want to exert backpressure in
26    // `poll_ready`, so that callers such as load balancers could choose to call
27    // another service rather than waiting for buffer capacity.
28    //
29    // Unfortunately, this can't be done easily using Tokio's bounded MPSC
30    // channel, because it doesn't expose a polling-based interface, only an
31    // `async fn ready`, which borrows the sender. Therefore, we implement our
32    // own bounded MPSC on top of the unbounded channel, using a semaphore to
33    // limit how many items are in the channel.
34    semaphore: PollSemaphore,
35    // The current semaphore permit, if one has been acquired.
36    //
37    // This is acquired in `poll_ready` and taken in `call`.
38    permit: Option<OwnedSemaphorePermit>,
39    handle: Handle,
40}
41
42impl<T, Request> Buffer<T, Request>
43where
44    T: Service<Request>,
45    T::Error: Into<crate::BoxError>,
46{
47    /// Creates a new [`Buffer`] wrapping `service`.
48    ///
49    /// `bound` gives the maximal number of requests that can be queued for the service before
50    /// backpressure is applied to callers.
51    ///
52    /// The default Tokio executor is used to run the given service, which means that this method
53    /// must be called while on the Tokio runtime.
54    ///
55    /// # A note on choosing a `bound`
56    ///
57    /// When [`Buffer`]'s implementation of [`poll_ready`] returns [`Poll::Ready`], it reserves a
58    /// slot in the channel for the forthcoming [`call`]. However, if this call doesn't arrive,
59    /// this reserved slot may be held up for a long time. As a result, it's advisable to set
60    /// `bound` to be at least the maximum number of concurrent requests the [`Buffer`] will see.
61    /// If you do not, all the slots in the buffer may be held up by futures that have just called
62    /// [`poll_ready`] but will not issue a [`call`], which prevents other senders from issuing new
63    /// requests.
64    ///
65    /// # A note on the scope of `bound`
66    ///
67    /// Note that `bound` will only limit the rate of the _submission_ of [Message]s to the [Worker],
68    /// not their _execution_. If the execution itself is asynchronous, concurrency should be further
69    /// controlled by applying an appropriate [tower::Layer] on the returned service component.
70    ///
71    /// [`Poll::Ready`]: std::task::Poll::Ready
72    /// [`call`]: crate::Service::call
73    /// [`poll_ready`]: crate::Service::poll_ready
74    pub fn new(service: T, bound: usize) -> (Self, Self, Self, Self)
75    where
76        T: Send + 'static,
77        T::Future: Send,
78        T::Error: Send + Sync,
79        Request: Send + 'static,
80    {
81        let (svc1, svc2, svc3, svc4, worker) = Self::pair(service, bound);
82        tokio::spawn(worker.run());
83        (svc1, svc2, svc3, svc4)
84    }
85
86    /// Creates a new [`Buffer`] wrapping `service`, but returns the background worker.
87    ///
88    /// This is useful if you do not want to spawn directly onto the tokio runtime
89    /// but instead want to use your own executor. This will return the [`Buffer`] and
90    /// the background `Worker` that you can then spawn.
91    #[allow(clippy::type_complexity)]
92    pub fn pair(
93        service: T,
94        bound: usize,
95    ) -> (
96        Buffer<T, Request>,
97        Buffer<T, Request>,
98        Buffer<T, Request>,
99        Buffer<T, Request>,
100        Worker<T, Request>,
101    )
102    where
103        T: Send + 'static,
104        T::Error: Send + Sync,
105        Request: Send + 'static,
106    {
107        let (tx1, rx1) = mpsc::unbounded_channel();
108        let (tx2, rx2) = mpsc::unbounded_channel();
109        let (tx3, rx3) = mpsc::unbounded_channel();
110        let (tx4, rx4) = mpsc::unbounded_channel();
111
112        let semaphore1 = Arc::new(Semaphore::new(bound));
113        let semaphore2 = Arc::new(Semaphore::new(bound));
114        let semaphore3 = Arc::new(Semaphore::new(bound));
115        let semaphore4 = Arc::new(Semaphore::new(bound));
116
117        let (handle, worker) = Worker::new(
118            service,
119            rx1,
120            &semaphore1,
121            rx2,
122            &semaphore2,
123            rx3,
124            &semaphore3,
125            rx4,
126            &semaphore4,
127        );
128
129        let buffer1 = Buffer {
130            tx: tx1,
131            handle: handle.clone(),
132            semaphore: PollSemaphore::new(semaphore1),
133            permit: None,
134        };
135        let buffer2 = Buffer {
136            tx: tx2,
137            handle: handle.clone(),
138            semaphore: PollSemaphore::new(semaphore2),
139            permit: None,
140        };
141        let buffer3 = Buffer {
142            tx: tx3,
143            handle: handle.clone(),
144            semaphore: PollSemaphore::new(semaphore3),
145            permit: None,
146        };
147        let buffer4 = Buffer {
148            tx: tx4,
149            handle,
150            semaphore: PollSemaphore::new(semaphore4),
151            permit: None,
152        };
153
154        (buffer1, buffer2, buffer3, buffer4, worker)
155    }
156
157    fn get_worker_error(&self) -> crate::BoxError {
158        self.handle.get_error_on_closed()
159    }
160}
161
162impl<T, Request> Service<Request> for Buffer<T, Request>
163where
164    T: Service<Request>,
165    T::Error: Into<crate::BoxError>,
166{
167    type Response = T::Response;
168    type Error = crate::BoxError;
169    type Future = ResponseFuture<T::Future>;
170
171    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
172        // First, check if the worker is still alive.
173        if self.tx.is_closed() {
174            // If the inner service has errored, then we error here.
175            return Poll::Ready(Err(self.get_worker_error()));
176        }
177
178        // Then, check if we've already acquired a permit.
179        if self.permit.is_some() {
180            // We've already reserved capacity to send a request. We're ready!
181            return Poll::Ready(Ok(()));
182        }
183
184        // Finally, if we haven't already acquired a permit, poll the semaphore
185        // to acquire one. If we acquire a permit, then there's enough buffer
186        // capacity to send a new request. Otherwise, we need to wait for
187        // capacity.
188        let permit =
189            ready!(self.semaphore.poll_acquire(cx)).ok_or_else(|| self.get_worker_error())?;
190        self.permit = Some(permit);
191
192        Poll::Ready(Ok(()))
193    }
194
195    fn call(&mut self, request: Request) -> Self::Future {
196        tracing::trace!("sending request to buffer worker");
197        let _permit = self
198            .permit
199            .take()
200            .expect("buffer full; poll_ready must be called first");
201
202        // get the current Span so that we can explicitly propagate it to the worker
203        // if we didn't do this, events on the worker related to this span wouldn't be counted
204        // towards that span since the worker would have no way of entering it.
205        let span = tracing::Span::current();
206
207        // If we've made it here, then a semaphore permit has already been
208        // acquired, so we can freely allocate a oneshot.
209        let (tx, rx) = oneshot::channel();
210
211        match self.tx.send(Message {
212            request,
213            span,
214            tx,
215            _permit,
216        }) {
217            Err(_) => ResponseFuture::failed(self.get_worker_error()),
218            Ok(_) => ResponseFuture::new(rx),
219        }
220    }
221}
222
223impl<T, Request> Clone for Buffer<T, Request>
224where
225    T: Service<Request>,
226{
227    fn clone(&self) -> Self {
228        Self {
229            tx: self.tx.clone(),
230            handle: self.handle.clone(),
231            semaphore: self.semaphore.clone(),
232            // The new clone hasn't acquired a permit yet. It will when it's
233            // next polled ready.
234            permit: None,
235        }
236    }
237}