# License: MIT
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH

"""Load, update, monitor and retrieve machine learning models."""

import asyncio
import logging
import pickle
from asyncio import CancelledError
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, TypeVar, cast

from frequenz.channels.file_watcher import EventType, FileWatcher

from frequenz.sdk.actor import BackgroundService

_logger = logging.getLogger(__name__)

T = TypeVar("T")


@dataclass
class _Model(Generic[T]):
    """Represent a machine learning model."""

    data: T
    path: Path


class ModelManager(BackgroundService, Generic[T]):
    """Load, update, monitor and retrieve machine learning models."""

    def __init__(self, model_paths: dict[str, Path]):
        """Initialize the model manager with the specified model paths.

        Args:
            model_paths: A dictionary of model keys and their corresponding file paths.
        """
        super().__init__()
        self._models: dict[str, _Model[T]] = {}
        self.model_paths = model_paths
        self.load_models()

    def load_models(self) -> None:
        """Load the models from the specified paths."""
        for key, path in self.model_paths.items():
            self._models[key] = _Model(data=self._load(path), path=path)

    @staticmethod
    def _load(path: Path) -> T:
        """Load the model from the specified path.

        Args:
            path: The path to the model file.

        Returns:
            The loaded model data.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        if not path.exists():
            raise FileNotFoundError(f"The model path {path} does not exist.")
        with path.open("rb") as file:
            return cast(T, pickle.load(file))

    def start(self) -> None:
        """Start the model monitoring service by creating a background task."""
        if not self.is_running:
            task = asyncio.create_task(self.run())
            self._tasks.add(task)
            _logger.info("Started ModelManager service with task %s", task)

    async def run(self) -> None:
        """Monitor model file paths and reload models as necessary."""
        model_paths = [model.path for model in self._models.values()]
        file_watcher = FileWatcher(paths=list(model_paths))
        _logger.info("Monitoring model paths for changes.")
        async for event in file_watcher:
            if event.type in (EventType.CREATE, EventType.MODIFY):
                _logger.info("Model file %s modified, reloading...", event.path)
                self.reload_model(Path(event.path))

    def reload_model(self, path: Path) -> None:
        """Reload the model from the specified path.

        Args:
            path: The path to the model file.
        """
        for key, model in self._models.items():
            if model.path == path:
                try:
                    model.data = self._load(path)
                    _logger.info("Successfully reloaded model from %s", path)
                except Exception as e:  # pylint: disable=broad-except
                    _logger.error("Failed to reload model from %s: %s", path, e)

    async def stop(self, msg: str | None = None) -> None:
|  | 100 | +        """Stop all model monitoring tasks with enhanced exception handling. | 
|  | 101 | +
 | 
|  | 102 | +        Args: | 
|  | 103 | +            msg: An optional message to log when stopping the service. | 
|  | 104 | +        """ | 
|  | 105 | +        _logger.info("Stopping ModelManager service: %s", msg) | 
|  | 106 | +        # Attempt to cancel all running tasks | 
|  | 107 | +        for task in list(self._tasks): | 
|  | 108 | +            task.cancel() | 
|  | 109 | +            try: | 
|  | 110 | +                await task  # Wait for task to be cancelled | 
|  | 111 | +            except CancelledError: | 
|  | 112 | +                _logger.info("Task %s cancelled successfully", task) | 
|  | 113 | +            except Exception as e:  # pylint: disable=broad-except | 
|  | 114 | +                _logger.error("Error while cancelling task %s: %s", task, e) | 
|  | 115 | + | 
|  | 116 | +        # Call the parent stop method if it handles additional teardown | 
|  | 117 | +        try: | 
|  | 118 | +            await super().stop(msg) | 
|  | 119 | +        except Exception as e:  # pylint: disable=broad-except | 
|  | 120 | +            _logger.error("Error during stop in superclass: %s", e) | 
|  | 121 | + | 
|  | 122 | +    def get_model(self, key: str) -> T: | 
|  | 123 | +        """Retrieve a loaded model by key. | 
|  | 124 | +
 | 
|  | 125 | +        Args: | 
|  | 126 | +            key: The key of the model to retrieve. | 
|  | 127 | +
 | 
|  | 128 | +        Returns: | 
|  | 129 | +            The loaded model data. | 
|  | 130 | +
 | 
|  | 131 | +        Raises: | 
|  | 132 | +            KeyError: If the model with the specified key is not found. | 
|  | 133 | +        """ | 
|  | 134 | +        try: | 
|  | 135 | +            return self._models[key].data | 
|  | 136 | +        except KeyError as exc: | 
|  | 137 | +            raise KeyError(f"Model with key '{key}' is not found.") from exc | 
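

# Illustrative usage sketch (not part of the original module): it shows how a
# caller might produce a pickled model file, construct a ModelManager, start
# the file watcher, retrieve a model and shut the service down.  The "demo"
# key and the temporary file path are hypothetical placeholders.
async def _example_usage() -> None:
    """Run a minimal ModelManager lifecycle with a placeholder pickled model."""
    import tempfile  # Local import to keep the module's import list unchanged.

    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = Path(tmp_dir) / "demo.pkl"
        # Any picklable object can stand in for a trained model in this sketch.
        with model_path.open("wb") as file:
            pickle.dump({"weights": [0.1, 0.2, 0.3]}, file)

        manager: ModelManager[dict[str, list[float]]] = ModelManager(
            model_paths={"demo": model_path}
        )
        manager.start()  # Begin watching the model file for changes.
        _logger.info("Demo model: %s", manager.get_model("demo"))
        await manager.stop("example finished")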