Skip to content

Commit e5f833a

Browse files
Add model manager as BackgroundService for handling machine learning models
Signed-off-by: Idlir Shkurti <[email protected]>
1 parent ca89486 commit e5f833a

File tree

4 files changed

+275
-0
lines changed

4 files changed

+275
-0
lines changed

src/frequenz/sdk/ml/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Model interface."""
5+
6+
from ._model_manager import ModelManager
7+
8+
__all__ = [
9+
"ModelManager",
10+
]
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Load, update, monitor and retrieve machine learning models."""
5+
6+
import asyncio
7+
import logging
8+
import pickle
9+
from dataclasses import dataclass
10+
from pathlib import Path
11+
from typing import Generic, TypeVar, cast
12+
13+
from frequenz.channels.file_watcher import EventType, FileWatcher
14+
from typing_extensions import override
15+
16+
from frequenz.sdk.actor import BackgroundService
17+
18+
_logger = logging.getLogger(__name__)
19+
20+
T = TypeVar("T")
21+
22+
23+
@dataclass
class _Model(Generic[T]):
    """Pair a loaded machine learning model with the file it was read from.

    Attributes:
        data: The deserialized model object.
        path: Location of the file the model was loaded from.
    """

    data: T
    path: Path
29+
30+
31+
class ModelNotFoundError(Exception):
    """Signal that a requested model could not be located."""

    def __init__(self, key: str) -> None:
        """Create the error for the model identified by `key`.

        Args:
            key: Identifier of the model that could not be found.
        """
        message = f"Model with key '{key}' is not found."
        super().__init__(message)
41+
42+
43+
class ModelManager(BackgroundService, Generic[T]):
    """Load, update, monitor and retrieve machine learning models.

    All configured models are loaded eagerly when the manager is constructed.
    Once started, a background task watches the model files and reloads a
    model whenever its file is created or modified.
    """

    def __init__(self, model_paths: dict[str, Path], *, name: str | None = None):
        """Initialize the model manager with the specified model paths.

        Every model is loaded immediately, so construction fails if any of the
        files is missing.

        Args:
            model_paths: A dictionary of model keys and their corresponding file paths.
            name: The name of the model manager service.

        Raises:
            ModelNotFoundError: If one of the model files does not exist.
        """
        super().__init__(name=name)
        self._models: dict[str, _Model[T]] = {}
        self.model_paths = model_paths
        self.load_models()

    def load_models(self) -> None:
        """Load the models from the specified paths.

        Raises:
            ModelNotFoundError: If one of the model files does not exist.
        """
        for key, path in self.model_paths.items():
            self._models[key] = _Model(data=self._load(path), path=path)

    @staticmethod
    def _load(path: Path) -> T:
        """Load the model from the specified path.

        Args:
            path: The path to the model file.

        Returns:
            T: The loaded model data.

        Raises:
            ModelNotFoundError: If the model file does not exist.
        """
        try:
            with path.open("rb") as file:
                # SECURITY: unpickling executes arbitrary code embedded in the
                # file. Model files must only come from trusted locations.
                return cast(T, pickle.load(file))
        except FileNotFoundError as exc:
            raise ModelNotFoundError(str(path)) from exc

    @override
    def start(self) -> None:
        """Start the model monitoring service by creating a background task.

        Calling this while the service is already running does nothing.
        """
        if not self.is_running:
            task = asyncio.create_task(self._monitor_paths())
            self._tasks.add(task)
            _logger.info(
                "%s: Started ModelManager service with task %s",
                self.name,
                task,
            )

    async def _monitor_paths(self) -> None:
        """Monitor model file paths and reload models as necessary."""
        model_paths = [model.path for model in self._models.values()]
        # NOTE(review): the list() copy is kept as-is; presumably it widens the
        # element type for the FileWatcher signature -- confirm before removing.
        file_watcher = FileWatcher(
            paths=list(model_paths), event_types=[EventType.CREATE, EventType.MODIFY]
        )
        _logger.info("%s: Monitoring model paths for changes.", self.name)
        async for event in file_watcher:
            _logger.info(
                "%s: Reloading model from file %s due to a %s event...",
                self.name,
                event.path,
                event.type.name,
            )
            # NOTE(review): reload_model() matches by path equality; confirm the
            # watcher reports paths in the same form they were configured in.
            self.reload_model(Path(event.path))

    def reload_model(self, path: Path) -> None:
        """Reload the model from the specified path.

        Every registered model whose path equals `path` is reloaded. A failed
        reload is logged and the previously loaded model data is kept.

        Args:
            path: The path to the model file.
        """
        for model in self._models.values():
            if model.path != path:
                continue
            try:
                data = self._load(path)
            except Exception:  # pylint: disable=broad-except
                # Best effort: keep serving the old model if the new file is
                # missing or cannot be unpickled.
                _logger.exception(
                    "%s: Failed to reload model from %s", self.name, path
                )
            else:
                model.data = data
                _logger.info(
                    "%s: Successfully reloaded model from %s",
                    self.name,
                    path,
                )

    def get_model(self, key: str) -> T:
        """Retrieve a loaded model by key.

        Args:
            key: The key of the model to retrieve.

        Returns:
            The loaded model data.

        Raises:
            KeyError: If the model with the specified key is not found.
        """
        try:
            return self._models[key].data
        except KeyError as exc:
            raise KeyError(f"Model with key '{key}' is not found.") from exc

