Description
Please fill out the form below.
System Information
- Spark or PySpark: PySpark
- SDK Version: 2.3.4
- Spark Version:
- Algorithm (e.g. KMeans): Random Cut Forest Estimator
Describe the problem
I have the following PySpark code, which tries to construct a SageMakerEstimator
for a Random Cut Forest image:
# Random Cut Forest Estimator
import boto3
from pyspark.sql.types import *
from sagemaker_pyspark import IAMRole
from sagemaker import get_execution_role
from sagemaker_pyspark import SageMakerEstimator
from sagemaker_pyspark import RandomNamePolicyFactory
from sagemaker_pyspark import EndpointCreationPolicy
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker_pyspark.transformation.serializers.serializers import ProtobufRequestRowSerializer
from sagemaker_pyspark.transformation.deserializers.deserializers import ProtobufResponseRowDeserializer

# role and region are defined as follows (omitted from the original snippet):
role = get_execution_role()           # IAM role used by SageMaker
region = boto3.Session().region_name  # AWS region for the image URIs

response_schema = StructType([StructField("score", DoubleType(), False)])

estimator = SageMakerEstimator(
    trainingImage = get_image_uri(region, 'randomcutforest'),  # Training image
    modelImage = get_image_uri(region, 'randomcutforest'),     # Model image
    requestRowSerializer = ProtobufRequestRowSerializer(featuresColumnName="features"),
    responseRowDeserializer = ProtobufResponseRowDeserializer(response_schema, protobufKeys=["score"]),
    sagemakerRole = IAMRole(role),
    hyperParameters = {"feature_dim": "6"},
    trainingInstanceType = "ml.m4.4xlarge",
    trainingInstanceCount = 1,
    endpointInstanceType = "ml.t2.medium",
    endpointInitialInstanceCount = 1,
    trainingSparkDataFormat = "sagemaker",
    namePolicyFactory = RandomNamePolicyFactory("sparksm-4-"),
    endpointCreationPolicy = EndpointCreationPolicy.CREATE_ON_CONSTRUCT
)
When I run this code with PySpark, I get the following error:
Py4JError: An error occurred while calling None.com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.ProtobufResponseRowDeserializer. Trace:
py4j.Py4JException: Constructor com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.ProtobufResponseRowDeserializer([class org.apache.spark.sql.types.StructType, class scala.collection.immutable.$colon$colon]) does not exist
at py4j.reflection.ReflectionEngine.getConstructor(ReflectionEngine.java:179)
at py4j.reflection.ReflectionEngine.getConstructor(ReflectionEngine.java:196)
at py4j.Gateway.invoke(Gateway.java:237)
at py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)
at py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:745)
The problem is in the ProtobufResponseRowDeserializer. According to the Scala source code, its constructor expects a Seq for the protobuf keys.
What is the correct counterpart in PySpark? It evidently does not accept a Python list of strings.
I searched the sagemaker-spark SDK and could not find any reference to this.
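As an exploratory sketch (my own assumption, not a verified fix): if the Scala constructor actually takes the keys as an Option[Seq[String]] rather than a plain list, the conversion could in principle be done by hand through the Py4J gateway. Both `PythonUtils.toSeq` and `scala.Option.apply` here are assumptions about what is reachable from the gateway:

```python
# Untested sketch: convert a Python list of strings into a Scala
# Option[Seq[String]] via the Py4J gateway. I have not confirmed that
# this works with sagemaker_pyspark.

def to_scala_option_seq(jvm, values):
    """`jvm` is the Py4J JVM view, e.g. spark.sparkContext._jvm."""
    seq = jvm.PythonUtils.toSeq(values)   # Python list -> Scala Seq[String]
    return jvm.scala.Option.apply(seq)    # Seq -> Option[Seq[String]]
```

Even if the conversion itself works, the sagemaker_pyspark wrapper would still have to accept the resulting Java object, so this may only help when calling the underlying Java constructor directly.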