use std::marker::PhantomData;
use tokio_util::codec::{Decoder, Encoder};
use bytes::{Buf, BufMut, BytesMut};
/// Writes `val` into `buf` as a protobuf varint after shifting it left by one bit.
///
/// The shift is the exact inverse of the right shift performed by
/// [`decode_varint`]. Because it is a plain `<<`, the most significant bit of
/// `val` is discarded.
///
/// NOTE(review): the extra bit presumably keeps the prefix compatible with a
/// peer that uses a zigzag/signed varint for non-negative lengths — confirm
/// against the other end of the wire.
pub fn encode_varint<B: BufMut>(val: u64, buf: &mut B) {
    let shifted = val << 1;
    prost::encoding::encode_varint(shifted, buf);
}
/// Reads one protobuf varint from `buf` and returns it shifted right by one bit.
///
/// This undoes the left shift applied by [`encode_varint`].
///
/// # Errors
///
/// Returns whatever error `prost::encoding::decode_varint` reports for `buf`.
pub fn decode_varint<B: Buf>(buf: &mut B) -> Result<u64, prost::DecodeError> {
    prost::encoding::decode_varint(buf).map(|raw| raw >> 1)
}
/// Stateful frame decoder that yields `prost` messages of type `M`.
///
/// Each frame is a varint length prefix (read with [`decode_varint`]) followed
/// by that many bytes of encoded message.
pub struct Decode<M> {
// Which part of the current frame the decoder is waiting for.
state: DecodeState,
// Ties the decoder to the message type `M` without storing a value of it.
_marker: PhantomData<M>,
}
impl<M> Default for Decode<M> {
fn default() -> Self {
Self {
state: DecodeState::Head,
_marker: PhantomData,
}
}
}
/// Progress of the decoder through the frame it is currently reading.
#[derive(Debug)]
enum DecodeState {
// No length prefix has been read yet for the next frame.
Head,
// The length prefix has been consumed; `len` body bytes are still required.
Body { len: usize },
}
impl<M: prost::Message + Default> Decoder for Decode<M> {
    type Item = M;
    type Error = crate::BoxError;

    /// Attempts to decode one length-prefixed message from `src`.
    ///
    /// A frame is a varint length prefix (see [`decode_varint`]) followed by
    /// that many bytes of `prost`-encoded `M`.
    ///
    /// Returns `Ok(None)` when `src` does not yet hold a complete prefix or a
    /// complete body; bytes consumed so far are remembered in `self.state`, so
    /// the call can simply be repeated once more data has arrived.
    ///
    /// # Errors
    ///
    /// Returns an error when the length prefix is malformed (it cannot be
    /// decoded even though a maximum-length varint's worth of bytes is
    /// buffered) or when the body fails to decode as `M`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // A 64-bit varint never occupies more than this many bytes on the wire.
        const MAX_VARINT_LEN: usize = 10;

        loop {
            match self.state {
                DecodeState::Head => {
                    tracing::trace!(?src, "decoding head");
                    // Decode through a borrowed cursor so that `src` is left
                    // untouched when the prefix is still incomplete, without
                    // copying the whole receive buffer.
                    let mut peek: &[u8] = &src[..];
                    let len = match decode_varint(&mut peek) {
                        // Lossless on 64-bit targets; narrower targets
                        // truncate, exactly as the wire value is interpreted
                        // by the rest of this decoder.
                        Ok(len) => len as usize,
                        // With this many bytes buffered, more input cannot
                        // turn the prefix into a valid varint: report it
                        // instead of waiting forever.
                        Err(error) if src.len() >= MAX_VARINT_LEN => {
                            return Err(error.into());
                        }
                        Err(_) => {
                            tracing::trace!(?self.state, src.len = src.len(), "waiting for header data");
                            return Ok(None);
                        }
                    };
                    // Drop exactly the prefix bytes the cursor walked over.
                    let consumed = src.len() - peek.len();
                    src.advance(consumed);
                    self.state = DecodeState::Body { len };
                    tracing::trace!(?self.state, "ready for body");
                    // Fall through to the `Body` arm on the next iteration.
                }
                DecodeState::Body { len } => {
                    if src.len() < len {
                        tracing::trace!(?self.state, src.len = src.len(), "waiting for body");
                        return Ok(None);
                    }
                    let body = src.split_to(len);
                    tracing::trace!(?body, "decoding body");
                    let message = M::decode(body)?;
                    // Only a successfully decoded body moves on to the next frame.
                    self.state = DecodeState::Head;
                    return Ok(Some(message));
                }
            }
        }
    }
}
/// Frame encoder for `prost` messages of type `M`.
///
/// Each message is written as a varint length prefix (see [`encode_varint`])
/// followed by the encoded message bytes.
pub struct Encode<M> {
// Ties the encoder to the message type `M` without storing a value of it.
_marker: PhantomData<M>,
}
impl<M> Default for Encode<M> {
fn default() -> Self {
Self {
_marker: PhantomData,
}
}
}
impl<M: prost::Message + Sized + std::fmt::Debug> Encoder<M> for Encode<M> {
    type Error = crate::BoxError;

    /// Appends `item` to `dst` as one length-prefixed frame.
    ///
    /// The frame is the message's encoded length written with
    /// [`encode_varint`], followed by the encoded message itself. The message
    /// is encoded straight into `dst`, so no intermediate buffer is allocated.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `prost` if the message cannot be encoded;
    /// in that case `dst` is restored to the length it had on entry.
    fn encode(&mut self, item: M, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let start = dst.len();
        let body_len = item.encoded_len();

        encode_varint(body_len as u64, dst);
        // Grow once up front rather than piecemeal while fields are written.
        dst.reserve(body_len);

        if let Err(error) = item.encode(dst) {
            // Never leave a dangling length prefix behind for the next frame.
            dst.truncate(start);
            return Err(error.into());
        }
        Ok(())
    }
}