Skip to content

Commit 142404f

Browse files
decode mp3 with librosa if torchaudio >= 0.12 doesn't work as a temporary workaround (#4923)
* decode mp3 with librosa if torchaudio is > 0.12 (ideally version of ffmpeg should be checked too) * decode mp3 with torchaudio>=0.12 if it works (instead of librosa) * fix incorrect marks for mp3 tests (require torchaudio, not sndfile) * add tests for latest torchaudio + separate stage in CI for it (first try) * install ffmpeg only on ubuntu * use mock to emulate torchaudio fail, add tests for librosa (not all of them) * test torchaudio_latest only on ubuntu * try/except decoding with librosa for file-like objects * more tests for latest torchaudio, should be complete set now * replace logging with warnings * fix tests: catch warnings with a pytest context manager Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 8ba0522 commit 142404f

File tree

5 files changed

+237
-17
lines changed

5 files changed

+237
-17
lines changed

.github/workflows/ci.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,13 @@ jobs:
7272
- name: Test with pytest
7373
run: |
7474
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
75+
- name: Install dependencies to test torchaudio>=0.12 on Ubuntu
76+
if: ${{ matrix.os == 'ubuntu-latest' }}
77+
run: |
78+
pip uninstall -y torchaudio torch
79+
pip install "torchaudio>=0.12"
80+
sudo apt-get -y install ffmpeg
81+
- name: Test torchaudio>=0.12 on Ubuntu
82+
if: ${{ matrix.os == 'ubuntu-latest' }}
83+
run: |
84+
python -m pytest -rfExX -m torchaudio_latest -n 2 --dist loadfile -sv ./tests/features/test_audio.py

src/datasets/features/audio.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import warnings
23
from dataclasses import dataclass, field
34
from io import BytesIO
45
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union
@@ -268,7 +269,7 @@ def _decode_non_mp3_file_like(self, file, format=None):
268269
if version.parse(sf.__libsndfile_version__) < version.parse("1.0.30"):
269270
raise RuntimeError(
270271
"Decoding .opus files requires 'libsndfile'>=1.0.30, "
271-
+ "it can be installed via conda: `conda install -c conda-forge libsndfile>=1.0.30`"
272+
+ 'it can be installed via conda: `conda install -c conda-forge "libsndfile>=1.0.30"`'
272273
)
273274
array, sampling_rate = sf.read(file)
274275
array = array.T
@@ -282,19 +283,44 @@ def _decode_non_mp3_file_like(self, file, format=None):
282283
def _decode_mp3(self, path_or_file):
283284
try:
284285
import torchaudio
285-
import torchaudio.transforms as T
286286
except ImportError as err:
287-
raise ImportError(
288-
"Decoding 'mp3' audio files, requires 'torchaudio<0.12.0': pip install 'torchaudio<0.12.0'"
289-
) from err
290-
if not version.parse(torchaudio.__version__) < version.parse("0.12.0"):
291-
raise RuntimeError(
292-
"Decoding 'mp3' audio files, requires 'torchaudio<0.12.0': pip install 'torchaudio<0.12.0'"
293-
)
294-
try:
295-
torchaudio.set_audio_backend("sox_io")
296-
except RuntimeError as err:
297-
raise ImportError("To support decoding 'mp3' audio files, please install 'sox'.") from err
287+
raise ImportError("To support decoding 'mp3' audio files, please install 'torchaudio'.") from err
288+
if version.parse(torchaudio.__version__) < version.parse("0.12.0"):
289+
try:
290+
torchaudio.set_audio_backend("sox_io")
291+
except RuntimeError as err:
292+
raise ImportError("To support decoding 'mp3' audio files, please install 'sox'.") from err
293+
array, sampling_rate = self._decode_mp3_torchaudio(path_or_file)
294+
else:
295+
try: # try torchaudio anyway because sometimes it works (depending on the os and os packages installed)
296+
array, sampling_rate = self._decode_mp3_torchaudio(path_or_file)
297+
except RuntimeError:
298+
try:
299+
# flake8: noqa
300+
import librosa
301+
except ImportError as err:
302+
raise ImportError(
303+
"Your version of `torchaudio` (>=0.12.0) doesn't support decoding 'mp3' files on your machine. "
304+
"To support 'mp3' decoding with `torchaudio>=0.12.0`, please install `ffmpeg>=4` system package "
305+
'or downgrade `torchaudio` to <0.12: `pip install "torchaudio<0.12"`. '
306+
"To support decoding 'mp3' audio files without `torchaudio`, please install `librosa`: "
307+
"`pip install librosa`. Note that decoding will be extremely slow in that case."
308+
) from err
309+
# try to decode with librosa for torchaudio>=0.12.0 as a workaround
310+
warnings.warn("Decoding mp3 with `librosa` instead of `torchaudio`, decoding is slow.")
311+
try:
312+
array, sampling_rate = self._decode_mp3_librosa(path_or_file)
313+
except RuntimeError as err:
314+
raise RuntimeError(
315+
"Decoding of 'mp3' failed, probably because of streaming mode "
316+
"(`librosa` cannot decode 'mp3' file-like objects, only path-like)."
317+
) from err
318+
319+
return array, sampling_rate
320+
321+
def _decode_mp3_torchaudio(self, path_or_file):
322+
import torchaudio
323+
import torchaudio.transforms as T
298324

