new model ids 2 #929

Merged Jan 16, 2025 · 20 commits
inference/core/entities/types.py (2 additions, 1 deletion)
@@ -1,5 +1,6 @@
DatasetID = str
VersionID = str
ModelID = str
VersionID = int
TaskType = str
ModelType = str
WorkspaceID = str
inference/core/managers/active_learning.py (13 additions, 0 deletions)
@@ -11,6 +11,7 @@
from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT
from inference.core.managers.base import ModelManager
from inference.core.registries.base import ModelRegistry
from inference.core.utils.roboflow import get_model_id_chunks
from inference.models.aliases import resolve_roboflow_model_alias

ACTIVE_LEARNING_ELIGIBLE_PARAM = "active_learning_eligible"
@@ -39,10 +40,14 @@ async def infer_from_request(
active_learning_disabled_for_request = getattr(
request, DISABLE_ACTIVE_LEARNING_PARAM, False
)
# TODO: active learning is disabled for instant models; to be enabled in the future
_, version_id = get_model_id_chunks(model_id=model_id)
roboflow_instant_model = version_id is None
if (
not active_learning_eligible
or active_learning_disabled_for_request
or request.api_key is None
or roboflow_instant_model
):
return prediction
self.register(prediction=prediction, model_id=model_id, request=request)
@@ -58,10 +63,14 @@ def infer_from_request_sync(
active_learning_disabled_for_request = getattr(
request, DISABLE_ACTIVE_LEARNING_PARAM, False
)
# TODO: active learning is disabled for instant models; to be enabled in the future
_, version_id = get_model_id_chunks(model_id=model_id)
roboflow_instant_model = version_id is None
if (
not active_learning_eligible
or active_learning_disabled_for_request
or request.api_key is None
or roboflow_instant_model
):
return prediction
self.register(prediction=prediction, model_id=model_id, request=request)
@@ -196,10 +205,14 @@ def infer_from_request_sync(
prediction = super().infer_from_request_sync(
model_id=model_id, request=request, **kwargs
)
# TODO: active learning is disabled for instant models; to be enabled in the future
_, version_id = get_model_id_chunks(model_id=model_id)
roboflow_instant_model = version_id is None
if (
not active_learning_eligible
or active_learning_disabled_for_request
or request.api_key is None
or roboflow_instant_model
):
return prediction
if BACKGROUND_TASKS_PARAM not in kwargs:
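The gate added above hinges on the new get_model_id_chunks helper: an instant model is recognized purely by the absence of a numeric version in its ID, so active-learning registration is skipped for it. A minimal sketch of that check (illustrative only, with a made-up model ID, not code from this PR):

from inference.core.utils.roboflow import get_model_id_chunks

# Instant models have no numeric version chunk, so version_id comes back as None
_, version_id = get_model_id_chunks(model_id="my-workspace/my-instant-model")
roboflow_instant_model = version_id is None  # True -> the manager returns the prediction without registering it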
inference/core/models/roboflow.py (2 additions, 1 deletion)
@@ -59,6 +59,7 @@
from inference.core.utils.image_utils import load_image
from inference.core.utils.onnx import get_onnxruntime_execution_providers
from inference.core.utils.preprocess import letterbox_image, prepare
from inference.core.utils.roboflow import get_model_id_chunks
from inference.core.utils.visualisation import draw_detection_predictions
from inference.models.aliases import resolve_roboflow_model_alias

@@ -116,7 +117,7 @@ def __init__(
self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
self.api_key = api_key if api_key else API_KEY
model_id = resolve_roboflow_model_alias(model_id=model_id)
self.dataset_id, self.version_id = model_id.split("/")
self.dataset_id, self.version_id = get_model_id_chunks(model_id=model_id)
self.endpoint = model_id
self.device_id = GLOBAL_DEVICE_ID
self.cache_dir = os.path.join(cache_dir_root, self.endpoint)
inference/core/models/stubs.py (2 additions, 1 deletion)
@@ -10,14 +10,15 @@
from inference.core.models.base import Model
from inference.core.models.types import PreprocessReturnMetadata
from inference.core.utils.image_utils import encode_image_to_jpeg_bytes
from inference.core.utils.roboflow import get_model_id_chunks


class ModelStub(Model):
def __init__(self, model_id: str, api_key: str):
super().__init__()
self.model_id = model_id
self.api_key = api_key
self.dataset_id, self.version_id = model_id.split("/")
self.dataset_id, self.version_id = get_model_id_chunks(model_id=model_id)
self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
initialise_cache(model_id=model_id)

inference/core/registries/roboflow.py (38 additions, 18 deletions)
@@ -3,7 +3,13 @@

from inference.core.cache import cache
from inference.core.devices.utils import GLOBAL_DEVICE_ID
from inference.core.entities.types import DatasetID, ModelType, TaskType, VersionID
from inference.core.entities.types import (
DatasetID,
ModelID,
ModelType,
TaskType,
VersionID,
)
from inference.core.env import LAMBDA, MODEL_CACHE_DIR
from inference.core.exceptions import (
MissingApiKeyError,
@@ -19,6 +25,7 @@
PROJECT_TASK_TYPE_KEY,
ModelEndpointType,
get_roboflow_dataset_type,
get_roboflow_instant_model_data,
get_roboflow_model_data,
get_roboflow_workspace,
)
@@ -49,7 +56,7 @@ class RoboflowModelRegistry(ModelRegistry):
then returns a model class based on the model type.
"""

def get_model(self, model_id: str, api_key: str) -> Model:
def get_model(self, model_id: ModelID, api_key: str) -> Model:
"""Returns the model class based on the given model id and API key.

Args:
@@ -70,7 +77,7 @@ def get_model(self, model_id: str, api_key: str) -> Model:


def get_model_type(
model_id: str,
model_id: ModelID,
api_key: Optional[str] = None,
) -> Tuple[TaskType, ModelType]:
"""Retrieves the model type based on the given model ID and API key.
@@ -115,16 +122,24 @@ def get_model_type(
model_type=model_type,
)
return project_task_type, model_type
api_data = get_roboflow_model_data(
api_key=api_key,
model_id=model_id,
endpoint_type=ModelEndpointType.ORT,
device_id=GLOBAL_DEVICE_ID,
).get("ort")

if version_id is not None:
api_data = get_roboflow_model_data(
api_key=api_key,
model_id=model_id,
endpoint_type=ModelEndpointType.ORT,
device_id=GLOBAL_DEVICE_ID,
).get("ort")
project_task_type = api_data.get("type", "object-detection")
else:
api_data = get_roboflow_instant_model_data(
api_key=api_key,
model_id=model_id,
)
project_task_type = api_data.get("taskType", "object-detection")
if api_data is None:
raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
# some older projects do not have type field - hence defaulting
project_task_type = api_data.get("type", "object-detection")
model_type = api_data.get("modelType")
if model_type is None or model_type == "ort":
# some very old model versions do not have modelType reported - and API respond in a generic way -
@@ -143,7 +158,8 @@


def get_model_metadata_from_cache(
dataset_id: str, version_id: str
dataset_id: Union[DatasetID, ModelID],
version_id: Optional[VersionID],
) -> Optional[Tuple[TaskType, ModelType]]:
if LAMBDA:
return _get_model_metadata_from_cache(
@@ -158,7 +174,7 @@ def get_model_metadata_from_cache(


def _get_model_metadata_from_cache(
dataset_id: str, version_id: str
dataset_id: Union[DatasetID, ModelID], version_id: Optional[VersionID]
) -> Optional[Tuple[TaskType, ModelType]]:
model_type_cache_path = construct_model_type_cache_path(
dataset_id=dataset_id, version_id=version_id
@@ -193,8 +209,8 @@ def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> bool:


def save_model_metadata_in_cache(
dataset_id: DatasetID,
version_id: VersionID,
dataset_id: Union[DatasetID, ModelID],
version_id: Optional[VersionID],
project_task_type: TaskType,
model_type: ModelType,
) -> None:
@@ -219,8 +235,8 @@


def _save_model_metadata_in_cache(
dataset_id: DatasetID,
version_id: VersionID,
dataset_id: Union[DatasetID, ModelID],
version_id: Optional[VersionID],
project_task_type: TaskType,
model_type: ModelType,
) -> None:
@@ -236,6 +252,10 @@
)


def construct_model_type_cache_path(dataset_id: str, version_id: str) -> str:
cache_dir = os.path.join(MODEL_CACHE_DIR, dataset_id, version_id)
def construct_model_type_cache_path(
dataset_id: Union[DatasetID, ModelID], version_id: Optional[VersionID]
) -> str:
cache_dir = os.path.join(
MODEL_CACHE_DIR, dataset_id, version_id if version_id else ""
)
return os.path.join(cache_dir, "model_type.json")
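For readers tracing the cache-layout change, a short sketch of where model-type metadata would land for versioned versus instant models; the paths assume MODEL_CACHE_DIR is /tmp/cache and the IDs are made up (an illustration, not code from the PR):

from inference.core.registries.roboflow import construct_model_type_cache_path

construct_model_type_cache_path(dataset_id="my-project", version_id="3")
# -> /tmp/cache/my-project/3/model_type.json
construct_model_type_cache_path(dataset_id="my-workspace/my-instant-model", version_id=None)
# -> /tmp/cache/my-workspace/my-instant-model/model_type.json (the version segment collapses to "")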
inference/core/roboflow_api.py (34 additions, 0 deletions)
@@ -14,6 +14,7 @@
from inference.core.cache.base import BaseCache
from inference.core.entities.types import (
DatasetID,
ModelID,
ModelType,
TaskType,
VersionID,
@@ -246,6 +247,39 @@ def get_roboflow_model_data(
return api_data


@wrap_roboflow_api_errors()
def get_roboflow_instant_model_data(
api_key: str,
model_id: ModelID,
cache_prefix: str = "roboflow_api_data",
) -> dict:
api_data_cache_key = f"{cache_prefix}:{model_id}"
api_data = cache.get(api_data_cache_key)
if api_data is not None:
logger.debug(f"Loaded model data from cache with key: {api_data_cache_key}.")
return api_data
else:
params = [
("model", model_id),
]
if api_key is not None:
params.append(("api_key", api_key))
api_url = _add_params_to_url(
url=f"{API_BASE_URL}/getWeights",
params=params,
)
api_data = _get_from_url(url=api_url)
cache.set(
api_data_cache_key,
api_data,
expire=10,
)
logger.debug(
f"Loaded model data from Roboflow API and saved to cache with key: {api_data_cache_key}."
)
return api_data


@wrap_roboflow_api_errors()
def get_roboflow_base_lora(
api_key: str, repo: str, revision: str, device_id: str
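A minimal usage sketch of the new helper; the /getWeights endpoint and the 10-second cache expiry come from the diff above, while the API key and model ID below are placeholders:

from inference.core.roboflow_api import get_roboflow_instant_model_data

api_data = get_roboflow_instant_model_data(
    api_key="YOUR_ROBOFLOW_API_KEY",           # placeholder, not a real key
    model_id="my-workspace/my-instant-model",  # version-less ("instant") model ID
)
# The registry then reads the task type, defaulting when the field is missing
task_type = api_data.get("taskType", "object-detection")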
inference/core/utils/roboflow.py (25 additions, 4 deletions)
@@ -1,11 +1,32 @@
from typing import Tuple
from typing import Optional, Tuple, Union

from inference.core.entities.types import DatasetID, VersionID
from inference.core.entities.types import DatasetID, ModelID, VersionID
from inference.core.exceptions import InvalidModelIDError


def get_model_id_chunks(model_id: str) -> Tuple[DatasetID, VersionID]:
def get_model_id_chunks(
model_id: str,
) -> Tuple[Union[DatasetID, ModelID], Optional[VersionID]]:
model_id_chunks = model_id.split("/")
if len(model_id_chunks) != 2:
raise InvalidModelIDError(f"Model ID: `{model_id}` is invalid.")
return model_id_chunks[0], model_id_chunks[1]
dataset_id, version_id = model_id_chunks[0], model_id_chunks[1]
if dataset_id.lower() in {
"clip",
"cogvlm",
"doctr",
"doctr_rec",
"doctr_det",
"gaze",
"grounding_dino",
"sam",
"sam2",
"owlv2",
"trocr",
"yolo_world",
}:
return dataset_id, version_id
try:
return dataset_id, str(int(version_id))
except Exception:
return model_id, None
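To summarize the new parsing rules, a hedged sketch of what get_model_id_chunks would return for a few representative IDs (the IDs themselves are made up for illustration):

from inference.core.utils.roboflow import get_model_id_chunks

get_model_id_chunks("my-project/3")    # -> ("my-project", "3"): classic dataset/version pair
get_model_id_chunks("yolo_world/s")    # -> ("yolo_world", "s"): foundation models keep their alias version
get_model_id_chunks("my-workspace/my-instant-model")
# -> ("my-workspace/my-instant-model", None): a non-numeric version marks an instant model
get_model_id_chunks("not-a-valid-id")  # raises InvalidModelIDError: exactly two "/"-separated chunks are required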
(file header not captured in this view; the diff below touches the MQTT Writer Sink workflow block)
@@ -1,3 +1,4 @@
import threading
from typing import List, Literal, Optional, Type, Union

import paho.mqtt.client as mqtt
@@ -93,10 +94,23 @@ def describe_outputs(cls) -> List[OutputDefinition]:
OutputDefinition(name="message", kind=[STRING_KIND]),
]

@classmethod
def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.3.0,<2.0.0"


class MQTTWriterSinkBlockV1(WorkflowBlock):
def __init__(self):
self.mqtt_client: Optional[mqtt.Client] = None
self._connected = threading.Event()

def __del__(self):
try:
if self.mqtt_client is not None:
self.mqtt_client.disconnect()
self.mqtt_client.loop_stop()
except Exception as e:
logger.error("Failed to disconnect MQTT client: %s", e)

@classmethod
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
@@ -125,7 +139,12 @@ def run(
)
try:
# TODO: blocking, consider adding fire_and_forget like in OPC writer
print("Connecting")
self.mqtt_client.connect(host, port)
self.mqtt_client.loop_start()

if not self._connected.wait(timeout=timeout):
raise Exception("Connection timeout")
except Exception as e:
logger.error("Failed to connect to MQTT broker: %s", e)
return {
@@ -136,7 +155,10 @@
if not self.mqtt_client.is_connected():
try:
# TODO: blocking
print("Reconnecting")
self.mqtt_client.reconnect()
if not self._connected.wait(timeout=timeout):
raise Exception("Connection timeout")
except Exception as e:
logger.error("Failed to connect to MQTT broker: %s", e)
return {
@@ -163,8 +185,10 @@

def mqtt_on_connect(self, client, userdata, flags, reason_code, properties=None):
logger.info("Connected with result code %s", reason_code)
self._connected.set()

def mqtt_on_connect_fail(
self, client, userdata, flags, reason_code, properties=None
):
logger.error(f"Failed to connect with result code %s", reason_code)
self._connected.clear()
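The connection handling added in this block follows a connect-then-wait pattern built on threading.Event. A generic, standalone sketch of that pattern (assuming paho-mqtt 2.x and a hypothetical broker host; this is not the block's actual code):

import threading

import paho.mqtt.client as mqtt

connected = threading.Event()

def on_connect(client, userdata, flags, reason_code, properties=None):
    connected.set()  # signal the waiting thread that the broker accepted the connection

client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
client.on_connect = on_connect
client.connect("broker.example.com", 1883)  # hypothetical broker
client.loop_start()  # paho's network loop runs in a background thread

if not connected.wait(timeout=5):  # block until on_connect fires, or give up
    raise TimeoutError("MQTT broker did not accept the connection in time")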