1use crate::common::ChannelCount;
2use crate::Sample;
3
/// Iterator that converts interleaved samples laid out in frames of `from` channels into
/// interleaved samples laid out in frames of `to` channels.
///
/// Input channels beyond `to` are dropped. Output channels beyond `from` are filled with
/// `0.0`, except that mono input (`from == 1`) is duplicated onto the second output channel.
#[derive(Clone, Debug)]
pub struct ChannelCountConverter<I>
where
    I: Iterator<Item = Sample>,
{
    /// Source of interleaved input samples.
    input: I,
    /// Number of channels per input frame.
    from: ChannelCount,
    /// Number of channels per output frame.
    to: ChannelCount,
    /// First sample of the current input frame; replayed on output channel 1 when `from == 1`.
    sample_repeat: Option<Sample>,
    /// Index within the current output frame of the next sample to yield (kept in `0..to`).
    next_output_sample_pos: ChannelCount,
}
16
17impl<I> ChannelCountConverter<I>
18where
19 I: Iterator<Item = Sample>,
20{
21 #[inline]
28 pub fn new(input: I, from: ChannelCount, to: ChannelCount) -> ChannelCountConverter<I> {
29 assert!(from >= 1);
30 assert!(to >= 1);
31
32 ChannelCountConverter {
33 input,
34 from,
35 to,
36 sample_repeat: None,
37 next_output_sample_pos: 0,
38 }
39 }
40
41 #[inline]
43 pub fn into_inner(self) -> I {
44 self.input
45 }
46
47 #[inline]
49 pub fn inner_mut(&mut self) -> &mut I {
50 &mut self.input
51 }
52}
53
54impl<I> Iterator for ChannelCountConverter<I>
55where
56 I: Iterator<Item = Sample>,
57{
58 type Item = I::Item;
59
60 fn next(&mut self) -> Option<Self::Item> {
61 let result = match self.next_output_sample_pos {
62 0 => {
63 let value = self.input.next();
65 self.sample_repeat = value;
66 value
67 }
68 x if x < self.from => self.input.next(),
69 1 => self.sample_repeat,
70 _ => Some(0.0),
71 };
72
73 if result.is_some() {
74 self.next_output_sample_pos += 1;
75 }
76
77 if self.next_output_sample_pos == self.to {
78 self.next_output_sample_pos = 0;
79
80 if self.from > self.to {
81 for _ in self.to..self.from {
82 self.input.next(); }
84 }
85 }
86
87 result
88 }
89
90 #[inline]
91 fn size_hint(&self) -> (usize, Option<usize>) {
92 let (min, max) = self.input.size_hint();
93
94 let consumed = std::cmp::min(self.from, self.next_output_sample_pos) as usize;
95
96 let min = ((min + consumed) / self.from as usize * self.to as usize)
97 .saturating_sub(self.next_output_sample_pos as usize);
98
99 let max = max.map(|max| {
100 ((max + consumed) / self.from as usize * self.to as usize)
101 .saturating_sub(self.next_output_sample_pos as usize)
102 });
103
104 (min, max)
105 }
106}
107
108impl<I> ExactSizeIterator for ChannelCountConverter<I> where I: ExactSizeIterator<Item = Sample> {}
109
#[cfg(test)]
mod test {
    use super::ChannelCountConverter;
    use crate::common::ChannelCount;
    use crate::Sample;

    #[test]
    fn remove_channels() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let output = ChannelCountConverter::new(input.into_iter(), 3, 2).collect::<Vec<_>>();
        assert_eq!(output, [1.0, 2.0, 4.0, 5.0]);

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let output = ChannelCountConverter::new(input.into_iter(), 4, 1).collect::<Vec<_>>();
        assert_eq!(output, [1.0, 5.0]);
    }

    #[test]
    fn add_channels() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = ChannelCountConverter::new(input.into_iter(), 1, 2).collect::<Vec<_>>();
        assert_eq!(output, [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]);

        let input = vec![1.0, 2.0];
        let output = ChannelCountConverter::new(input.into_iter(), 1, 4).collect::<Vec<_>>();
        assert_eq!(output, [1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0]);

        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = ChannelCountConverter::new(input.into_iter(), 2, 4).collect::<Vec<_>>();
        assert_eq!(output, [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]);
    }

    #[test]
    fn size_hint() {
        // Checks that `size_hint` is exact before every call to `next` and after exhaustion.
        fn test(input: &[Sample], from: ChannelCount, to: ChannelCount) {
            let mut converter = ChannelCountConverter::new(input.iter().copied(), from, to);
            let count = converter.clone().count();
            for left_in_iter in (0..=count).rev() {
                assert_eq!(
                    converter.size_hint(),
                    (left_in_iter, Some(left_in_iter)),
                    "from = {from}, to = {to}, left_in_iter = {left_in_iter}"
                );
                converter.next();
            }
            assert_eq!(converter.size_hint(), (0, Some(0)));
        }

        test(&[1.0, 2.0, 3.0], 1, 2);
        test(&[1.0, 2.0, 3.0, 4.0], 2, 4);
        test(&[1.0, 2.0, 3.0, 4.0], 4, 2);
        test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 8);
        test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 4, 1);
    }

    #[test]
    fn empty_input() {
        let output = ChannelCountConverter::new(std::iter::empty(), 2, 3);
        assert_eq!(output.size_hint(), (0, Some(0)));
        assert!(output.collect::<Vec<Sample>>().is_empty());
    }

    #[test]
    fn len_more() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = ChannelCountConverter::new(input.into_iter(), 2, 3);
        assert_eq!(output.len(), 6);
    }

    #[test]
    fn len_less() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = ChannelCountConverter::new(input.into_iter(), 2, 1);
        assert_eq!(output.len(), 2);
    }
}