Skip to main content

rodio/source/
take.rs

1use std::time::Duration;
2
3use super::SeekError;
4use crate::common::{ChannelCount, SampleRate};
5use crate::math::NANOS_PER_SEC;
6use crate::{Float, Sample, Source};
7
8/// Internal function that builds a `TakeDuration` object.
9pub fn take_duration<I>(input: I, duration: Duration) -> TakeDuration<I>
10where
11    I: Source,
12{
13    TakeDuration {
14        current_span_len: input.current_span_len(),
15        duration_per_sample: TakeDuration::get_duration_per_sample(&input),
16        input,
17        remaining_duration: duration,
18        requested_duration: duration,
19        filter: None,
20    }
21}
22
23/// A filter that can be applied to a `TakeDuration`.
24#[derive(Clone, Debug)]
25enum DurationFilter {
26    FadeOut,
27}
28impl DurationFilter {
29    fn apply<I: Iterator>(&self, sample: Sample, parent: &TakeDuration<I>) -> Sample {
30        match self {
31            DurationFilter::FadeOut => {
32                let remaining = parent.remaining_duration.as_millis() as Float;
33                let total = parent.requested_duration.as_millis() as Float;
34                sample * remaining / total
35            }
36        }
37    }
38}
39
40/// A source that truncates the given source to a certain duration.
41#[derive(Clone, Debug)]
42pub struct TakeDuration<I> {
43    input: I,
44    remaining_duration: Duration,
45    requested_duration: Duration,
46    filter: Option<DurationFilter>,
47    // Remaining samples in current span.
48    current_span_len: Option<usize>,
49    // Only updated when the current span len is exhausted.
50    duration_per_sample: Duration,
51}
52
53impl<I> TakeDuration<I>
54where
55    I: Source,
56{
57    /// Returns the duration elapsed for each sample extracted.
58    #[inline]
59    fn get_duration_per_sample(input: &I) -> Duration {
60        let ns = NANOS_PER_SEC / (input.sample_rate().get() as u64 * input.channels().get() as u64);
61        // \|/ the maximum value of `ns` is one billion, so this can't fail
62        Duration::new(0, ns as u32)
63    }
64
65    /// Returns a reference to the inner source.
66    #[inline]
67    pub fn inner(&self) -> &I {
68        &self.input
69    }
70
71    /// Returns a mutable reference to the inner source.
72    #[inline]
73    pub fn inner_mut(&mut self) -> &mut I {
74        &mut self.input
75    }
76
77    /// Returns the inner source.
78    #[inline]
79    pub fn into_inner(self) -> I {
80        self.input
81    }
82
83    /// Make the truncated source end with a FadeOut. The fadeout covers the
84    /// entire length of the take source.
85    pub fn set_filter_fadeout(&mut self) {
86        self.filter = Some(DurationFilter::FadeOut);
87    }
88
89    /// Remove any filter set.
90    pub fn clear_filter(&mut self) {
91        self.filter = None;
92    }
93}
94
95impl<I> Iterator for TakeDuration<I>
96where
97    I: Source,
98{
99    type Item = <I as Iterator>::Item;
100
101    fn next(&mut self) -> Option<<I as Iterator>::Item> {
102        if let Some(span_len) = self.current_span_len.take() {
103            if span_len > 0 {
104                self.current_span_len = Some(span_len - 1);
105            } else {
106                self.current_span_len = self.input.current_span_len();
107                // Sample rate might have changed
108                self.duration_per_sample = Self::get_duration_per_sample(&self.input);
109            }
110        }
111
112        if self.remaining_duration <= self.duration_per_sample {
113            None
114        } else if let Some(sample) = self.input.next() {
115            let sample = match &self.filter {
116                Some(filter) => filter.apply(sample, self),
117                None => sample,
118            };
119
120            self.remaining_duration -= self.duration_per_sample;
121
122            Some(sample)
123        } else {
124            None
125        }
126    }
127
128    // TODO: size_hint
129}
130
131impl<I> Source for TakeDuration<I>
132where
133    I: Iterator + Source,
134{
135    #[inline]
136    fn current_span_len(&self) -> Option<usize> {
137        let remaining_nanos = self.remaining_duration.as_secs() * NANOS_PER_SEC
138            + self.remaining_duration.subsec_nanos() as u64;
139        let nanos_per_sample = self.duration_per_sample.as_secs() * NANOS_PER_SEC
140            + self.duration_per_sample.subsec_nanos() as u64;
141        let remaining_samples = (remaining_nanos / nanos_per_sample) as usize;
142
143        self.input
144            .current_span_len()
145            .filter(|value| *value < remaining_samples)
146            .or(Some(remaining_samples))
147    }
148
149    #[inline]
150    fn channels(&self) -> ChannelCount {
151        self.input.channels()
152    }
153
154    #[inline]
155    fn sample_rate(&self) -> SampleRate {
156        self.input.sample_rate()
157    }
158
159    #[inline]
160    fn total_duration(&self) -> Option<Duration> {
161        if let Some(duration) = self.input.total_duration() {
162            if duration < self.requested_duration {
163                Some(duration)
164            } else {
165                Some(self.requested_duration)
166            }
167        } else {
168            None
169        }
170    }
171
172    #[inline]
173    fn try_seek(&mut self, pos: Duration) -> Result<(), SeekError> {
174        let result = self.input.try_seek(pos);
175        if result.is_ok() {
176            // This assumes that the seek implementation of the codec always
177            // starts again at the beginning of a span. Which is the case with
178            // symphonia.
179            self.remaining_duration = self
180                .total_duration()
181                .unwrap_or(self.requested_duration)
182                .checked_sub(pos)
183                .unwrap_or(Duration::ZERO);
184            self.current_span_len = self.current_span_len();
185        }
186        result
187    }
188}