Skip to content

Commit 4bb5fb6

Browse files
elizjo and mrDzurb authored
[AQUA] GPU Shape Recommendation (#1221)
Co-authored-by: Dmitrii Cherkasov <[email protected]>
1 parent 9652d9c commit 4bb5fb6

25 files changed

+3420
-65
lines changed

ads/aqua/cli.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,18 +96,20 @@ def _validate_value(flag, value):
9696
"If you intend to chain a function call to the result, please separate the "
9797
"flag and the subsequent function call with separator `-`."
9898
)
99-
99+
100100
@staticmethod
def install():
    """Install the ADS Aqua extension from a wheel file.

    Set the environment variable `AQUA_EXTENSTION_PATH` to change the wheel
    file path. The variable name is misspelled; it is kept as-is so that
    environments which already set it keep working.

    Return
    ------
    int:
        Installation status: the exit code of the `pip install` command
        (0 on success, non-zero on failure).
    """
    import subprocess

    wheel_file_path = os.environ.get(
        "AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl"
    )
    # The default path contains a `*` wildcard, so the command is run through
    # the shell to have it expanded.
    # NOTE: the path is interpolated into a shell command line, so the
    # environment variable must only ever be set from a trusted source.
    # `check=False`: a failed install is reported through the exit code
    # instead of raising CalledProcessError.
    status = subprocess.run(f"pip install {wheel_file_path}", shell=True, check=False)
    # Return the exit code itself; `status.check_returncode` is a method,
    # not the integer status promised by the docstring.
    return status.returncode

ads/aqua/common/entities.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,17 @@ class Config:
4646
arbitrary_types_allowed = True
4747
protected_namespaces = ()
4848

49+
class ComputeRank(Serializable):
    """
    Represents the cost and performance ranking for a compute shape.

    Both ranks are optional and default to None when no value is supplied.
    """

    # Optional[...] because the default is None; a bare `int` annotation
    # contradicts that default and rejects an explicitly passed None.
    cost: Optional[int] = Field(
        default=None,
        description="The relative rank of the cost of the shape. Range is [10 (cost-effective), 100 (most-expensive)]",
    )

    # NOTE(review): the documented upper bound here (110) differs from the one
    # documented for `cost` (100) — confirm this is intended.
    performance: Optional[int] = Field(
        default=None,
        description="The relative rank of the performance of the shape. Range is [10 (lower performance), 110 (highest performance)]",
    )
4960

5061
class GPUSpecs(Serializable):
5162
"""
@@ -61,6 +72,12 @@ class GPUSpecs(Serializable):
6172
gpu_type: Optional[str] = Field(
6273
default=None, description="The type of GPU (e.g., 'V100, A100, H100')."
6374
)
75+
quantization: Optional[List[str]] = Field(
76+
default_factory=list, description="The quantization format supported by shape. (ex. bitsandbytes, fp8, etc.)"
77+
)
78+
ranking: Optional[ComputeRank] = Field(
79+
None, description="The relative rank of the cost and performance of the shape."
80+
)
6481

6582

6683
class GPUShapesIndex(Serializable):
@@ -84,6 +101,10 @@ class ComputeShapeSummary(Serializable):
84101
including CPU, memory, and optional GPU characteristics.
85102
"""
86103

