Skip to content

Integration of client & ml model id in end-to-end ml streaming flow #196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 118 additions & 10 deletions pymilo/streaming/communicator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-
"""PyMilo Communication Mediums."""
import uuid
import json
import asyncio
import uvicorn
import requests
import websockets
from enum import Enum
from pydantic import BaseModel
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect, HTTPException
from .interfaces import ClientCommunicator
from .param import PYMILO_INVALID_URL, PYMILO_CLIENT_WEBSOCKET_NOT_CONNECTED
from .util import validate_websocket_url, validate_http_url
Expand Down Expand Up @@ -83,6 +84,46 @@ def attribute_type(self, payload):
response = self.session.post(url=self._server_url + "/attribute_type/", json=payload, timeout=5)
return response.json()

def register_client(self):
"""
Register client in the PyMiloServer.

:return: newly allocated client id
"""
response = self.session.get(url=self._server_url + "/register/", timeout=5)
return response.json()["client_id"]

def register_model(self, payload):
"""
Register ML model in the PyMiloServer.

:param payload: request payload
:type payload: dict
:return: newly allocated ml model id
"""
response = self.session.post(url=self._server_url + "/request_model_id/", json=payload, timeout=5)
return response.json()["ml_model_id"]

def get_clients(self):
"""
Get all clients registered in the PyMiloServer.

:return: list of client ids
"""
response = self.session.get(url=self._server_url + "/clients/", timeout=5)
return response.json()["clients_id"]

def get_ml_models(self, payload):
"""
Get all ML models registered for this specific client in the PyMiloServer.

:param payload: request payload
:type payload: dict
:return: list of ml model ids
"""
response = self.session.post(url=self._server_url + "/client/models/", json=payload, timeout=5)
return response.json()["ml_models_id"]


class RESTServerCommunicator():
"""Facilitate working with the communication medium from the server side for the REST protocol."""
Expand Down Expand Up @@ -130,35 +171,90 @@ class AttributeCallPayload(StandardPayload):
class AttributeTypePayload(StandardPayload):
attribute: str

@self.app.get("/register/")
async def register_client():
client_id = str(uuid.uuid4())
self._ps.init_client(client_id)
return {
"client_id": client_id
}

@self.app.post("/request_model_id/")
async def request_model(request: Request):
body = await request.json()
body = self.parse(body)
client_id = body["client_id"]
model_id = str(uuid.uuid4())
is_succeed, detail_message = self._ps.init_ml_model(client_id, model_id)
if not is_succeed:
raise HTTPException(status_code=404, detail=detail_message)
return {
"client_id": client_id,
"ml_model_id": model_id,
}

@self.app.get("/clients/")
async def get_client():
return {
"clients_id": self._ps.get_clients(),
}

@self.app.post("/client/models/")
async def get_client_models(request: Request):
body = await request.json()
body = self.parse(body)
client_id = body["client_id"]
return {
"client_id": client_id,
"ml_models_id": self._ps.get_ml_models(client_id),
}

@self.app.get("/download/")
async def download(request: Request):
body = await request.json()
body = self.parse(body)
payload = DownloadPayload(**body)
message = "/download request from client: {} for model: {}".format(payload.client_id, payload.ml_model_id)
client_id = payload.client_id
ml_model_id = payload.ml_model_id
is_valid, invalidity_reason = self._ps._validate_id(client_id, ml_model_id)
if not is_valid:
raise HTTPException(status_code=404, detail=invalidity_reason)
message = "/download request from client: {} for model: {}".format(client_id, ml_model_id)
return {
"message": message,
"payload": self._ps.export_model(),
"payload": self._ps.export_model(client_id, ml_model_id),
}

@self.app.post("/upload/")
async def upload(request: Request):
body = await request.json()
body = self.parse(body)
payload = UploadPayload(**body)
message = "/upload request from client: {} for model: {}".format(payload.client_id, payload.ml_model_id)
client_id = payload.client_id
ml_model_id = payload.ml_model_id
is_valid, invalidity_reason = self._ps._validate_id(client_id, ml_model_id)
if not is_valid:
raise HTTPException(status_code=404, detail=invalidity_reason)
message = "/upload request from client: {} for model: {}".format(client_id, ml_model_id)
return {
"message": message,
"payload": self._ps.update_model(payload.model)
"payload": self._ps.update_model(client_id, ml_model_id, payload.model)
}

@self.app.post("/attribute_call/")
async def attribute_call(request: Request):
body = await request.json()
body = self.parse(body)
payload = AttributeCallPayload(**body)
client_id = payload.client_id
ml_model_id = payload.ml_model_id
is_valid, invalidity_reason = self._ps._validate_id(client_id, ml_model_id)
if not is_valid:
raise HTTPException(status_code=404, detail=invalidity_reason)
message = "/attribute_call request from client: {} for model: {}".format(
payload.client_id, payload.ml_model_id)
client_id,
ml_model_id,
)
result = self._ps.execute_model(payload)
return {
"message": message,
Expand All @@ -170,8 +266,15 @@ async def attribute_type(request: Request):
body = await request.json()
body = self.parse(body)
payload = AttributeTypePayload(**body)
client_id = payload.client_id
ml_model_id = payload.ml_model_id
is_valid, invalidity_reason = self._ps._validate_id(client_id, ml_model_id)
if not is_valid:
raise HTTPException(status_code=404, detail=invalidity_reason)
message = "/attribute_type request from client: {} for model: {}".format(
payload.client_id, payload.ml_model_id)
client_id,
ml_model_id,
)
is_callable, field_value = self._ps.is_callable_attribute(payload)
return {
"message": message,
Expand Down Expand Up @@ -398,7 +501,7 @@ async def handle_message(self, websocket: WebSocket, message: str):
payload = self.parse(message['payload'])

if action == "download":
response = self._handle_download()
response = self._handle_download(payload)
elif action == "upload":
response = self._handle_upload(payload)
elif action == "attribute_call":
Expand All @@ -412,15 +515,20 @@ async def handle_message(self, websocket: WebSocket, message: str):
except Exception as e:
await websocket.send_text(json.dumps({"error": str(e)}))

def _handle_download(self) -> dict:
def _handle_download(self, payload) -> dict:
"""
Handle download requests.

:param payload: the payload containing the ids associated with the requested model for download.
:type payload: dict
:return: a response containing the exported model.
"""
return {
"message": "Download request received.",
"payload": self._ps.export_model(),
"payload": self._ps.export_model(
payload["client_id"],
payload["ml_model_id"],
),
}

def _handle_upload(self, payload: dict) -> dict:
Expand Down
46 changes: 46 additions & 0 deletions pymilo/streaming/pymilo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,52 @@ def upload(self):
else:
print(PYMILO_CLIENT_LOCAL_MODEL_UPLOAD_FAILED)

def register(self):
"""
Register client in the remote server.

:return: None
"""
_id = self._communicator.register_client()
self.client_id = _id

def register_ml_model(self):
"""
Register ML model in the remote server.

:return: None
"""
_id = self._communicator.register_model(
self.encrypt_compress(
{
"client_id": self.client_id,
}
)
)
self.ml_model_id = _id

def get_clients(self):
"""
Get all clients in the remote server.

:return: list of client ids
"""
return self._communicator.get_clients()

def get_ml_models(self):
"""
Get all registered ml models in the remote server for this client.

:return: list of ml model ids
"""
return self._communicator.get_ml_models(
self.encrypt_compress(
{
"client_id": self.client_id,
}
)
)

def __getattr__(self, attribute):
"""
Overwrite the __getattr__ default function to extract requested.
Expand Down
Loading
Loading