Skip to content

Commit 73a7601

Browse files
authored
Enhancing get_all_bundles_list under monai.bundle to support model zoo NGC hosting (#6997)
Fixes #6833 ### Description Add `model_info_url` in `get_all_bundles_list`, `get_bundle_versions`, and `get_bundle_info`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <[email protected]>
1 parent b31367f commit 73a7601

File tree

2 files changed

+66
-28
lines changed

2 files changed

+66
-28
lines changed

monai/bundle/scripts.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -541,11 +541,16 @@ def load(
541541
return model
542542

543543

544+
@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.2", replaced="1.5")
544545
def _get_all_bundles_info(
545546
repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None
546547
) -> dict[str, dict[str, dict[str, Any]]]:
547548
if has_requests:
548-
request_url = f"https://api.github.com/repos/{repo}/releases"
549+
if tag == "hosting_storage_v1":
550+
request_url = f"https://api.github.com/repos/{repo}/releases"
551+
else:
552+
request_url = f"https://raw.githubusercontent.com/{repo}/{tag}/models/model_info.json"
553+
549554
if auth_token is not None:
550555
headers = {"Authorization": f"Bearer {auth_token}"}
551556
resp = requests_get(request_url, headers=headers)
@@ -558,33 +563,39 @@ def _get_all_bundles_info(
558563
bundle_name_pattern = re.compile(r"_v\d*.")
559564
bundles_info: dict[str, dict[str, dict[str, Any]]] = {}
560565

561-
for release in releases_list:
562-
if release["tag_name"] == tag:
563-
for asset in release["assets"]:
564-
asset_name = bundle_name_pattern.split(asset["name"])[0]
565-
if asset_name not in bundles_info:
566-
bundles_info[asset_name] = {}
567-
asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "")
568-
bundles_info[asset_name][asset_version] = {
569-
"id": asset["id"],
570-
"name": asset["name"],
571-
"size": asset["size"],
572-
"download_count": asset["download_count"],
573-
"browser_download_url": asset["browser_download_url"],
574-
"created_at": asset["created_at"],
575-
"updated_at": asset["updated_at"],
576-
}
577-
return bundles_info
566+
if tag == "hosting_storage_v1":
567+
for release in releases_list:
568+
if release["tag_name"] == tag:
569+
for asset in release["assets"]:
570+
asset_name = bundle_name_pattern.split(asset["name"])[0]
571+
if asset_name not in bundles_info:
572+
bundles_info[asset_name] = {}
573+
asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "")
574+
bundles_info[asset_name][asset_version] = dict(asset)
575+
return bundles_info
576+
else:
577+
for asset in releases_list.keys():
578+
asset_name = bundle_name_pattern.split(asset)[0]
579+
if asset_name not in bundles_info:
580+
bundles_info[asset_name] = {}
581+
asset_version = asset.split(f"{asset_name}_v")[-1]
582+
bundles_info[asset_name][asset_version] = {
583+
"name": asset,
584+
"browser_download_url": releases_list[asset]["source"],
585+
}
578586
return bundles_info
579587

580588

589+
@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5")
581590
def get_all_bundles_list(
582591
repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None
583592
) -> list[tuple[str, str]]:
584593
"""
585594
Get all bundles names (and the latest versions) that are stored in the release of specified repository
586-
with the provided tag. The default values of arguments correspond to the release of MONAI model zoo.
587-
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
595+
with the provided tag. If tag is "dev", will get model information from
596+
https://raw.githubusercontent.com/repo_owner/repo_name/dev/models/model_info.json.
597+
The default values of arguments correspond to the release of MONAI model zoo. In order to increase the
598+
rate limits of calling Github APIs, you can input your personal access token.
588599
Please check the following link for more details about rate limiting:
589600
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
590601
@@ -610,6 +621,7 @@ def get_all_bundles_list(
610621
return bundles_list
611622

612623

624+
@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5")
613625
def get_bundle_versions(
614626
bundle_name: str,
615627
repo: str = "Project-MONAI/model-zoo",
@@ -618,7 +630,8 @@ def get_bundle_versions(
618630
) -> dict[str, list[str] | str]:
619631
"""
620632
Get the latest version, as well as all existing versions of a bundle that is stored in the release of specified
621-
repository with the provided tag.
633+
repository with the provided tag. If tag is "dev", will get model information from
634+
https://raw.githubusercontent.com/repo_owner/repo_name/dev/models/model_info.json.
622635
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
623636
Please check the following link for more details about rate limiting:
624637
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
@@ -646,6 +659,7 @@ def get_bundle_versions(
646659
return {"latest_version": all_versions[-1], "all_versions": all_versions}
647660

648661

662+
@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5")
649663
def get_bundle_info(
650664
bundle_name: str,
651665
version: str | None = None,
@@ -656,7 +670,9 @@ def get_bundle_info(
656670
"""
657671
Get all information
658672
(include "id", "name", "size", "download_count", "browser_download_url", "created_at", "updated_at") of a bundle
659-
with the specified bundle name and version.
673+
with the specified bundle name and version which is stored in the release of specified repository with the provided tag.
674+
Since v1.5, "hosting_storage_v1" will be deprecated in favor of 'dev', which contains only "name" and "browser_download_url".
675+
information about a bundle.
660676
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
661677
Please check the following link for more details about rate limiting:
662678
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
@@ -685,7 +701,7 @@ def get_bundle_info(
685701
if version not in bundle_info:
686702
raise ValueError(f"version: {version} of bundle: {bundle_name} is not existing.")
687703

688-
return bundle_info[version]
704+
return bundle_info[version] # type: ignore[no-any-return]
689705

690706

691707
@deprecated_arg("runner_id", since="1.1", removed="1.3", new_name="run_id", msg_suffix="please use `run_id` instead.")

tests/test_bundle_get_data.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,34 @@
2525

2626
TEST_CASE_2 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None}]
2727

28-
TEST_CASE_FAKE_TOKEN = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}]
28+
TEST_CASE_3 = [{"tag": "hosting_storage_v1"}]
29+
30+
TEST_CASE_4 = [{"tag": "dev"}]
31+
32+
TEST_CASE_5 = [{"bundle_name": "brats_mri_segmentation", "tag": "dev"}]
33+
34+
TEST_CASE_6 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None, "tag": "dev"}]
35+
36+
TEST_CASE_FAKE_TOKEN_1 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}]
37+
38+
TEST_CASE_FAKE_TOKEN_2 = [
39+
{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken", "tag": "dev"}
40+
]
2941

3042

3143
@skip_if_windows
3244
@SkipIfNoModule("requests")
3345
class TestGetBundleData(unittest.TestCase):
46+
@parameterized.expand([TEST_CASE_3, TEST_CASE_4])
3447
@skip_if_quick
35-
def test_get_all_bundles_list(self):
48+
def test_get_all_bundles_list(self, params):
3649
with skip_if_downloading_fails():
37-
output = get_all_bundles_list()
50+
output = get_all_bundles_list(**params)
3851
self.assertTrue(isinstance(output, list))
3952
self.assertTrue(isinstance(output[0], tuple))
4053
self.assertTrue(len(output[0]) == 2)
4154

42-
@parameterized.expand([TEST_CASE_1])
55+
@parameterized.expand([TEST_CASE_1, TEST_CASE_5])
4356
@skip_if_quick
4457
def test_get_bundle_versions(self, params):
4558
with skip_if_downloading_fails():
@@ -57,7 +70,16 @@ def test_get_bundle_info(self, params):
5770
for key in ["id", "name", "size", "download_count", "browser_download_url"]:
5871
self.assertTrue(key in output)
5972

60-
@parameterized.expand([TEST_CASE_FAKE_TOKEN])
73+
@parameterized.expand([TEST_CASE_5, TEST_CASE_6])
74+
@skip_if_quick
75+
def test_get_bundle_info_monaihosting(self, params):
76+
with skip_if_downloading_fails():
77+
output = get_bundle_info(**params)
78+
self.assertTrue(isinstance(output, dict))
79+
for key in ["name", "browser_download_url"]:
80+
self.assertTrue(key in output)
81+
82+
@parameterized.expand([TEST_CASE_FAKE_TOKEN_1, TEST_CASE_FAKE_TOKEN_2])
6183
@skip_if_quick
6284
def test_fake_token(self, params):
6385
with skip_if_downloading_fails():

0 commit comments

Comments
 (0)