from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from sagemaker_pyspark import IAMRole, classpath_jars
from sagemaker_pyspark.algorithms import KMeansSageMakerEstimator
# The sagemaker_pyspark JARs must be on the driver classpath before the Spark
# JVM starts. If you submitted this job with --jars there is nothing to do.
# Otherwise, add them to the builder *before* getOrCreate(); once a
# SparkContext exists, setting spark.driver.extraClassPath has no effect:
#     .config("spark.driver.extraClassPath", ":".join(classpath_jars()))
spark = (
    SparkSession.builder
    .appName("SageMaker")
    .getOrCreate()
)

# IAM role ARN passed to SageMaker via IAMRole below. Replace the placeholder
# with a real role ARN before running.
iam_role = "<role_Arn>"
region = "us-east-1"

# Feature count of the libsvm MNIST sample rows. Used both when reading the
# data and as the k-means feature dimension, so the two cannot drift apart.
MNIST_NUM_FEATURES = 784


def _load_mnist_split(split):
    """Read one split ("train" or "test") of the SageMaker MNIST sample data.

    The data is read from the region-specific public sample bucket in libsvm
    format with a fixed feature count of MNIST_NUM_FEATURES.
    """
    path = "s3a://sagemaker-sample-data-{}/spark/mnist/{}/".format(region, split)
    return (
        spark.read.format("libsvm")
        .option("numFeatures", str(MNIST_NUM_FEATURES))
        .load(path)
    )


training_data = _load_mnist_split("train")
# Use the held-out split. Previously this also read "train", so the model was
# applied to the exact data it was trained on.
test_data = _load_mnist_split("test")

# NOTE(review): because endpoint* settings are supplied, fit() presumably also
# deploys a hosted endpoint that transform() calls. Nothing here deletes it,
# so confirm that the endpoint and model are cleaned up after the run to avoid
# ongoing charges.
kmeans_estimator = KMeansSageMakerEstimator(
    trainingInstanceType="ml.m4.xlarge",
    trainingInstanceCount=1,
    endpointInstanceType="ml.m4.xlarge",
    endpointInitialInstanceCount=1,
    sagemakerRole=IAMRole(iam_role),
)
kmeans_estimator.setK(10)
kmeans_estimator.setFeatureDim(MNIST_NUM_FEATURES)

kmeans_model = kmeans_estimator.fit(training_data)
transformed_data = kmeans_model.transform(test_data)
transformed_data.show()