Skip to content

Commit 21ba00b

Browse files
authored
Real-Time What-If Analysis with Model Catalog and Model Deployment Integration (#1041)
2 parents 5451c2c + 4a5e552 commit 21ba00b

File tree

10 files changed

+658
-4
lines changed

10 files changed

+658
-4
lines changed

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import fsspec
1414
import oracledb
15+
import json
1516
import pandas as pd
1617

1718
from ads.common.object_storage_details import ObjectStorageDetails
@@ -125,7 +126,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
125126
return data
126127

127128

128-
def write_data(data, filename, format, storage_options, index=False, **kwargs):
129+
def write_data(data, filename, format, storage_options=None, index=False, **kwargs):
129130
disable_print()
130131
if not format:
131132
_, format = os.path.splitext(filename)
@@ -141,6 +142,15 @@ def write_data(data, filename, format, storage_options, index=False, **kwargs):
141142
)
142143

143144

145+
def write_simple_json(data, path):
146+
if ObjectStorageDetails.is_oci_path(path):
147+
storage_options = default_signer()
148+
else:
149+
storage_options = {}
150+
with fsspec.open(path, mode="w", **storage_options) as f:
151+
json.dump(data, f, indent=4)
152+
153+
144154
def merge_category_columns(data, target_category_columns):
145155
result = data.apply(
146156
lambda x: "__".join([str(x[col]) for col in target_category_columns]), axis=1

ads/opctl/operator/lowcode/forecast/__main__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from .operator_config import ForecastOperatorConfig
1919
from .model.forecast_datasets import ForecastDatasets
20+
from .whatifserve import ModelDeploymentManager
2021

2122

2223
def operate(operator_config: ForecastOperatorConfig) -> None:
@@ -27,6 +28,15 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
2728
ForecastOperatorModelFactory.get_model(
2829
operator_config, datasets
2930
).generate_report()
31+
# saving to model catalog
32+
spec = operator_config.spec
33+
if spec.what_if_analysis and datasets.additional_data:
34+
mdm = ModelDeploymentManager(spec, datasets.additional_data)
35+
mdm.save_to_catalog()
36+
if spec.what_if_analysis.model_deployment:
37+
mdm.create_deployment()
38+
mdm.save_deployment_info()
39+
3040

3141
def verify(spec: Dict, **kwargs: Dict) -> bool:
3242
"""Verifies the forecasting operator config."""

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def get_data_multi_indexed(self):
167167
self.historical_data.data,
168168
self.additional_data.data,
169169
],
170-
axis=1,
170+
axis=1
171171
)
172172

173173
def get_data_by_series(self, include_horizon=True):

ads/opctl/operator/lowcode/forecast/operator_config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,35 @@
1818

1919
from .const import SpeedAccuracyMode, SupportedMetrics, SupportedModels
2020

21+
@dataclass
22+
class AutoScaling(DataClassSerializable):
23+
"""Class representing simple autoscaling policy"""
24+
minimum_instance: int = 1
25+
maximum_instance: int = None
26+
cool_down_in_seconds: int = 600
27+
scale_in_threshold: int = 10
28+
scale_out_threshold: int = 80
29+
scaling_metric: str = "CPU_UTILIZATION"
30+
31+
@dataclass(repr=True)
32+
class ModelDeploymentServer(DataClassSerializable):
33+
"""Class representing model deployment server specification for whatif-analysis."""
34+
display_name: str = None
35+
initial_shape: str = None
36+
description: str = None
37+
log_group: str = None
38+
log_id: str = None
39+
auto_scaling: AutoScaling = field(default_factory=AutoScaling)
40+
41+
42+
@dataclass(repr=True)
43+
class WhatIfAnalysis(DataClassSerializable):
44+
"""Class representing operator specification for whatif-analysis."""
45+
model_display_name: str = None
46+
compartment_id: str = None
47+
project_id: str = None
48+
model_deployment: ModelDeploymentServer = field(default_factory=ModelDeploymentServer)
49+
2150

2251
@dataclass(repr=True)
2352
class TestData(InputData):
@@ -90,12 +119,14 @@ class ForecastOperatorSpec(DataClassSerializable):
90119
confidence_interval_width: float = None
91120
metric: str = None
92121
tuning: Tuning = field(default_factory=Tuning)
122+
what_if_analysis: WhatIfAnalysis = field(default_factory=WhatIfAnalysis)
93123

94124
def __post_init__(self):
95125
"""Adjusts the specification details."""
96126
self.output_directory = self.output_directory or OutputDirectory(
97127
url=find_output_dirname(self.output_directory)
98128
)
129+
self.generate_model_pickle = True if self.generate_model_pickle or self.what_if_analysis else False
99130
self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
100131
self.model = self.model or SupportedModels.Prophet
101132
self.confidence_interval_width = self.confidence_interval_width or 0.80

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,69 @@ spec:
353353
meta:
354354
description: "Report file generation can be enabled using this flag. Defaults to true."
355355

356+
what_if_analysis:
357+
type: dict
358+
required: false
359+
schema:
360+
model_deployment:
361+
type: dict
362+
required: false
363+
meta: "If model_deployment id is not specified, a new model deployment is created; otherwise, the model is linked to the specified model deployment."
364+
schema:
365+
id:
366+
type: string
367+
required: false
368+
display_name:
369+
type: string
370+
required: false
371+
initial_shape:
372+
type: string
373+
required: false
374+
description:
375+
type: string
376+
required: false
377+
log_group:
378+
type: string
379+
required: true
380+
log_id:
381+
type: string
382+
required: true
383+
auto_scaling:
384+
type: dict
385+
required: false
386+
schema:
387+
minimum_instance:
388+
type: integer
389+
required: true
390+
maximum_instance:
391+
type: integer
392+
required: true
393+
scale_in_threshold:
394+
type: integer
395+
required: true
396+
scale_out_threshold:
397+
type: integer
398+
required: true
399+
scaling_metric:
400+
type: string
401+
required: true
402+
cool_down_in_seconds:
403+
type: integer
404+
required: true
405+
model_display_name:
406+
type: string
407+
required: true
408+
project_id:
409+
type: string
410+
required: false
411+
meta: "If not provided, The project OCID from config.PROJECT_OCID is used"
412+
compartment_id:
413+
type: string
414+
required: false
415+
meta: "If not provided, The compartment OCID from config.NB_SESSION_COMPARTMENT_OCID is used."
416+
meta:
417+
description: "When enabled, the models are saved to the model catalog. Defaults to false."
418+
356419
generate_metrics:
357420
type: boolean
358421
required: false
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
7+
from .deployment_manager import ModelDeploymentManager

0 commit comments

Comments
 (0)