1use std::time::Duration;
2
3use super::SeekError;
4use crate::common::{ChannelCount, SampleRate};
5use crate::math::NANOS_PER_SEC;
6use crate::{Float, Sample, Source};
7
8pub fn take_duration<I>(input: I, duration: Duration) -> TakeDuration<I>
10where
11 I: Source,
12{
13 TakeDuration {
14 current_span_len: input.current_span_len(),
15 duration_per_sample: TakeDuration::get_duration_per_sample(&input),
16 input,
17 remaining_duration: duration,
18 requested_duration: duration,
19 filter: None,
20 }
21}
22
23#[derive(Clone, Debug)]
25enum DurationFilter {
26 FadeOut,
27}
28impl DurationFilter {
29 fn apply<I: Iterator>(&self, sample: Sample, parent: &TakeDuration<I>) -> Sample {
30 match self {
31 DurationFilter::FadeOut => {
32 let remaining = parent.remaining_duration.as_millis() as Float;
33 let total = parent.requested_duration.as_millis() as Float;
34 sample * remaining / total
35 }
36 }
37 }
38}
39
/// A source that yields samples from `input` until the requested duration has been used up
/// (or `input` runs out of samples).
#[derive(Clone, Debug)]
pub struct TakeDuration<I> {
    /// The wrapped source.
    input: I,
    /// Time still available before the take ends; reduced by `duration_per_sample` for every
    /// sample yielded.
    remaining_duration: Duration,
    /// The duration originally requested when the take was created.
    requested_duration: Duration,
    /// Optional gain envelope (e.g. a fade-out) applied to each yielded sample.
    filter: Option<DurationFilter>,
    /// Samples left in the current span, counted down once per sample by `next`; `None` when
    /// the inner source reports no span length.
    current_span_len: Option<usize>,
    /// Time covered by one sample; recomputed by `next` when `current_span_len` reaches zero,
    /// i.e. at span boundaries.
    duration_per_sample: Duration,
}
52
53impl<I> TakeDuration<I>
54where
55 I: Source,
56{
57 #[inline]
59 fn get_duration_per_sample(input: &I) -> Duration {
60 let ns = NANOS_PER_SEC / (input.sample_rate().get() as u64 * input.channels().get() as u64);
61 Duration::new(0, ns as u32)
63 }
64
65 #[inline]
67 pub fn inner(&self) -> &I {
68 &self.input
69 }
70
71 #[inline]
73 pub fn inner_mut(&mut self) -> &mut I {
74 &mut self.input
75 }
76
77 #[inline]
79 pub fn into_inner(self) -> I {
80 self.input
81 }
82
83 pub fn set_filter_fadeout(&mut self) {
86 self.filter = Some(DurationFilter::FadeOut);
87 }
88
89 pub fn clear_filter(&mut self) {
91 self.filter = None;
92 }
93}
94
95impl<I> Iterator for TakeDuration<I>
96where
97 I: Source,
98{
99 type Item = <I as Iterator>::Item;
100
101 fn next(&mut self) -> Option<<I as Iterator>::Item> {
102 if let Some(span_len) = self.current_span_len.take() {
103 if span_len > 0 {
104 self.current_span_len = Some(span_len - 1);
105 } else {
106 self.current_span_len = self.input.current_span_len();
107 self.duration_per_sample = Self::get_duration_per_sample(&self.input);
109 }
110 }
111
112 if self.remaining_duration <= self.duration_per_sample {
113 None
114 } else if let Some(sample) = self.input.next() {
115 let sample = match &self.filter {
116 Some(filter) => filter.apply(sample, self),
117 None => sample,
118 };
119
120 self.remaining_duration -= self.duration_per_sample;
121
122 Some(sample)
123 } else {
124 None
125 }
126 }
127
128 }
130
131impl<I> Source for TakeDuration<I>
132where
133 I: Iterator + Source,
134{
135 #[inline]
136 fn current_span_len(&self) -> Option<usize> {
137 let remaining_nanos = self.remaining_duration.as_secs() * NANOS_PER_SEC
138 + self.remaining_duration.subsec_nanos() as u64;
139 let nanos_per_sample = self.duration_per_sample.as_secs() * NANOS_PER_SEC
140 + self.duration_per_sample.subsec_nanos() as u64;
141 let remaining_samples = (remaining_nanos / nanos_per_sample) as usize;
142
143 self.input
144 .current_span_len()
145 .filter(|value| *value < remaining_samples)
146 .or(Some(remaining_samples))
147 }
148
149 #[inline]
150 fn channels(&self) -> ChannelCount {
151 self.input.channels()
152 }
153
154 #[inline]
155 fn sample_rate(&self) -> SampleRate {
156 self.input.sample_rate()
157 }
158
159 #[inline]
160 fn total_duration(&self) -> Option<Duration> {
161 if let Some(duration) = self.input.total_duration() {
162 if duration < self.requested_duration {
163 Some(duration)
164 } else {
165 Some(self.requested_duration)
166 }
167 } else {
168 None
169 }
170 }
171
172 #[inline]
173 fn try_seek(&mut self, pos: Duration) -> Result<(), SeekError> {
174 let result = self.input.try_seek(pos);
175 if result.is_ok() {
176 self.remaining_duration = self
180 .total_duration()
181 .unwrap_or(self.requested_duration)
182 .checked_sub(pos)
183 .unwrap_or(Duration::ZERO);
184 self.current_span_len = self.current_span_len();
185 }
186 result
187 }
188}