Skip to content

Commit e86745a

Browse files
authored
Feat/edge (#38)
* dev: grpc edge implementation, s3 jobs not working * add deps, increment version * dev: set origin in init connection * docs: update var desc * fix: s3 job submission format * docs: add edge usage info to README
1 parent 16a4e64 commit e86745a

24 files changed

+4332
-6
lines changed

README.md

+32
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,38 @@ except error.ResultsError as ex:
130130
results_json = outputs['results.json']
131131
print(results_json)
132132
```
133+
## Modzy Edge Functionality
134+
135+
The SDK provides the following support for Modzy Edge:
136+
137+
```python
138+
from modzy.edge.client import EdgeClient
139+
140+
# initialize edge client
141+
client = EdgeClient('localhost',55000)
142+
143+
# submit text job, wait for completion, get results
144+
job_id = client.submit_text("ed542963de","1.0.27",{"input.txt": "this is awesome"})
145+
final_job_details = client.block_until_complete(job_id)
146+
results = client.get_results(job_id)
147+
148+
# submit embedded job (bytes), wait for completion, get results
149+
job_id = client.submit_embedded("ed542963de","1.0.27",{"input.txt": b"this is awesome"})
150+
final_job_details = client.block_until_complete(job_id)
151+
results = client.get_results(job_id)
152+
153+
# submit S3 job, wait for completion, get results
154+
job_id = client.submit_aws_s3("ed542963de","1.0.27",{"input.txt": {"bucket":bucket,"key":key}},
155+
access_key,secret_key,region)
156+
final_job_details = client.block_until_complete(job_id)
157+
results = client.get_results(job_id)
158+
159+
# get job details for a particular job
160+
job_details = client.get_job_details(job_id)
161+
162+
# get job details for all jobs run on your Modzy Edge instance
163+
all_job_details = client.get_all_job_details()
164+
```
133165

134166
## Features
135167

modzy/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55

66
from .client import ApiClient # noqa
7-
8-
__version__ = '0.6.0'
7+
from .edge.client import EdgeClient
8+
__version__ = '0.7.0'
99

1010
logging.getLogger(__name__).addHandler(logging.NullHandler())

modzy/edge/__init__.py

Whitespace-only changes.

modzy/edge/client.py

+335
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
from os import access
2+
import grpc
3+
import time
4+
import logging
5+
from ..error import ApiError,Timeout
6+
from .._util import depth,encode_data_uri
7+
from grpc._channel import _InactiveRpcError
8+
from google.protobuf.empty_pb2 import Empty
9+
from google.protobuf.struct_pb2 import Struct
10+
from google.protobuf.json_format import MessageToDict
11+
from .proto.jobs.v1.job_pb2 import JobIdentifier
12+
from .proto.results.v1.results_pb2_grpc import ResultsServiceStub
13+
from .proto.jobs.v1.job_pb2 import JobInput,JobSubmission,JobIdentifier
14+
from .proto.jobs.v1.job_pb2_grpc import JobServiceStub
15+
from .proto.common.v1.common_pb2 import ModelIdentifier
16+
17+
class EdgeClient:
    """The Edge API client object.

    This class is used to interact with the Modzy Edge API over an insecure
    gRPC channel. The channel is opened when the client is constructed and
    can be released with `close()` or by using the client as a context
    manager.

    Attributes:
        host (str): The host for the Modzy Edge API.
        port (int): The port on which Modzy Edge is listening.
        origin (str): The `host:port` target the gRPC channel connects to.
    """

    # Job statuses after which a job will not change state any more.
    _TERMINAL_STATUSES = frozenset({"COMPLETE", "CANCELLED", "FAILED"})

    def __init__(self, host, port):
        """Creates an `EdgeClient` instance and connects to Modzy Edge.

        Args:
            host (str): The host for the API.
            port (int): Port for the API.

        Raises:
            ApiError: If the connectivity check against Modzy Edge fails.
        """
        self.logger = logging.getLogger(__name__)
        self.host = host
        self.port = port
        self._initialize_connection()

    def __enter__(self):
        """Returns the client itself so it can be used in a `with` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Closes the underlying channel when leaving a `with` block."""
        self.close()

    def close(self):
        """Closes the gRPC channel held by this client."""
        self._channel.close()

    def _initialize_connection(self):
        """Creates the channel and service stubs, then verifies the connection.

        The connection is verified by retrieving all job details with a
        five second timeout.

        Raises:
            ApiError: If the verification request fails.
        """
        self.origin = '{}:{}'.format(self.host, self.port)
        self._channel = grpc.insecure_channel(self.origin)
        self.jobs_service_stub = JobServiceStub(self._channel)
        self.results_service_stub = ResultsServiceStub(self._channel)
        try:
            self.get_all_job_details(timeout=5)
        except Exception:
            # Do not leak the channel when the endpoint is unreachable; the
            # original error is still propagated to the caller.
            self._channel.close()
            raise

    def __fix_single_source_job(self, sources, s3=False):
        """Compatibility function to check and fix the sources parameter if it is a single source dict.

        Args:
            sources (dict): a single or double source dict
            s3 (bool): whether the innermost values are S3 descriptors
                (dicts), which adds one nesting level to the expected depth.

        Returns:
            dict: a properly formatted sources dictionary

        """
        single_source_depth = 2 if s3 else 1
        if depth(sources) == single_source_depth:
            return {'job': sources}
        return sources

    def __parse_inactive_rpc_error(self, inactive_rpc_error):
        """Parse relevant info from _InactiveRpcError.

        Args:
            inactive_rpc_error (_InactiveRpcError): Error to be parsed.

        Returns:
            str: The text of the error's `details` line with double quotes
                removed, or the full error text if there is no such line.

        """
        error_text = str(inactive_rpc_error)
        for line in error_text.splitlines():
            if line.startswith('\tdetails'):
                # Split on the first '=' only so that messages which contain
                # '=' themselves are not truncated.
                _, _, message = line.partition('=')
                return message.strip().replace('"', '')
        # No details line available: fall back to the whole error text
        # instead of failing while reporting another failure.
        return error_text

    def __invoke_rpc(self, rpc, request, **kwargs):
        """Calls a stub method and translates inactive RPC errors to ApiError."""
        try:
            return rpc(request, **kwargs)
        except _InactiveRpcError as e:
            raise ApiError(self.__parse_inactive_rpc_error(e), self.origin) from e

    def __submit_job(self, job_type, identifier, version, sources, explain, **input_fields):
        """Builds a JobSubmission from prepared sources and submits it.

        Args:
            job_type (str): Value for the `type` field of the job input.
            identifier (str): The model identifier.
            version (str): The model version string.
            sources (dict): Sources already normalized to the multi-source layout.
            explain (bool): Whether an explainable result is requested.
            **input_fields: Additional fields passed to `JobInput`.

        Returns:
            str: Job identifier returned by Modzy Edge.
        """
        sources_struct = Struct()
        for source_name, source_inputs in sources.items():
            sources_struct[source_name] = source_inputs

        job_input = JobInput(type=job_type, sources=sources_struct, **input_fields)
        model_identifier = ModelIdentifier(identifier=identifier, version=version)
        job_submission = JobSubmission(model=model_identifier, input=job_input, explain=explain)

        job_receipt = self.__invoke_rpc(self.jobs_service_stub.SubmitJob, job_submission)
        return job_receipt.job_identifier

    def submit_embedded(self, identifier, version, sources, explain=False):
        """Submits a job containing embedded data.

        Args:
            identifier (str): The model identifier.
            version (str): The model version string.
            sources (dict): A mapping of source names to embedded sources. Each source should be a
                mapping of model input filename to the data to embed (see the example below).
                A single source may be passed directly without the outer mapping.
            explain (bool): indicates if you desire an explainable result for your model.

        Returns:
            str: Job identifier returned by Modzy Edge.

        Raises:
            ApiError: An ApiError will be raised if the API returns an error status,
                or the client is unable to connect.

        Example:
            .. code-block::

                job = client.submit_embedded('model-identifier', '1.2.3',
                {
                    'source-name-1': {
                        'model-input-name-1': b'some bytes',
                        'model-input-name-2': bytearray([1,2,3,4]),
                    },
                    'source-name-2': {
                        'model-input-name-1': b'some bytes',
                        'model-input-name-2': bytearray([1,2,3,4]),
                    }
                })

        """
        encoded_sources = {
            source: {
                key: encode_data_uri(value)
                for key, value in inputs.items()
            }
            for source, inputs in self.__fix_single_source_job(sources).items()
        }
        return self.__submit_job("embedded", identifier, version, encoded_sources, explain)

    def submit_text(self, identifier, version, sources, explain=False):
        """Submits text data for a multiple source `Job`.

        Args:
            identifier (str): The model identifier.
            version (str): The model version string.
            sources (dict): A mapping of source names to text sources. Each source should be a
                mapping of model input filename to the text to submit.
                A single source may be passed directly without the outer mapping.
            explain (bool): indicates if you desire an explainable result for your model.

        Returns:
            str: Job identifier returned by Modzy Edge.

        Raises:
            ApiError: An ApiError will be raised if the API returns an error status,
                or the client is unable to connect.

        Example:
            .. code-block::

                job = client.submit_text('model-identifier', '1.2.3',
                {
                    'source-name-1': {
                        'model-input-name-1': 'some text',
                        'model-input-name-2': 'some more text',
                    },
                    'source-name-2': {
                        'model-input-name-1': 'some text 2',
                        'model-input-name-2': 'some more text 2',
                    }
                })

        """
        return self.__submit_job("text", identifier, version,
                                 self.__fix_single_source_job(sources), explain)

    def submit_aws_s3(self, identifier, version, sources, access_key_id, secret_access_key, region, explain=False):
        """Submits AWS S3 hosted data for a multiple source `Job`.

        Args:
            identifier (str): The model identifier.
            version (str): The model version string.
            sources (dict): A mapping of source names to S3 sources. Each source should be a
                mapping of model input filename to S3 bucket and key.
                A single source may be passed directly without the outer mapping.
            access_key_id (str): The AWS Access Key ID.
            secret_access_key (str): The AWS Secret Access Key.
            region (str): The AWS Region.
            explain (bool): indicates if you desire an explainable result for your model.

        Returns:
            str: Job identifier returned by Modzy Edge.

        Raises:
            ApiError: An ApiError will be raised if the API returns an error status,
                or the client is unable to connect.

        Example:
            .. code-block::

                job = client.submit_aws_s3('model-identifier', '1.2.3',
                    {
                        'source-name-1': {
                            'model-input-name-1': {
                                'bucket': 'my-bucket',
                                'key': '/my/data/file-1.dat'
                            },
                            'model-input-name-2': {
                                'bucket': 'my-bucket',
                                'key': '/my/data/file-2.dat'
                            }
                        },
                        'source-name-2': {
                            'model-input-name-1': {
                                'bucket': 'my-bucket',
                                'key': '/my/data/file-3.dat'
                            },
                            'model-input-name-2': {
                                'bucket': 'my-bucket',
                                'key': '/my/data/file-4.dat'
                            }
                        }
                    },
                    access_key_id='AWS_ACCESS_KEY_ID',
                    secret_access_key='AWS_SECRET_ACCESS_KEY',
                    region='us-east-1',
                )
        """
        return self.__submit_job("aws-s3", identifier, version,
                                 self.__fix_single_source_job(sources, s3=True), explain,
                                 accessKeyID=access_key_id,
                                 secretAccessKey=secret_access_key,
                                 region=region)

    def get_job_details(self, job_identifier):
        """Get job details.

        Args:
            job_identifier (str): The job identifier.

        Returns:
            dict: Details for requested job.

        Raises:
            ApiError: An ApiError will be raised if the API returns an error status,
                or the client is unable to connect.
        """
        request = JobIdentifier(identifier=job_identifier)
        job_details = self.__invoke_rpc(self.jobs_service_stub.GetJob, request)
        return MessageToDict(job_details)

    def get_all_job_details(self, timeout=None):
        """Get job details for all jobs.

        Args:
            timeout (int): Optional timeout value in seconds.

        Returns:
            dict: Details for all jobs that have been run.

        Raises:
            ApiError: An ApiError will be raised if the API returns an error status,
                or the client is unable to connect.
        """
        all_job_details = self.__invoke_rpc(self.jobs_service_stub.GetJobs, Empty(), timeout=timeout)
        return MessageToDict(all_job_details)

    def block_until_complete(self, job_identifier, poll_interval=0.01, timeout=30):
        """Block until job complete.

        Args:
            job_identifier (str): The job identifier.
            poll_interval (float): Seconds to sleep between status checks.
            timeout (float): Maximum number of seconds to wait, or `None` to
                wait indefinitely.

        Returns:
            dict: Final job details.

        Raises:
            ApiError: An ApiError will be raised if the API returns an error status,
                or the client is unable to connect.
            Timeout: If the job has not reached a final status in time.
        """
        # A monotonic clock keeps the deadline immune to wall-clock adjustments.
        endby = time.monotonic() + timeout if (timeout is not None) else None
        while True:
            job_details = self.get_job_details(job_identifier)
            # MessageToDict may leave out fields that hold their default
            # value, so a missing status means "not finished yet".
            if job_details.get('status') in self._TERMINAL_STATUSES:
                return job_details
            time.sleep(poll_interval)
            # Give up once less than one poll interval remains before the deadline.
            if (endby is not None) and (time.monotonic() > endby - poll_interval):
                raise Timeout('timed out before completion')

    def get_results(self, job_identifier):
        """Get results for a job.

        Args:
            job_identifier (str): The job identifier.

        Returns:
            dict: Results for the requested job.

        Raises:
            ApiError: An ApiError will be raised if the API returns an error status,
                or the client is unable to connect.
        """
        request = JobIdentifier(identifier=job_identifier)
        results = self.__invoke_rpc(self.results_service_stub.GetResults, request)
        return MessageToDict(results)

modzy/edge/proto/__init__.py

Whitespace-only changes.

modzy/edge/proto/common/__init__.py

Whitespace-only changes.

modzy/edge/proto/common/v1/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)