Skip to main content

rodio/decoder/
symphonia.rs

1use core::time::Duration;
2use std::sync::Arc;
3use symphonia::{
4    core::{
5        audio::{AudioBufferRef, SampleBuffer, SignalSpec},
6        codecs::{Decoder, DecoderOptions, CODEC_TYPE_NULL},
7        errors::Error,
8        formats::{FormatOptions, FormatReader, SeekMode, SeekTo, SeekedTo},
9        io::MediaSourceStream,
10        meta::MetadataOptions,
11        probe::Hint,
12        units,
13    },
14    default::get_probe,
15};
16
17use super::{DecoderError, Settings};
18use crate::{
19    common::{assert_error_traits, ChannelCount, Sample, SampleRate},
20    source, Source,
21};
22
23pub(crate) struct SymphoniaDecoder {
24    decoder: Box<dyn Decoder>,
25    current_span_offset: usize,
26    format: Box<dyn FormatReader>,
27    total_duration: Option<Duration>,
28    buffer: SampleBuffer<Sample>,
29    spec: SignalSpec,
30    seek_mode: SeekMode,
31}
32
33impl SymphoniaDecoder {
34    pub(crate) fn new(mss: MediaSourceStream, settings: &Settings) -> Result<Self, DecoderError> {
35        match SymphoniaDecoder::init(mss, settings) {
36            Err(e) => match e {
37                Error::IoError(e) => Err(DecoderError::IoError(e.to_string())),
38                Error::DecodeError(e) => Err(DecoderError::DecodeError(e)),
39                Error::SeekError(_) => {
40                    unreachable!("Seek errors should not occur during initialization")
41                }
42                Error::Unsupported(_) => Err(DecoderError::UnrecognizedFormat),
43                Error::LimitError(e) => Err(DecoderError::LimitError(e)),
44                Error::ResetRequired => Err(DecoderError::ResetRequired),
45            },
46            Ok(Some(decoder)) => Ok(decoder),
47            Ok(None) => Err(DecoderError::NoStreams),
48        }
49    }
50
51    #[inline]
52    pub(crate) fn into_inner(self) -> MediaSourceStream {
53        self.format.into_inner()
54    }
55
56    fn init(
57        mss: MediaSourceStream,
58        settings: &Settings,
59    ) -> symphonia::core::errors::Result<Option<SymphoniaDecoder>> {
60        let mut hint = Hint::new();
61        if let Some(ext) = settings.hint.as_ref() {
62            hint.with_extension(ext);
63        }
64        if let Some(typ) = settings.mime_type.as_ref() {
65            hint.mime_type(typ);
66        }
67        let format_opts: FormatOptions = FormatOptions {
68            enable_gapless: settings.gapless,
69            ..Default::default()
70        };
71        let metadata_opts: MetadataOptions = Default::default();
72        let seek_mode = if settings.coarse_seek {
73            SeekMode::Coarse
74        } else {
75            SeekMode::Accurate
76        };
77        let mut probed = get_probe().format(&hint, mss, &format_opts, &metadata_opts)?;
78
79        let stream = match probed.format.default_track() {
80            Some(stream) => stream,
81            None => return Ok(None),
82        };
83
84        // Select the first supported track
85        let track_id = probed
86            .format
87            .tracks()
88            .iter()
89            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
90            .ok_or(symphonia::core::errors::Error::Unsupported(
91                "No track with supported codec",
92            ))?
93            .id;
94
95        let track = match probed
96            .format
97            .tracks()
98            .iter()
99            .find(|track| track.id == track_id)
100        {
101            Some(track) => track,
102            None => return Ok(None),
103        };
104
105        let mut decoder = symphonia::default::get_codecs()
106            .make(&track.codec_params, &DecoderOptions::default())?;
107        let total_duration: Option<Duration> = track
108            .codec_params
109            .time_base
110            .zip(stream.codec_params.n_frames)
111            .map(|(base, spans)| base.calc_time(spans).into())
112            .filter(|d: &Duration| !d.is_zero());
113
114        let decoded = loop {
115            let current_span = match probed.format.next_packet() {
116                Ok(packet) => packet,
117                Err(Error::IoError(_)) => break decoder.last_decoded(),
118                Err(e) => return Err(e),
119            };
120
121            // If the packet does not belong to the selected track, skip over it
122            if current_span.track_id() != track_id {
123                continue;
124            }
125
126            match decoder.decode(&current_span) {
127                Ok(decoded) => break decoded,
128                Err(e) => match e {
129                    Error::DecodeError(_) => {
130                        // Decode errors are intentionally ignored with no retry limit.
131                        // This behavior ensures that the decoder skips over problematic packets
132                        // and continues processing the rest of the stream.
133                        continue;
134                    }
135                    _ => return Err(e),
136                },
137            }
138        };
139        let spec = decoded.spec().to_owned();
140        let buffer = SymphoniaDecoder::get_buffer(decoded, &spec);
141        Ok(Some(SymphoniaDecoder {
142            decoder,
143            current_span_offset: 0,
144            format: probed.format,
145            total_duration,
146            buffer,
147            spec,
148            seek_mode,
149        }))
150    }
151
152    #[inline]
153    fn get_buffer(decoded: AudioBufferRef, spec: &SignalSpec) -> SampleBuffer<Sample> {
154        let duration = units::Duration::from(decoded.capacity() as u64);
155        let mut buffer = SampleBuffer::<Sample>::new(duration, *spec);
156        buffer.copy_interleaved_ref(decoded);
157        buffer
158    }
159}
160
161impl Source for SymphoniaDecoder {
162    #[inline]
163    fn current_span_len(&self) -> Option<usize> {
164        Some(self.buffer.len())
165    }
166
167    #[inline]
168    fn channels(&self) -> ChannelCount {
169        ChannelCount::new(
170            self.spec
171                .channels
172                .count()
173                .try_into()
174                .expect("rodio only support up to u16::MAX channels (65_535)"),
175        )
176        .expect("audio should always have at least one channel")
177    }
178
179    #[inline]
180    fn sample_rate(&self) -> SampleRate {
181        SampleRate::new(self.spec.rate).expect("audio should always have a non zero SampleRate")
182    }
183
184    #[inline]
185    fn total_duration(&self) -> Option<Duration> {
186        self.total_duration
187    }
188
189    fn try_seek(&mut self, pos: Duration) -> Result<(), source::SeekError> {
190        if matches!(self.seek_mode, SeekMode::Accurate)
191            && self.decoder.codec_params().time_base.is_none()
192        {
193            return Err(source::SeekError::SymphoniaDecoder(
194                SeekError::AccurateSeekNotSupported,
195            ));
196        }
197
198        // Seeking should be "saturating", meaning: target positions beyond the end of the stream
199        // are clamped to the end.
200        let mut target = pos;
201        if let Some(total_duration) = self.total_duration {
202            if target > total_duration {
203                target = total_duration;
204            }
205        }
206
207        // Remember the current channel, so we can restore it after seeking.
208        let active_channel = self.current_span_offset % self.channels().get() as usize;
209
210        let seek_res = match self.format.seek(
211            self.seek_mode,
212            SeekTo::Time {
213                time: target.into(),
214                track_id: None,
215            },
216        ) {
217            Err(Error::SeekError(symphonia::core::errors::SeekErrorKind::ForwardOnly)) => {
218                return Err(source::SeekError::SymphoniaDecoder(
219                    SeekError::RandomAccessNotSupported,
220                ));
221            }
222            other => other.map_err(Arc::new).map_err(SeekError::Demuxer),
223        }?;
224
225        // Seeking is a demuxer operation without the decoder knowing about it,
226        // so we need to reset the decoder to make sure it's in sync and prevent
227        // audio glitches.
228        self.decoder.reset();
229
230        // Force the iterator to decode the next packet.
231        self.current_span_offset = usize::MAX;
232
233        // Symphonia does not seek to the exact position, it seeks to the closest keyframe.
234        // If accurate seeking is required, fast-forward to the exact position.
235        if matches!(self.seek_mode, SeekMode::Accurate) {
236            self.refine_position(seek_res)?;
237        }
238
239        // After seeking, we are at the beginning of an inter-sample frame, i.e. the first
240        // channel. We need to advance the iterator to the right channel.
241        for _ in 0..active_channel {
242            self.next();
243        }
244
245        Ok(())
246    }
247}
248
249/// Error returned when the try_seek implementation of the symphonia decoder fails.
250#[derive(Debug, thiserror::Error, Clone)]
251pub enum SeekError {
252    /// Accurate seeking is not supported
253    ///
254    /// This error occurs when the decoder cannot extract time base information from the source.
255    /// You may catch this error to try a coarse seek instead.
256    #[error("Accurate seeking is not supported on this file/byte stream that lacks time base information")]
257    AccurateSeekNotSupported,
258    /// The decoder does not support random access seeking
259    ///
260    /// This error occurs when the source is not seekable or does not have a known byte length.
261    #[error("The decoder needs to know the length of the file/byte stream to be able to seek backwards. You can set that by using the `DecoderBuilder` or creating a decoder using `Decoder::try_from(some_file)`.")]
262    RandomAccessNotSupported,
263    /// Demuxer failed to seek
264    #[error("Demuxer failed to seek")]
265    Demuxer(#[source] Arc<symphonia::core::errors::Error>),
266}
267assert_error_traits!(SeekError);
268
269impl SymphoniaDecoder {
270    /// Note span offset must be set after
271    fn refine_position(&mut self, seek_res: SeekedTo) -> Result<(), source::SeekError> {
272        // Calculate the number of samples to skip.
273        let mut samples_to_skip = (Duration::from(
274            self.decoder
275                .codec_params()
276                .time_base
277                .expect("time base availability guaranteed by caller")
278                .calc_time(seek_res.required_ts.saturating_sub(seek_res.actual_ts)),
279        )
280        .as_secs_f32()
281            * self.sample_rate().get() as f32
282            * self.channels().get() as f32)
283            .ceil() as usize;
284
285        // Re-align the seek position to the first channel.
286        samples_to_skip -= samples_to_skip % self.channels().get() as usize;
287
288        // Skip ahead to the precise position.
289        for _ in 0..samples_to_skip {
290            self.next();
291        }
292
293        Ok(())
294    }
295}
296
297impl Iterator for SymphoniaDecoder {
298    type Item = Sample;
299
300    fn next(&mut self) -> Option<Self::Item> {
301        if self.current_span_offset >= self.buffer.len() {
302            let decoded = loop {
303                let packet = self.format.next_packet().ok()?;
304                let decoded = match self.decoder.decode(&packet) {
305                    Ok(decoded) => decoded,
306                    Err(Error::DecodeError(_)) => {
307                        // Skip over packets that cannot be decoded. This ensures the iterator
308                        // continues processing subsequent packets instead of terminating due to
309                        // non-critical decode errors.
310                        continue;
311                    }
312                    Err(_) => return None,
313                };
314
315                // Loop until we get a packet with audio frames. This is necessary because some
316                // formats can have packets with only metadata, particularly when rewinding, in
317                // which case the iterator would otherwise end with `None`.
318                // Note: checking `decoded.frames()` is more reliable than `packet.dur()`, which
319                // can resturn non-zero durations for packets without audio frames.
320                if decoded.frames() > 0 {
321                    break decoded;
322                }
323            };
324
325            decoded.spec().clone_into(&mut self.spec);
326            self.buffer = SymphoniaDecoder::get_buffer(decoded, &self.spec);
327            self.current_span_offset = 0;
328        }
329
330        let sample = *self.buffer.samples().get(self.current_span_offset)?;
331        self.current_span_offset += 1;
332
333        Some(sample)
334    }
335}