Skip to main content

rodio/conversions/
channels.rs

1use crate::common::ChannelCount;
2use crate::Sample;
3
4/// Iterator that converts from a certain channel count to another.
5#[derive(Clone, Debug)]
6pub struct ChannelCountConverter<I>
7where
8    I: Iterator<Item = Sample>,
9{
10    input: I,
11    from: ChannelCount,
12    to: ChannelCount,
13    sample_repeat: Option<Sample>,
14    next_output_sample_pos: u16,
15}
16
17impl<I> ChannelCountConverter<I>
18where
19    I: Iterator<Item = Sample>,
20{
21    /// Initializes the iterator.
22    ///
23    /// # Panic
24    ///
25    /// Panics if `from` or `to` are equal to 0.
26    ///
27    #[inline]
28    pub fn new(input: I, from: ChannelCount, to: ChannelCount) -> ChannelCountConverter<I> {
29        ChannelCountConverter {
30            input,
31            from,
32            to,
33            sample_repeat: None,
34            next_output_sample_pos: 0,
35        }
36    }
37
38    /// Destroys this iterator and returns the underlying iterator.
39    #[inline]
40    pub fn into_inner(self) -> I {
41        self.input
42    }
43
44    /// Get mutable access to the iterator
45    #[inline]
46    pub fn inner_mut(&mut self) -> &mut I {
47        &mut self.input
48    }
49}
50
51impl<I> Iterator for ChannelCountConverter<I>
52where
53    I: Iterator<Item = Sample>,
54{
55    type Item = I::Item;
56
57    fn next(&mut self) -> Option<Self::Item> {
58        let result = match self.next_output_sample_pos {
59            0 => {
60                // save first sample for mono -> stereo conversion
61                let value = self.input.next();
62                self.sample_repeat = value;
63                value
64            }
65            x if x < self.from.get() => self.input.next(),
66            1 => self.sample_repeat,
67            _ => Some(0.0),
68        };
69
70        if result.is_some() {
71            self.next_output_sample_pos += 1;
72        }
73
74        if self.next_output_sample_pos == self.to.get() {
75            self.next_output_sample_pos = 0;
76
77            if self.from > self.to {
78                for _ in self.to.get()..self.from.get() {
79                    self.input.next(); // discarding extra input
80                }
81            }
82        }
83
84        result
85    }
86
87    #[inline]
88    fn size_hint(&self) -> (usize, Option<usize>) {
89        let (min, max) = self.input.size_hint();
90
91        let consumed = std::cmp::min(self.from.get(), self.next_output_sample_pos) as usize;
92
93        let min = ((min + consumed) / self.from.get() as usize * self.to.get() as usize)
94            .saturating_sub(self.next_output_sample_pos as usize);
95
96        let max = max.map(|max| {
97            ((max + consumed) / self.from.get() as usize * self.to.get() as usize)
98                .saturating_sub(self.next_output_sample_pos as usize)
99        });
100
101        (min, max)
102    }
103}
104
105impl<I> ExactSizeIterator for ChannelCountConverter<I> where I: ExactSizeIterator<Item = Sample> {}
106
#[cfg(test)]
mod test {
    use super::ChannelCountConverter;
    use crate::common::ChannelCount;
    use crate::math::nz;
    use crate::Sample;

    /// Builds a converter over a borrowed slice of samples.
    fn converter(
        input: &[Sample],
        from: ChannelCount,
        to: ChannelCount,
    ) -> ChannelCountConverter<std::iter::Copied<std::slice::Iter<'_, Sample>>> {
        ChannelCountConverter::new(input.iter().copied(), from, to)
    }

    /// Converts `input` and collects every produced sample.
    fn convert(input: &[Sample], from: ChannelCount, to: ChannelCount) -> Vec<Sample> {
        converter(input, from, to).collect()
    }

    #[test]
    fn remove_channels() {
        // 3 -> 2: the third channel of every frame is dropped.
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], nz!(3), nz!(2)),
            [1.0, 2.0, 4.0, 5.0]
        );
        // 4 -> 1: only the first channel of every frame survives.
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], nz!(4), nz!(1)),
            [1.0, 5.0]
        );
    }

    #[test]
    fn add_channels() {
        // 1 -> 2: mono is duplicated into both channels.
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0], nz!(1), nz!(2)),
            [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
        );
        // 1 -> 4: mono fills the first two channels, the remaining ones are silent.
        assert_eq!(
            convert(&[1.0, 2.0], nz!(1), nz!(4)),
            [1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0]
        );
        // 2 -> 4: the added channels are silent.
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0], nz!(2), nz!(4)),
            [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]
        );
    }

    #[test]
    fn size_hint() {
        /// Steps through the whole conversion, checking the hint is exact at every point.
        fn check(input: &[Sample], from: ChannelCount, to: ChannelCount) {
            let mut conv = converter(input, from, to);
            let total = conv.clone().count();
            for left_in_iter in (0..=total).rev() {
                println!("left_in_iter = {left_in_iter}");
                assert_eq!(conv.size_hint(), (left_in_iter, Some(left_in_iter)));
                conv.next();
            }
            assert_eq!(conv.size_hint(), (0, Some(0)));
        }

        check(&[1.0, 2.0, 3.0], nz!(1), nz!(2));
        check(&[1.0, 2.0, 3.0, 4.0], nz!(2), nz!(4));
        check(&[1.0, 2.0, 3.0, 4.0], nz!(4), nz!(2));
        check(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], nz!(3), nz!(8));
        check(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], nz!(4), nz!(1));
    }

    #[test]
    fn len_more() {
        assert_eq!(converter(&[1.0, 2.0, 3.0, 4.0], nz!(2), nz!(3)).len(), 6);
    }

    #[test]
    fn len_less() {
        assert_eq!(converter(&[1.0, 2.0, 3.0, 4.0], nz!(2), nz!(1)).len(), 2);
    }
}