tests/model/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for the model package."""

tests/model/test_model_manager.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for machine learning model manager."""
5+
6+
import pickle
7+
from dataclasses import dataclass
8+
from pathlib import Path
9+
from typing import Any
10+
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
11+
12+
import pytest
13+
14+
from frequenz.sdk.ml import ModelManager
15+
16+
17+
@dataclass
18+
class MockModel:
19+
"""Mock model for unit testing purposes."""
20+
21+
data: int | str
22+
23+
def predict(self) -> int | str:
24+
"""Make a prediction based on the model data."""
25+
return self.data
26+
27+
28+
async def test_model_manager_loading() -> None:
    """Test loading models using ModelManager with direct configuration."""
    model1 = MockModel("Model 1 Data")
    model2 = MockModel("Model 2 Data")
    pickled_model1 = pickle.dumps(model1)
    pickled_model2 = pickle.dumps(model2)

    model_paths = {
        "model1": Path("path/to/model1.pkl"),
        "model2": Path("path/to/model2.pkl"),
    }

    mock_files = {
        "path/to/model1.pkl": mock_open(read_data=pickled_model1)(),
        "path/to/model2.pkl": mock_open(read_data=pickled_model2)(),
    }

    def mock_open_func(file_path: Path, *__args: Any, **__kwargs: Any) -> Any:
        """Mock open function to return the correct mock file object.

        Args:
            file_path: The path to the file to open.
            *__args: Variable length argument list. This can be used to pass additional
                positional parameters typically used in file opening operations,
                such as `mode` or `buffering`.
            **__kwargs: Arbitrary keyword arguments. This can include parameters like
                `encoding` and `errors`, common in file opening operations.

        Returns:
            Any: The mock file object.

        Raises:
            FileNotFoundError: If the file path is not in the mock files dictionary.
        """
        file_path_str = str(file_path)
        if file_path_str in mock_files:
            file_handle = MagicMock()
            file_handle.__enter__.return_value = mock_files[file_path_str]
            return file_handle
        raise FileNotFoundError(f"No mock setup for {file_path_str}")

    with (
        patch("pathlib.Path.open", new=mock_open_func),
        patch.object(Path, "exists", return_value=True),
        # ModelManager binds FileWatcher into its own module namespace with a
        # from-import, so the replacement must target that module; patching
        # frequenz.channels.file_watcher would not affect the code under test.
        # NOTE(review): calling an AsyncMock returns a coroutine; if the watcher
        # task is ever expected to iterate in this test, confirm this stand-in.
        patch("frequenz.sdk.ml._model_manager.FileWatcher", new_callable=AsyncMock),
    ):
        model_manager: ModelManager[MockModel] = ModelManager(
            model_paths=model_paths
        )
        model_manager.start()  # Start the service

        assert isinstance(model_manager.get_model("model1"), MockModel)
        assert model_manager.get_model("model1").data == "Model 1 Data"
        assert model_manager.get_model("model2").data == "Model 2 Data"

        with pytest.raises(KeyError):
            model_manager.get_model("key3")

        await model_manager.stop()  # Stop the service to clean up
88+
89+
90+
async def test_model_manager_update() -> None:
    """Test updating a model in ModelManager."""
    original_model = MockModel("Original Data")
    updated_model = MockModel("Updated Data")
    pickled_original_model = pickle.dumps(original_model)
    pickled_updated_model = pickle.dumps(updated_model)

    model_paths = {"model1": Path("path/to/model1.pkl")}

    mock_file = mock_open(read_data=pickled_original_model)
    with (
        patch("pathlib.Path.open", mock_file),
        patch.object(Path, "exists", return_value=True),
        # ModelManager binds FileWatcher into its own module namespace with a
        # from-import, so the replacement must target that module; patching
        # frequenz.channels.file_watcher would not affect the code under test.
        # NOTE(review): calling an AsyncMock returns a coroutine; if the watcher
        # task is ever expected to iterate in this test, confirm this stand-in.
        patch("frequenz.sdk.ml._model_manager.FileWatcher", new_callable=AsyncMock),
    ):
        model_manager = ModelManager[MockModel](model_paths=model_paths)
        model_manager.start()  # Start the service

        assert model_manager.get_model("model1").data == "Original Data"

        # Simulate updating the model file: subsequent reads of the mocked file
        # return the pickled updated model.
        mock_file.return_value.read.return_value = pickled_updated_model
        model_manager.reload_model(Path("path/to/model1.pkl"))
        assert model_manager.get_model("model1").data == "Updated Data"

        await model_manager.stop()  # Stop the service to clean up

0 commit comments

Comments
 (0)