tower_abci/v037/
codec.rs

1use std::marker::PhantomData;
2
3use tokio_util::codec::{Decoder, Encoder};
4
5use bytes::{BufMut, BytesMut};
6
7pub struct Decode<M> {
8    state: DecodeState,
9    _marker: PhantomData<M>,
10}
11
12impl<M> Default for Decode<M> {
13    fn default() -> Self {
14        Self {
15            state: DecodeState::Head,
16            _marker: PhantomData,
17        }
18    }
19}
20
/// Position of a `Decode` within the frame currently being read.
#[derive(Debug)]
enum DecodeState {
    /// Expecting the varint length prefix of the next message.
    Head,
    /// The length prefix has been read; expecting the message body.
    Body {
        /// Size in bytes of the pending message body.
        len: usize,
    },
}
26
27impl<M: prost::Message + Default> Decoder for Decode<M> {
28    type Item = M;
29    type Error = crate::BoxError;
30
31    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
32        match self.state {
33            DecodeState::Head => {
34                tracing::trace!(?src, "decoding head");
35
36                // we don't use decode_varint directly, because it advances the
37                // buffer regardless of success, but Decoder assumes that when
38                // the buffer advances we've consumed the data. this is sort of
39                // a sad hack, but it works.
40                // TODO(erwan): fix this
41
42                // Tendermint socket protocol:
43                //   "Messages are serialized using Protobuf3 and length-prefixed
44                //    with an unsigned varint"
45                // See: https://github.com/tendermint/tendermint/blob/v0.37.x/spec/abci/abci++_client_server.md#socket
46                let mut tmp = src.clone().freeze();
47                let len = match prost::encoding::decode_varint(&mut tmp) {
48                    Ok(_) => {
49                        // advance the real buffer
50                        prost::encoding::decode_varint(src).unwrap() as usize
51                    }
52                    Err(_) => {
53                        tracing::trace!(?self.state, src.len = src.len(), "waiting for header data");
54                        return Ok(None);
55                    }
56                };
57                self.state = DecodeState::Body { len };
58                tracing::trace!(?self.state, "ready for body");
59
60                // Recurse to attempt body decoding.
61                self.decode(src)
62            }
63            DecodeState::Body { len } => {
64                if src.len() < len {
65                    tracing::trace!(?self.state, src.len = src.len(), "waiting for body");
66                    return Ok(None);
67                }
68
69                let body = src.split_to(len);
70                tracing::trace!(?body, "decoding body");
71                let message = M::decode(body)?;
72
73                // Now reset the decoder state for the next message.
74                self.state = DecodeState::Head;
75
76                Ok(Some(message))
77            }
78        }
79    }
80}
81
/// An [`Encoder`] that writes protobuf messages of type `M` as length-prefixed frames.
///
/// Holds no runtime state; `M` is tracked only at the type level.
#[derive(Debug)]
pub struct Encode<M> {
    _marker: PhantomData<M>,
}
85
86impl<M> Default for Encode<M> {
87    fn default() -> Self {
88        Self {
89            _marker: PhantomData,
90        }
91    }
92}
93
94impl<M: prost::Message + Sized + std::fmt::Debug> Encoder<M> for Encode<M> {
95    type Error = crate::BoxError;
96
97    fn encode(&mut self, item: M, dst: &mut BytesMut) -> Result<(), Self::Error> {
98        let mut buf = BytesMut::new();
99        item.encode(&mut buf)?;
100        let buf = buf.freeze();
101
102        // Tendermint socket protocol:
103        //   "Messages are serialized using Protobuf3 and length-prefixed
104        //    with an unsigned varint"
105        // See: https://github.com/tendermint/tendermint/blob/v0.37.x/spec/abci/abci++_client_server.md#socket
106        prost::encoding::encode_varint(buf.len() as u64, dst);
107        dst.put(buf);
108
109        Ok(())
110    }
111}