What is the correct way to construct a ProtobufResponseRowDeserializer in PySpark? #118

@RunshengSong

Description

Please fill out the form below.

System Information

  • Spark or PySpark: PySpark
  • SDK Version: 2.3.4
  • Spark Version:
  • Algorithm (e.g. KMeans): Random Cut Forest Estimator

Describe the problem

I have the following PySpark code that tries to construct a SageMakerEstimator for a Random Cut Forest image:

# Random Cut Forest Estimator
from pyspark.sql.types import *
from sagemaker_pyspark import IAMRole
from sagemaker import get_execution_role
from sagemaker_pyspark import SageMakerEstimator
from sagemaker_pyspark import RandomNamePolicyFactory
from sagemaker_pyspark import EndpointCreationPolicy
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker_pyspark.transformation.serializers.serializers import ProtobufRequestRowSerializer
from sagemaker_pyspark.transformation.deserializers.deserializers import ProtobufResponseRowDeserializer


response_schema = StructType([StructField("score", DoubleType(), False)])

estimator = SageMakerEstimator(
    trainingImage = get_image_uri(region, 'randomcutforest'), # Training image 
    modelImage = get_image_uri(region, 'randomcutforest'), # Model image
    requestRowSerializer = ProtobufRequestRowSerializer(featuresColumnName="features"),
    responseRowDeserializer = ProtobufResponseRowDeserializer(response_schema, protobufKeys["score"]),
    sagemakerRole = IAMRole(role),
    hyperParameters = {"feature_dim": "6"}, 
    trainingInstanceType = "ml.m4.4xlarge",
    trainingInstanceCount = 1,
    endpointInstanceType = "ml.t2.medium",
    endpointInitialInstanceCount = 1,
    trainingSparkDataFormat = "sagemaker",
    namePolicyFactory = RandomNamePolicyFactory("sparksm-4-"),
    endpointCreationPolicy = EndpointCreationPolicy.CREATE_ON_CONSTRUCT
    )

When I run this code with PySpark, I get the following error:

Py4JError: An error occurred while calling None.com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.ProtobufResponseRowDeserializer. Trace:
py4j.Py4JException: Constructor com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.ProtobufResponseRowDeserializer([class org.apache.spark.sql.types.StructType, class scala.collection.immutable.$colon$colon]) does not exist
	at py4j.reflection.ReflectionEngine.getConstructor(ReflectionEngine.java:179)
	at py4j.reflection.ReflectionEngine.getConstructor(ReflectionEngine.java:196)
	at py4j.Gateway.invoke(Gateway.java:237)
	at py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)
	at py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:745)

The problem is in the ProtobufResponseRowDeserializer. According to the Scala source for this class, the protobufKeys argument should be a Seq.

What is the correct counterpart in PySpark? It obviously doesn't accept a plain Python list of strings.

I searched the sagemaker-spark-sdk but couldn't find any reference to this.
