16
16
import json
17
17
import typing
18
18
import urllib .request
19
+ import urllib .error
19
20
from urllib .parse import urlparse
20
21
import re
21
22
import base64
22
23
from dataclasses import dataclass
23
24
import argparse
24
25
25
-
26
26
def options ():
27
27
p = argparse .ArgumentParser (description = __doc__ )
28
28
p .add_argument ("--hash-only" , action = "store_true" )
29
29
p .add_argument ("sources" , type = pathlib .Path , nargs = "+" )
30
30
return p .parse_args ()
31
31
32
32
33
+ TIMEOUT = 20
34
+
35
+ def warn (message : str ) -> None :
36
+ print (f"WARNING: { message } " , file = sys .stderr )
37
+
38
+
33
39
@dataclass
34
40
class Endpoint :
35
41
name : str
@@ -41,6 +47,10 @@ def update_headers(self, d: typing.Iterable[typing.Tuple[str, str]]):
41
47
self .headers .update ((k .capitalize (), v ) for k , v in d )
42
48
43
49
50
+ class NoEndpointsFound (Exception ):
51
+ pass
52
+
53
+
44
54
opts = options ()
45
55
sources = [p .resolve () for p in opts .sources ]
46
56
source_dir = pathlib .Path (os .path .commonpath (src .parent for src in sources ))
@@ -105,18 +115,12 @@ def get_endpoints() -> typing.Iterable[Endpoint]:
105
115
"download" ,
106
116
]
107
117
try :
108
- res = subprocess .run (cmd , stdout = subprocess .PIPE , timeout = 15 )
118
+ res = subprocess .run (cmd , stdout = subprocess .PIPE , timeout = TIMEOUT )
109
119
except subprocess .TimeoutExpired :
110
- print (
111
- f"WARNING: ssh timed out when connecting to { server } , ignoring { endpoint .name } endpoint" ,
112
- file = sys .stderr ,
113
- )
120
+ warn (f"ssh timed out when connecting to { server } , ignoring { endpoint .name } endpoint" )
114
121
continue
115
122
if res .returncode != 0 :
116
- print (
117
- f"WARNING: ssh failed when connecting to { server } , ignoring { endpoint .name } endpoint" ,
118
- file = sys .stderr ,
119
- )
123
+ warn (f"ssh failed when connecting to { server } , ignoring { endpoint .name } endpoint" )
120
124
continue
121
125
ssh_resp = json .loads (res .stdout )
122
126
endpoint .href = ssh_resp .get ("href" , endpoint )
@@ -139,10 +143,7 @@ def get_endpoints() -> typing.Iterable[Endpoint]:
139
143
input = f"protocol={ url .scheme } \n host={ url .netloc } \n path={ url .path [1 :]} \n " ,
140
144
)
141
145
if credentials is None :
142
- print (
143
- f"WARNING: no authorization method found, ignoring { data .name } endpoint" ,
144
- file = sys .stderr ,
145
- )
146
+ warn (f"no authorization method found, ignoring { endpoint .name } endpoint" )
146
147
continue
147
148
credentials = dict (get_env (credentials ))
148
149
auth = base64 .b64encode (
@@ -176,18 +177,18 @@ def get_locations(objects):
176
177
data = json .dumps (data ).encode ("ascii" ),
177
178
)
178
179
try :
179
- with urllib .request .urlopen (req ) as resp :
180
+ with urllib .request .urlopen (req , timeout = TIMEOUT ) as resp :
180
181
data = json .load (resp )
181
- except urllib .request . HTTPError as e :
182
- print (f"WARNING: encountered HTTPError { e } , ignoring endpoint { e .name } " )
182
+ except urllib .error . URLError as e :
183
+ warn (f"encountered { type ( e ). __name__ } { e } , ignoring endpoint { endpoint .name } " )
183
184
continue
184
185
assert len (data ["objects" ]) == len (
185
186
indexes
186
187
), f"received { len (data )} objects, expected { len (indexes )} "
187
188
for i , resp in zip (indexes , data ["objects" ]):
188
189
ret [i ] = f'{ resp ["oid" ]} { resp ["actions" ]["download" ]["href" ]} '
189
190
return ret
190
- raise Exception ( f"no valid endpoint found" )
191
+ raise NoEndpointsFound
191
192
192
193
193
194
def get_lfs_object (path ):
@@ -204,6 +205,10 @@ def get_lfs_object(path):
204
205
return {"oid" : sha256 , "size" : size }
205
206
206
207
207
- objects = [get_lfs_object (src ) for src in sources ]
208
- for resp in get_locations (objects ):
209
- print (resp )
208
+ try :
209
+ objects = [get_lfs_object (src ) for src in sources ]
210
+ for resp in get_locations (objects ):
211
+ print (resp )
212
+ except NoEndpointsFound as e :
213
+ print (f"ERROR: no valid endpoints found" , file = sys .stderr )
214
+ sys .exit (1 )
0 commit comments