|
4 | 4 | from collections import abc |
5 | 5 | import dataclasses |
6 | 6 | import gzip |
7 | | -from io import BufferedIOBase, BytesIO, RawIOBase, TextIOWrapper |
| 7 | +from io import BufferedIOBase, BytesIO, RawIOBase, StringIO, TextIOWrapper |
8 | 8 | import mmap |
9 | 9 | import os |
10 | | -from typing import IO, Any, AnyStr, Dict, List, Mapping, Optional, Tuple, cast |
| 10 | +from typing import IO, Any, AnyStr, Dict, List, Mapping, Optional, Tuple, Union, cast |
11 | 11 | from urllib.parse import ( |
12 | 12 | urljoin, |
13 | 13 | urlparse as parse_url, |
@@ -707,17 +707,36 @@ def __init__( |
707 | 707 | archive_name: Optional[str] = None, |
708 | 708 | **kwargs, |
709 | 709 | ): |
710 | | - if mode in ["wb", "rb"]: |
711 | | - mode = mode.replace("b", "") |
| 710 | + mode = mode.replace("b", "") |
712 | 711 | self.archive_name = archive_name |
| 712 | + self.multiple_write_buffer: Optional[Union[StringIO, BytesIO]] = None |
| 713 | + |
713 | 714 | kwargs_zip: Dict[str, Any] = {"compression": zipfile.ZIP_DEFLATED} |
714 | 715 | kwargs_zip.update(kwargs) |
| 716 | + |
715 | 717 | super().__init__(file, mode, **kwargs_zip) # type: ignore[arg-type] |
716 | 718 |
|
717 | 719 | def write(self, data): |
| 720 | + # buffer multiple write calls, write on flush |
| 721 | + if self.multiple_write_buffer is None: |
| 722 | + self.multiple_write_buffer = ( |
| 723 | + BytesIO() if isinstance(data, bytes) else StringIO() |
| 724 | + ) |
| 725 | + self.multiple_write_buffer.write(data) |
| 726 | + |
| 727 | + def flush(self) -> None: |
| 728 | + # write to actual handle and close write buffer |
| 729 | + if self.multiple_write_buffer is None or self.multiple_write_buffer.closed: |
| 730 | + return |
| 731 | + |
718 | 732 | # ZipFile needs a non-empty string |
719 | 733 | archive_name = self.archive_name or self.filename or "zip" |
720 | | - super().writestr(archive_name, data) |
| 734 | + with self.multiple_write_buffer: |
| 735 | + super().writestr(archive_name, self.multiple_write_buffer.getvalue()) |
| 736 | + |
| 737 | + def close(self): |
| 738 | + self.flush() |
| 739 | + super().close() |
721 | 740 |
|
722 | 741 | @property |
723 | 742 | def closed(self): |
|
0 commit comments