104+
available: Optional[bool] = Field(
105+
default = False,
106+
description="True if shape is available on user tenancy, "
107+
)
87108
core_count: Optional[int] = Field(
88109
default=None,
89110
description="Total number of CPU cores available for the compute shape.",

ads/aqua/common/errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ class AquaValueError(AquaError, ValueError):
5555
def __init__(self, reason, status=403, service_payload=None):
5656
super().__init__(reason, status, service_payload)
5757

58+
class AquaRecommendationError(AquaError):
    """
    Exception raised for models incompatible with shape recommendation tool.

    The constructor forwards `reason`, `status` and `service_payload`
    unchanged to `AquaError`; `status` defaults to 400 and
    `service_payload` defaults to None.
    """

    def __init__(self, reason, status=400, service_payload=None):
        # Only the default status differs from the base call; everything
        # else is delegated to AquaError.
        super().__init__(reason, status, service_payload)
5863

5964
class AquaFileNotFoundError(AquaError, FileNotFoundError):
6065
"""Exception raised for missing target file."""

ads/aqua/common/utils.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,10 +1267,10 @@ def load_gpu_shapes_index(
12671267
auth: Optional[Dict[str, Any]] = None,
12681268
) -> GPUShapesIndex:
12691269
"""
1270-
Load the GPU shapes index, preferring the OS bucket copy over the local one.
1270+
Load the GPU shapes index, merging based on freshness.
12711271
1272-
Attempts to read `gpu_shapes_index.json` from OCI Object Storage first;
1273-
if that succeeds, those entries will override the local defaults.
1272+
Compares last-modified timestamps of local and remote files,
1273+
merging the shapes from the fresher file on top of the older one.
12741274
12751275
Parameters
12761276
----------
@@ -1291,7 +1291,9 @@ def load_gpu_shapes_index(
12911291
file_name = "gpu_shapes_index.json"
12921292

12931293
# Try remote load
1294-
remote_data: Dict[str, Any] = {}
1294+
local_data, remote_data = {}, {}
1295+
local_mtime, remote_mtime = None, None
1296+
12951297
if CONDA_BUCKET_NS:
12961298
try:
12971299
auth = auth or authutil.default_signer()
@@ -1301,8 +1303,24 @@ def load_gpu_shapes_index(
13011303
logger.debug(
13021304
"Loading GPU shapes index from Object Storage: %s", storage_path
13031305
)
1304-
with fsspec.open(storage_path, mode="r", **auth) as f:
1306+
1307+
fs = fsspec.filesystem("oci", **auth)
1308+
with fs.open(storage_path, mode="r") as f:
13051309
remote_data = json.load(f)
1310+
1311+
remote_info = fs.info(storage_path)
1312+
remote_mtime_str = remote_info.get("timeModified", None)
1313+
if remote_mtime_str:
1314+
# Convert OCI timestamp (e.g., 'Mon, 04 Aug 2025 06:37:13 GMT') to epoch time
1315+
remote_mtime = datetime.strptime(
1316+
remote_mtime_str, "%a, %d %b %Y %H:%M:%S %Z"
1317+
).timestamp()
1318+
1319+
logger.debug(
1320+
"Remote GPU shapes last-modified time: %s",
1321+
datetime.fromtimestamp(remote_mtime).strftime("%Y-%m-%d %H:%M:%S"),
1322+
)
1323+
13061324
logger.debug(
13071325
"Loaded %d shapes from Object Storage",
13081326
len(remote_data.get("shapes", {})),
@@ -1311,12 +1329,19 @@ def load_gpu_shapes_index(
13111329
logger.debug("Remote load failed (%s); falling back to local", ex)
13121330

13131331
# Load local copy
1314-
local_data: Dict[str, Any] = {}
13151332
local_path = os.path.join(os.path.dirname(__file__), "../resources", file_name)
13161333
try:
13171334
logger.debug("Loading GPU shapes index from local file: %s", local_path)
13181335
with open(local_path) as f:
13191336
local_data = json.load(f)
1337+
1338+
local_mtime = os.path.getmtime(local_path)
1339+
1340+
logger.debug(
1341+
"Local GPU shapes last-modified time: %s",
1342+
datetime.fromtimestamp(local_mtime).strftime("%Y-%m-%d %H:%M:%S"),
1343+
)
1344+
13201345
logger.debug(
13211346
"Loaded %d shapes from local file", len(local_data.get("shapes", {}))
13221347
)
@@ -1326,7 +1351,24 @@ def load_gpu_shapes_index(
13261351
# Merge: remote shapes override local
13271352
local_shapes = local_data.get("shapes", {})
13281353
remote_shapes = remote_data.get("shapes", {})
1329-
merged_shapes = {**local_shapes, **remote_shapes}
1354+
merged_shapes = {}
1355+
1356+
if local_mtime and remote_mtime:
1357+
if remote_mtime >= local_mtime:
1358+
logger.debug("Remote data is fresher or equal; merging remote over local.")
1359+
merged_shapes = {**local_shapes, **remote_shapes}
1360+
else:
1361+
logger.debug("Local data is fresher; merging local over remote.")
1362+
merged_shapes = {**remote_shapes, **local_shapes}
1363+
elif remote_shapes:
1364+
logger.debug("Only remote shapes available.")
1365+
merged_shapes = remote_shapes
1366+
elif local_shapes:
1367+
logger.debug("Only local shapes available.")
1368+
merged_shapes = local_shapes
1369+
else:
1370+
logger.error("No GPU shapes data found in either source.")
1371+
merged_shapes = {}
13301372

13311373
return GPUShapesIndex(shapes=merged_shapes)
13321374

ads/aqua/extension/deployment_handler.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ def get(self, id: Union[str, List[str]] = None):
5757
return self.get_deployment_config(
5858
model_id=id.split(",") if "," in id else id
5959
)
60+
elif paths.startswith("aqua/deployments/recommend_shapes"):
61+
if not id or not isinstance(id, str):
62+
raise HTTPError(
63+
400,
64+
f"Invalid request format for {self.request.path}. "
65+
"Expected a single model OCID specified as --model_id",
66+
)
67+
id = id.replace(" ", "")
68+
return self.get_recommend_shape(model_id=id)
6069
elif paths.startswith("aqua/deployments/shapes"):
6170
return self.list_shapes()
6271
elif paths.startswith("aqua/deployments"):
@@ -161,6 +170,32 @@ def get_deployment_config(self, model_id: Union[str, List[str]]):
161170

162171
return self.finish(deployment_config)
163172

173+
def get_recommend_shape(self, model_id: str):
    """
    Builds the shape recommendation report for one Aqua model and sends it as the response.

    The compartment is read from the `compartment_id` request argument and
    falls back to `COMPARTMENT_OCID` when the argument is absent. The report
    is always requested with `generate_table=False`.

    Parameters
    ----------
    model_id : str
        A single model ID (str).

    Returns
    -------
    The value returned by `self.finish(...)` when called with the report
    produced by `AquaDeploymentApp.recommend_shape`.
    """
    deployment_app = AquaDeploymentApp()
    report = deployment_app.recommend_shape(
        model_id=model_id,
        compartment_id=self.get_argument(
            "compartment_id", default=COMPARTMENT_OCID
        ),
        generate_table=False,
    )
    return self.finish(report)
198+
164199
def list_shapes(self):
165200
"""
166201
Lists the valid model deployment shapes.
@@ -408,6 +443,7 @@ def get(self, model_deployment_id):
408443
("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
409444
("deployments/config/?([^/]*)", AquaDeploymentHandler),
410445
("deployments/shapes/?([^/]*)", AquaDeploymentHandler),
446+
("deployments/recommend_shapes/?([^/]*)", AquaDeploymentHandler),
411447
("deployments/?([^/]*)", AquaDeploymentHandler),
412448
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
413449
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),

ads/aqua/modeldeployment/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111

1212
DEFAULT_WAIT_TIME = 12000
1313
DEFAULT_POLL_INTERVAL = 10
14+

ads/aqua/modeldeployment/deployment.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
import shlex
99
import threading
1010
from datetime import datetime, timedelta
11-
from typing import Dict, List, Optional
11+
from typing import Dict, List, Optional, Union
1212

1313
from cachetools import TTLCache, cached
1414
from oci.data_science.models import ModelDeploymentShapeSummary
1515
from pydantic import ValidationError
16+
from rich.table import Table
1617

1718
from ads.aqua.app import AquaApp, logger
1819
from ads.aqua.common.entities import (
@@ -67,14 +68,22 @@
6768
ModelDeploymentConfigSummary,
6869
MultiModelDeploymentConfigLoader,
6970
)
70-
from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME
71+
from ads.aqua.modeldeployment.constants import (
72+
DEFAULT_POLL_INTERVAL,
73+
DEFAULT_WAIT_TIME,
74+
)
7175
from ads.aqua.modeldeployment.entities import (
7276
AquaDeployment,
7377
AquaDeploymentDetail,
7478
ConfigValidationError,
7579
CreateModelDeploymentDetails,
7680
)
7781
from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
82+
from ads.aqua.shaperecommend.recommend import AquaShapeRecommend
83+
from ads.aqua.shaperecommend.shape_report import (
84+
RequestRecommend,
85+
ShapeRecommendationReport,
86+
)
7887
from ads.common.object_storage_details import ObjectStorageDetails
7988
from ads.common.utils import UNKNOWN, get_log_links
8089
from ads.common.work_request import DataScienceWorkRequest
@@ -1257,6 +1266,50 @@ def validate_deployment_params(
12571266
)
12581267
return {"valid": True}
12591268

1269+
def recommend_shape(self, **kwargs) -> Union[Table, ShapeRecommendationReport]:
    """
    Generates the GPU deployment shape recommendation for a model.

    For the CLI (set generate_table = True), generates the table (in rich diff) with valid
    GPU deployment shapes for the provided model and configuration.

    For the API (set generate_table = False), generates the JSON with valid
    GPU deployment shapes for the provided model and configuration.

    The keyword arguments are validated by building a `RequestRecommend`; the
    request is then passed to `AquaShapeRecommend.which_shapes`, and that
    result is returned unchanged.

    Parameters
    ----------
    **kwargs
        Fields used to build the `RequestRecommend` object.
        NOTE(review): presumably includes `model_id`, `compartment_id` and
        `generate_table` — confirm against the `RequestRecommend` definition.

    Returns
    -------
    Table (generate_table = True)
        A table format for the recommendation report with compatible deployment shapes
        or troubleshooting info citing the largest shapes if no shape is suitable.

    ShapeRecommendationReport (generate_table = False)
        A recommendation report with compatible deployment shapes, or troubleshooting info
        citing the largest shapes if no shape is suitable.

    Raises
    ------
    AquaValueError
        If the keyword arguments fail `RequestRecommend` validation. The
        original `ValidationError` is attached as the cause.
    """
    try:
        request = RequestRecommend(**kwargs)
    except ValidationError as e:
        custom_error = build_pydantic_error_message(e)
        # Chain the pydantic error so the full validation details stay
        # available in the traceback.
        raise AquaValueError(
            f"Failed to request shape recommendation due to invalid input parameters: {custom_error}"
        ) from e

    # NOTE(review): errors for unsupported model types would come from
    # `which_shapes`, which is not visible here — confirm which exception
    # type it raises.
    return AquaShapeRecommend().which_shapes(request)
1312+
12601313
@telemetry(entry_point="plugin=deployment&action=list_shapes", name="aqua")
12611314
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
12621315
def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:

0 commit comments

Comments
 (0)