Skip to content

Commit 78d8a59

Browse files
Alec Glassfordtswast
Alec Glassford
authored andcommitted
Add custom prediction routine samples for AI Platform (GoogleCloudPlatform#2121)
* Add custom prediction routine samples Change-Id: I734bebd77970a3ab627b0cbffdcb8fef320c2de4 * Ensure line limit of 79 characters Change-Id: Ic3b512b7478a1e5052baf2978ed1fbc384793e2e
1 parent 3619a77 commit 78d8a59

File tree

6 files changed

+287
-0
lines changed

6 files changed

+287
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Custom prediction routines (beta)
2+
3+
Read the AI Platform documentation about custom prediction routines to learn how
4+
to use these samples:
5+
6+
* [Custom prediction routines (with a TensorFlow Keras
7+
example)](https://cloud.google.com/ml-engine/docs/tensorflow/custom-prediction-routines)
8+
* [Custom prediction routines (with a scikit-learn
9+
example)](https://cloud.google.com/ml-engine/docs/scikit/custom-prediction-routines)
10+
11+
If you want to package a predictor directly from this directory, make sure to
12+
edit `setup.py`: replace the reference to `predictor.py` with either
13+
`tensorflow-predictor.py` or `scikit-predictor.py`.
14+
15+
## What's next
16+
17+
For a more complete example of how to train and deploy a custom prediction
18+
routine, check out one of the following tutorials:
19+
20+
* [Creating a custom prediction routine with
21+
Keras](https://cloud.google.com/ml-engine/docs/tensorflow/custom-prediction-routine-keras)
22+
(also available as [a Jupyter
23+
notebook](https://colab.research.google.com/github/GoogleCloudPlatform/cloudml-samples/blob/master/notebooks/tensorflow/custom-prediction-routine-keras.ipynb))
24+
25+
* [Creating a custom prediction routine with
26+
scikit-learn](https://cloud.google.com/ml-engine/docs/scikit/custom-prediction-routine-scikit-learn)
27+
(also available as [a Jupyter
28+
notebook](https://colab.research.google.com/github/GoogleCloudPlatform/cloudml-samples/blob/master/notebooks/scikit-learn/custom-prediction-routine-scikit-learn.ipynb))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2019 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
class Predictor(object):
17+
"""Interface for constructing custom predictors."""
18+
19+
def predict(self, instances, **kwargs):
20+
"""Performs custom prediction.
21+
22+
Instances are the decoded values from the request. They have already
23+
been deserialized from JSON.
24+
25+
Args:
26+
instances: A list of prediction input instances.
27+
**kwargs: A dictionary of keyword args provided as additional
28+
fields on the predict request body.
29+
30+
Returns:
31+
A list of outputs containing the prediction results. This list must
32+
be JSON serializable.
33+
"""
34+
raise NotImplementedError()
35+
36+
@classmethod
37+
def from_path(cls, model_dir):
38+
"""Creates an instance of Predictor using the given path.
39+
40+
Loading of the predictor should be done in this method.
41+
42+
Args:
43+
model_dir: The local directory that contains the exported model
44+
file along with any additional files uploaded when creating the
45+
version resource.
46+
47+
Returns:
48+
An instance implementing this Predictor class.
49+
"""
50+
raise NotImplementedError()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2019 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
17+
18+
class ZeroCenterer(object):
19+
"""Stores means of each column of a matrix and uses them for preprocessing.
20+
"""
21+
22+
def __init__(self):
23+
"""On initialization, is not tied to any distribution."""
24+
self._means = None
25+
26+
def preprocess(self, data):
27+
"""Transforms a matrix.
28+
29+
The first time this is called, it stores the means of each column of
30+
the input. Then it transforms the input so each column has mean 0. For
31+
subsequent calls, it subtracts the stored means from each column. This
32+
lets you 'center' data at prediction time based on the distribution of
33+
the original training data.
34+
35+
Args:
36+
data: A NumPy matrix of numerical data.
37+
38+
Returns:
39+
A transformed matrix with the same dimensions as the input.
40+
"""
41+
if self._means is None: # during training only
42+
self._means = np.mean(data, axis=0)
43+
return data - self._means
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2019 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import pickle
17+
18+
import numpy as np
19+
from sklearn.externals import joblib
20+
21+
22+
class MyPredictor(object):
23+
"""An example Predictor for an AI Platform custom prediction routine."""
24+
25+
def __init__(self, model, preprocessor):
26+
"""Stores artifacts for prediction. Only initialized via `from_path`.
27+
"""
28+
self._model = model
29+
self._preprocessor = preprocessor
30+
31+
def predict(self, instances, **kwargs):
32+
"""Performs custom prediction.
33+
34+
Preprocesses inputs, then performs prediction using the trained
35+
scikit-learn model.
36+
37+
Args:
38+
instances: A list of prediction input instances.
39+
**kwargs: A dictionary of keyword args provided as additional
40+
fields on the predict request body.
41+
42+
Returns:
43+
A list of outputs containing the prediction results.
44+
"""
45+
inputs = np.asarray(instances)
46+
preprocessed_inputs = self._preprocessor.preprocess(inputs)
47+
outputs = self._model.predict(preprocessed_inputs)
48+
return outputs.tolist()
49+
50+
@classmethod
51+
def from_path(cls, model_dir):
52+
"""Creates an instance of MyPredictor using the given path.
53+
54+
This loads artifacts that have been copied from your model directory in
55+
Cloud Storage. MyPredictor uses them during prediction.
56+
57+
Args:
58+
model_dir: The local directory that contains the trained
59+
scikit-learn model and the pickled preprocessor instance. These
60+
are copied from the Cloud Storage model directory you provide
61+
when you deploy a version resource.
62+
63+
Returns:
64+
An instance of `MyPredictor`.
65+
"""
66+
model_path = os.path.join(model_dir, 'model.joblib')
67+
model = joblib.load(model_path)
68+
69+
preprocessor_path = os.path.join(model_dir, 'preprocessor.pkl')
70+
with open(preprocessor_path, 'rb') as f:
71+
preprocessor = pickle.load(f)
72+
73+
return cls(model, preprocessor)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2019 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from setuptools import setup
16+
17+
setup(
18+
name='my_custom_code',
19+
version='0.1',
20+
scripts=['predictor.py', 'preprocess.py'])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2019 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import pickle
17+
18+
import numpy as np
19+
from tensorflow import keras
20+
21+
22+
class MyPredictor(object):
23+
"""An example Predictor for an AI Platform custom prediction routine."""
24+
25+
def __init__(self, model, preprocessor):
26+
"""Stores artifacts for prediction. Only initialized via `from_path`.
27+
"""
28+
self._model = model
29+
self._preprocessor = preprocessor
30+
31+
def predict(self, instances, **kwargs):
32+
"""Performs custom prediction.
33+
34+
Preprocesses inputs, then performs prediction using the trained Keras
35+
model.
36+
37+
Args:
38+
instances: A list of prediction input instances.
39+
**kwargs: A dictionary of keyword args provided as additional
40+
fields on the predict request body.
41+
42+
Returns:
43+
A list of outputs containing the prediction results.
44+
"""
45+
inputs = np.asarray(instances)
46+
preprocessed_inputs = self._preprocessor.preprocess(inputs)
47+
outputs = self._model.predict(preprocessed_inputs)
48+
return outputs.tolist()
49+
50+
@classmethod
51+
def from_path(cls, model_dir):
52+
"""Creates an instance of MyPredictor using the given path.
53+
54+
This loads artifacts that have been copied from your model directory in
55+
Cloud Storage. MyPredictor uses them during prediction.
56+
57+
Args:
58+
model_dir: The local directory that contains the trained Keras
59+
model and the pickled preprocessor instance. These are copied
60+
from the Cloud Storage model directory you provide when you
61+
deploy a version resource.
62+
63+
Returns:
64+
An instance of `MyPredictor`.
65+
"""
66+
model_path = os.path.join(model_dir, 'model.h5')
67+
model = keras.models.load_model(model_path)
68+
69+
preprocessor_path = os.path.join(model_dir, 'preprocessor.pkl')
70+
with open(preprocessor_path, 'rb') as f:
71+
preprocessor = pickle.load(f)
72+
73+
return cls(model, preprocessor)

0 commit comments

Comments
 (0)