1use crate::common::ChannelCount;
2use crate::Sample;
3
4#[derive(Clone, Debug)]
6pub struct ChannelCountConverter<I>
7where
8 I: Iterator<Item = Sample>,
9{
10 input: I,
11 from: ChannelCount,
12 to: ChannelCount,
13 sample_repeat: Option<Sample>,
14 next_output_sample_pos: u16,
15}
16
17impl<I> ChannelCountConverter<I>
18where
19 I: Iterator<Item = Sample>,
20{
21 #[inline]
28 pub fn new(input: I, from: ChannelCount, to: ChannelCount) -> ChannelCountConverter<I> {
29 ChannelCountConverter {
30 input,
31 from,
32 to,
33 sample_repeat: None,
34 next_output_sample_pos: 0,
35 }
36 }
37
38 #[inline]
40 pub fn into_inner(self) -> I {
41 self.input
42 }
43
44 #[inline]
46 pub fn inner_mut(&mut self) -> &mut I {
47 &mut self.input
48 }
49}
50
51impl<I> Iterator for ChannelCountConverter<I>
52where
53 I: Iterator<Item = Sample>,
54{
55 type Item = I::Item;
56
57 fn next(&mut self) -> Option<Self::Item> {
58 let result = match self.next_output_sample_pos {
59 0 => {
60 let value = self.input.next();
62 self.sample_repeat = value;
63 value
64 }
65 x if x < self.from.get() => self.input.next(),
66 1 => self.sample_repeat,
67 _ => Some(0.0),
68 };
69
70 if result.is_some() {
71 self.next_output_sample_pos += 1;
72 }
73
74 if self.next_output_sample_pos == self.to.get() {
75 self.next_output_sample_pos = 0;
76
77 if self.from > self.to {
78 for _ in self.to.get()..self.from.get() {
79 self.input.next(); }
81 }
82 }
83
84 result
85 }
86
87 #[inline]
88 fn size_hint(&self) -> (usize, Option<usize>) {
89 let (min, max) = self.input.size_hint();
90
91 let consumed = std::cmp::min(self.from.get(), self.next_output_sample_pos) as usize;
92
93 let min = ((min + consumed) / self.from.get() as usize * self.to.get() as usize)
94 .saturating_sub(self.next_output_sample_pos as usize);
95
96 let max = max.map(|max| {
97 ((max + consumed) / self.from.get() as usize * self.to.get() as usize)
98 .saturating_sub(self.next_output_sample_pos as usize)
99 });
100
101 (min, max)
102 }
103}
104
105impl<I> ExactSizeIterator for ChannelCountConverter<I> where I: ExactSizeIterator<Item = Sample> {}
106
#[cfg(test)]
mod test {
    use super::ChannelCountConverter;
    use crate::common::ChannelCount;
    use crate::math::nz;
    use crate::Sample;

    /// Runs `input` through a converter and collects every produced sample.
    fn convert(input: &[Sample], from: ChannelCount, to: ChannelCount) -> Vec<Sample> {
        ChannelCountConverter::new(input.iter().copied(), from, to).collect()
    }

    #[test]
    fn remove_channels() {
        // 3 -> 2: the third channel of every frame is dropped.
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], nz!(3), nz!(2)),
            [1.0, 2.0, 4.0, 5.0]
        );
        // 4 -> 1: only the first channel of every frame survives.
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], nz!(4), nz!(1)),
            [1.0, 5.0]
        );
    }

    #[test]
    fn add_channels() {
        // Mono -> stereo duplicates each sample.
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0], nz!(1), nz!(2)),
            [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
        );
        // Mono -> 4: copied into the second channel, silence afterwards.
        assert_eq!(
            convert(&[1.0, 2.0], nz!(1), nz!(4)),
            [1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0]
        );
        // Stereo -> 4: the extra channels are silent.
        assert_eq!(
            convert(&[1.0, 2.0, 3.0, 4.0], nz!(2), nz!(4)),
            [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]
        );
    }

    #[test]
    fn size_hint() {
        let cases: [(&[Sample], ChannelCount, ChannelCount); 5] = [
            (&[1.0, 2.0, 3.0], nz!(1), nz!(2)),
            (&[1.0, 2.0, 3.0, 4.0], nz!(2), nz!(4)),
            (&[1.0, 2.0, 3.0, 4.0], nz!(4), nz!(2)),
            (&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], nz!(3), nz!(8)),
            (&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], nz!(4), nz!(1)),
        ];

        for &(input, from, to) in cases.iter() {
            let mut converter = ChannelCountConverter::new(input.iter().copied(), from, to);
            let total = converter.clone().count();
            // The hint must be exact before every call to `next`...
            for left_in_iter in (0..=total).rev() {
                println!("left_in_iter = {left_in_iter}");
                assert_eq!(converter.size_hint(), (left_in_iter, Some(left_in_iter)));
                converter.next();
            }
            // ...and stay at zero once the converter has been drained.
            assert_eq!(converter.size_hint(), (0, Some(0)));
        }
    }

    #[test]
    fn len_more() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(
            ChannelCountConverter::new(input.into_iter(), nz!(2), nz!(3)).len(),
            6
        );
    }

    #[test]
    fn len_less() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(
            ChannelCountConverter::new(input.into_iter(), nz!(2), nz!(1)).len(),
            2
        );
    }
}