rodio/source/
linear_ramp.rs

1use std::time::Duration;
2
3use super::SeekError;
4use crate::common::{ChannelCount, SampleRate};
5use crate::Source;
6
7/// Internal function that builds a `LinearRamp` object.
8pub fn linear_gain_ramp<I>(
9    input: I,
10    duration: Duration,
11    start_gain: f32,
12    end_gain: f32,
13    clamp_end: bool,
14) -> LinearGainRamp<I>
15where
16    I: Source,
17{
18    let duration_nanos = duration.as_nanos() as f32;
19    assert!(duration_nanos > 0.0f32);
20
21    LinearGainRamp {
22        input,
23        elapsed_ns: 0.0f32,
24        total_ns: duration_nanos,
25        start_gain,
26        end_gain,
27        clamp_end,
28        sample_idx: 0u64,
29    }
30}
31
/// Filter that adds a linear gain ramp to the source over a given time range.
///
/// Instances are built by `linear_gain_ramp`.
#[derive(Clone, Debug)]
pub struct LinearGainRamp<I> {
    /// Wrapped source whose samples are multiplied by the current gain.
    input: I,
    /// Current position on the ramp, in nanoseconds. Advanced by one sample period in `next`
    /// whenever `sample_idx` is a multiple of the channel count, and overwritten by `try_seek`.
    // NOTE(review): `f32` has a 24-bit significand, so nanosecond counts above 2^24 (~16.8 ms)
    // are rounded; accumulated rounding makes long ramps drift.
    elapsed_ns: f32,
    /// Length of the ramp, in nanoseconds.
    total_ns: f32,
    /// Gain applied at the start of the ramp (`elapsed_ns == 0`).
    start_gain: f32,
    /// Gain reached at the end of the ramp (`elapsed_ns == total_ns`).
    end_gain: f32,
    /// After the ramp has finished: `true` keeps applying `end_gain`, `false` applies `1.0`.
    clamp_end: bool,
    /// Number of samples produced while the ramp was still running; used to detect frame
    /// boundaries on multi-channel sources.
    sample_idx: u64,
}
43
44impl<I> LinearGainRamp<I>
45where
46    I: Source,
47{
48    /// Returns a reference to the innner source.
49    #[inline]
50    pub fn inner(&self) -> &I {
51        &self.input
52    }
53
54    /// Returns a mutable reference to the inner source.
55    #[inline]
56    pub fn inner_mut(&mut self) -> &mut I {
57        &mut self.input
58    }
59
60    /// Returns the inner source.
61    #[inline]
62    pub fn into_inner(self) -> I {
63        self.input
64    }
65}
66
67impl<I> Iterator for LinearGainRamp<I>
68where
69    I: Source,
70{
71    type Item = I::Item;
72
73    #[inline]
74    fn next(&mut self) -> Option<I::Item> {
75        let factor: f32;
76        let remaining_ns = self.total_ns - self.elapsed_ns;
77
78        if remaining_ns < 0.0 {
79            if self.clamp_end {
80                factor = self.end_gain;
81            } else {
82                factor = 1.0f32;
83            }
84        } else {
85            self.sample_idx += 1;
86
87            let p = self.elapsed_ns / self.total_ns;
88            factor = self.start_gain * (1.0f32 - p) + self.end_gain * p;
89        }
90
91        if self.sample_idx % (self.channels() as u64) == 0 {
92            self.elapsed_ns += 1000000000.0 / (self.input.sample_rate() as f32);
93        }
94
95        self.input.next().map(|value| value * factor)
96    }
97
98    #[inline]
99    fn size_hint(&self) -> (usize, Option<usize>) {
100        self.input.size_hint()
101    }
102}
103
104impl<I> ExactSizeIterator for LinearGainRamp<I> where I: Source + ExactSizeIterator {}
105
106impl<I> Source for LinearGainRamp<I>
107where
108    I: Source,
109{
110    #[inline]
111    fn current_span_len(&self) -> Option<usize> {
112        self.input.current_span_len()
113    }
114
115    #[inline]
116    fn channels(&self) -> ChannelCount {
117        self.input.channels()
118    }
119
120    #[inline]
121    fn sample_rate(&self) -> SampleRate {
122        self.input.sample_rate()
123    }
124
125    #[inline]
126    fn total_duration(&self) -> Option<Duration> {
127        self.input.total_duration()
128    }
129
130    #[inline]
131    fn try_seek(&mut self, pos: Duration) -> Result<(), SeekError> {
132        self.elapsed_ns = pos.as_nanos() as f32;
133        self.input.try_seek(pos)
134    }
135}
136
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;
    use crate::buffer::SamplesBuffer;
    use crate::Sample;

    /// Create a SamplesBuffer of `length` identical samples with value `value`.
    /// Returned buffer is one channel and has a sample rate of 1 hz, so each sample advances
    /// the ramp clock by exactly one second.
    fn const_source(length: u8, value: Sample) -> SamplesBuffer {
        let data: Vec<Sample> = vec![value; usize::from(length)];
        SamplesBuffer::new(1, 1, data)
    }

    /// Create a SamplesBuffer of `length` samples that repeat the values in `values`.
    /// Returned buffer is one channel and has a sample rate of 1 hz.
    fn cycle_source(length: u8, values: Vec<Sample>) -> SamplesBuffer {
        let data: Vec<Sample> = (0..usize::from(length))
            .map(|i| values[i % values.len()])
            .collect();

        SamplesBuffer::new(1, 1, data)
    }

    /// A 0 -> 1 ramp over four seconds reaches full gain on the fifth sample and stays there.
    #[test]
    fn test_linear_ramp() {
        let source1 = const_source(10, 1.0f32);
        let mut faded = linear_gain_ramp(source1, Duration::from_secs(4), 0.0, 1.0, true);

        for expected in [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] {
            assert_eq!(faded.next(), Some(expected));
        }
        assert_eq!(faded.next(), None);
    }

    /// With `clamp_end = true` the gain stays at `end_gain` after the ramp has finished.
    #[test]
    fn test_linear_ramp_clamped() {
        let source1 = const_source(10, 1.0f32);
        let mut faded = linear_gain_ramp(source1, Duration::from_secs(4), 0.0, 0.5, true);

        // fading in...
        for expected in [0.0, 0.125, 0.25, 0.375] {
            assert_eq!(faded.next(), Some(expected));
        }
        // fade is done
        for expected in [0.5, 0.5, 0.5, 0.5, 0.5, 0.5] {
            assert_eq!(faded.next(), Some(expected));
        }
        assert_eq!(faded.next(), None);
    }

    /// With `clamp_end = false` the gain returns to 1.0 once the ramp has been passed. The
    /// sample that lands exactly on the end of the ramp is still scaled by `end_gain`.
    #[test]
    fn test_linear_ramp_unclamped() {
        let source1 = const_source(10, 1.0f32);
        let mut faded = linear_gain_ramp(source1, Duration::from_secs(4), 0.0, 0.5, false);

        // fading in, ending exactly on end_gain...
        for expected in [0.0, 0.125, 0.25, 0.375, 0.5] {
            assert_eq!(faded.next(), Some(expected));
        }
        // ramp is over and not clamped: samples pass through unchanged
        for expected in [1.0, 1.0, 1.0, 1.0, 1.0] {
            assert_eq!(faded.next(), Some(expected));
        }
        assert_eq!(faded.next(), None);
    }

    /// Seeking moves the ramp position together with the source position.
    #[test]
    fn test_linear_ramp_seek() {
        let source1 = cycle_source(20, vec![0.0f32, 0.4f32, 0.8f32]);
        let mut faded = linear_gain_ramp(source1, Duration::from_secs(10), 0.0, 1.0, true);

        assert_abs_diff_eq!(faded.next().unwrap(), 0.0); // source value 0
        assert_abs_diff_eq!(faded.next().unwrap(), 0.04); // source value 0.4, ramp gain 0.1
        assert_abs_diff_eq!(faded.next().unwrap(), 0.16); // source value 0.8, ramp gain 0.2

        assert!(
            faded.try_seek(Duration::from_secs(5)).is_ok(),
            "try_seek() failed!"
        );
        assert_abs_diff_eq!(faded.next().unwrap(), 0.40); // source value 0.8, ramp gain 0.5
        assert_abs_diff_eq!(faded.next().unwrap(), 0.0); // source value 0, ramp gain 0.6
        assert_abs_diff_eq!(faded.next().unwrap(), 0.28); // source value 0.4, ramp gain 0.7

        assert!(
            faded.try_seek(Duration::from_secs(0)).is_ok(),
            "try_seek() failed!"
        );
        assert_abs_diff_eq!(faded.next().unwrap(), 0.0); // source value 0, ramp gain 0.0
        assert_abs_diff_eq!(faded.next().unwrap(), 0.04); // source value 0.4, ramp gain 0.1
        assert_abs_diff_eq!(faded.next().unwrap(), 0.16); // source value 0.8, ramp gain 0.2

        assert!(
            faded.try_seek(Duration::from_secs(10)).is_ok(),
            "try_seek() failed!"
        );
        assert_abs_diff_eq!(faded.next().unwrap(), 0.4); // source value 0.4, ramp gain 1.0
        assert_abs_diff_eq!(faded.next().unwrap(), 0.8); // source value 0.8, ramp gain 1.0
        assert_abs_diff_eq!(faded.next().unwrap(), 0.0); // source value 0, ramp gain 1.0
    }
}
231}