actix_http/
extensions.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    fmt,
5    hash::{BuildHasherDefault, Hasher},
6};
7
/// A hasher for `TypeId`s that takes advantage of their known characteristics.
///
/// When a key hashes itself with a single `write_u64` call, that `u64` is used directly as the
/// hash value instead of being re-hashed (e.g. with the default SipHash).
///
/// Author of `anymap` crate has done research on the topic:
/// <https://github.com/chris-morgan/anymap/blob/2e9a5704/src/lib.rs#L599>
#[derive(Debug, Default)]
struct NoOpHasher(u64);

impl Hasher for NoOpHasher {
    /// Fallback for input that does not arrive as a single `u64`.
    ///
    /// `TypeId`'s `Hash` implementation is a standard-library detail, and every other
    /// `Hasher::write_*` method (other widths, `write_usize`, `write_str`, ...) forwards here by
    /// default. Rather than panicking on every map operation should that ever happen, fold the
    /// bytes into the state with FNV-1a-style mixing: slower than the fast path, but correct.
    fn write(&mut self, bytes: &[u8]) {
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

        for &byte in bytes {
            self.0 = (self.0 ^ u64::from(byte)).wrapping_mul(FNV_PRIME);
        }
    }

    /// Fast path: take the written `u64` as the hash value verbatim.
    fn write_u64(&mut self, i: u64) {
        self.0 = i;
    }

