1
+ from os import access
2
+ import grpc
3
+ import time
4
+ import logging
5
+ from ..error import ApiError ,Timeout
6
+ from .._util import depth ,encode_data_uri
7
+ from grpc ._channel import _InactiveRpcError
8
+ from google .protobuf .empty_pb2 import Empty
9
+ from google .protobuf .struct_pb2 import Struct
10
+ from google .protobuf .json_format import MessageToDict
11
+ from .proto .jobs .v1 .job_pb2 import JobIdentifier
12
+ from .proto .results .v1 .results_pb2_grpc import ResultsServiceStub
13
+ from .proto .jobs .v1 .job_pb2 import JobInput ,JobSubmission ,JobIdentifier
14
+ from .proto .jobs .v1 .job_pb2_grpc import JobServiceStub
15
+ from .proto .common .v1 .common_pb2 import ModelIdentifier
16
+
17
+ class EdgeClient :
18
+ """The Edge API client object.
19
+
20
+ This class is used to interact with the Modzy Edge API.
21
+
22
+ Attributes:
23
+ host (str): The host for the Modzy Edge API.
24
+ port (int): The port on which Modzy Edge is listening.
25
+ """
26
+
27
+ def __init__ (self , host , port ):
28
+ """Creates an `ApiClient` instance.
29
+
30
+ Args:
31
+ host (str): The host for the API.
32
+ port (int): Port for the API.
33
+ """
34
+ self .logger = logging .getLogger (__name__ )
35
+ self .host = host
36
+ self .port = port
37
+ self ._initialize_connection ()
38
+
39
+ def _initialize_connection (self ):
40
+ # attempt to create channel, service stubs, and test retrieving job details to confirm connection
41
+ self .origin = '{}:{}' .format (self .host ,self .port )
42
+ self ._channel = grpc .insecure_channel (self .origin )
43
+ self .jobs_service_stub = JobServiceStub (self ._channel )
44
+ self .results_service_stub = ResultsServiceStub (self ._channel )
45
+ self .get_all_job_details (timeout = 5 )
46
+
47
+ def __fix_single_source_job (self , sources , s3 = False ):
48
+ """Compatibility function to check and fix the sources parameter if is a single source dict
49
+
50
+ Args:
51
+ sources (dict): a single of double source dict
52
+
53
+ Returns:
54
+ dict: a properly formatted sources dictionary
55
+
56
+ """
57
+ dict_levels = depth (sources )
58
+ if dict_levels == (1 + s3 ):
59
+ return {'job' : sources }
60
+ else :
61
+ return sources
62
+
63
+ def __parse_inactive_rpc_error (self ,inactive_rpc_error ):
64
+ """Parse relevant info from _InactiveRpcError.
65
+
66
+ Args:
67
+ inactive_rpc_error (_InactiveRpcError): Error to be parsed.
68
+
69
+ Returns:
70
+ ApiError: a formatted ApiError.
71
+
72
+ """
73
+ lines = str (inactive_rpc_error ).splitlines ()
74
+ details_index = [lines .index (l ) for l in lines if l .startswith ('\t details' )][0 ]
75
+ details_message = lines [details_index ].split ('=' )[1 ].strip ().replace ('"' ,'' )
76
+
77
+ return details_message
78
+
79
+ def submit_embedded (self , identifier , version , sources , explain = False ):
80
+ """Submits a job containing embedded data.
81
+
82
+ Args:
83
+ identifier (str): The model identifier.
84
+ version (str): The model version string.
85
+ sources (dict): A mapping of source names to text sources. Each source should be a
86
+ mapping of model input filename to filepath or file-like object.
87
+ explain (bool): indicates if you desire an explainable result for your model.`
88
+
89
+ Returns:
90
+ str: Job identifier returned by Modzy Edge.
91
+
92
+ Raises:
93
+ ApiError: An ApiError will be raised if the API returns an error status,
94
+ or the client is unable to connect.
95
+
96
+ Example:
97
+ .. code-block::
98
+
99
+ job = client.submit_embedded('model-identifier', '1.2.3',
100
+ {
101
+ 'source-name-1': {
102
+ 'model-input-name-1': b'some bytes',
103
+ 'model-input-name-2': bytearray([1,2,3,4]),
104
+ },
105
+ 'source-name-2': {
106
+ 'model-input-name-1': b'some bytes',
107
+ 'model-input-name-2': bytearray([1,2,3,4]),
108
+ }
109
+ })
110
+
111
+ """
112
+
113
+ sources = {
114
+ source : {
115
+ key : encode_data_uri (value )
116
+ for key , value in inputs .items ()
117
+ }
118
+ for source , inputs in self .__fix_single_source_job (sources ).items ()
119
+ }
120
+
121
+ sources_struct = Struct ()
122
+ for k ,v in sources .items ():
123
+ sources_struct [k ] = v
124
+
125
+ job_input = JobInput (type = "embedded" ,sources = sources_struct )
126
+ model_identifier = ModelIdentifier (identifier = identifier ,version = version )
127
+ job_submission = JobSubmission (model = model_identifier ,input = job_input ,explain = explain )
128
+
129
+ try :
130
+ job_receipt = self .jobs_service_stub .SubmitJob (job_submission )
131
+ except _InactiveRpcError as e :
132
+ raise ApiError (self .__parse_inactive_rpc_error (e ),self .origin ) from e
133
+
134
+ return job_receipt .job_identifier
135
+
136
+ def submit_text (self , identifier , version , sources , explain = False ):
137
+ """Submits text data for a multiple source `Job`.
138
+
139
+ Args:
140
+ identifier (str): The model identifier.
141
+ version (str): The model version string.
142
+ sources (dict): A mapping of source names to text sources. Each source should be a
143
+ mapping of model input filename to filepath or file-like object.
144
+ explain (bool): indicates if you desire an explainable result for your model.`
145
+
146
+ Returns:
147
+ str: Job identifier returned by Modzy Edge.
148
+
149
+ Raises:
150
+ ApiError: An ApiError will be raised if the API returns an error status,
151
+ or the client is unable to connect.
152
+
153
+ Example:
154
+ .. code-block::
155
+
156
+ job = client.submit_text('model-identifier', '1.2.3',
157
+ {
158
+ 'source-name-1': {
159
+ 'model-input-name-1': 'some text',
160
+ 'model-input-name-2': 'some more text',
161
+ },
162
+ 'source-name-2': {
163
+ 'model-input-name-1': 'some text 2',
164
+ 'model-input-name-2': 'some more text 2',
165
+ }
166
+ })
167
+
168
+ """
169
+ sources_struct = Struct ()
170
+ for k ,v in self .__fix_single_source_job (sources ).items ():
171
+ sources_struct [k ] = v
172
+
173
+ job_input = JobInput (type = "text" ,sources = sources_struct )
174
+ model_identifier = ModelIdentifier (identifier = identifier ,version = version )
175
+ job_submission = JobSubmission (model = model_identifier ,input = job_input ,explain = explain )
176
+
177
+ try :
178
+ job_receipt = self .jobs_service_stub .SubmitJob (job_submission )
179
+ except _InactiveRpcError as e :
180
+ raise ApiError (self .__parse_inactive_rpc_error (e ),self .origin ) from e
181
+
182
+ return job_receipt .job_identifier
183
+
184
+ def submit_aws_s3 (self , identifier , version , sources , access_key_id , secret_access_key , region , explain = False ):
185
+ """Submits AwS S3 hosted data for a multiple source `Job`.
186
+
187
+ Args:
188
+ identifier (str): The model identifier or a `Model` instance.
189
+ version (str): The model version string.
190
+ sources (dict): A mapping of source names to text sources. Each source should be a
191
+ mapping of model input filename to S3 bucket and key.
192
+ access_key_id (str): The AWS Access Key ID.
193
+ secret_access_key (str): The AWS Secret Access Key.
194
+ region (str): The AWS Region.
195
+ explain (bool): indicates if you desire an explainable result for your model.`
196
+
197
+ Returns:
198
+ str: Job identifier returned by Modzy Edge.
199
+
200
+ Raises:
201
+ ApiError: An ApiError will be raised if the API returns an error status,
202
+ or the client is unable to connect.
203
+
204
+ Example:
205
+ .. code-block::
206
+
207
+ job = client.submit_aws_s3('model-identifier', '1.2.3',
208
+ {
209
+ 'source-name-1': {
210
+ 'model-input-name-1': {
211
+ 'bucket': 'my-bucket',
212
+ 'key': '/my/data/file-1.dat'
213
+ },
214
+ 'model-input-name-2': {
215
+ 'bucket': 'my-bucket',
216
+ 'key': '/my/data/file-2.dat'
217
+ }
218
+ },
219
+ 'source-name-2': {
220
+ 'model-input-name-1': {
221
+ 'bucket': 'my-bucket',
222
+ 'key': '/my/data/file-3.dat'
223
+ },
224
+ 'model-input-name-2': {
225
+ 'bucket': 'my-bucket',
226
+ 'key': '/my/data/file-4.dat'
227
+ }
228
+ }
229
+ },
230
+ access_key_id='AWS_ACCESS_KEY_ID',
231
+ secret_access_key='AWS_SECRET_ACCESS_KEY',
232
+ region='us-east-1',
233
+ )
234
+ """
235
+ sources_struct = Struct ()
236
+ for k ,v in self .__fix_single_source_job (sources ,s3 = True ).items ():
237
+ sources_struct [k ] = v
238
+
239
+ job_input = JobInput (type = "aws-s3" ,accessKeyID = access_key_id ,secretAccessKey = secret_access_key ,
240
+ region = region ,sources = sources_struct )
241
+
242
+ model_identifier = ModelIdentifier (identifier = identifier ,version = version )
243
+ job_submission = JobSubmission (model = model_identifier ,input = job_input ,explain = explain )
244
+
245
+ try :
246
+ job_receipt = self .jobs_service_stub .SubmitJob (job_submission )
247
+ except _InactiveRpcError as e :
248
+ raise ApiError (self .__parse_inactive_rpc_error (e ),self .origin ) from e
249
+
250
+ return job_receipt .job_identifier
251
+
252
+ def get_job_details (self , job_identifier ):
253
+ """Get job details.
254
+
255
+ Args:
256
+ job_identifier (str): The job identifier.
257
+
258
+ Returns:
259
+ dict: Details for requested job.
260
+
261
+ Raises:
262
+ ApiError: An ApiError will be raised if the API returns an error status,
263
+ or the client is unable to connect.
264
+ """
265
+ job_identifier = JobIdentifier (identifier = job_identifier )
266
+
267
+ try :
268
+ job_details = self .jobs_service_stub .GetJob (job_identifier )
269
+ except _InactiveRpcError as e :
270
+ raise ApiError (self .__parse_inactive_rpc_error (e ),self .origin ) from e
271
+
272
+ return MessageToDict (job_details )
273
+
274
+ def get_all_job_details (self ,timeout = None ):
275
+ """Get job details for all jobs.
276
+
277
+ Args:
278
+ timeout (int): Optional timeout value in seconds.
279
+
280
+ Returns:
281
+ dict: Details for all jobs that have been run.
282
+
283
+ Raises:
284
+ ApiError: An ApiError will be raised if the API returns an error status,
285
+ or the client is unable to connect.
286
+ """
287
+ try :
288
+ all_job_details = self .jobs_service_stub .GetJobs (Empty (),timeout = timeout )
289
+ except _InactiveRpcError as e :
290
+ raise ApiError (self .__parse_inactive_rpc_error (e ),self .origin ) from e
291
+
292
+ return MessageToDict (all_job_details )
293
+
294
+ def block_until_complete (self , job_identifier , poll_interval = 0.01 , timeout = 30 ):
295
+ """Block until job complete.
296
+
297
+ Args:
298
+ job_identifier (str): The job identifier.
299
+
300
+ Returns:
301
+ dict: Final job details.
302
+
303
+ Raises:
304
+ ApiError: An ApiError will be raised if the API returns an error status,
305
+ or the client is unable to connect.
306
+ """
307
+ endby = time .time () + timeout if (timeout is not None ) else None
308
+ while True :
309
+ job_details = self .get_job_details (job_identifier )
310
+ if job_details ['status' ] in {"COMPLETE" ,"CANCELLED" ,"FAILED" }:
311
+ return job_details
312
+ time .sleep (poll_interval )
313
+ if (endby is not None ) and (time .time () > endby - poll_interval ):
314
+ raise Timeout ('timed out before completion' )
315
+
316
+ def get_results (self , job_identifier ):
317
+ """Block until job complete.
318
+
319
+ Args:
320
+ job_identifier (str): The job identifier.
321
+
322
+ Returns:
323
+ dict: Results for the requested job.
324
+
325
+ Raises:
326
+ ApiError: An ApiError will be raised if the API returns an error status,
327
+ or the client is unable to connect.
328
+ """
329
+ job_identifier = JobIdentifier (identifier = job_identifier )
330
+ try :
331
+ results = self .results_service_stub .GetResults (job_identifier )
332
+ except _InactiveRpcError as e :
333
+ raise ApiError (self .__parse_inactive_rpc_error (e ),self .origin ) from e
334
+
335
+ return MessageToDict (results )
0 commit comments