1use std::time::Duration;
2
3use super::SeekError;
4use crate::common::{ChannelCount, SampleRate};
5use crate::math::NANOS_PER_SEC;
6use crate::Source;
7
8fn remaining_samples(
9 until_playback: Duration,
10 sample_rate: SampleRate,
11 channels: ChannelCount,
12) -> usize {
13 let ns = until_playback.as_nanos();
14 let samples = ns * channels.get() as u128 * sample_rate.get() as u128 / NANOS_PER_SEC as u128;
15 samples as usize
16}
17
18pub fn delay<I>(input: I, duration: Duration) -> Delay<I>
20where
21 I: Source,
22{
23 Delay {
24 remaining_samples: remaining_samples(duration, input.sample_rate(), input.channels()),
25 requested_duration: duration,
26 input,
27 }
28}
29
/// Source adapter that emits silence before the samples of the wrapped source.
///
/// Built by [`delay`].
#[derive(Debug, Clone)]
pub struct Delay<I> {
    /// The wrapped source whose samples follow the silence.
    input: I,
    /// Silent samples still to be emitted before `input` is read.
    remaining_samples: usize,
    /// The delay that was asked for, kept for duration and seek calculations.
    requested_duration: Duration,
}
37
38impl<I> Delay<I>
39where
40 I: Source,
41{
42 #[inline]
44 pub fn inner(&self) -> &I {
45 &self.input
46 }
47
48 #[inline]
50 pub fn inner_mut(&mut self) -> &mut I {
51 &mut self.input
52 }
53
54 #[inline]
56 pub fn into_inner(self) -> I {
57 self.input
58 }
59}
60
61impl<I> Iterator for Delay<I>
62where
63 I: Source,
64{
65 type Item = <I as Iterator>::Item;
66
67 #[inline]
68 fn next(&mut self) -> Option<<I as Iterator>::Item> {
69 if self.remaining_samples >= 1 {
70 self.remaining_samples -= 1;
71 Some(0.0)
72 } else {
73 self.input.next()
74 }
75 }
76
77 #[inline]
78 fn size_hint(&self) -> (usize, Option<usize>) {
79 let (min, max) = self.input.size_hint();
80 (
81 min + self.remaining_samples,
82 max.map(|v| v + self.remaining_samples),
83 )
84 }
85}
86
87impl<I> Source for Delay<I>
88where
89 I: Iterator + Source,
90{
91 #[inline]
92 fn current_span_len(&self) -> Option<usize> {
93 self.input
94 .current_span_len()
95 .map(|val| val + self.remaining_samples)
96 }
97
98 #[inline]
99 fn channels(&self) -> ChannelCount {
100 self.input.channels()
101 }
102
103 #[inline]
104 fn sample_rate(&self) -> SampleRate {
105 self.input.sample_rate()
106 }
107
108 #[inline]
109 fn total_duration(&self) -> Option<Duration> {
110 self.input
111 .total_duration()
112 .map(|val| val + self.requested_duration)
113 }
114
115 #[inline]
129 fn try_seek(&mut self, pos: Duration) -> Result<(), SeekError> {
130 if pos < self.requested_duration {
131 self.input.try_seek(Duration::ZERO)?;
132 let until_playback = self.requested_duration - pos;
133 self.remaining_samples =
134 remaining_samples(until_playback, self.sample_rate(), self.channels());
135 }
136 let compensated_for_delay = pos.saturating_sub(self.requested_duration);
137 self.input.try_seek(compensated_for_delay)
138 }
139}