tower_abci/v034/
codec.rs

1use std::marker::PhantomData;
2
3use tokio_util::codec::{Decoder, Encoder};
4
5use bytes::{Buf, BufMut, BytesMut};
6
7// encode_varint and decode_varint will be removed once
8// https://github.com/tendermint/tendermint/issues/5783 lands in Tendermint.
9pub fn encode_varint<B: BufMut>(val: u64, mut buf: &mut B) {
10    prost::encoding::encode_varint(val << 1, &mut buf);
11}
12
13pub fn decode_varint<B: Buf>(mut buf: &mut B) -> Result<u64, prost::DecodeError> {
14    let len = prost::encoding::decode_varint(&mut buf)?;
15    Ok(len >> 1)
16}
17
18pub struct Decode<M> {
19    state: DecodeState,
20    _marker: PhantomData<M>,
21}
22
23impl<M> Default for Decode<M> {
24    fn default() -> Self {
25        Self {
26            state: DecodeState::Head,
27            _marker: PhantomData,
28        }
29    }
30}
31
/// Progress of the decoder through the frame currently being read.
#[derive(Debug)]
enum DecodeState {
    /// The next bytes are expected to be a frame's length header.
    Head,
    /// The header has been read; `len` bytes of message body are expected.
    Body { len: usize },
}
37
38impl<M: prost::Message + Default> Decoder for Decode<M> {
39    type Item = M;
40    type Error = crate::BoxError;
41
42    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
43        match self.state {
44            DecodeState::Head => {
45                tracing::trace!(?src, "decoding head");
46
47                // we don't use decode_varint directly, because it advances the
48                // buffer regardless of success, but Decoder assumes that when
49                // the buffer advances we've consumed the data. this is sort of
50                // a sad hack, but it works.
51                // fix this
52                let mut tmp = src.clone().freeze();
53                let len = match decode_varint(&mut tmp) {
54                    Ok(_) => {
55                        // advance the real buffer
56                        decode_varint(src).unwrap() as usize
57                    }
58                    Err(_) => {
59                        tracing::trace!(?self.state, src.len = src.len(), "waiting for header data");
60                        return Ok(None);
61                    }
62                };
63                self.state = DecodeState::Body { len };
64                tracing::trace!(?self.state, "ready for body");
65
66                // Recurse to attempt body decoding.
67                self.decode(src)
68            }
69            DecodeState::Body { len } => {
70                if src.len() < len {
71                    tracing::trace!(?self.state, src.len = src.len(), "waiting for body");
72                    return Ok(None);
73                }
74
75                let body = src.split_to(len);
76                tracing::trace!(?body, "decoding body");
77                let message = M::decode(body)?;
78
79                // Now reset the decoder state for the next message.
80                self.state = DecodeState::Head;
81
82                Ok(Some(message))
83            }
84        }
85    }
86}
87
/// Stateless encoder for messages of type `M`.
///
/// `M` is never stored; it only selects which message type the encoder
/// accepts.
pub struct Encode<M> {
    /// Zero-sized marker binding the encoder to its message type.
    _marker: PhantomData<M>,
}
91
92impl<M> Default for Encode<M> {
93    fn default() -> Self {
94        Self {
95            _marker: PhantomData,
96        }
97    }
98}
99
100impl<M: prost::Message + Sized + std::fmt::Debug> Encoder<M> for Encode<M> {
101    type Error = crate::BoxError;
102
103    fn encode(&mut self, item: M, dst: &mut BytesMut) -> Result<(), Self::Error> {
104        // rewrite this to avoid extra copy?
105        let mut buf = BytesMut::new();
106        item.encode(&mut buf)?;
107        let buf = buf.freeze();
108        encode_varint(buf.len() as u64, dst);
109        dst.put(buf);
110
111        Ok(())
112    }
113}