Skip to content

Commit df69212

Browse files
Merge branch 'main' into ODSC-71740/AQUA-BYOR
2 parents 4dfc037 + ddb3094 commit df69212

File tree

1 file changed

+43
-37
lines changed

1 file changed

+43
-37
lines changed

ads/aqua/common/utils.py

Lines changed: 43 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -1158,9 +1158,11 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
11581158

11591159

11601160
def build_pydantic_error_message(ex: ValidationError):
1161-
"""Added to handle error messages from pydantic model validator.
1161+
"""
1162+
Added to handle error messages from pydantic model validator.
11621163
Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1163-
message using msg field."""
1164+
message using msg field.
1165+
"""
11641166

11651167
return {
11661168
".".join(map(str, e["loc"])): e["msg"]
@@ -1185,67 +1187,71 @@ def is_pydantic_model(obj: object) -> bool:
11851187

11861188
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
11871189
def load_gpu_shapes_index(
1188-
auth: Optional[Dict] = None,
1190+
auth: Optional[Dict[str, Any]] = None,
11891191
) -> GPUShapesIndex:
11901192
"""
1191-
Loads the GPU shapes index from Object Storage or a local resource folder.
1193+
Load the GPU shapes index, preferring the OS bucket copy over the local one.
11921194
1193-
The function first attempts to load the file from an Object Storage bucket using fsspec.
1194-
If the loading fails (due to connection issues, missing file, etc.), it falls back to
1195-
loading the index from a local file.
1195+
Attempts to read `gpu_shapes_index.json` from OCI Object Storage first;
1196+
if that succeeds, those entries will override the local defaults.
11961197
11971198
Parameters
11981199
----------
1199-
auth: (Dict, optional). Defaults to None.
1200-
The default authentication is set using `ads.set_auth` API. If you need to override the
1201-
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
1202-
authentication signer and kwargs required to instantiate IdentityClient object.
1200+
auth
1201+
Optional auth dict (as returned by `ads.common.auth.default_signer()`)
1202+
to pass through to `fsspec.open()`.
12031203
12041204
Returns
12051205
-------
1206-
GPUShapesIndex: The parsed GPU shapes index.
1206+
GPUShapesIndex
1207+
Merged index where any shape present remotely supersedes the local entry.
12071208
12081209
Raises
12091210
------
1210-
FileNotFoundError: If the GPU shapes index cannot be found in either Object Storage or locally.
1211-
json.JSONDecodeError: If the JSON is malformed.
1211+
json.JSONDecodeError
1212+
If any of the JSON is malformed.
12121213
"""
12131214
file_name = "gpu_shapes_index.json"
1214-
data: Dict[str, Any] = {}
12151215

1216-
# Check if the CONDA_BUCKET_NS environment variable is set.
1216+
# Try remote load
1217+
remote_data: Dict[str, Any] = {}
12171218
if CONDA_BUCKET_NS:
12181219
try:
12191220
auth = auth or authutil.default_signer()
1220-
# Construct the object storage path. Adjust bucket name and path as needed.
12211221
storage_path = (
12221222
f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/service_pack/{file_name}"
12231223
)
1224-
logger.debug("Loading GPU shapes index from Object Storage")
1225-
with fsspec.open(storage_path, mode="r", **auth) as file_obj:
1226-
data = json.load(file_obj)
1227-
logger.debug("Successfully loaded GPU shapes index.")
1228-
except Exception as ex:
12291224
logger.debug(
1230-
f"Failed to load GPU shapes index from Object Storage. Details: {ex}"
1225+
"Loading GPU shapes index from Object Storage: %s", storage_path
12311226
)
1232-
1233-
# If loading from Object Storage failed, load from the local resource folder.
1234-
if not data:
1235-
try:
1236-
local_path = os.path.join(
1237-
os.path.dirname(__file__), "../resources", file_name
1238-
)
1239-
logger.debug(f"Loading GPU shapes index from {local_path}.")
1240-
with open(local_path) as file_obj:
1241-
data = json.load(file_obj)
1242-
logger.debug("Successfully loaded GPU shapes index.")
1243-
except Exception as e:
1227+
with fsspec.open(storage_path, mode="r", **auth) as f:
1228+
remote_data = json.load(f)
12441229
logger.debug(
1245-
f"Failed to load GPU shapes index from {local_path}. Details: {e}"
1230+
"Loaded %d shapes from Object Storage",
1231+
len(remote_data.get("shapes", {})),
12461232
)
1233+
except Exception as ex:
1234+
logger.debug("Remote load failed (%s); falling back to local", ex)
1235+
1236+
# Load local copy
1237+
local_data: Dict[str, Any] = {}
1238+
local_path = os.path.join(os.path.dirname(__file__), "../resources", file_name)
1239+
try:
1240+
logger.debug("Loading GPU shapes index from local file: %s", local_path)
1241+
with open(local_path) as f:
1242+
local_data = json.load(f)
1243+
logger.debug(
1244+
"Loaded %d shapes from local file", len(local_data.get("shapes", {}))
1245+
)
1246+
except Exception as ex:
1247+
logger.debug("Local load GPU shapes index failed (%s)", ex)
1248+
1249+
# Merge: remote shapes override local
1250+
local_shapes = local_data.get("shapes", {})
1251+
remote_shapes = remote_data.get("shapes", {})
1252+
merged_shapes = {**local_shapes, **remote_shapes}
12471253

1248-
return GPUShapesIndex(**data)
1254+
return GPUShapesIndex(shapes=merged_shapes)
12491255

12501256

12511257
def get_preferred_compatible_family(selected_families: set[str]) -> str:

0 commit comments

Comments
 (0)