1
1
import os
2
2
import tarfile
3
+ from contextlib import nullcontext
4
+ from unittest .mock import patch
3
5
4
6
import pyarrow as pa
5
7
import pytest
6
8
7
9
from datasets import Dataset , concatenate_datasets , load_dataset
8
10
from datasets .features import Audio , Features , Sequence , Value
9
11
10
- from ..utils import require_libsndfile_with_opus , require_sndfile , require_sox , require_torchaudio
12
+ from ..utils import (
13
+ require_libsndfile_with_opus ,
14
+ require_sndfile ,
15
+ require_sox ,
16
+ require_torchaudio ,
17
+ require_torchaudio_latest ,
18
+ )
11
19
12
20
13
21
@pytest .fixture ()
@@ -135,6 +143,26 @@ def test_audio_decode_example_mp3(shared_datadir):
135
143
assert decoded_example ["sampling_rate" ] == 44100
136
144
137
145
146
@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_audio_decode_example_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
    """Decode a 44.1 kHz mp3 both when ``torchaudio.load`` works and when it raises.

    When ``torchaudio_failed`` is true, ``torchaudio.load`` is replaced by a mock
    that raises ``RuntimeError`` and a ``UserWarning`` mentioning decoding with
    `librosa` instead of `torchaudio` must be emitted; the decoded example is
    expected to be the same in both cases.
    """
    mp3_path = str(shared_datadir / "test_audio_44100.mp3")
    feature = Audio()

    if torchaudio_failed:
        # Every call to the patched loader raises, and the warning is required.
        load_patch = patch("torchaudio.load", side_effect=RuntimeError())
        warning_check = pytest.warns(
            UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
        )
    else:
        load_patch = nullcontext()
        warning_check = nullcontext()

    with load_patch, warning_check:
        encoded = feature.encode_example(mp3_path)
        decoded = feature.decode_example(encoded)
        assert decoded["path"] == mp3_path
        assert decoded["array"].shape == (110592,)
        assert decoded["sampling_rate"] == 44100
164
+
165
+
138
166
@require_libsndfile_with_opus
139
167
def test_audio_decode_example_opus (shared_datadir ):
140
168
audio_path = str (shared_datadir / "test_audio_48000.opus" )
@@ -178,6 +206,34 @@ def test_audio_resampling_mp3_different_sampling_rates(shared_datadir):
178
206
assert decoded_example ["sampling_rate" ] == 48000
179
207
180
208
209
@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_audio_resampling_mp3_different_sampling_rates_torchaudio_latest(shared_datadir, torchaudio_failed):
    """Decode two mp3 files with an ``Audio`` feature set to 48 kHz.

    The 44.1 kHz and 16 kHz fixtures are decoded in turn and each result must
    report ``sampling_rate == 48000`` with the expected array length. With
    ``torchaudio_failed`` set, ``torchaudio.load`` is mocked to raise and the
    `librosa` fallback warning must be emitted.
    """
    feature = Audio(sampling_rate=48000)
    # (file path, expected shape of the decoded array), checked in this order
    cases = [
        (str(shared_datadir / "test_audio_44100.mp3"), (120373,)),
        (str(shared_datadir / "test_audio_16000.mp3"), (122688,)),
    ]

    # if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
    if torchaudio_failed:
        load_patch = patch("torchaudio.load", side_effect=RuntimeError())
        warning_check = pytest.warns(
            UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
        )
    else:
        load_patch = nullcontext()
        warning_check = nullcontext()

    with load_patch, warning_check:
        for mp3_path, expected_shape in cases:
            decoded = feature.decode_example(feature.encode_example(mp3_path))
            assert decoded.keys() == {"path", "array", "sampling_rate"}
            assert decoded["path"] == mp3_path
            assert decoded["array"].shape == expected_shape
            assert decoded["sampling_rate"] == 48000
235
+
236
+
181
237
@require_sndfile
182
238
def test_dataset_with_audio_feature (shared_datadir ):
183
239
audio_path = str (shared_datadir / "test_audio_44100.wav" )
@@ -266,6 +322,38 @@ def test_dataset_with_audio_feature_tar_mp3(tar_mp3_path):
266
322
assert column [0 ]["sampling_rate" ] == 44100
267
323
268
324
325
@pytest.mark.torchaudio_latest
@require_torchaudio_latest
def test_dataset_with_audio_feature_tar_mp3_torchaudio_latest(tar_mp3_path):
    """Build a dataset from the first member of a tar archive and decode it.

    The member is stored as ``{"path", "bytes"}`` and then read back through
    row, batch and column access; each access must yield the same decoded
    example.
    """
    # no test for librosa here because it doesn't support file-like objects, only paths
    audio_filename = "test_audio_44100.mp3"

    records = []
    for member_path, member_file in iter_archive(tar_mp3_path):
        records.append({"path": member_path, "bytes": member_file.read()})
        break  # only the first archive member is used
    dataset = Dataset.from_dict({"audio": records}, features=Features({"audio": Audio()}))

    def check_decoded(decoded):
        # Shared expectations for one decoded audio example.
        assert decoded.keys() == {"path", "array", "sampling_rate"}
        assert decoded["path"] == audio_filename
        assert decoded["array"].shape == (110592,)
        assert decoded["sampling_rate"] == 44100

    # Single-row access
    row = dataset[0]
    assert row.keys() == {"audio"}
    check_decoded(row["audio"])

    # Batch (slice) access
    first_batch = dataset[:1]
    assert first_batch.keys() == {"audio"}
    assert len(first_batch["audio"]) == 1
    check_decoded(first_batch["audio"][0])

    # Column access
    audio_column = dataset["audio"]
    assert len(audio_column) == 1
    check_decoded(audio_column[0])
