Skip to content

Commit 1597e4f

Browse files
committed
wave: sampwidth for IEEE Float must be 4 or 8
This is also similar to what libsndfile does
1 parent 689ae2d commit 1597e4f

File tree

3 files changed

+70
-7
lines changed

3 files changed

+70
-7
lines changed

Doc/library/wave.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ Wave_write Objects
218218

219219
Set the sample width to *n* bytes.
220220

221+
For :data:`WAVE_FORMAT_IEEE_FLOAT`, only 4-byte (32-bit) and
222+
8-byte (64-bit) sample widths are supported.
223+
221224

222225
.. method:: getsampwidth()
223226

@@ -273,6 +276,9 @@ Wave_write Objects
273276
Supported values are :data:`WAVE_FORMAT_PCM` and
274277
:data:`WAVE_FORMAT_IEEE_FLOAT`.
275278

279+
When setting :data:`WAVE_FORMAT_IEEE_FLOAT`, the sample width must be
280+
4 or 8 bytes.
281+
276282

277283
.. method:: getformat()
278284

@@ -288,6 +294,8 @@ Wave_write Objects
288294
For backwards compatibility, a 6-item tuple without *format* is also
289295
accepted and defaults to :data:`WAVE_FORMAT_PCM`.
290296

297+
For ``format=WAVE_FORMAT_IEEE_FLOAT``, *sampwidth* must be 4 or 8.
298+
291299

292300
.. method:: getparams()
293301

Lib/test/test_wave.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,20 +189,31 @@ def test_setparams_7_tuple_uses_format(self):
189189
self.addCleanup(unlink, filename)
190190

191191
with wave.open(filename, 'wb') as w:
192-
w.setparams((1, 2, 22050, 0, 'NONE', 'not compressed',
192+
w.setparams((1, 4, 22050, 0, 'NONE', 'not compressed',
193193
wave.WAVE_FORMAT_IEEE_FLOAT))
194194
self.assertEqual(w.getformat(), wave.WAVE_FORMAT_IEEE_FLOAT)
195195

196+
def test_setparams_7_tuple_ieee_64bit_sampwidth(self):
197+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
198+
filename = fp.name
199+
self.addCleanup(unlink, filename)
200+
201+
with wave.open(filename, 'wb') as w:
202+
w.setparams((1, 8, 22050, 0, 'NONE', 'not compressed',
203+
wave.WAVE_FORMAT_IEEE_FLOAT))
204+
self.assertEqual(w.getformat(), wave.WAVE_FORMAT_IEEE_FLOAT)
205+
self.assertEqual(w.getsampwidth(), 8)
206+
196207
def test_getparams_backward_compatible_shape(self):
197208
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
198209
filename = fp.name
199210
self.addCleanup(unlink, filename)
200211

201212
with wave.open(filename, 'wb') as w:
202-
w.setparams((1, 2, 22050, 0, 'NONE', 'not compressed',
213+
w.setparams((1, 4, 22050, 0, 'NONE', 'not compressed',
203214
wave.WAVE_FORMAT_IEEE_FLOAT))
204215
params = w.getparams()
205-
self.assertEqual(params, (1, 2, 22050, 0, 'NONE', 'not compressed'))
216+
self.assertEqual(params, (1, 4, 22050, 0, 'NONE', 'not compressed'))
206217

207218
def test_getformat_setformat(self):
208219
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
@@ -211,12 +222,51 @@ def test_getformat_setformat(self):
211222

212223
with wave.open(filename, 'wb') as w:
213224
w.setnchannels(1)
214-
w.setsampwidth(2)
225+
w.setsampwidth(4)
215226
w.setframerate(22050)
216227
self.assertEqual(w.getformat(), wave.WAVE_FORMAT_PCM)
217228
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
218229
self.assertEqual(w.getformat(), wave.WAVE_FORMAT_IEEE_FLOAT)
219230

231+
def test_setformat_ieee_requires_32_or_64_bit_sampwidth(self):
232+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
233+
filename = fp.name
234+
self.addCleanup(unlink, filename)
235+
236+
with wave.open(filename, 'wb') as w:
237+
w.setnchannels(1)
238+
w.setsampwidth(2)
239+
w.setframerate(22050)
240+
with self.assertRaisesRegex(wave.Error,
241+
'unsupported sample width for IEEE float format'):
242+
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
243+
244+
def test_setsampwidth_ieee_requires_32_or_64_bit(self):
245+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
246+
filename = fp.name
247+
self.addCleanup(unlink, filename)
248+
249+
with wave.open(filename, 'wb') as w:
250+
w.setnchannels(1)
251+
w.setframerate(22050)
252+
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
253+
with self.assertRaisesRegex(wave.Error,
254+
'unsupported sample width for IEEE float format'):
255+
w.setsampwidth(2)
256+
w.setsampwidth(4)
257+
258+
def test_setsampwidth_ieee_accepts_64_bit(self):
259+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
260+
filename = fp.name
261+
self.addCleanup(unlink, filename)
262+
263+
with wave.open(filename, 'wb') as w:
264+
w.setnchannels(1)
265+
w.setframerate(22050)
266+
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
267+
w.setsampwidth(8)
268+
self.assertEqual(w.getsampwidth(), 8)
269+
220270
def test_read_getformat(self):
221271
b = b'RIFF' + struct.pack('<L', 36) + b'WAVE'
222272
b += b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 1, 11025, 11025, 1, 8)
@@ -297,10 +347,10 @@ def test_ieee_float_has_fact_chunk(self):
297347

298348
with wave.open(filename, 'wb') as w:
299349
w.setnchannels(1)
300-
w.setsampwidth(2)
350+
w.setsampwidth(4)
301351
w.setframerate(22050)
302352
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
303-
w.writeframes(b'\x00\x00' * nframes)
353+
w.writeframes(b'\x00\x00\x00\x00' * nframes)
304354

305355
with open(filename, 'rb') as f:
306356
f.read(12)

Lib/wave.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,10 @@ def getnchannels(self):
506506
def setsampwidth(self, sampwidth):
507507
if self._datawritten:
508508
raise Error('cannot change parameters after starting to write')
509-
if sampwidth < 1 or sampwidth > 4:
509+
if self._format == WAVE_FORMAT_IEEE_FLOAT:
510+
if sampwidth not in (4, 8):
511+
raise Error('unsupported sample width for IEEE float format')
512+
elif sampwidth < 1 or sampwidth > 4:
510513
raise Error('bad sample width')
511514
self._sampwidth = sampwidth
512515

@@ -548,6 +551,8 @@ def setformat(self, format):
548551
raise Error('cannot change parameters after starting to write')
549552
if format not in (WAVE_FORMAT_IEEE_FLOAT, WAVE_FORMAT_PCM):
550553
raise Error('unsupported wave format')
554+
if format == WAVE_FORMAT_IEEE_FLOAT and self._sampwidth and self._sampwidth not in (4, 8):
555+
raise Error('unsupported sample width for IEEE float format')
551556
self._format = format
552557

553558
def getformat(self):

0 commit comments

Comments
 (0)