1
1
import requests
2
2
from tqdm import tqdm
3
3
import zipfile
4
+ import tarfile
4
5
import os
5
6
import shutil
6
7
@@ -26,7 +27,7 @@ def _create_if_not_exists(path: str, remove=True) -> None:
26
27
27
28
os .makedirs (path , exist_ok = True )
28
29
29
- def download (url : str , filename : str , unzip = False , unzip_path : str = None , force_download = False , force_unzip = False , clean = False ) -> str :
30
+ def download (url : str , filename : str , unzip = True , unzip_path : str = None , force_download = False , force_unzip = False , clean = False ) -> str :
30
31
"""
31
32
Download a file from a OneDrive url.
32
33
@@ -84,7 +85,7 @@ def download(url: str, filename: str, unzip=False, unzip_path: str = None, force
84
85
85
86
# unzip file if necessary
86
87
if unzip :
87
- if filename .endswith (".zip" ):
88
+ if filename .endswith (".zip" ) or filename . endswith ( ".tar.gz" ) :
88
89
unzip_path = unzip_path if unzip_path is not None else os .path .split (filename )[0 ]
89
90
clean_unzip_path = force_unzip and os .path .realpath (unzip_path ) not in os .path .realpath (filename )
90
91
ret_path = unzip_path
@@ -94,10 +95,16 @@ def download(url: str, filename: str, unzip=False, unzip_path: str = None, force
94
95
if force_unzip :
95
96
print ("Warning: overwriting existing files!" )
96
97
97
- with zipfile .ZipFile (filename , 'r' ) as zip_ref :
98
- for file in tqdm (iterable = zip_ref .namelist (), total = len (zip_ref .namelist ()), desc = "Extracting files" ):
99
- if not os .path .exists (os .path .join (unzip_path , file )) or force_unzip :
100
- zip_ref .extract (member = file , path = unzip_path )
98
+ if filename .endswith (".zip" ):
99
+ with zipfile .ZipFile (filename , 'r' ) as zip_ref :
100
+ for file in tqdm (iterable = zip_ref .namelist (), total = len (zip_ref .namelist ()), desc = "Extracting files" ):
101
+ if not os .path .exists (os .path .join (unzip_path , file )) or force_unzip :
102
+ zip_ref .extract (member = file , path = unzip_path )
103
+ elif filename .endswith (".tar.gz" ):
104
+ with tarfile .open (filename , 'r:gz' ) as tar_ref :
105
+ for file in tqdm (iterable = tar_ref .getnames (), total = len (tar_ref .getnames ()), desc = "Extracting files" ):
106
+ if not os .path .exists (os .path .join (unzip_path , file )) or force_unzip :
107
+ tar_ref .extract (member = file , path = unzip_path )
101
108
102
109
if clean :
103
110
os .remove (filename )
@@ -109,7 +116,7 @@ def download(url: str, filename: str, unzip=False, unzip_path: str = None, force
109
116
110
117
111
118
if __name__ == "__main__" :
112
- ln = " https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/EQ-DxzGOF7lBt90A601kvVEBR_ca9PtUdN_asesZ-F80bw?download=1"
119
+ ln = ' https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/EdD9BE-36ohCpMy0_EKyYb0BnnSnrP7g8TOqTaeYsy-FCA?e=OJVriO'
113
120
print ('Downloading dataset' )
114
- ret = download (ln , filename = "./tmp/" , unzip = True )
121
+ ret = download (ln , filename = "./tmp/" , clean = True )
115
122
print (ret )
0 commit comments