355
+
356
+
269
357
@require_sndfile
270
358
def test_dataset_with_audio_feature_with_none ():
271
359
data = {"audio" : [None ]}
@@ -328,7 +416,7 @@ def test_resampling_at_loading_dataset_with_audio_feature(shared_datadir):
328
416
329
417
330
418
@require_sox
331
- @require_sndfile
419
+ @require_torchaudio
332
420
def test_resampling_at_loading_dataset_with_audio_feature_mp3 (shared_datadir ):
333
421
audio_path = str (shared_datadir / "test_audio_44100.mp3" )
334
422
data = {"audio" : [audio_path ]}
@@ -355,6 +443,43 @@ def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir):
355
443
assert column [0 ]["sampling_rate" ] == 16000
356
444
357
445
446
@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_resampling_at_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
    """Read a 44.1 kHz mp3 through a dataset whose ``Audio`` feature is set to 16 kHz.

    Row, batch and column access must each return an example with
    ``sampling_rate == 16000`` and the expected array length. With
    ``torchaudio_failed`` set, ``torchaudio.load`` is mocked to raise and the
    `librosa` fallback warning must be emitted.
    """
    mp3_path = str(shared_datadir / "test_audio_44100.mp3")
    dataset = Dataset.from_dict(
        {"audio": [mp3_path]},
        features=Features({"audio": Audio(sampling_rate=16000)}),
    )

    def check_resampled(decoded):
        # Shared expectations for one decoded, resampled audio example.
        assert decoded.keys() == {"path", "array", "sampling_rate"}
        assert decoded["path"] == mp3_path
        assert decoded["array"].shape == (40125,)
        assert decoded["sampling_rate"] == 16000

    # if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
    if torchaudio_failed:
        load_patch = patch("torchaudio.load", side_effect=RuntimeError())
        warning_check = pytest.warns(
            UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
        )
    else:
        load_patch = nullcontext()
        warning_check = nullcontext()

    with load_patch, warning_check:
        # Single-row access
        row = dataset[0]
        assert row.keys() == {"audio"}
        check_resampled(row["audio"])

        # Batch (slice) access
        first_batch = dataset[:1]
        assert first_batch.keys() == {"audio"}
        assert len(first_batch["audio"]) == 1
        check_resampled(first_batch["audio"][0])

        # Column access
        audio_column = dataset["audio"]
        assert len(audio_column) == 1
        check_resampled(audio_column[0])
481
+
482
+
358
483
@require_sndfile
359
484
def test_resampling_after_loading_dataset_with_audio_feature (shared_datadir ):
360
485
audio_path = str (shared_datadir / "test_audio_44100.wav" )
@@ -386,7 +511,7 @@ def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir):
386
511
387
512
388
513
@require_sox
389
- @require_sndfile
514
+ @require_torchaudio
390
515
def test_resampling_after_loading_dataset_with_audio_feature_mp3 (shared_datadir ):
391
516
audio_path = str (shared_datadir / "test_audio_44100.mp3" )
392
517
data = {"audio" : [audio_path ]}
@@ -416,6 +541,46 @@ def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir)
416
541
assert column [0 ]["sampling_rate" ] == 16000
417
542
418
543
544
@pytest.mark.torchaudio_latest
@require_torchaudio_latest
@pytest.mark.parametrize("torchaudio_failed", [False, True])
def test_resampling_after_loading_dataset_with_audio_feature_mp3_torchaudio_latest(shared_datadir, torchaudio_failed):
    """Cast the audio column to 16 kHz after the dataset has been created.

    The dataset is first read with a default ``Audio()`` feature (44100 Hz is
    expected), then the column is cast to ``Audio(sampling_rate=16000)`` and
    row, batch and column access are checked. With ``torchaudio_failed`` set,
    ``torchaudio.load`` is mocked to raise and the `librosa` fallback warning
    must be emitted.
    """
    mp3_path = str(shared_datadir / "test_audio_44100.mp3")
    dataset = Dataset.from_dict(
        {"audio": [mp3_path]},
        features=Features({"audio": Audio()}),
    )

    def check_resampled(decoded):
        # Shared expectations for one decoded example after the cast.
        assert decoded.keys() == {"path", "array", "sampling_rate"}
        assert decoded["path"] == mp3_path
        assert decoded["array"].shape == (40125,)
        assert decoded["sampling_rate"] == 16000

    # if torchaudio>=0.12 failed, mp3 must be decoded anyway (with librosa)
    if torchaudio_failed:
        load_patch = patch("torchaudio.load", side_effect=RuntimeError())
        warning_check = pytest.warns(
            UserWarning, match=r"Decoding mp3 with `librosa` instead of `torchaudio`.+?"
        )
    else:
        load_patch = nullcontext()
        warning_check = nullcontext()

    with load_patch, warning_check:
        # Before the cast the file's own sampling rate is reported.
        original_row = dataset[0]
        assert original_row["audio"]["sampling_rate"] == 44100

        dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

        # Single-row access
        row = dataset[0]
        assert row.keys() == {"audio"}
        check_resampled(row["audio"])

        # Batch (slice) access
        first_batch = dataset[:1]
        assert first_batch.keys() == {"audio"}
        assert len(first_batch["audio"]) == 1
        check_resampled(first_batch["audio"][0])

        # Column access
        audio_column = dataset["audio"]
        assert len(audio_column) == 1
        check_resampled(audio_column[0])
582
+
583
+
419
584
@pytest .mark .parametrize (
420
585
"build_data" ,
421
586
[
0 commit comments