1use std::marker::PhantomData;
2
3use tokio_util::codec::{Decoder, Encoder};
4
5use bytes::{BufMut, BytesMut};
6
7pub struct Decode<M> {
8 state: DecodeState,
9 _marker: PhantomData<M>,
10}
11
12impl<M> Default for Decode<M> {
13 fn default() -> Self {
14 Self {
15 state: DecodeState::Head,
16 _marker: PhantomData,
17 }
18 }
19}
20
/// Where `Decode` is within the frame it is currently reading.
#[derive(Debug)]
enum DecodeState {
    /// The varint length prefix has not been consumed yet.
    Head,
    /// The prefix was consumed; `len` body bytes are still expected.
    Body { len: usize },
}
26
27impl<M: prost::Message + Default> Decoder for Decode<M> {
28 type Item = M;
29 type Error = crate::BoxError;
30
31 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
32 match self.state {
33 DecodeState::Head => {
34 tracing::trace!(?src, "decoding head");
35 let mut tmp = src.clone().freeze();
46 let len = match prost::encoding::decode_varint(&mut tmp) {
47 Ok(_) => {
48 prost::encoding::decode_varint(src).unwrap() as usize
50 }
51 Err(_) => {
52 tracing::trace!(?self.state, src.len = src.len(), "waiting for header data");
53 return Ok(None);
54 }
55 };
56 self.state = DecodeState::Body { len };
57 tracing::trace!(?self.state, "ready for body");
58
59 self.decode(src)
61 }
62 DecodeState::Body { len } => {
63 if src.len() < len {
64 tracing::trace!(?self.state, src.len = src.len(), "waiting for body");
65 return Ok(None);
66 }
67
68 let body = src.split_to(len);
69 tracing::trace!(?body, "decoding body");
70 let message = M::decode(body)?;
71
72 self.state = DecodeState::Head;
74
75 Ok(Some(message))
76 }
77 }
78 }
79}
80
/// Stateless encoder that writes messages of type `M` as a varint length
/// prefix followed by the protobuf-encoded bytes.
pub struct Encode<M> {
    // Ties the codec to its message type without storing a value of `M`.
    _marker: PhantomData<M>,
}
84
85impl<M> Default for Encode<M> {
86 fn default() -> Self {
87 Self {
88 _marker: PhantomData,
89 }
90 }
91}
92
93impl<M: prost::Message + Sized + std::fmt::Debug> Encoder<M> for Encode<M> {
94 type Error = crate::BoxError;
95
96 fn encode(&mut self, item: M, dst: &mut BytesMut) -> Result<(), Self::Error> {
97 let mut buf = BytesMut::new();
98 item.encode(&mut buf)?;
99 let buf = buf.freeze();
100
101 prost::encoding::encode_varint(buf.len() as u64, dst);
106 dst.put(buf);
107
108 Ok(())
109 }
110}