Skip to main content

rodio/source/
linear_ramp.rs

1use std::time::Duration;
2
3use super::SeekError;
4use crate::common::{ChannelCount, SampleRate};
5use crate::math::{duration_to_float, NANOS_PER_SEC};
6use crate::{Float, Source};
7
8/// Internal function that builds a `LinearRamp` object.
9pub fn linear_gain_ramp<I>(
10    input: I,
11    duration: Duration,
12    start_gain: Float,
13    end_gain: Float,
14    clamp_end: bool,
15) -> LinearGainRamp<I>
16where
17    I: Source,
18{
19    assert!(!duration.is_zero(), "duration must be greater than zero");
20
21    LinearGainRamp {
22        input,
23        elapsed: Duration::ZERO,
24        total: duration,
25        start_gain,
26        end_gain,
27        clamp_end,
28        sample_idx: 0u64,
29    }
30}
31
32/// Filter that adds a linear gain ramp to the source over a given time range.
33#[derive(Clone, Debug)]
34pub struct LinearGainRamp<I> {
35    input: I,
36    elapsed: Duration,
37    total: Duration,
38    start_gain: Float,
39    end_gain: Float,
40    clamp_end: bool,
41    sample_idx: u64,
42}
43
44impl<I> LinearGainRamp<I>
45where
46    I: Source,
47{
48    /// Returns a reference to the inner source.
49    #[inline]
50    pub fn inner(&self) -> &I {
51        &self.input
52    }
53
54    /// Returns a mutable reference to the inner source.
55    #[inline]
56    pub fn inner_mut(&mut self) -> &mut I {
57        &mut self.input
58    }
59
60    /// Returns the inner source.
61    #[inline]
62    pub fn into_inner(self) -> I {
63        self.input
64    }
65}
66
67impl<I> Iterator for LinearGainRamp<I>
68where
69    I: Source,
70{
71    type Item = I::Item;
72
73    #[inline]
74    fn next(&mut self) -> Option<I::Item> {
75        let factor: Float;
76
77        if self.elapsed >= self.total {
78            if self.clamp_end {
79                factor = self.end_gain;
80            } else {
81                factor = 1.0;
82            }
83        } else {
84            self.sample_idx += 1;
85
86            // Calculate progress (0.0 to 1.0) using appropriate precision for Float type
87            let p = duration_to_float(self.elapsed) / duration_to_float(self.total);
88
89            factor = self.start_gain * (1.0 - p) + self.end_gain * p;
90        }
91
92        if self.sample_idx.is_multiple_of(self.channels().get() as u64) {
93            let sample_duration =
94                Duration::from_nanos(NANOS_PER_SEC / self.input.sample_rate().get() as u64);
95            self.elapsed += sample_duration;
96        }
97
98        self.input.next().map(|value| value * factor)
99    }
100
101    #[inline]
102    fn size_hint(&self) -> (usize, Option<usize>) {
103        self.input.size_hint()
104    }
105}
106
107impl<I> ExactSizeIterator for LinearGainRamp<I> where I: Source + ExactSizeIterator {}
108
109impl<I> Source for LinearGainRamp<I>
110where
111    I: Source,
112{
113    #[inline]
114    fn current_span_len(&self) -> Option<usize> {
115        self.input.current_span_len()
116    }
117
118    #[inline]
119    fn channels(&self) -> ChannelCount {
120        self.input.channels()
121    }
122
123    #[inline]
124    fn sample_rate(&self) -> SampleRate {
125        self.input.sample_rate()
126    }
127
128    #[inline]
129    fn total_duration(&self) -> Option<Duration> {
130        self.input.total_duration()
131    }
132
133    #[inline]
134    fn try_seek(&mut self, pos: Duration) -> Result<(), SeekError> {
135        self.elapsed = pos;
136        self.input.try_seek(pos)
137    }
138}
139
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;
    use crate::buffer::SamplesBuffer;
    use crate::math::nz;
    use crate::Sample;

    /// Create a SamplesBuffer of `length` identical samples with value `value`.
    /// Returned buffer is one channel and has a sample rate of 1 hz.
    fn const_source(length: u8, value: Sample) -> SamplesBuffer {
        SamplesBuffer::new(nz!(1), nz!(1), vec![value; usize::from(length)])
    }

    /// Create a SamplesBuffer of `length` samples repeating `values` in order.
    /// Returned buffer is one channel and has a sample rate of 1 hz.
    fn cycle_source(length: u8, values: &[Sample]) -> SamplesBuffer {
        let data: Vec<Sample> = values
            .iter()
            .copied()
            .cycle()
            .take(usize::from(length))
            .collect();

        SamplesBuffer::new(nz!(1), nz!(1), data)
    }

    /// Assert that `source` yields exactly `expected`, in order, and then ends.
    fn assert_yields(source: &mut impl Iterator<Item = Sample>, expected: &[Sample]) {
        for (i, &want) in expected.iter().enumerate() {
            assert_eq!(source.next(), Some(want), "sample {i}");
        }
        assert_eq!(source.next(), None);
    }

    #[test]
    fn test_linear_ramp() {
        let source1 = const_source(10, 1.0);
        let mut faded = linear_gain_ramp(source1, Duration::from_secs(4), 0.0, 1.0, true);

        assert_yields(
            &mut faded,
            &[0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        );
    }

    #[test]
    fn test_linear_ramp_clamped() {
        let source1 = const_source(10, 1.0);
        let mut faded = linear_gain_ramp(source1, Duration::from_secs(4), 0.0, 0.5, true);

        // Fading in for four samples, then held at the clamped end gain.
        assert_yields(
            &mut faded,
            &[0.0, 0.125, 0.25, 0.375, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
        );
    }

    #[test]
    fn test_linear_ramp_stereo_advances_per_frame() {
        let data: Vec<Sample> = vec![1.0; 8];
        let source = SamplesBuffer::new(nz!(2), nz!(1), data);
        let mut faded = linear_gain_ramp(source, Duration::from_secs(2), 0.0, 1.0, true);

        // Both channels of a frame share one gain; the ramp advances per frame.
        assert_yields(&mut faded, &[0.0, 0.0, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_linear_ramp_seek() {
        let source1 = cycle_source(20, &[0.0, 0.4, 0.8]);
        let mut faded = linear_gain_ramp(source1, Duration::from_secs(10), 0.0, 1.0, true);

        assert_abs_diff_eq!(faded.next().unwrap(), 0.0); // source value 0
        assert_abs_diff_eq!(faded.next().unwrap(), 0.04); // source value 0.4, ramp gain 0.1
        assert_abs_diff_eq!(faded.next().unwrap(), 0.16); // source value 0.8, ramp gain 0.2

        faded
            .try_seek(Duration::from_secs(5))
            .expect("try_seek() failed!");
        assert_abs_diff_eq!(faded.next().unwrap(), 0.40); // source value 0.8, ramp gain 0.5
        assert_abs_diff_eq!(faded.next().unwrap(), 0.0); // source value 0, ramp gain 0.6
        assert_abs_diff_eq!(faded.next().unwrap(), 0.28); // source value 0.4. ramp gain 0.7

        faded
            .try_seek(Duration::from_secs(0))
            .expect("try_seek() failed!");
        assert_abs_diff_eq!(faded.next().unwrap(), 0.0); // source value 0, ramp gain 0.0
        assert_abs_diff_eq!(faded.next().unwrap(), 0.04); // source value 0.4, ramp gain 0.1
        assert_abs_diff_eq!(faded.next().unwrap(), 0.16); // source value 0.8. ramp gain 0.2

        faded
            .try_seek(Duration::from_secs(10))
            .expect("try_seek() failed!");
        assert_abs_diff_eq!(faded.next().unwrap(), 0.4); // source value 0.4, ramp gain 1.0
        assert_abs_diff_eq!(faded.next().unwrap(), 0.8); // source value 0.8, ramp gain 1.0
        assert_abs_diff_eq!(faded.next().unwrap(), 0.0); // source value 0. ramp gain 1.0
    }
}