299325
array, sampling_rate = torchaudio.load(path_or_file, format="mp3")
300326
if self.sampling_rate and self.sampling_rate != sampling_rate:
@@ -306,3 +332,9 @@ def _decode_mp3(self, path_or_file):
306332
if self.mono:
307333
array = array.mean(axis=0)
308334
return array, sampling_rate
335+
336+
def _decode_mp3_librosa(self, path_or_file):
337+
import librosa
338+
339+
array, sampling_rate = librosa.load(path_or_file, mono=self.mono, sr=self.sampling_rate)
340+
return array, sampling_rate

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ def pytest_collection_modifyitems(config, items):
1515
item.add_marker(pytest.mark.unit)
1616

1717

18+
def pytest_configure(config):
19+
config.addinivalue_line("markers", "torchaudio_latest: mark test to run with torchaudio>=0.12")
20+
21+
1822
@pytest.fixture(autouse=True)
1923
def set_test_cache_config(tmp_path_factory, monkeypatch):
2024
# test_hf_cache_home = tmp_path_factory.mktemp("cache") # TODO: why a cache dir per test function does not work?

tests/features/test_audio.py

Lines changed: 168 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
import os
22
import tarfile
3+
from contextlib import nullcontext
4+
from unittest.mock import patch
35

46
import pyarrow as pa
57
import pytest
68

79
from datasets import Dataset, concatenate_datasets, load_dataset
810
from datasets.features import Audio, Features, Sequence, Value
911

10-
from ..utils import require_libsndfile_with_opus, require_sndfile, require_sox, require_torchaudio
12+
from ..utils import (
13+
require_libsndfile_with_opus,
14+
require_sndfile,
15+
require_sox,
16+
require_torchaudio,
17+
require_torchaudio_latest,
18+
)
1119

1220

1321
@pytest.fixture()
@@ -135,6 +143,26 @@ def test_audio_decode_example_mp3(shared_datadir):
135143
assert decoded_example["sampling_rate"] == 44100
136144

137145

146+
@pytest.mark.torchaudio_latest
147+
@require_torchaudio_latest
148+
@pytest.mark.parametrize("torchaudio_failed", [False, True])
149+
def test_audio_decode_example_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
150+
audio_path = str(shared_datadir / "test_audio_44100.mp3")
151+
audio = Audio()
152+
153+
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
154+
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
155+
) if torchaudio_failed else nullcontext():
156+
157+
if torchaudio_failed:
158+
load_mock.side_effect = RuntimeError()
159+
160+
decoded_example = audio.decode_example(audio.encode_example(audio_path))
161+
assert decoded_example["path"] == audio_path
162+
assert decoded_example["array"].shape == (110592,)
163+
assert decoded_example["sampling_rate"] == 44100
164+
165+
138166
@require_libsndfile_with_opus
139167
def test_audio_decode_example_opus(shared_datadir):
140168
audio_path = str(shared_datadir / "test_audio_48000.opus")
@@ -178,6 +206,34 @@ def test_audio_resampling_mp3_different_sampling_rates(shared_datadir):
178206
assert decoded_example["sampling_rate"] == 48000
179207

180208