    fn finish(&self) -> u64 {
        self.0
    }
}
28
29/// A type map for request extensions.
30///
31/// All entries into this map must be owned types (or static references).
32#[derive(Default)]
33pub struct Extensions {
34    // use no-op hasher with a std HashMap with for faster lookups on the small `TypeId` keys
35    map: HashMap<TypeId, Box<dyn Any>, BuildHasherDefault<NoOpHasher>>,
36}
37
38impl Extensions {
39    /// Creates an empty `Extensions`.
40    #[inline]
41    pub fn new() -> Extensions {
42        Extensions {
43            map: HashMap::default(),
44        }
45    }
46
47    /// Insert an item into the map.
48    ///
49    /// If an item of this type was already stored, it will be replaced and returned.
50    ///
51    /// ```
52    /// # use actix_http::Extensions;
53    /// let mut map = Extensions::new();
54    /// assert_eq!(map.insert(""), None);
55    /// assert_eq!(map.insert(1u32), None);
56    /// assert_eq!(map.insert(2u32), Some(1u32));
57    /// assert_eq!(*map.get::<u32>().unwrap(), 2u32);
58    /// ```
59    pub fn insert<T: 'static>(&mut self, val: T) -> Option<T> {
60        self.map
61            .insert(TypeId::of::<T>(), Box::new(val))
62            .and_then(downcast_owned)
63    }
64
65    /// Check if map contains an item of a given type.
66    ///
67    /// ```
68    /// # use actix_http::Extensions;
69    /// let mut map = Extensions::new();
70    /// assert!(!map.contains::<u32>());
71    ///
72    /// assert_eq!(map.insert(1u32), None);
73    /// assert!(map.contains::<u32>());
74    /// ```
75    pub fn contains<T: 'static>(&self) -> bool {
76        self.map.contains_key(&TypeId::of::<T>())
77    }
78
79    /// Get a reference to an item of a given type.
80    ///
81    /// ```
82    /// # use actix_http::Extensions;
83    /// let mut map = Extensions::new();
84    /// map.insert(1u32);
85    /// assert_eq!(map.get::<u32>(), Some(&1u32));
86    /// ```
87    pub fn get<T: 'static>(&self) -> Option<&T> {
88        self.map
89            .get(&TypeId::of::<T>())
90            .and_then(|boxed| boxed.downcast_ref())
91    }
92
93    /// Get a mutable reference to an item of a given type.
94    ///
95    /// ```
96    /// # use actix_http::Extensions;
97    /// let mut map = Extensions::new();
98    /// map.insert(1u32);
99    /// assert_eq!(map.get_mut::<u32>(), Some(&mut 1u32));
100    /// ```
101    pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
102        self.map
103            .get_mut(&TypeId::of::<T>())
104            .and_then(|boxed| boxed.downcast_mut())
105    }
106
107    /// Inserts the given `value` into the extensions if it is not present, then returns a reference
108    /// to the value in the extensions.
109    ///
110    /// ```
111    /// # use actix_http::Extensions;
112    /// let mut map = Extensions::new();
113    /// assert_eq!(map.get::<Vec<u32>>(), None);
114    ///
115    /// map.get_or_insert(Vec::<u32>::new()).push(1);
116    /// assert_eq!(map.get::<Vec<u32>>(), Some(&vec![1]));
117    ///
118    /// map.get_or_insert(Vec::<u32>::new()).push(2);
119    /// assert_eq!(map.get::<Vec<u32>>(), Some(&vec![1,2]));
120    /// ```
121    pub fn get_or_insert<T: 'static>(&mut self, value: T) -> &mut T {
122        self.get_or_insert_with(|| value)
123    }
124
125    /// Inserts a value computed from `f` into the extensions if the given `value` is not present,
126    /// then returns a reference to the value in the extensions.
127    ///
128    /// ```
129    /// # use actix_http::Extensions;
130    /// let mut map = Extensions::new();
131    /// assert_eq!(map.get::<Vec<u32>>(), None);
132    ///
133    /// map.get_or_insert_with(Vec::<u32>::new).push(1);
134    /// assert_eq!(map.get::<Vec<u32>>(), Some(&vec![1]));
135    ///
136    /// map.get_or_insert_with(Vec::<u32>::new).push(2);
137    /// assert_eq!(map.get::<Vec<u32>>(), Some(&vec![1,2]));
138    /// ```
139    pub fn get_or_insert_with<T: 'static, F: FnOnce() -> T>(&mut self, default: F) -> &mut T {
140        self.map
141            .entry(TypeId::of::<T>())
142            .or_insert_with(|| Box::new(default()))
143            .downcast_mut()
144            .expect("extensions map should now contain a T value")
145    }
146
147    /// Remove an item from the map of a given type.
148    ///
149    /// If an item of this type was already stored, it will be returned.
150    ///
151    /// ```
152    /// # use actix_http::Extensions;
153    /// let mut map = Extensions::new();
154    ///
155    /// map.insert(1u32);
156    /// assert_eq!(map.get::<u32>(), Some(&1u32));
157    ///
158    /// assert_eq!(map.remove::<u32>(), Some(1u32));
159    /// assert!(!map.contains::<u32>());
160    /// ```
161    pub fn remove<T: 'static>(&mut self) -> Option<T> {
162        self.map.remove(&TypeId::of::<T>()).and_then(downcast_owned)
163    }
164
165    /// Clear the `Extensions` of all inserted extensions.
166    ///
167    /// ```
168    /// # use actix_http::Extensions;
169    /// let mut map = Extensions::new();
170    ///
171    /// map.insert(1u32);
172    /// assert!(map.contains::<u32>());
173    ///
174    /// map.clear();
175    /// assert!(!map.contains::<u32>());
176    /// ```
177    #[inline]
178    pub fn clear(&mut self) {
179        self.map.clear();
180    }
181
182    /// Extends self with the items from another `Extensions`.
183    pub fn extend(&mut self, other: Extensions) {
184        self.map.extend(other.map);
185    }
186}
187
188impl fmt::Debug for Extensions {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        f.debug_struct("Extensions").finish()
191    }
192}
193
/// Moves the value out of `boxed` when it holds a `T`.
///
/// Returns `None` (dropping the box and its contents) when the boxed value is some other type.
fn downcast_owned<T: 'static>(boxed: Box<dyn Any>) -> Option<T> {
    match boxed.downcast::<T>() {
        Ok(typed) => Some(*typed),
        Err(_) => None,
    }
}
197
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remove() {
        let mut map = Extensions::new();

        map.insert::<i8>(123);
        assert!(map.get::<i8>().is_some());

        // `remove` hands back the stored value...
        assert_eq!(map.remove::<i8>(), Some(123));
        assert!(map.get::<i8>().is_none());

        // ...and is a no-op once the value is gone
        assert_eq!(map.remove::<i8>(), None);
    }

    #[test]
    fn test_insert_replaces() {
        let mut map = Extensions::new();

        assert_eq!(map.insert(1u32), None);
        assert_eq!(map.insert(2u32), Some(1u32));
        assert_eq!(map.get::<u32>(), Some(&2u32));

        // a different type is a different key and leaves the `u32` untouched
        assert_eq!(map.insert(3u64), None);
        assert_eq!(map.get::<u32>(), Some(&2u32));
        assert_eq!(map.get::<u64>(), Some(&3u64));
    }

    #[test]
    fn test_get_or_insert() {
        let mut map = Extensions::new();

        *map.get_or_insert(1u32) += 1;
        assert_eq!(map.get::<u32>(), Some(&2));

        // an existing value is kept; the argument is discarded
        assert_eq!(*map.get_or_insert(100u32), 2);

        let mut calls = 0;
        map.get_or_insert_with(|| {
            calls += 1;
            String::from("a")
        })
        .push('b');
        map.get_or_insert_with(|| {
            calls += 1;
            String::new()
        })
        .push('c');

        // `default` only runs when the value is missing
        assert_eq!(calls, 1);
        assert_eq!(map.get::<String>().map(String::as_str), Some("abc"));
    }

    #[test]
    fn test_clear() {
        let mut map = Extensions::new();

        map.insert::<i8>(8);
        map.insert::<i16>(16);
        map.insert::<i32>(32);

        assert!(map.contains::<i8>());
        assert!(map.contains::<i16>());
        assert!(map.contains::<i32>());

        map.clear();

        assert!(!map.contains::<i8>());
        assert!(!map.contains::<i16>());
        assert!(!map.contains::<i32>());

        map.insert::<i8>(10);
        assert_eq!(*map.get::<i8>().unwrap(), 10);
    }

    #[test]
    fn test_integers() {
        static A: u32 = 8;

        let mut map = Extensions::new();

        map.insert::<i8>(8);
        map.insert::<i16>(16);
        map.insert::<i32>(32);
        map.insert::<i64>(64);
        map.insert::<i128>(128);
        map.insert::<u8>(8);
        map.insert::<u16>(16);
        map.insert::<u32>(32);
        map.insert::<u64>(64);
        map.insert::<u128>(128);
        map.insert::<&'static u32>(&A);
        assert!(map.get::<i8>().is_some());
        assert!(map.get::<i16>().is_some());
        assert!(map.get::<i32>().is_some());
        assert!(map.get::<i64>().is_some());
        assert!(map.get::<i128>().is_some());
        assert!(map.get::<u8>().is_some());
        assert!(map.get::<u16>().is_some());
        assert!(map.get::<u32>().is_some());
        assert!(map.get::<u64>().is_some());
        assert!(map.get::<u128>().is_some());
        assert!(map.get::<&'static u32>().is_some());
    }

    #[test]
    fn test_composition() {
        struct Magi<T>(pub T);

        struct Madoka {
            pub god: bool,
        }

        struct Homura {
            pub attempts: usize,
        }

        struct Mami {
            pub guns: usize,
        }

        let mut map = Extensions::new();

        map.insert(Magi(Madoka { god: false }));
        map.insert(Magi(Homura { attempts: 0 }));
        map.insert(Magi(Mami { guns: 999 }));

        assert!(!map.get::<Magi<Madoka>>().unwrap().0.god);
        assert_eq!(0, map.get::<Magi<Homura>>().unwrap().0.attempts);
        assert_eq!(999, map.get::<Magi<Mami>>().unwrap().0.guns);
    }

    #[test]
    fn test_extensions() {
        #[derive(Debug, PartialEq)]
        struct MyType(i32);

        let mut extensions = Extensions::new();

        extensions.insert(5i32);
        extensions.insert(MyType(10));

        assert_eq!(extensions.get(), Some(&5i32));
        assert_eq!(extensions.get_mut(), Some(&mut 5i32));

        assert_eq!(extensions.remove::<i32>(), Some(5i32));
        assert!(extensions.get::<i32>().is_none());

        assert_eq!(extensions.get::<bool>(), None);
        assert_eq!(extensions.get(), Some(&MyType(10)));
    }

    #[test]
    fn test_extend() {
        #[derive(Debug, PartialEq)]
        struct MyType(i32);

        let mut extensions = Extensions::new();

        extensions.insert(5i32);
        extensions.insert(MyType(10));

        let mut other = Extensions::new();

        other.insert(15i32);
        other.insert(20u8);

        extensions.extend(other);

        assert_eq!(extensions.get(), Some(&15i32));
        assert_eq!(extensions.get_mut(), Some(&mut 15i32));

        assert_eq!(extensions.remove::<i32>(), Some(15i32));
        assert!(extensions.get::<i32>().is_none());

        assert_eq!(extensions.get::<bool>(), None);
        assert_eq!(extensions.get(), Some(&MyType(10)));

        assert_eq!(extensions.get(), Some(&20u8));
        assert_eq!(extensions.get_mut(), Some(&mut 20u8));
    }

    #[test]
    fn test_debug() {
        let mut map = Extensions::new();
        assert_eq!(format!("{:?}", map), "Extensions");

        // contents are type-erased and never printed
        map.insert(1u32);
        assert_eq!(format!("{:?}", map), "Extensions");
        assert_eq!(format!("{:#?}", map), "Extensions");
    }
}