@@ -1158,9 +1158,11 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
1158
1158
1159
1159
1160
1160
def build_pydantic_error_message (ex : ValidationError ):
1161
- """Added to handle error messages from pydantic model validator.
1161
+ """
1162
+ Added to handle error messages from pydantic model validator.
1162
1163
Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1163
- message using msg field."""
1164
+ message using msg field.
1165
+ """
1164
1166
1165
1167
return {
1166
1168
"." .join (map (str , e ["loc" ])): e ["msg" ]
@@ -1185,67 +1187,71 @@ def is_pydantic_model(obj: object) -> bool:
1185
1187
1186
1188
def _extract_gpu_shapes(index_data: Any, source: str) -> Dict[str, Any]:
    """Return the ``shapes`` mapping of a parsed index, or ``{}`` if it is unusable."""
    shapes = index_data.get("shapes", {}) if isinstance(index_data, dict) else None
    if not isinstance(shapes, dict):
        # Valid JSON with the wrong structure (e.g. a top-level list or
        # `"shapes": null`) must not break the merge performed by the caller.
        logger.debug(
            "Ignoring GPU shapes index from %s: unexpected JSON structure", source
        )
        return {}
    return shapes


@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
def load_gpu_shapes_index(
    auth: Optional[Dict[str, Any]] = None,
) -> GPUShapesIndex:
    """
    Load the GPU shapes index, letting the Object Storage copy override the local one.

    The bundled `gpu_shapes_index.json` from the local resources folder is always
    read. When `CONDA_BUCKET_NS` is set, the same file is additionally read from
    Object Storage and its shapes take precedence over local entries with the
    same key. Both loads are best effort: any failure is logged at debug level
    and that source simply contributes no shapes.

    Parameters
    ----------
    auth
        Optional keyword arguments unpacked into `fsspec.open()`. When omitted
        (or empty), `authutil.default_signer()` is used.
        NOTE(review): the result is cached by the decorator above; a dict
        argument is presumably not hashable as a cache key - confirm that
        callers only use the default.

    Returns
    -------
    GPUShapesIndex
        Index built from the merged shapes; it is empty if neither source could
        be loaded.

    Notes
    -----
    Loading errors, including malformed JSON, are not propagated to the caller.
    """
    file_name = "gpu_shapes_index.json"

    # Try remote load
    remote_shapes: Dict[str, Any] = {}
    if CONDA_BUCKET_NS:
        try:
            auth = auth or authutil.default_signer()
            storage_path = (
                f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/service_pack/{file_name}"
            )
            logger.debug(
                "Loading GPU shapes index from Object Storage: %s", storage_path
            )
            with fsspec.open(storage_path, mode="r", **auth) as f:
                remote_shapes = _extract_gpu_shapes(json.load(f), "Object Storage")
            logger.debug("Loaded %d shapes from Object Storage", len(remote_shapes))
        except Exception as ex:
            logger.debug("Remote load failed (%s); falling back to local", ex)

    # Load local copy
    local_shapes: Dict[str, Any] = {}
    local_path = os.path.join(os.path.dirname(__file__), "../resources", file_name)
    try:
        logger.debug("Loading GPU shapes index from local file: %s", local_path)
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(local_path, encoding="utf-8") as f:
            local_shapes = _extract_gpu_shapes(json.load(f), "local file")
        logger.debug("Loaded %d shapes from local file", len(local_shapes))
    except Exception as ex:
        logger.debug("Local load GPU shapes index failed (%s)", ex)

    # Merge: remote shapes override local
    merged_shapes = {**local_shapes, **remote_shapes}

    return GPUShapesIndex(shapes=merged_shapes)
1249
1255
1250
1256
1251
1257
def get_preferred_compatible_family (selected_families : set [str ]) -> str :
0 commit comments