Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(
) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True)
self.eval()

self.channel_wise = channel_wise
Expand Down Expand Up @@ -297,7 +297,7 @@ class RadImageNetPerceptualSimilarity(nn.Module):

def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
super().__init__()
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True)
self.eval()

for param in self.parameters():
Expand Down
43 changes: 27 additions & 16 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,32 @@

nib, _ = optional_import("nibabel")
http_error, has_req = optional_import("requests", name="HTTPError")
file_url_error, has_gdown = optional_import("gdown.exceptions", name="FileURLRetrievalError")


quick_test_var = "QUICKTEST"
_tf32_enabled = None
_test_data_config: dict = {}

MODULE_PATH = Path(__file__).resolve().parents[1]

DOWNLOAD_EXCEPTS: tuple[type, ...] = (ContentTooShortError, HTTPError, ConnectionError)
if has_req:
DOWNLOAD_EXCEPTS += (http_error,)
if has_gdown:
DOWNLOAD_EXCEPTS += (file_url_error,)

DOWNLOAD_FAIL_MSGS = (
"unexpected EOF", # incomplete download
"network issue",
"gdown dependency", # gdown not installed
"md5 check",
"limit", # HTTP Error 503: Egress is over the account limit
"authenticate",
"timed out", # urlopen error [Errno 110] Connection timed out
"HTTPError", # HTTPError: 429 Client Error: Too Many Requests for huggingface hub
)


def testing_data_config(*keys):
"""get _test_data_config[keys0][keys1]...[keysN]"""
Expand Down Expand Up @@ -142,29 +161,21 @@ def assert_allclose(

@contextmanager
def skip_if_downloading_fails():
"""
Skips a test if downloading something raises an exception recognised to indicate a download has failed.
"""

try:
yield
except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_req else () as e: # noqa: B030
raise unittest.SkipTest(f"error while downloading: {e}") from e
except DOWNLOAD_EXCEPTS as e:
raise unittest.SkipTest(f"Error while downloading: {e}") from e
except ssl.SSLError as ssl_e:
if "decryption failed" in str(ssl_e):
raise unittest.SkipTest(f"SSL error while downloading: {ssl_e}") from ssl_e
except (RuntimeError, OSError) as rt_e:
err_str = str(rt_e)
if any(
k in err_str
for k in (
"unexpected EOF", # incomplete download
"network issue",
"gdown dependency", # gdown not installed
"md5 check",
"limit", # HTTP Error 503: Egress is over the account limit
"authenticate",
"timed out", # urlopen error [Errno 110] Connection timed out
"HTTPError", # HTTPError: 429 Client Error: Too Many Requests for huggingface hub
)
):
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e # incomplete download
if any(k in err_str for k in DOWNLOAD_FAIL_MSGS):
raise unittest.SkipTest(f"Error while downloading: {rt_e}") from rt_e # incomplete download

raise rt_e

Expand Down
Loading