209+
@pytest.mark.torchaudio_latest
210+
@require_torchaudio_latest
211+
@pytest.mark.parametrize("torchaudio_failed", [False, True])
212+
def test_audio_resampling_mp3_different_sampling_rates_torchaudio_latest(shared_datadir, torchaudio_failed):
213+
audio_path = str(shared_datadir / "test_audio_44100.mp3")
214+
audio_path2 = str(shared_datadir / "test_audio_16000.mp3")
215+
audio = Audio(sampling_rate=48000)
216+
217+
# if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
218+
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
219+
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
220+
) if torchaudio_failed else nullcontext():
221+
if torchaudio_failed:
222+
load_mock.side_effect = RuntimeError()
223+
224+
decoded_example = audio.decode_example(audio.encode_example(audio_path))
225+
assert decoded_example.keys() == {"path", "array", "sampling_rate"}
226+
assert decoded_example["path"] == audio_path
227+
assert decoded_example["array"].shape == (120373,)
228+
assert decoded_example["sampling_rate"] == 48000
229+
230+
decoded_example = audio.decode_example(audio.encode_example(audio_path2))
231+
assert decoded_example.keys() == {"path", "array", "sampling_rate"}
232+
assert decoded_example["path"] == audio_path2
233+
assert decoded_example["array"].shape == (122688,)
234+
assert decoded_example["sampling_rate"] == 48000
235+
236+
181237
@require_sndfile
182238
def test_dataset_with_audio_feature(shared_datadir):
183239
audio_path = str(shared_datadir / "test_audio_44100.wav")
@@ -266,6 +322,38 @@ def test_dataset_with_audio_feature_tar_mp3(tar_mp3_path):
266322
assert column[0]["sampling_rate"] == 44100
267323

268324

325+
@pytest.mark.torchaudio_latest
326+
@require_torchaudio_latest
327+
def test_dataset_with_audio_feature_tar_mp3_torchaudio_latest(tar_mp3_path):
328+
# no test for librosa here because it doesn't support file-like objects, only paths
329+
audio_filename = "test_audio_44100.mp3"
330+
data = {"audio": []}
331+
for file_path, file_obj in iter_archive(tar_mp3_path):
332+
data["audio"].append({"path": file_path, "bytes": file_obj.read()})
333+
break
334+
features = Features({"audio": Audio()})
335+
dset = Dataset.from_dict(data, features=features)
336+
item = dset[0]
337+
assert item.keys() == {"audio"}
338+
assert item["audio"].keys() == {"path", "array", "sampling_rate"}
339+
assert item["audio"]["path"] == audio_filename
340+
assert item["audio"]["array"].shape == (110592,)
341+
assert item["audio"]["sampling_rate"] == 44100
342+
batch = dset[:1]
343+
assert batch.keys() == {"audio"}
344+
assert len(batch["audio"]) == 1
345+
assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"}
346+
assert batch["audio"][0]["path"] == audio_filename
347+
assert batch["audio"][0]["array"].shape == (110592,)
348+
assert batch["audio"][0]["sampling_rate"] == 44100
349+
column = dset["audio"]
350+
assert len(column) == 1
351+
assert column[0].keys() == {"path", "array", "sampling_rate"}
352+
assert column[0]["path"] == audio_filename
353+
assert column[0]["array"].shape == (110592,)
354+
assert column[0]["sampling_rate"] == 44100
355+
356+
269357
@require_sndfile
270358
def test_dataset_with_audio_feature_with_none():
271359
data = {"audio": [None]}
@@ -328,7 +416,7 @@ def test_resampling_at_loading_dataset_with_audio_feature(shared_datadir):
328416

329417

330418
@require_sox
331-
@require_sndfile
419+
@require_torchaudio
332420
def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir):
333421
audio_path = str(shared_datadir / "test_audio_44100.mp3")
334422
data = {"audio": [audio_path]}
@@ -355,6 +443,43 @@ def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir):
355443
assert column[0]["sampling_rate"] == 16000
356444

357445

