1use std::marker::PhantomData;
2
3use tokio_util::codec::{Decoder, Encoder};
4
5use bytes::{Buf, BufMut, BytesMut};
6
7pub fn encode_varint<B: BufMut>(val: u64, mut buf: &mut B) {
10 prost::encoding::encode_varint(val << 1, &mut buf);
11}
12
13pub fn decode_varint<B: Buf>(mut buf: &mut B) -> Result<u64, prost::DecodeError> {
14 let len = prost::encoding::decode_varint(&mut buf)?;
15 Ok(len >> 1)
16}
17
/// A [`Decoder`] for frames made of an [`encode_varint`] length header followed
/// by a protobuf-encoded `M`.
pub struct Decode<M> {
    /// Progress through the frame currently being read.
    state: DecodeState,
    /// Ties the decoder to its message type without storing an `M`.
    _marker: PhantomData<M>,
}
22
23impl<M> Default for Decode<M> {
24 fn default() -> Self {
25 Self {
26 state: DecodeState::Head,
27 _marker: PhantomData,
28 }
29 }
30}
31
/// Where [`Decode`] is within the current frame.
#[derive(Debug)]
enum DecodeState {
    /// Waiting for (or about to parse) the varint length header.
    Head,
    /// Header parsed; waiting for `len` bytes of message body.
    Body { len: usize },
}
37
38impl<M: prost::Message + Default> Decoder for Decode<M> {
39 type Item = M;
40 type Error = crate::BoxError;
41
42 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
43 match self.state {
44 DecodeState::Head => {
45 tracing::trace!(?src, "decoding head");
46
47 let mut tmp = src.clone().freeze();
53 let len = match decode_varint(&mut tmp) {
54 Ok(_) => {
55 decode_varint(src).unwrap() as usize
57 }
58 Err(_) => {
59 tracing::trace!(?self.state, src.len = src.len(), "waiting for header data");
60 return Ok(None);
61 }
62 };
63 self.state = DecodeState::Body { len };
64 tracing::trace!(?self.state, "ready for body");
65
66 self.decode(src)
68 }
69 DecodeState::Body { len } => {
70 if src.len() < len {
71 tracing::trace!(?self.state, src.len = src.len(), "waiting for body");
72 return Ok(None);
73 }
74
75 let body = src.split_to(len);
76 tracing::trace!(?body, "decoding body");
77 let message = M::decode(body)?;
78
79 self.state = DecodeState::Head;
81
82 Ok(Some(message))
83 }
84 }
85 }
86}
87
/// An [`Encoder`] that writes each `M` as an [`encode_varint`] length header
/// followed by the message's protobuf encoding.
pub struct Encode<M> {
    /// Ties the encoder to its message type without storing an `M`.
    _marker: PhantomData<M>,
}
91
92impl<M> Default for Encode<M> {
93 fn default() -> Self {
94 Self {
95 _marker: PhantomData,
96 }
97 }
98}
99
100impl<M: prost::Message + Sized + std::fmt::Debug> Encoder<M> for Encode<M> {
101 type Error = crate::BoxError;
102
103 fn encode(&mut self, item: M, dst: &mut BytesMut) -> Result<(), Self::Error> {
104 let mut buf = BytesMut::new();
106 item.encode(&mut buf)?;
107 let buf = buf.freeze();
108 encode_varint(buf.len() as u64, dst);
109 dst.put(buf);
110
111 Ok(())
112 }
113}