Skip to content

Commit b79be71

Browse files
authored
Merge pull request #17306 from github/redsun82/bazel-lfs
Bazel: fix logging bug in `git_lfs_probe.py`
2 parents b3fa4f3 + 0738e01 commit b79be71

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

misc/bazel/internal/git_lfs_probe.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,26 @@
1616
import json
1717
import typing
1818
import urllib.request
19+
import urllib.error
1920
from urllib.parse import urlparse
2021
import re
2122
import base64
2223
from dataclasses import dataclass
2324
import argparse
2425

25-
2626
def options():
2727
p = argparse.ArgumentParser(description=__doc__)
2828
p.add_argument("--hash-only", action="store_true")
2929
p.add_argument("sources", type=pathlib.Path, nargs="+")
3030
return p.parse_args()
3131

3232

33+
TIMEOUT = 20
34+
35+
def warn(message: str) -> None:
36+
print(f"WARNING: {message}", file=sys.stderr)
37+
38+
3339
@dataclass
3440
class Endpoint:
3541
name: str
@@ -41,6 +47,10 @@ def update_headers(self, d: typing.Iterable[typing.Tuple[str, str]]):
4147
self.headers.update((k.capitalize(), v) for k, v in d)
4248

4349

50+
class NoEndpointsFound(Exception):
51+
pass
52+
53+
4454
opts = options()
4555
sources = [p.resolve() for p in opts.sources]
4656
source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
@@ -105,18 +115,12 @@ def get_endpoints() -> typing.Iterable[Endpoint]:
105115
"download",
106116
]
107117
try:
108-
res = subprocess.run(cmd, stdout=subprocess.PIPE, timeout=15)
118+
res = subprocess.run(cmd, stdout=subprocess.PIPE, timeout=TIMEOUT)
109119
except subprocess.TimeoutExpired:
110-
print(
111-
f"WARNING: ssh timed out when connecting to {server}, ignoring {endpoint.name} endpoint",
112-
file=sys.stderr,
113-
)
120+
warn(f"ssh timed out when connecting to {server}, ignoring {endpoint.name} endpoint")
114121
continue
115122
if res.returncode != 0:
116-
print(
117-
f"WARNING: ssh failed when connecting to {server}, ignoring {endpoint.name} endpoint",
118-
file=sys.stderr,
119-
)
123+
warn(f"ssh failed when connecting to {server}, ignoring {endpoint.name} endpoint")
120124
continue
121125
ssh_resp = json.loads(res.stdout)
122126
endpoint.href = ssh_resp.get("href", endpoint)
@@ -139,10 +143,7 @@ def get_endpoints() -> typing.Iterable[Endpoint]:
139143
input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n",
140144
)
141145
if credentials is None:
142-
print(
143-
f"WARNING: no authorization method found, ignoring {data.name} endpoint",
144-
file=sys.stderr,
145-
)
146+
warn(f"no authorization method found, ignoring {endpoint.name} endpoint")
146147
continue
147148
credentials = dict(get_env(credentials))
148149
auth = base64.b64encode(
@@ -176,18 +177,18 @@ def get_locations(objects):
176177
data=json.dumps(data).encode("ascii"),
177178
)
178179
try:
179-
with urllib.request.urlopen(req) as resp:
180+
with urllib.request.urlopen(req, timeout=TIMEOUT) as resp:
180181
data = json.load(resp)
181-
except urllib.request.HTTPError as e:
182-
print(f"WARNING: encountered HTTPError {e}, ignoring endpoint {e.name}")
182+
except urllib.error.URLError as e:
183+
warn(f"encountered {type(e).__name__} {e}, ignoring endpoint {endpoint.name}")
183184
continue
184185
assert len(data["objects"]) == len(
185186
indexes
186187
), f"received {len(data)} objects, expected {len(indexes)}"
187188
for i, resp in zip(indexes, data["objects"]):
188189
ret[i] = f'{resp["oid"]} {resp["actions"]["download"]["href"]}'
189190
return ret
190-
raise Exception(f"no valid endpoint found")
191+
raise NoEndpointsFound
191192

192193

193194
def get_lfs_object(path):
@@ -204,6 +205,10 @@ def get_lfs_object(path):
204205
return {"oid": sha256, "size": size}
205206

206207

207-
objects = [get_lfs_object(src) for src in sources]
208-
for resp in get_locations(objects):
209-
print(resp)
208+
try:
209+
objects = [get_lfs_object(src) for src in sources]
210+
for resp in get_locations(objects):
211+
print(resp)
212+
except NoEndpointsFound as e:
213+
print(f"ERROR: no valid endpoints found", file=sys.stderr)
214+
sys.exit(1)

0 commit comments

Comments
 (0)