rodio/source/take.rs

1use std::time::Duration;
2
3use super::SeekError;
4use crate::common::{ChannelCount, SampleRate};
5use crate::{Sample, Source};
6
7/// Internal function that builds a `TakeDuration` object.
8pub fn take_duration<I>(input: I, duration: Duration) -> TakeDuration<I>
9where
10    I: Source,
11{
12    TakeDuration {
13        current_span_len: input.current_span_len(),
14        duration_per_sample: TakeDuration::get_duration_per_sample(&input),
15        input,
16        remaining_duration: duration,
17        requested_duration: duration,
18        filter: None,
19    }
20}
21
22/// A filter that can be applied to a `TakeDuration`.
23#[derive(Clone, Debug)]
24enum DurationFilter {
25    FadeOut,
26}
27impl DurationFilter {
28    fn apply<I: Iterator>(&self, sample: Sample, parent: &TakeDuration<I>) -> Sample {
29        match self {
30            DurationFilter::FadeOut => {
31                let remaining = parent.remaining_duration.as_millis() as f32;
32                let total = parent.requested_duration.as_millis() as f32;
33                sample * remaining / total
34            }
35        }
36    }
37}
38
/// Number of nanoseconds in one second, used to convert between `Duration`s
/// and per-sample timings.
const NANOS_PER_SEC: u64 = 1_000_000_000;
40
/// A source that truncates the given source to a certain duration.
///
/// Time is accounted per sample: each yielded sample subtracts
/// `duration_per_sample` from `remaining_duration`, and the iterator stops
/// yielding once the remaining duration is too small for another sample.
#[derive(Clone, Debug)]
pub struct TakeDuration<I> {
    // The wrapped source.
    input: I,
    // Time left in the take; reduced by `duration_per_sample` per yielded sample.
    remaining_duration: Duration,
    // The duration originally requested; the denominator of the fade-out gain.
    requested_duration: Duration,
    // Optional per-sample filter (currently only a fade-out).
    filter: Option<DurationFilter>,
    // Remaining samples in current span. `None` when the input reported no span
    // length, in which case `next` never refreshes `duration_per_sample`.
    current_span_len: Option<usize>,
    // Time accounted to one sample: one second divided by
    // (sample_rate * channels), truncated to whole nanoseconds. Recomputed when
    // a new span starts, since its sample rate or channel count may differ.
    duration_per_sample: Duration,
}
53
54impl<I> TakeDuration<I>
55where
56    I: Source,
57{
58    /// Returns the duration elapsed for each sample extracted.
59    #[inline]
60    fn get_duration_per_sample(input: &I) -> Duration {
61        let ns = NANOS_PER_SEC / (input.sample_rate() as u64 * input.channels() as u64);
62        // \|/ the maximum value of `ns` is one billion, so this can't fail
63        Duration::new(0, ns as u32)
64    }
65
66    /// Returns a reference to the inner source.
67    #[inline]
68    pub fn inner(&self) -> &I {
69        &self.input
70    }
71
72    /// Returns a mutable reference to the inner source.
73    #[inline]
74    pub fn inner_mut(&mut self) -> &mut I {
75        &mut self.input
76    }
77
78    /// Returns the inner source.
79    #[inline]
80    pub fn into_inner(self) -> I {
81        self.input
82    }
83
84    /// Make the truncated source end with a FadeOut. The fadeout covers the
85    /// entire length of the take source.
86    pub fn set_filter_fadeout(&mut self) {
87        self.filter = Some(DurationFilter::FadeOut);
88    }
89
90    /// Remove any filter set.
91    pub fn clear_filter(&mut self) {
92        self.filter = None;
93    }
94}
95
96impl<I> Iterator for TakeDuration<I>
97where
98    I: Source,
99{
100    type Item = <I as Iterator>::Item;
101
102    fn next(&mut self) -> Option<<I as Iterator>::Item> {
103        if let Some(span_len) = self.current_span_len.take() {
104            if span_len > 0 {
105                self.current_span_len = Some(span_len - 1);
106            } else {
107                self.current_span_len = self.input.current_span_len();
108                // Sample rate might have changed
109                self.duration_per_sample = Self::get_duration_per_sample(&self.input);
110            }
111        }
112
113        if self.remaining_duration <= self.duration_per_sample {
114            None
115        } else if let Some(sample) = self.input.next() {
116            let sample = match &self.filter {
117                Some(filter) => filter.apply(sample, self),
118                None => sample,
119            };
120
121            self.remaining_duration -= self.duration_per_sample;
122
123            Some(sample)
124        } else {
125            None
126        }
127    }
128
129    // TODO: size_hint
130}
131
132impl<I> Source for TakeDuration<I>
133where
134    I: Iterator + Source,
135{
136    #[inline]
137    fn current_span_len(&self) -> Option<usize> {
138        let remaining_nanos = self.remaining_duration.as_secs() * NANOS_PER_SEC
139            + self.remaining_duration.subsec_nanos() as u64;
140        let nanos_per_sample = self.duration_per_sample.as_secs() * NANOS_PER_SEC
141            + self.duration_per_sample.subsec_nanos() as u64;
142        let remaining_samples = (remaining_nanos / nanos_per_sample) as usize;
143
144        self.input
145            .current_span_len()
146            .filter(|value| *value < remaining_samples)
147            .or(Some(remaining_samples))
148    }
149
150    #[inline]
151    fn channels(&self) -> ChannelCount {
152        self.input.channels()
153    }
154
155    #[inline]
156    fn sample_rate(&self) -> SampleRate {
157        self.input.sample_rate()
158    }
159
160    #[inline]
161    fn total_duration(&self) -> Option<Duration> {
162        if let Some(duration) = self.input.total_duration() {
163            if duration < self.requested_duration {
164                Some(duration)
165            } else {
166                Some(self.requested_duration)
167            }
168        } else {
169            None
170        }
171    }
172
173    #[inline]
174    fn try_seek(&mut self, pos: Duration) -> Result<(), SeekError> {
175        self.input.try_seek(pos)
176    }
177}