use std::{ marker::PhantomData, pin::Pin, task::{Context, Poll}, }; use axum::extract::ws; use futures::{Sink, Stream}; use pin_project::pin_project; use tokio::net::TcpStream; use tokio_tungstenite as tt; use tokio_tungstenite::tungstenite as ts; pub trait IntoData { fn into_data(self) -> Option>; } impl IntoData for ws::Message { fn into_data(self) -> Option> { match self { ws::Message::Binary(x) => Some(x), _ => None, } } } impl IntoData for ts::Message { fn into_data(self) -> Option> { match self { ts::Message::Binary(x) => Some(x), _ => None, } } } #[pin_project] pub struct WebSocketTransport where Message: IntoData + From>, Transport: Stream> + Sink, { #[pin] inner: Transport, ghost: PhantomData<(Req, Resp)>, } impl From for WebSocketTransport { fn from(inner: ws::WebSocket) -> Self { Self { inner, ghost: PhantomData, } } } impl From>> for WebSocketTransport< Req, Resp, ts::Message, tt::WebSocketStream>, tt::tungstenite::Error, > { fn from(inner: tokio_tungstenite::WebSocketStream>) -> Self { Self { inner, ghost: PhantomData, } } } impl Stream for WebSocketTransport where Req: for<'de> serde::Deserialize<'de>, Message: IntoData + From> + std::fmt::Debug, Transport: Stream> + Sink, { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match futures::ready!(self.as_mut().project().inner.poll_next(cx)) { Some(Ok(msg)) => { let bin = msg.into_data(); match bin { Some(bin) => Poll::Ready(Some(Ok(bincode::deserialize_from::<&[u8], Req>( bin.as_ref(), ) .unwrap()))), None => Poll::Ready(None), } } Some(Err(err)) => Poll::Ready(Some(Err(err))), None => Poll::Ready(None), } } } impl Sink for WebSocketTransport where Resp: serde::Serialize, Message: IntoData + From>, Transport: Stream> + Sink, { type Error = Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.as_mut().project().inner.poll_ready(cx) } fn start_send(mut self: Pin<&mut Self>, item: Resp) -> Result<(), Self::Error> { let msg = Message::from(bincode::serialize(&item).unwrap()); self.as_mut().project().inner.start_send(msg) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.as_mut().project().inner.poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.as_mut().project().inner.poll_close(cx) } }