1use crate::Marker;
2use std::num::{NonZeroU32, NonZeroUsize};
3
4pub struct MessageLen {
7 wip: Option<WIP>,
9 max_position: NonZeroUsize,
11 position: usize,
13 sequences_wip: Vec<Seq>,
16 current_depth: u16,
18 max_depth: u16,
20 max_len: u32,
22}
23
/// Reason why [`MessageLen`] could not report a complete message length.
///
/// The type is plain data, so it is `Copy` and comparable, which lets callers
/// store it and match or compare it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LenError {
    /// The input ended inside the message. The value is a lower bound on the
    /// total message length in bytes, counted from the start of the message.
    Truncated(NonZeroUsize),
    /// Parsing cannot continue, e.g. because a configured nesting-depth or
    /// length limit was exceeded. The end of the message is unknown.
    ParseError,
}
32
33impl LenError {
34 #[must_use]
36 pub fn len(&self) -> usize {
37 match *self {
38 Self::ParseError => 0,
39 Self::Truncated(l) => l.get(),
40 }
41 }
42}
43
44impl MessageLen {
45 #[must_use]
51 pub fn new() -> Self {
52 Self::with_limits(1024, (u32::MAX as usize).min(isize::MAX as usize / 2))
53 }
54
55 #[must_use]
63 pub fn with_limits(max_depth: usize, max_len: usize) -> Self {
64 Self {
65 max_position: NonZeroUsize::new(1).unwrap(),
66 position: 0,
67 current_depth: 0,
68 max_depth: max_depth.min(u16::MAX as _) as u16,
69 max_len: max_len.min(u32::MAX as _) as u32,
70 sequences_wip: Vec::new(),
71 wip: Some(WIP::NextMarker),
72 }
73 }
74
75 pub fn len_of(complete_message: &[u8]) -> Result<usize, LenError> {
89 Self::with_limits(1024, 1 << 30).incremental_len(complete_message.as_ref())
90 }
91
92 pub fn incremental_len(&mut self, mut next_message_fragment: &[u8]) -> Result<usize, LenError> {
110 let data = &mut next_message_fragment;
111 let Some(wip) = self.wip.take() else {
112 return Ok(self.position); };
114 match wip {
115 WIP::Data(Data { bytes_left }) => self.skip_data(data, bytes_left.get()),
116 WIP::MarkerLen(wip) => self.read_marker_with_len(data, wip),
117 WIP::NextMarker => self.read_one_item(data),
118 WIP::LimitExceeded => {
119 self.wip = Some(WIP::LimitExceeded); return Err(LenError::ParseError);
121 },
122 }.ok_or(LenError::Truncated(self.max_position))?;
123
124 while let Some(seq) = self.sequences_wip.pop() {
125 self.current_depth = seq.depth;
126 debug_assert!(self.wip.is_none());
127 self.read_sequence(data, seq.items_left.get() - 1).ok_or(LenError::Truncated(self.max_position))?;
128 }
129 debug_assert!(self.wip.is_none());
130 debug_assert!(self.max_position.get() <= self.position);
131 Ok(self.position)
132 }
133
134 pub fn reset(&mut self) {
136 self.max_position = NonZeroUsize::new(1).unwrap();
137 self.position = 0;
138 self.current_depth = 0;
139 self.sequences_wip.clear();
140 self.wip = Some(WIP::NextMarker);
141 }
142
143 fn read_one_item(&mut self, data: &mut &[u8]) -> Option<()> {
144 debug_assert!(self.wip.is_none());
145 let marker = self.read_marker(data)?;
146 match marker {
147 Marker::FixPos(_) => Some(()),
148 Marker::FixMap(len) => self.read_sequence(data, u32::from(len) * 2),
149 Marker::FixArray(len) => self.read_sequence(data, u32::from(len)),
150 Marker::FixStr(len) => self.skip_data(data, len.into()),
151 Marker::Null |
152 Marker::Reserved |
153 Marker::False |
154 Marker::True => Some(()),
155 Marker::Str8 |
156 Marker::Str16 |
157 Marker::Str32 |
158 Marker::Bin8 |
159 Marker::Bin16 |
160 Marker::Bin32 |
161 Marker::Array16 |
162 Marker::Array32 |
163 Marker::Map16 |
164 Marker::Map32 => self.read_marker_with_len(data, MarkerLen { marker, buf: [0; 4], has: 0 }),
165 Marker::Ext8 |
166 Marker::Ext16 |
167 Marker::Ext32 => todo!(),
168 Marker::F32 => self.skip_data(data, 4),
169 Marker::F64 => self.skip_data(data, 8),
170 Marker::U8 => self.skip_data(data, 1),
171 Marker::U16 => self.skip_data(data, 2),
172 Marker::U32 => self.skip_data(data, 4),
173 Marker::U64 => self.skip_data(data, 8),
174 Marker::I8 => self.skip_data(data, 1),
175 Marker::I16 => self.skip_data(data, 2),
176 Marker::I32 => self.skip_data(data, 4),
177 Marker::I64 => self.skip_data(data, 8),
178 Marker::FixExt1 |
179 Marker::FixExt2 |
180 Marker::FixExt4 |
181 Marker::FixExt8 |
182 Marker::FixExt16 => todo!(),
183 Marker::FixNeg(_) => Some(()),
184 }
185 }
186
187 fn read_marker_with_len(&mut self, data: &mut &[u8], mut wip: MarkerLen) -> Option<()> {
188 let size = wip.size();
189 debug_assert!(wip.has < size && size > 0 && size <= 4);
190 let dest = &mut wip.buf[0..size as usize];
191 let wanted = dest.len().checked_sub(wip.has as _)?;
192
193 let taken = self.take_bytes(data, wanted as u32);
194 dest[wip.has as usize..][..taken.len()].copy_from_slice(taken);
195 wip.has += taken.len() as u8;
196 if wip.has < size {
197 return self.fail(WIP::MarkerLen(wip));
198 }
199 let len = match dest.len() {
200 1 => dest[0].into(),
201 2 => u16::from_be_bytes(dest.try_into().unwrap()).into(),
202 4 => u32::from_be_bytes(dest.try_into().unwrap()),
203 _ => {
204 debug_assert!(false);
205 return None;
206 },
207 };
208 if len >= self.max_len {
209 return self.fail(WIP::LimitExceeded);
210 }
211 match wip.marker {
212 Marker::Bin8 |
213 Marker::Bin16 |
214 Marker::Bin32 |
215 Marker::Str8 |
216 Marker::Str16 |
217 Marker::Str32 => self.skip_data(data, len),
218 Marker::Ext8 |
219 Marker::Ext16 |
220 Marker::Ext32 => todo!(),
221 Marker::Array16 |
222 Marker::Array32 => self.read_sequence(data, len),
223 Marker::Map16 |
224 Marker::Map32 => {
225 if let Some(len) = len.checked_mul(2).filter(|&l| l < self.max_len) {
226 self.read_sequence(data, len)
227 } else {
228 self.fail(WIP::LimitExceeded)
229 }
230 },
231 _ => {
232 debug_assert!(false);
233 None
234 },
235 }
236 }
237
238 fn read_sequence(&mut self, data: &mut &[u8], mut items_left: u32) -> Option<()> {
239 self.current_depth += 1;
240 if self.current_depth > self.max_depth {
241 return self.fail(WIP::LimitExceeded);
242 }
243 while let Some(non_zero) = NonZeroU32::new(items_left) {
244 let position_before_item = self.position;
245 self.read_one_item(data).or_else(|| {
246 self.set_max_position(position_before_item + items_left as usize);
247 self.sequences_wip.push(Seq {
249 items_left: non_zero,
250 depth: self.current_depth - 1,
251 });
252 None
253 })?;
254 items_left -= 1;
255 }
256 debug_assert!(self.current_depth > 0);
257 self.current_depth -= 1;
258 Some(())
259 }
260
261 fn skip_data(&mut self, data: &mut &[u8], wanted: u32) -> Option<()> {
262 let taken = self.take_bytes(data, wanted);
263 if let Some(bytes_left) = NonZeroU32::new(wanted - taken.len() as u32) {
264 debug_assert!(data.is_empty());
265 self.fail(WIP::Data(Data { bytes_left }))
266 } else {
267 Some(())
268 }
269 }
270
271 fn read_marker(&mut self, data: &mut &[u8]) -> Option<Marker> {
272 let Some((&b, rest)) = data.split_first() else {
273 debug_assert!(data.is_empty());
274 return self.fail(WIP::NextMarker);
275 };
276 self.position += 1;
277 *data = rest;
278 Some(Marker::from_u8(b))
279 }
280
281 fn set_max_position(&mut self, position: usize) {
282 self.max_position = NonZeroUsize::new(self.max_position.get().max(position)).unwrap();
283 }
284
285 fn take_bytes<'data>(&mut self, data: &mut &'data [u8], wanted: u32) -> &'data [u8] {
287 let (taken, rest) = data.split_at(data.len().min(wanted as usize));
288 self.position += taken.len();
289 *data = rest;
290 taken
291 }
292
293 #[inline(always)]
294 fn fail<T>(&mut self, wip: WIP) -> Option<T> {
295 debug_assert!(self.wip.is_none());
296 let pos = match self.wip.insert(wip) {
297 WIP::NextMarker => self.position + 1,
298 WIP::Data(Data { bytes_left }) => self.position + bytes_left.get() as usize,
299 WIP::MarkerLen(m) => self.position + (m.size() - m.has) as usize,
300 WIP::LimitExceeded => 0,
301 };
302 self.set_max_position(pos);
303 None
304 }
305}
306
307enum WIP {
308 NextMarker,
309 Data(Data),
310 MarkerLen(MarkerLen),
311 LimitExceeded,
312}
313
/// An array or map whose parsing was interrupted.
struct Seq {
    /// Depth to restore before resuming; re-entering the sequence increments it again.
    depth: u16,
    /// Items still to read, including the one that was in progress.
    items_left: NonZeroU32,
}
/// Payload bytes that still have to be skipped.
struct Data {
    /// Remaining byte count; never zero.
    bytes_left: NonZeroU32,
}
316struct MarkerLen { marker: Marker, buf: [u8; 4], has: u8 }
317
318impl MarkerLen {
319 fn size(&self) -> u8 {
320 match self.marker {
321 Marker::Bin8 => 1,
322 Marker::Bin16 => 2,
323 Marker::Bin32 => 4,
324 Marker::Ext8 => 1,
325 Marker::Ext16 => 2,
326 Marker::Ext32 => 4,
327 Marker::Str8 => 1,
328 Marker::Str16 => 2,
329 Marker::Str32 => 4,
330 Marker::Array16 => 2,
331 Marker::Array32 => 4,
332 Marker::Map16 => 2,
333 Marker::Map32 => 4,
334 _ => unimplemented!(),
335 }
336 }
337}