446+
@pytest.mark.torchaudio_latest
447+
@require_torchaudio_latest
448+
@pytest.mark.parametrize("torchaudio_failed", [False, True])
449+
def test_resampling_at_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
450+
audio_path = str(shared_datadir / "test_audio_44100.mp3")
451+
data = {"audio": [audio_path]}
452+
features = Features({"audio": Audio(sampling_rate=16000)})
453+
dset = Dataset.from_dict(data, features=features)
454+
455+
# if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
456+
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
457+
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
458+
) if torchaudio_failed else nullcontext():
459+
if torchaudio_failed:
460+
load_mock.side_effect = RuntimeError()
461+
462+
item = dset[0]
463+
assert item.keys() == {"audio"}
464+
assert item["audio"].keys() == {"path", "array", "sampling_rate"}
465+
assert item["audio"]["path"] == audio_path
466+
assert item["audio"]["array"].shape == (40125,)
467+
assert item["audio"]["sampling_rate"] == 16000
468+
batch = dset[:1]
469+
assert batch.keys() == {"audio"}
470+
assert len(batch["audio"]) == 1
471+
assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"}
472+
assert batch["audio"][0]["path"] == audio_path
473+
assert batch["audio"][0]["array"].shape == (40125,)
474+
assert batch["audio"][0]["sampling_rate"] == 16000
475+
column = dset["audio"]
476+
assert len(column) == 1
477+
assert column[0].keys() == {"path", "array", "sampling_rate"}
478+
assert column[0]["path"] == audio_path
479+
assert column[0]["array"].shape == (40125,)
480+
assert column[0]["sampling_rate"] == 16000
481+
482+
358483
@require_sndfile
359484
def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir):
360485
audio_path = str(shared_datadir / "test_audio_44100.wav")
@@ -386,7 +511,7 @@ def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir):
386511

387512

388513
@require_sox
389-
@require_sndfile
514+
@require_torchaudio
390515
def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir):
391516
audio_path = str(shared_datadir / "test_audio_44100.mp3")
392517
data = {"audio": [audio_path]}
@@ -416,6 +541,46 @@ def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir)
416541
assert column[0]["sampling_rate"] == 16000
417542

418543

544+
@pytest.mark.torchaudio_latest
545+
@require_torchaudio_latest
546+
@pytest.mark.parametrize("torchaudio_failed", [False, True])
547+
def test_resampling_after_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
548+
audio_path = str(shared_datadir / "test_audio_44100.mp3")
549+
data = {"audio": [audio_path]}
550+
features = Features({"audio": Audio()})
551+
dset = Dataset.from_dict(data, features=features)
552+
553+
# if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
554+
with patch("torchaudio.load") if torchaudio_failed else nullcontext() as load_mock, pytest.warns(
555+
UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
556+
) if torchaudio_failed else nullcontext():
557+
if torchaudio_failed:
558+
load_mock.side_effect = RuntimeError()
559+
560+
item = dset[0]
561+
assert item["audio"]["sampling_rate"] == 44100
562+
dset = dset.cast_column("audio", Audio(sampling_rate=16000))
563+
item = dset[0]
564+
assert item.keys() == {"audio"}
565+
assert item["audio"].keys() == {"path", "array", "sampling_rate"}
566+
assert item["audio"]["path"] == audio_path
567+
assert item["audio"]["array"].shape == (40125,)
568+
assert item["audio"]["sampling_rate"] == 16000
569+
batch = dset[:1]
570+
assert batch.keys() == {"audio"}
571+
assert len(batch["audio"]) == 1
572+
assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"}
573+
assert batch["audio"][0]["path"] == audio_path
574+
assert batch["audio"][0]["array"].shape == (40125,)
575+
assert batch["audio"][0]["sampling_rate"] == 16000
576+
column = dset["audio"]
577+
assert len(column) == 1
578+
assert column[0].keys() == {"path", "array", "sampling_rate"}
579+
assert column[0]["path"] == audio_path
580+
assert column[0]["array"].shape == (40125,)
581+
assert column[0]["sampling_rate"] == 16000
582+
583+
419584
@pytest.mark.parametrize(
420585
"build_data",
421586
[

tests/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,16 @@ def parse_flag_from_env(key, default=False):
6464
find_library("sox") is None,
6565
reason="test requires sox OS dependency; only available on non-Windows: 'sudo apt-get install sox'",
6666
)
67-
require_torchaudio = pytest.mark.skipif(find_spec("torchaudio") is None, reason="test requires torchaudio")
67+
require_torchaudio = pytest.mark.skipif(
68+
find_spec("torchaudio") is None
69+
or version.parse(import_module("torchaudio").__version__) >= version.parse("0.12.0"),
70+
reason="test requires torchaudio<0.12",
71+
)
72+
require_torchaudio_latest = pytest.mark.skipif(
73+
find_spec("torchaudio") is None
74+
or version.parse(import_module("torchaudio").__version__) < version.parse("0.12.0"),
75+
reason="test requires torchaudio>=0.12",
76+
)
6877

6978

7079
def require_beam(test_case):

0 commit comments

Comments
 (0)