Skip to content

Commit deec0c2

Browse files
feat: support for samples (input/output) routes. (#7)
1 parent 6d498b3 commit deec0c2

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

modzy/models.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,44 @@ def get_version(self, model, version):
107107
ApiError: A subclass of ApiError will be raised if the API returns an error status,
108108
or the client is unable to connect.
109109
"""
110-
self.logger.debug("getting versopm model %s version %s", model, version)
110+
self.logger.debug("getting version model %s version %s", model, version)
111111
modelId = Model._coerce_identifier(model)
112112
versionId = ModelVersion._coerce_identifier(version)
113113
json_obj = self._api_client.http.get('{}/{}/versions/{}'.format(self._base_route, modelId, versionId))
114114
return ModelVersion(json_obj, self._api_client)
115115

116+
def get_version_input_sample(self, model, version):
117+
"""Gets the input sample associated with the model and version provided.
118+
119+
Returns:
120+
String: A json string with the input sample
121+
122+
Raises:
123+
ApiError: A subclass of ApiError will be raised if the API returns an error status,
124+
or the client is unable to connect.
125+
"""
126+
self.logger.debug("getting input sample: model %s version %s", model, version)
127+
modelId = Model._coerce_identifier(model)
128+
versionId = ModelVersion._coerce_identifier(version)
129+
json_obj = self._api_client.http.get('{}/{}/versions/{}/sample-input'.format(self._base_route, modelId, versionId))
130+
return json_obj
131+
132+
def get_version_output_sample(self, model, version):
133+
"""Gets the output sample associated with the model and version provided.
134+
135+
Returns:
136+
String: A json string with the output sample
137+
138+
Raises:
139+
ApiError: A subclass of ApiError will be raised if the API returns an error status,
140+
or the client is unable to connect.
141+
"""
142+
self.logger.debug("getting output sample: model %s version %s", model, version)
143+
modelId = Model._coerce_identifier(model)
144+
versionId = ModelVersion._coerce_identifier(version)
145+
json_obj = self._api_client.http.get('{}/{}/versions/{}/sample-output'.format(self._base_route, modelId, versionId))
146+
return json_obj
147+
116148
def get_all(self):
117149
"""Gets a list of all `Model` instances.
118150

tests/test_models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,21 @@ def test_get_model_versions(client, logger):
102102
logger.debug("version: %s", version)
103103
assert version.version
104104
assert len(version) # just going to assume there should be some versions
105+
106+
107+
def test_get_model_version(client, logger):
108+
version = client.models.get_version(MODEL_ID, '0.0.27')
109+
logger.debug("version: %s", version)
110+
assert version.version
111+
112+
113+
def test_get_model_version_input_sample(client, logger):
114+
input_sample = client.models.get_version_input_sample(MODEL_ID, '0.0.27')
115+
logger.debug("version: %s", input_sample)
116+
assert input_sample
117+
118+
119+
def test_get_model_version_output_sample(client, logger):
120+
output_sample = client.models.get_version_output_sample(MODEL_ID, '0.0.27')
121+
logger.debug("version: %s", output_sample)
122+
assert output_sample

0 commit comments

Comments
 (0)