rodio/decoder/
symphonia.rs

1use core::fmt;
2use core::time::Duration;
3use symphonia::{
4    core::{
5        audio::{AudioBufferRef, SampleBuffer, SignalSpec},
6        codecs::{Decoder, DecoderOptions, CODEC_TYPE_NULL},
7        errors::Error,
8        formats::{FormatOptions, FormatReader, SeekMode, SeekTo, SeekedTo},
9        io::MediaSourceStream,
10        meta::MetadataOptions,
11        probe::Hint,
12        units,
13    },
14    default::get_probe,
15};
16
17use super::{DecoderError, Settings};
18use crate::{
19    common::{ChannelCount, Sample, SampleRate},
20    source, Source,
21};
22
/// Audio source that demuxes and decodes a media stream with Symphonia and yields the decoded
/// samples one at a time through its `Iterator` implementation.
pub(crate) struct SymphoniaDecoder {
    /// Codec decoder created for the selected track.
    decoder: Box<dyn Decoder>,
    /// Index into `buffer` of the next sample to yield.
    current_span_offset: usize,
    /// Format reader (demuxer) that supplies the packets to decode.
    format: Box<dyn FormatReader>,
    /// Total duration of the stream; `None` unless both the time base and the frame count
    /// were available at initialization.
    total_duration: Option<Duration>,
    /// Interleaved samples copied out of the most recently decoded packet.
    buffer: SampleBuffer<Sample>,
    /// Signal specification (sample rate and channels) of the most recently decoded packet.
    spec: SignalSpec,
    /// Seek mode (coarse or accurate) selected from the settings at initialization.
    seek_mode: SeekMode,
}
32
33impl SymphoniaDecoder {
34    pub(crate) fn new(mss: MediaSourceStream, settings: &Settings) -> Result<Self, DecoderError> {
35        match SymphoniaDecoder::init(mss, settings) {
36            Err(e) => match e {
37                Error::IoError(e) => Err(DecoderError::IoError(e.to_string())),
38                Error::DecodeError(e) => Err(DecoderError::DecodeError(e)),
39                Error::SeekError(_) => {
40                    unreachable!("Seek errors should not occur during initialization")
41                }
42                Error::Unsupported(_) => Err(DecoderError::UnrecognizedFormat),
43                Error::LimitError(e) => Err(DecoderError::LimitError(e)),
44                Error::ResetRequired => Err(DecoderError::ResetRequired),
45            },
46            Ok(Some(decoder)) => Ok(decoder),
47            Ok(None) => Err(DecoderError::NoStreams),
48        }
49    }
50
51    #[inline]
52    pub(crate) fn into_inner(self) -> MediaSourceStream {
53        self.format.into_inner()
54    }
55
56    fn init(
57        mss: MediaSourceStream,
58        settings: &Settings,
59    ) -> symphonia::core::errors::Result<Option<SymphoniaDecoder>> {
60        let mut hint = Hint::new();
61        if let Some(ext) = settings.hint.as_ref() {
62            hint.with_extension(ext);
63        }
64        if let Some(typ) = settings.mime_type.as_ref() {
65            hint.mime_type(typ);
66        }
67        let format_opts: FormatOptions = FormatOptions {
68            enable_gapless: settings.gapless,
69            ..Default::default()
70        };
71        let metadata_opts: MetadataOptions = Default::default();
72        let seek_mode = if settings.coarse_seek {
73            SeekMode::Coarse
74        } else {
75            SeekMode::Accurate
76        };
77        let mut probed = get_probe().format(&hint, mss, &format_opts, &metadata_opts)?;
78
79        let stream = match probed.format.default_track() {
80            Some(stream) => stream,
81            None => return Ok(None),
82        };
83
84        // Select the first supported track
85        let track_id = probed
86            .format
87            .tracks()
88            .iter()
89            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
90            .ok_or(symphonia::core::errors::Error::Unsupported(
91                "No track with supported codec",
92            ))?
93            .id;
94
95        let track = match probed
96            .format
97            .tracks()
98            .iter()
99            .find(|track| track.id == track_id)
100        {
101            Some(track) => track,
102            None => return Ok(None),
103        };
104
105        let mut decoder = symphonia::default::get_codecs()
106            .make(&track.codec_params, &DecoderOptions::default())?;
107        let total_duration = stream
108            .codec_params
109            .time_base
110            .zip(stream.codec_params.n_frames)
111            .map(|(base, spans)| base.calc_time(spans).into());
112
113        let decoded = loop {
114            let current_span = match probed.format.next_packet() {
115                Ok(packet) => packet,
116                Err(Error::IoError(_)) => break decoder.last_decoded(),
117                Err(e) => return Err(e),
118            };
119
120            // If the packet does not belong to the selected track, skip over it
121            if current_span.track_id() != track_id {
122                continue;
123            }
124
125            match decoder.decode(&current_span) {
126                Ok(decoded) => break decoded,
127                Err(e) => match e {
128                    Error::DecodeError(_) => {
129                        // Decode errors are intentionally ignored with no retry limit.
130                        // This behavior ensures that the decoder skips over problematic packets
131                        // and continues processing the rest of the stream.
132                        continue;
133                    }
134                    _ => return Err(e),
135                },
136            }
137        };
138        let spec = decoded.spec().to_owned();
139        let buffer = SymphoniaDecoder::get_buffer(decoded, &spec);
140        Ok(Some(SymphoniaDecoder {
141            decoder,
142            current_span_offset: 0,
143            format: probed.format,
144            total_duration,
145            buffer,
146            spec,
147            seek_mode,
148        }))
149    }
150
151    #[inline]
152    fn get_buffer(decoded: AudioBufferRef, spec: &SignalSpec) -> SampleBuffer<Sample> {
153        let duration = units::Duration::from(decoded.capacity() as u64);
154        let mut buffer = SampleBuffer::<Sample>::new(duration, *spec);
155        buffer.copy_interleaved_ref(decoded);
156        buffer
157    }
158}
159
160impl Source for SymphoniaDecoder {
161    #[inline]
162    fn current_span_len(&self) -> Option<usize> {
163        Some(self.buffer.len())
164    }
165
166    #[inline]
167    fn channels(&self) -> ChannelCount {
168        self.spec.channels.count() as ChannelCount
169    }
170
171    #[inline]
172    fn sample_rate(&self) -> SampleRate {
173        self.spec.rate
174    }
175
176    #[inline]
177    fn total_duration(&self) -> Option<Duration> {
178        self.total_duration
179    }
180
181    fn try_seek(&mut self, pos: Duration) -> Result<(), source::SeekError> {
182        if matches!(self.seek_mode, SeekMode::Accurate)
183            && self.decoder.codec_params().time_base.is_none()
184        {
185            return Err(source::SeekError::SymphoniaDecoder(
186                SeekError::AccurateSeekNotSupported,
187            ));
188        }
189
190        // Seeking should be "saturating", meaning: target positions beyond the end of the stream
191        // are clamped to the end.
192        let mut target = pos;
193        if let Some(total_duration) = self.total_duration {
194            if target > total_duration {
195                target = total_duration;
196            }
197        }
198
199        // Remember the current channel, so we can restore it after seeking.
200        let active_channel = self.current_span_offset % self.channels() as usize;
201
202        let seek_res = match self.format.seek(
203            self.seek_mode,
204            SeekTo::Time {
205                time: target.into(),
206                track_id: None,
207            },
208        ) {
209            Err(Error::SeekError(symphonia::core::errors::SeekErrorKind::ForwardOnly)) => {
210                return Err(source::SeekError::SymphoniaDecoder(
211                    SeekError::RandomAccessNotSupported,
212                ));
213            }
214            other => other.map_err(SeekError::Demuxer),
215        }?;
216
217        // Seeking is a demuxer operation without the decoder knowing about it,
218        // so we need to reset the decoder to make sure it's in sync and prevent
219        // audio glitches.
220        self.decoder.reset();
221
222        // Force the iterator to decode the next packet.
223        self.current_span_offset = usize::MAX;
224
225        // Symphonia does not seek to the exact position, it seeks to the closest keyframe.
226        // If accurate seeking is required, fast-forward to the exact position.
227        if matches!(self.seek_mode, SeekMode::Accurate) {
228            self.refine_position(seek_res)?;
229        }
230
231        // After seeking, we are at the beginning of an inter-sample frame, i.e. the first
232        // channel. We need to advance the iterator to the right channel.
233        for _ in 0..active_channel {
234            self.next();
235        }
236
237        Ok(())
238    }
239}
240
/// Error returned when the `try_seek` implementation of the Symphonia decoder fails.
#[derive(Debug)]
pub enum SeekError {
    /// Accurate seeking is not supported.
    ///
    /// This error occurs when the decoder cannot extract time base information from the source.
    /// You may catch this error to try a coarse seek instead.
    AccurateSeekNotSupported,
    /// The decoder does not support random access seeking.
    ///
    /// This error occurs when the source is not seekable or does not have a known byte length.
    RandomAccessNotSupported,
    /// The demuxer failed to seek; the wrapped value is the underlying Symphonia error.
    Demuxer(symphonia::core::errors::Error),
}
256
257impl fmt::Display for SeekError {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        match self {
260            SeekError::AccurateSeekNotSupported => {
261                write!(
262                    f,
263                    "Accurate seeking is not supported on this file/byte stream that lacks time base information"
264                )
265            }
266            SeekError::RandomAccessNotSupported => {
267                write!(f, "The decoder needs to know the length of the file/byte stream to be able to seek backwards. You can set that by using the `DecoderBuilder` or creating a decoder using `Decoder::try_from(some_file)`.")
268            }
269            SeekError::Demuxer(err) => {
270                write!(f, "Demuxer failed to seek: {err:?}")
271            }
272        }
273    }
274}
275
276impl std::error::Error for SeekError {
277    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
278        match self {
279            SeekError::AccurateSeekNotSupported => None,
280            SeekError::RandomAccessNotSupported => None,
281            SeekError::Demuxer(err) => Some(err),
282        }
283    }
284}
285
286impl SymphoniaDecoder {
287    /// Note span offset must be set after
288    fn refine_position(&mut self, seek_res: SeekedTo) -> Result<(), source::SeekError> {
289        // Calculate the number of samples to skip.
290        let mut samples_to_skip = (Duration::from(
291            self.decoder
292                .codec_params()
293                .time_base
294                .expect("time base availability guaranteed by caller")
295                .calc_time(seek_res.required_ts.saturating_sub(seek_res.actual_ts)),
296        )
297        .as_secs_f32()
298            * self.sample_rate() as f32
299            * self.channels() as f32)
300            .ceil() as usize;
301
302        // Re-align the seek position to the first channel.
303        samples_to_skip -= samples_to_skip % self.channels() as usize;
304
305        // Skip ahead to the precise position.
306        for _ in 0..samples_to_skip {
307            self.next();
308        }
309
310        Ok(())
311    }
312}
313
314impl Iterator for SymphoniaDecoder {
315    type Item = Sample;
316
317    fn next(&mut self) -> Option<Self::Item> {
318        if self.current_span_offset >= self.buffer.len() {
319            let decoded = loop {
320                let packet = self.format.next_packet().ok()?;
321                let decoded = match self.decoder.decode(&packet) {
322                    Ok(decoded) => decoded,
323                    Err(Error::DecodeError(_)) => {
324                        // Skip over packets that cannot be decoded. This ensures the iterator
325                        // continues processing subsequent packets instead of terminating due to
326                        // non-critical decode errors.
327                        continue;
328                    }
329                    Err(_) => return None,
330                };
331
332                // Loop until we get a packet with audio frames. This is necessary because some
333                // formats can have packets with only metadata, particularly when rewinding, in
334                // which case the iterator would otherwise end with `None`.
335                // Note: checking `decoded.frames()` is more reliable than `packet.dur()`, which
336                // can resturn non-zero durations for packets without audio frames.
337                if decoded.frames() > 0 {
338                    break decoded;
339                }
340            };
341
342            decoded.spec().clone_into(&mut self.spec);
343            self.buffer = SymphoniaDecoder::get_buffer(decoded, &self.spec);
344            self.current_span_offset = 0;
345        }
346
347        let sample = *self.buffer.samples().get(self.current_span_offset)?;
348        self.current_span_offset += 1;
349
350        Some(sample)
351    }
352}