tower_abci/v038/
codec.rs

1use std::marker::PhantomData;
2
3use tokio_util::codec::{Decoder, Encoder};
4
5use bytes::{BufMut, BytesMut};
6
/// A [`Decoder`] for varint length-prefixed protobuf messages of type `M`.
///
/// Each frame is an unsigned varint length followed by exactly that many
/// bytes of protobuf-encoded `M`. The decoder is stateful: once the length
/// prefix has been consumed, the expected body length is remembered across
/// calls until the full body has arrived.
pub struct Decode<M> {
    /// Which part of the current frame the decoder is waiting for.
    state: DecodeState,
    // Records the message type `M` without storing a value of it.
    _marker: PhantomData<M>,
}

impl<M> Default for Decode<M> {
    /// Creates a decoder positioned at the start of a frame, i.e. expecting
    /// a length prefix.
    fn default() -> Self {
        let state = DecodeState::Head;
        Decode {
            state,
            _marker: PhantomData,
        }
    }
}

/// Progress through the current length-delimited frame.
#[derive(Debug)]
enum DecodeState {
    /// The varint length prefix has not been read yet.
    Head,
    /// The prefix has been consumed; the decoder is waiting for the body.
    Body {
        /// Number of body bytes announced by the length prefix.
        len: usize,
    },
}
26
27impl<M: prost::Message + Default> Decoder for Decode<M> {
28    type Item = M;
29    type Error = crate::BoxError;
30
31    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
32        match self.state {
33            DecodeState::Head => {
34                tracing::trace!(?src, "decoding head");
35                // we don't use decode_varint directly, because it advances the
36                // buffer regardless of success, but Decoder assumes that when
37                // the buffer advances we've consumed the data. this is sort of
38                // a sad hack, but it works.
39                // TODO(erwan): fix this
40
41                // Tendermint socket protocol:
42                //   "Messages are serialized using Protobuf3 and length-prefixed
43                //    with an unsigned varint"
44                // See: https://github.com/tendermint/tendermint/blob/v0.38.x/spec/abci/abci++_client_server.md#socket
45                let mut tmp = src.clone().freeze();
46                let len = match prost::encoding::decode_varint(&mut tmp) {
47                    Ok(_) => {
48                        // advance the real buffer
49                        prost::encoding::decode_varint(src).unwrap() as usize
50                    }
51                    Err(_) => {
52                        tracing::trace!(?self.state, src.len = src.len(), "waiting for header data");
53                        return Ok(None);
54                    }
55                };
56                self.state = DecodeState::Body { len };
57                tracing::trace!(?self.state, "ready for body");
58
59                // Recurse to attempt body decoding.
60                self.decode(src)
61            }
62            DecodeState::Body { len } => {
63                if src.len() < len {
64                    tracing::trace!(?self.state, src.len = src.len(), "waiting for body");
65                    return Ok(None);
66                }
67
68                let body = src.split_to(len);
69                tracing::trace!(?body, "decoding body");
70                let message = M::decode(body)?;
71
72                // Now reset the decoder state for the next message.
73                self.state = DecodeState::Head;
74
75                Ok(Some(message))
76            }
77        }
78    }
79}
80
/// An [`Encoder`] that writes protobuf messages of type `M` as varint
/// length-prefixed frames.
pub struct Encode<M> {
    // Records the message type `M` without storing a value of it.
    _marker: PhantomData<M>,
}

impl<M> Default for Encode<M> {
    /// Creates a new encoder; it carries no state besides the type marker.
    fn default() -> Self {
        let _marker = PhantomData;
        Encode { _marker }
    }
}
92
93impl<M: prost::Message + Sized + std::fmt::Debug> Encoder<M> for Encode<M> {
94    type Error = crate::BoxError;
95
96    fn encode(&mut self, item: M, dst: &mut BytesMut) -> Result<(), Self::Error> {
97        let mut buf = BytesMut::new();
98        item.encode(&mut buf)?;
99        let buf = buf.freeze();
100
101        // Tendermint socket protocol:
102        //   "Messages are serialized using Protobuf3 and length-prefixed
103        //    with an unsigned varint"
104        // See: https://github.com/tendermint/tendermint/blob/v0.38.x/spec/abci/abci++_client_server.md#socket
105        prost::encoding::encode_varint(buf.len() as u64, dst);
106        dst.put(buf);
107
108        Ok(())
109    }
110}