Skip to main content

rmp/decode/
est.rs

1use crate::Marker;
2use std::num::{NonZeroU32, NonZeroUsize};
3
4/// Incremental MessagePack parser that can parse incomplete messages,
5/// and report their estimated total length.
6pub struct MessageLen {
7    /// The last operation interrupted
8    wip: Option<WIP>,
9    /// Max size estimate
10    max_position: NonZeroUsize,
11    /// Bytes read so far
12    position: usize,
13    /// Stack of open arrays and maps
14    /// It is not a complete stack. Used only when resumption is needed.
15    sequences_wip: Vec<Seq>,
16    /// Nesting of arrays and maps
17    current_depth: u16,
18    /// Configured limit
19    max_depth: u16,
20    /// Configured limit
21    max_len: u32,
22}
23
/// [`MessageLen`] result
///
/// The type is small and plain data, so it is `Copy` and comparable, which
/// lets callers match results with `==` / `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LenError {
    /// The message is truncated, and needs at least this many bytes to parse
    Truncated(NonZeroUsize),
    /// The message is invalid or exceeded size limits
    ParseError,
}
32
33impl LenError {
34    /// Get expected min length or 0 on error
35    #[must_use]
36    pub fn len(&self) -> usize {
37        match *self {
38            Self::ParseError => 0,
39            Self::Truncated(l) => l.get(),
40        }
41    }
42}
43
44impl MessageLen {
45    /// New parser with default limits
46    ///
47    /// If you have all MessagePack data in memory already, you can use [`MessageLen::len_of`].
48    /// If you're reading data in a streaming fashion, you can feed chunks of data
49    /// to [`MessageLen::incremental_len`].
50    #[must_use]
51    pub fn new() -> Self {
52        Self::with_limits(1024, (u32::MAX as usize).min(isize::MAX as usize / 2))
53    }
54
55    /// * `max_depth` limits nesting of arrays and maps
56    ///
57    /// * `max_len` is maximum size of any string, byte string, map, or array.
58    ///   For maps and arrays this is the number of items, not bytes.
59    ///
60    /// Messages can be both deep and wide, being `max_depth` * `max_len` in size.
61    /// You should also limit the maximum byte size of the message (outside of this parser).
62    #[must_use]
63    pub fn with_limits(max_depth: usize, max_len: usize) -> Self {
64        Self {
65            max_position: NonZeroUsize::new(1).unwrap(),
66            position: 0,
67            current_depth: 0,
68            max_depth: max_depth.min(u16::MAX as _) as u16,
69            max_len: max_len.min(u32::MAX as _) as u32,
70            sequences_wip: Vec::new(),
71            wip: Some(WIP::NextMarker),
72        }
73    }
74
75    /// Parse the entire message to find if it's complete, and what is its serialized size in bytes.
76    ///
77    /// If it returns `Ok(len)`, then the first `len` bytes of the given slice
78    /// parse as a single MessagePack object.
79    /// The length may be shorter than the slice given (extra data is gracefully ignored).
80    ///
81    /// `Err(LenError::Truncated(len))` means that the the object is incomplete, the slice is truncated,
82    /// and it would need *at least* this many bytes to parse.
83    /// The `len` is always the lower bound, and never exceeds actual message length.
84    ///
85    /// `Err(LenError::ParseError)` — the end of the message is unknown.
86    ///
87    /// Don't call this function in a loop. Use [`MessageLen::incremental_len`] instead.
88    pub fn len_of(complete_message: &[u8]) -> Result<usize, LenError> {
89        Self::with_limits(1024, 1 << 30).incremental_len(complete_message.as_ref())
90    }
91
92    /// Parse more bytes, and re-evaluate required message length.
93    ///
94    /// This function is stateful and keeps "appending" the data to its evaluation.
95    ///
96    /// * `Ok(len)` — size of the whole MessagePack object, in bytes, starting at the beginning
97    ///   of all data given to this function, including previous calls (not just this slice).
98    ///   The object's data may end before the end of this slice. In such case the extra bytes
99    ///   are gracefully ignored, and have not been parsed.
100    ///
101    /// * `Err(LenError::Truncated(len))` — all bytes of this slice have been consumed,
102    ///   and that was still not enough. The object needs at least `len` bytes in total
103    ///   (counting from the start of all data given to this function, not just this slice).
104    ///   The `len` is always the lower bound, and never exceeds actual message length,
105    ///   so it's safe to read the additional bytes without overshooting the end of the message.
106    ///
107    /// * `Err(LenError::ParseError)` — the end of the message cannot be determined, and this
108    ///   is a non-recoverable error. Any further calls to this function may return nonsense.
109    pub fn incremental_len(&mut self, mut next_message_fragment: &[u8]) -> Result<usize, LenError> {
110        let data = &mut next_message_fragment;
111        let Some(wip) = self.wip.take() else {
112            return Ok(self.position); // must have succeded already
113        };
114        match wip {
115            WIP::Data(Data { bytes_left }) => self.skip_data(data, bytes_left.get()),
116            WIP::MarkerLen(wip) => self.read_marker_with_len(data, wip),
117            WIP::NextMarker => self.read_one_item(data),
118            WIP::LimitExceeded => {
119                self.wip = Some(WIP::LimitExceeded); // put it back!
120                return Err(LenError::ParseError);
121            },
122        }.ok_or(LenError::Truncated(self.max_position))?;
123
124        while let Some(seq) = self.sequences_wip.pop() {
125            self.current_depth = seq.depth;
126            debug_assert!(self.wip.is_none());
127            self.read_sequence(data, seq.items_left.get() - 1).ok_or(LenError::Truncated(self.max_position))?;
128        }
129        debug_assert!(self.wip.is_none());
130        debug_assert!(self.max_position.get() <= self.position);
131        Ok(self.position)
132    }
133
134    /// Forget all the state. The next call to `incremental_len` will assume it's the start of a new message.
135    pub fn reset(&mut self) {
136        self.max_position = NonZeroUsize::new(1).unwrap();
137        self.position = 0;
138        self.current_depth = 0;
139        self.sequences_wip.clear();
140        self.wip = Some(WIP::NextMarker);
141    }
142
143    fn read_one_item(&mut self, data: &mut &[u8]) -> Option<()> {
144        debug_assert!(self.wip.is_none());
145        let marker = self.read_marker(data)?;
146        match marker {
147            Marker::FixPos(_) => Some(()),
148            Marker::FixMap(len) => self.read_sequence(data, u32::from(len) * 2),
149            Marker::FixArray(len) => self.read_sequence(data, u32::from(len)),
150            Marker::FixStr(len) => self.skip_data(data, len.into()),
151            Marker::Null |
152            Marker::Reserved |
153            Marker::False |
154            Marker::True => Some(()),
155            Marker::Str8 |
156            Marker::Str16 |
157            Marker::Str32 |
158            Marker::Bin8 |
159            Marker::Bin16 |
160            Marker::Bin32 |
161            Marker::Array16 |
162            Marker::Array32 |
163            Marker::Map16 |
164            Marker::Map32 => self.read_marker_with_len(data, MarkerLen { marker, buf: [0; 4], has: 0 }),
165            Marker::Ext8 |
166            Marker::Ext16 |
167            Marker::Ext32 => todo!(),
168            Marker::F32 => self.skip_data(data, 4),
169            Marker::F64 => self.skip_data(data, 8),
170            Marker::U8 => self.skip_data(data, 1),
171            Marker::U16 => self.skip_data(data, 2),
172            Marker::U32 => self.skip_data(data, 4),
173            Marker::U64 => self.skip_data(data, 8),
174            Marker::I8 => self.skip_data(data, 1),
175            Marker::I16 => self.skip_data(data, 2),
176            Marker::I32 => self.skip_data(data, 4),
177            Marker::I64 => self.skip_data(data, 8),
178            Marker::FixExt1 |
179            Marker::FixExt2 |
180            Marker::FixExt4 |
181            Marker::FixExt8 |
182            Marker::FixExt16 => todo!(),
183            Marker::FixNeg(_) => Some(()),
184        }
185    }
186
187    fn read_marker_with_len(&mut self, data: &mut &[u8], mut wip: MarkerLen) -> Option<()> {
188        let size = wip.size();
189        debug_assert!(wip.has < size && size > 0 && size <= 4);
190        let dest = &mut wip.buf[0..size as usize];
191        let wanted = dest.len().checked_sub(wip.has as _)?;
192
193        let taken = self.take_bytes(data, wanted as u32);
194        dest[wip.has as usize..][..taken.len()].copy_from_slice(taken);
195        wip.has += taken.len() as u8;
196        if wip.has < size {
197            return self.fail(WIP::MarkerLen(wip));
198        }
199        let len = match dest.len() {
200            1 => dest[0].into(),
201            2 => u16::from_be_bytes(dest.try_into().unwrap()).into(),
202            4 => u32::from_be_bytes(dest.try_into().unwrap()),
203            _ => {
204                debug_assert!(false);
205                return None;
206            },
207        };
208        if len >= self.max_len {
209            return self.fail(WIP::LimitExceeded);
210        }
211        match wip.marker {
212            Marker::Bin8 |
213            Marker::Bin16 |
214            Marker::Bin32 |
215            Marker::Str8 |
216            Marker::Str16 |
217            Marker::Str32 => self.skip_data(data, len),
218            Marker::Ext8 |
219            Marker::Ext16 |
220            Marker::Ext32 => todo!(),
221            Marker::Array16 |
222            Marker::Array32 => self.read_sequence(data, len),
223            Marker::Map16 |
224            Marker::Map32 => {
225                if let Some(len) = len.checked_mul(2).filter(|&l| l < self.max_len) {
226                    self.read_sequence(data, len)
227                } else {
228                    self.fail(WIP::LimitExceeded)
229                }
230            },
231            _ => {
232                debug_assert!(false);
233                None
234            },
235        }
236    }
237
238    fn read_sequence(&mut self, data: &mut &[u8], mut items_left: u32) -> Option<()> {
239        self.current_depth += 1;
240        if self.current_depth > self.max_depth {
241            return self.fail(WIP::LimitExceeded);
242        }
243        while let Some(non_zero) = NonZeroU32::new(items_left) {
244            let position_before_item = self.position;
245            self.read_one_item(data).or_else(|| {
246                self.set_max_position(position_before_item + items_left as usize);
247                // -1, because it will increase depth again when resumed
248                self.sequences_wip.push(Seq {
249                    items_left: non_zero,
250                    depth: self.current_depth - 1,
251                });
252                None
253            })?;
254            items_left -= 1;
255        }
256        debug_assert!(self.current_depth > 0);
257        self.current_depth -= 1;
258        Some(())
259    }
260
261    fn skip_data(&mut self, data: &mut &[u8], wanted: u32) -> Option<()> {
262        let taken = self.take_bytes(data, wanted);
263        if let Some(bytes_left) = NonZeroU32::new(wanted - taken.len() as u32) {
264            debug_assert!(data.is_empty());
265            self.fail(WIP::Data(Data { bytes_left }))
266        } else {
267            Some(())
268        }
269    }
270
271    fn read_marker(&mut self, data: &mut &[u8]) -> Option<Marker> {
272        let Some((&b, rest)) = data.split_first() else {
273            debug_assert!(data.is_empty());
274            return self.fail(WIP::NextMarker);
275        };
276        self.position += 1;
277        *data = rest;
278        Some(Marker::from_u8(b))
279    }
280
281    fn set_max_position(&mut self, position: usize) {
282        self.max_position = NonZeroUsize::new(self.max_position.get().max(position)).unwrap();
283    }
284
285    /// May return less than requested
286    fn take_bytes<'data>(&mut self, data: &mut &'data [u8], wanted: u32) -> &'data [u8] {
287        let (taken, rest) = data.split_at(data.len().min(wanted as usize));
288        self.position += taken.len();
289        *data = rest;
290        taken
291    }
292
293    #[inline(always)]
294    fn fail<T>(&mut self, wip: WIP) -> Option<T> {
295        debug_assert!(self.wip.is_none());
296        let pos = match self.wip.insert(wip) {
297            WIP::NextMarker => self.position + 1,
298            WIP::Data(Data { bytes_left }) => self.position + bytes_left.get() as usize,
299            WIP::MarkerLen(m) => self.position + (m.size() - m.has) as usize,
300            WIP::LimitExceeded => 0,
301        };
302        self.set_max_position(pos);
303        None
304    }
305}
306
307enum WIP {
308    NextMarker,
309    Data(Data),
310    MarkerLen(MarkerLen),
311    LimitExceeded,
312}
313
314struct Seq { items_left: NonZeroU32, depth: u16 }
315struct Data { bytes_left: NonZeroU32 }
316struct MarkerLen { marker: Marker, buf: [u8; 4], has: u8 }
317
318impl MarkerLen {
319    fn size(&self) -> u8 {
320        match self.marker {
321            Marker::Bin8 => 1,
322            Marker::Bin16 => 2,
323            Marker::Bin32 => 4,
324            Marker::Ext8 => 1,
325            Marker::Ext16 => 2,
326            Marker::Ext32 => 4,
327            Marker::Str8 => 1,
328            Marker::Str16 => 2,
329            Marker::Str32 => 4,
330            Marker::Array16 => 2,
331            Marker::Array32 => 4,
332            Marker::Map16 => 2,
333            Marker::Map32 => 4,
334            _ => unimplemented!(),
335        }
336    }
337}