rodio/conversions/
channels.rs

1use crate::common::ChannelCount;
2use crate::Sample;
3
4/// Iterator that converts from a certain channel count to another.
5#[derive(Clone, Debug)]
6pub struct ChannelCountConverter<I>
7where
8    I: Iterator<Item = Sample>,
9{
10    input: I,
11    from: ChannelCount,
12    to: ChannelCount,
13    sample_repeat: Option<Sample>,
14    next_output_sample_pos: ChannelCount,
15}
16
17impl<I> ChannelCountConverter<I>
18where
19    I: Iterator<Item = Sample>,
20{
21    /// Initializes the iterator.
22    ///
23    /// # Panic
24    ///
25    /// Panics if `from` or `to` are equal to 0.
26    ///
27    #[inline]
28    pub fn new(input: I, from: ChannelCount, to: ChannelCount) -> ChannelCountConverter<I> {
29        assert!(from >= 1);
30        assert!(to >= 1);
31
32        ChannelCountConverter {
33            input,
34            from,
35            to,
36            sample_repeat: None,
37            next_output_sample_pos: 0,
38        }
39    }
40
41    /// Destroys this iterator and returns the underlying iterator.
42    #[inline]
43    pub fn into_inner(self) -> I {
44        self.input
45    }
46
47    /// Get mutable access to the iterator
48    #[inline]
49    pub fn inner_mut(&mut self) -> &mut I {
50        &mut self.input
51    }
52}
53
54impl<I> Iterator for ChannelCountConverter<I>
55where
56    I: Iterator<Item = Sample>,
57{
58    type Item = I::Item;
59
60    fn next(&mut self) -> Option<Self::Item> {
61        let result = match self.next_output_sample_pos {
62            0 => {
63                // save first sample for mono -> stereo conversion
64                let value = self.input.next();
65                self.sample_repeat = value;
66                value
67            }
68            x if x < self.from => self.input.next(),
69            1 => self.sample_repeat,
70            _ => Some(0.0),
71        };
72
73        if result.is_some() {
74            self.next_output_sample_pos += 1;
75        }
76
77        if self.next_output_sample_pos == self.to {
78            self.next_output_sample_pos = 0;
79
80            if self.from > self.to {
81                for _ in self.to..self.from {
82                    self.input.next(); // discarding extra input
83                }
84            }
85        }
86
87        result
88    }
89
90    #[inline]
91    fn size_hint(&self) -> (usize, Option<usize>) {
92        let (min, max) = self.input.size_hint();
93
94        let consumed = std::cmp::min(self.from, self.next_output_sample_pos) as usize;
95
96        let min = ((min + consumed) / self.from as usize * self.to as usize)
97            .saturating_sub(self.next_output_sample_pos as usize);
98
99        let max = max.map(|max| {
100            ((max + consumed) / self.from as usize * self.to as usize)
101                .saturating_sub(self.next_output_sample_pos as usize)
102        });
103
104        (min, max)
105    }
106}
107
108impl<I> ExactSizeIterator for ChannelCountConverter<I> where I: ExactSizeIterator<Item = Sample> {}
109
#[cfg(test)]
mod test {
    use super::ChannelCountConverter;
    use crate::common::ChannelCount;
    use crate::Sample;

    /// Converts `input` from `from` to `to` channels and collects the output.
    fn convert(input: &[Sample], from: ChannelCount, to: ChannelCount) -> Vec<Sample> {
        ChannelCountConverter::new(input.iter().copied(), from, to).collect()
    }

    #[test]
    fn remove_channels() {
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2),
            [1.0, 2.0, 4.0, 5.0]
        );
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 4, 1),
            [1.0, 5.0]
        );
    }

    #[test]
    fn add_channels() {
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0], 1, 2),
            [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
        );
        assert_eq!(
            convert(&[1.0, 2.0], 1, 4),
            [1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0]
        );
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0], 2, 4),
            [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]
        );
    }

    #[test]
    fn size_hint() {
        /// Steps through the whole conversion, checking that the reported size
        /// always equals the number of samples actually left.
        fn check(input: &[Sample], from: ChannelCount, to: ChannelCount) {
            let mut converter = ChannelCountConverter::new(input.iter().copied(), from, to);
            let total = converter.clone().count();
            for remaining in (0..=total).rev() {
                assert_eq!(
                    converter.size_hint(),
                    (remaining, Some(remaining)),
                    "wrong size_hint with {remaining} samples left"
                );
                converter.next();
            }
            assert_eq!(converter.size_hint(), (0, Some(0)));
        }

        check(&[1.0, 2.0, 3.0], 1, 2);
        check(&[1.0, 2.0, 3.0, 4.0], 2, 4);
        check(&[1.0, 2.0, 3.0, 4.0], 4, 2);
        check(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 8);
        check(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 4, 1);
    }

    #[test]
    fn len_more() {
        let samples = [1.0, 2.0, 3.0, 4.0];
        let converter = ChannelCountConverter::new(samples.iter().copied(), 2, 3);
        assert_eq!(converter.len(), 6);
    }

    #[test]
    fn len_less() {
        let samples = [1.0, 2.0, 3.0, 4.0];
        let converter = ChannelCountConverter::new(samples.iter().copied(), 2, 1);
        assert_eq!(converter.len(), 2);
    }
}