1use std::{
2 fmt, io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use bitflags::bitflags;
8use bytes::{Buf, BytesMut};
9use futures_core::{ready, Stream};
10use futures_sink::Sink;
11use pin_project_lite::pin_project;
12
13use crate::{AsyncRead, AsyncWrite, Decoder, Encoder};
14
/// Low-water mark: when a buffer has fewer than this many spare bytes of capacity,
/// it is grown before more data is read into / encoded into it.
const LW: usize = 1 << 10;
/// High-water mark: initial capacity of the read and write buffers, and the
/// write-buffer length at which the transport stops reporting itself write-ready.
const HW: usize = 1 << 13;
19
20bitflags! {
21 #[derive(Debug, Clone, Copy)]
22 struct Flags: u8 {
23 const EOF = 0b0001;
24 const READABLE = 0b0010;
25 }
26}
27
pin_project! {
    /// A unified `Stream` and `Sink` over an I/O object.
    ///
    /// Incoming bytes are decoded into frames with a `Decoder`, and outgoing items are
    /// encoded into bytes with an `Encoder`. Both directions are buffered internally.
    pub struct Framed<T, U> {
        // The underlying I/O object; structurally pinned so it can be polled through `Pin`.
        #[pin]
        io: T,
        // Codec used to decode frames (`next_item`) and encode items (`write`).
        codec: U,
        // Read-side state bits (`EOF`, `READABLE`).
        flags: Flags,
        // Bytes read from `io` that have not yet been consumed by the decoder.
        read_buf: BytesMut,
        // Encoded bytes waiting to be written to `io` by `flush`.
        write_buf: BytesMut,
    }
}
45
46impl<T, U> Framed<T, U>
47where
48 T: AsyncRead + AsyncWrite,
49 U: Decoder,
50{
51 pub fn new(io: T, codec: U) -> Framed<T, U> {
55 Framed {
56 io,
57 codec,
58 flags: Flags::empty(),
59 read_buf: BytesMut::with_capacity(HW),
60 write_buf: BytesMut::with_capacity(HW),
61 }
62 }
63}
64
65impl<T, U> Framed<T, U> {
66 pub fn codec_ref(&self) -> &U {
68 &self.codec
69 }
70
71 pub fn codec_mut(&mut self) -> &mut U {
73 &mut self.codec
74 }
75
76 pub fn io_ref(&self) -> &T {
81 &self.io
82 }
83
84 pub fn io_mut(&mut self) -> &mut T {
89 &mut self.io
90 }
91
92 pub fn io_pin(self: Pin<&mut Self>) -> Pin<&mut T> {
94 self.project().io
95 }
96
97 pub fn is_read_buf_empty(&self) -> bool {
99 self.read_buf.is_empty()
100 }
101
102 pub fn is_write_buf_empty(&self) -> bool {
104 self.write_buf.is_empty()
105 }
106
107 pub fn is_write_buf_full(&self) -> bool {
109 self.write_buf.len() >= HW
110 }
111
112 pub fn is_write_ready(&self) -> bool {
116 self.write_buf.len() < HW
117 }
118
119 pub fn replace_codec<U2>(self, codec: U2) -> Framed<T, U2> {
121 Framed {
122 codec,
123 io: self.io,
124 flags: self.flags,
125 read_buf: self.read_buf,
126 write_buf: self.write_buf,
127 }
128 }
129
130 pub fn into_map_io<F, T2>(self, f: F) -> Framed<T2, U>
132 where
133 F: Fn(T) -> T2,
134 {
135 Framed {
136 io: f(self.io),
137 codec: self.codec,
138 flags: self.flags,
139 read_buf: self.read_buf,
140 write_buf: self.write_buf,
141 }
142 }
143
144 pub fn into_map_codec<F, U2>(self, f: F) -> Framed<T, U2>
146 where
147 F: Fn(U) -> U2,
148 {
149 Framed {
150 io: self.io,
151 codec: f(self.codec),
152 flags: self.flags,
153 read_buf: self.read_buf,
154 write_buf: self.write_buf,
155 }
156 }
157}
158
159impl<T, U> Framed<T, U> {
160 pub fn write<I>(mut self: Pin<&mut Self>, item: I) -> Result<(), <U as Encoder<I>>::Error>
162 where
163 T: AsyncWrite,
164 U: Encoder<I>,
165 {
166 let this = self.as_mut().project();
167 let remaining = this.write_buf.capacity() - this.write_buf.len();
168 if remaining < LW {
169 this.write_buf.reserve(HW - remaining);
170 }
171
172 this.codec.encode(item, this.write_buf)?;
173 Ok(())
174 }
175
176 pub fn next_item(
178 mut self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 ) -> Poll<Option<Result<<U as Decoder>::Item, U::Error>>>
181 where
182 T: AsyncRead,
183 U: Decoder,
184 {
185 loop {
186 let this = self.as_mut().project();
187 if this.flags.contains(Flags::READABLE) {
193 if this.flags.contains(Flags::EOF) {
194 match this.codec.decode_eof(this.read_buf) {
195 Ok(Some(frame)) => return Poll::Ready(Some(Ok(frame))),
196 Ok(None) => return Poll::Ready(None),
197 Err(err) => return Poll::Ready(Some(Err(err))),
198 }
199 }
200
201 tracing::trace!("attempting to decode a frame");
202
203 match this.codec.decode(this.read_buf) {
204 Ok(Some(frame)) => {
205 tracing::trace!("frame decoded from buffer");
206 return Poll::Ready(Some(Ok(frame)));
207 }
208 Err(err) => return Poll::Ready(Some(Err(err))),
209 _ => (), }
211
212 this.flags.remove(Flags::READABLE);
213 }
214
215 debug_assert!(!this.flags.contains(Flags::EOF));
216
217 let remaining = this.read_buf.capacity() - this.read_buf.len();
219 if remaining < LW {
220 this.read_buf.reserve(HW - remaining)
221 }
222
223 let cnt = match tokio_util::io::poll_read_buf(this.io, cx, this.read_buf) {
224 Poll::Pending => return Poll::Pending,
225 Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
226 Poll::Ready(Ok(cnt)) => cnt,
227 };
228
229 if cnt == 0 {
230 this.flags.insert(Flags::EOF);
231 }
232 this.flags.insert(Flags::READABLE);
233 }
234 }
235
236 pub fn flush<I>(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), U::Error>>
238 where
239 T: AsyncWrite,
240 U: Encoder<I>,
241 {
242 let mut this = self.as_mut().project();
243 tracing::trace!("flushing framed transport");
244
245 while !this.write_buf.is_empty() {
246 tracing::trace!("writing; remaining={}", this.write_buf.len());
247
248 let n = ready!(this.io.as_mut().poll_write(cx, this.write_buf))?;
249
250 if n == 0 {
251 return Poll::Ready(Err(io::Error::new(
252 io::ErrorKind::WriteZero,
253 "failed to write frame to transport",
254 )
255 .into()));
256 }
257
258 this.write_buf.advance(n);
260 }
261
262 ready!(this.io.poll_flush(cx))?;
264
265 tracing::trace!("framed transport flushed");
266 Poll::Ready(Ok(()))
267 }
268
269 pub fn close<I>(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), U::Error>>
271 where
272 T: AsyncWrite,
273 U: Encoder<I>,
274 {
275 let mut this = self.as_mut().project();
276 ready!(this.io.as_mut().poll_flush(cx))?;
277 ready!(this.io.as_mut().poll_shutdown(cx))?;
278 Poll::Ready(Ok(()))
279 }
280}
281
282impl<T, U> Stream for Framed<T, U>
283where
284 T: AsyncRead,
285 U: Decoder,
286{
287 type Item = Result<U::Item, U::Error>;
288
289 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
290 self.next_item(cx)
291 }
292}
293
294impl<T, U, I> Sink<I> for Framed<T, U>
295where
296 T: AsyncWrite,
297 U: Encoder<I>,
298 U::Error: From<io::Error>,
299{
300 type Error = U::Error;
301
302 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303 if self.is_write_ready() {
304 Poll::Ready(Ok(()))
305 } else {
306 self.flush(cx)
307 }
308 }
309
310 fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
311 self.write(item)
312 }
313
314 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
315 self.flush(cx)
316 }
317
318 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
319 self.close(cx)
320 }
321}
322
323impl<T, U> fmt::Debug for Framed<T, U>
324where
325 T: fmt::Debug,
326 U: fmt::Debug,
327{
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 f.debug_struct("Framed")
330 .field("io", &self.io)
331 .field("codec", &self.codec)
332 .finish()
333 }
334}
335
336impl<T, U> Framed<T, U> {
337 pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> {
344 Framed {
345 io: parts.io,
346 codec: parts.codec,
347 flags: parts.flags,
348 write_buf: parts.write_buf,
349 read_buf: parts.read_buf,
350 }
351 }
352
353 pub fn into_parts(self) -> FramedParts<T, U> {
359 FramedParts {
360 io: self.io,
361 codec: self.codec,
362 flags: self.flags,
363 read_buf: self.read_buf,
364 write_buf: self.write_buf,
365 }
366 }
367}
368
/// The component pieces of a [`Framed`].
///
/// Produced by [`Framed::into_parts`] and consumed by [`Framed::from_parts`].
#[derive(Debug)]
pub struct FramedParts<T, U> {
    /// The underlying I/O object.
    pub io: T,

    /// The codec used for encoding and decoding.
    pub codec: U,

    /// Bytes read from `io` that have not yet been decoded.
    pub read_buf: BytesMut,

    /// Encoded bytes that have not yet been written to `io`.
    pub write_buf: BytesMut,

    // Read-side state bits carried between `Framed` and its parts; not `pub`.
    flags: Flags,
}
389
390impl<T, U> FramedParts<T, U> {
391 pub fn new(io: T, codec: U) -> FramedParts<T, U> {
393 FramedParts {
394 io,
395 codec,
396 flags: Flags::empty(),
397 read_buf: BytesMut::new(),
398 write_buf: BytesMut::new(),
399 }
400 }
401
402 pub fn with_read_buf(io: T, codec: U, read_buf: BytesMut) -> FramedParts<T, U> {
404 FramedParts {
405 io,
406 codec,
407 read_buf,
408 flags: Flags::empty(),
409 write_buf: BytesMut::new(),
410 }
411 }
412}