Start Spark Session

LR Model in Spark - Diamonds

This expands on the Build LR & Test/Evaluate LR Model in Python - Diamonds examples, but this time we'll use SparkML to build a linear regression (LR) model that predicts the price of diamonds.


  1. Use PySpark to connect to a Spark cluster.
  2. Create a Spark session.
  3. Read a CSV file into a DataFrame.
  4. Split the dataset into training and testing sets.
  5. Use VectorAssembler to combine multiple columns into a single vector column.
  6. Use linear regression to build a prediction model.
  7. Use metrics to evaluate the model.
  8. Stop the Spark session.
# To suppress warnings generated by the code
def warn(*args, **kwargs):
    pass

import warnings
warnings.warn = warn
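The snippet above silences warnings by replacing `warnings.warn` with a no-op, which only affects code that looks up `warnings.warn` at call time (a module that did `from warnings import warn` keeps the original). A minimal sketch of the more conventional alternative, a warnings filter, is shown below; it assumes you only want to hide Python-level warnings, not the JVM log output Spark prints. `noisy_function` is a hypothetical stand-in for library code.

```python
import warnings


def noisy_function():
    """Hypothetical function that emits a Python warning before returning."""
    warnings.warn("this API is deprecated", DeprecationWarning)
    return 42


# catch_warnings() restores the previous filter state when the block exits;
# record=True collects any warnings that are NOT filtered out.
with warnings.catch_warnings(record=True) as caught:
    warnings.filterwarnings("ignore")  # drop every warning raised inside this block
    result = noisy_function()

print(result, len(caught))  # 42 0
```

Calling `warnings.filterwarnings("ignore")` at the top level instead of inside `catch_warnings()` would silence warnings for the rest of the session.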

# FindSpark simplifies the process of using Apache Spark with Python
import findspark
findspark.init()  # locate the local Spark installation and make PySpark importable

from pyspark.sql import SparkSession

# Import classes for Spark ML
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression

# Import classes for metrics
from pyspark.ml.evaluation import RegressionEvaluator

# Create SparkSession
spark = SparkSession.builder.appName("Regressing using SparkML").getOrCreate()

Import Data

import wget
wget.download("")  # pass the URL of diamonds.csv here

CSV to SparkDF

# Use spark.read.csv to load the data into a DataFrame.
# header=True indicates that the CSV file has a header row.
# inferSchema=True tells Spark to infer the data types of the columns automatically.

# Load the diamonds dataset
diamond_data = spark.read.csv("diamonds.csv", header=True, inferSchema=True)
diamond_data.show(5)

|  s|carat|    cut|color|clarity|depth|table|price|   x|   y|   z|
|  1| 0.23|  Ideal|    E|    SI2| 61.5| 55.0|  326|3.95|3.98|2.43|
|  2| 0.21|Premium|    E|    SI1| 59.8| 61.0|  326|3.89|3.84|2.31|
|  3| 0.23|   Good|    E|    VS1| 56.9| 65.0|  327|4.05|4.07|2.31|
|  4| 0.29|Premium|    I|    VS2| 62.4| 58.0|  334| 4.2|4.23|2.63|
|  5| 0.31|   Good|    J|    SI2| 63.3| 58.0|  335|4.34|4.35|2.75|
only showing top 5 rows

Identify Label/Input Columns

  • Use the price column as the label (target) column.
  • Use the carat, depth, and table columns as features.
# Combine carat, depth, and table into a single vector column named "features"
assembler = VectorAssembler(inputCols=["carat", "depth", "table"], outputCol="features")
diamond_transformed_data = assembler.transform(diamond_data)

# Print the vectorized features (input) and label (target) columns
diamond_transformed_data.select("features", "price").show()
|        features|price|
|[0.23,61.5,55.0]|  326|
|[0.21,59.8,61.0]|  326|
|[0.23,56.9,65.0]|  327|
|[0.29,62.4,58.0]|  334|
|[0.31,63.3,58.0]|  335|
|[0.24,62.8,57.0]|  336|
|[0.24,62.3,57.0]|  336|
|[0.26,61.9,55.0]|  337|
|[0.22,65.1,61.0]|  337|
|[0.23,59.4,61.0]|  338|
| [0.3,64.0,55.0]|  339|
|[0.23,62.8,56.0]|  340|
|[0.22,60.4,61.0]|  342|
|[0.31,62.2,54.0]|  344|
| [0.2,60.2,62.0]|  345|
|[0.32,60.9,58.0]|  345|
| [0.3,62.0,54.0]|  348|
| [0.3,63.4,54.0]|  351|
| [0.3,63.8,56.0]|  351|
| [0.3,62.7,59.0]|  351|
only showing top 20 rows
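VectorAssembler concatenates the values of the `inputCols`, in the order listed, into one vector per row. The plain-Python sketch below mimics that on the first two rows shown above, using dicts and lists; it is an illustration of the idea only, since Spark stores the result as a Spark ML vector column rather than a Python list. The helper name `assemble` is hypothetical.

```python
def assemble(row, input_cols):
    """Plain-Python stand-in for VectorAssembler: gather input_cols into one list."""
    return [float(row[col]) for col in input_cols]


# First two rows of diamonds.csv, as printed above
rows = [
    {"carat": 0.23, "depth": 61.5, "table": 55.0, "price": 326},
    {"carat": 0.21, "depth": 59.8, "table": 61.0, "price": 326},
]

input_cols = ["carat", "depth", "table"]  # order determines vector positions
for row in rows:
    row["features"] = assemble(row, input_cols)
    print(row["features"], row["price"])
# [0.23, 61.5, 55.0] 326
# [0.21, 59.8, 61.0] 326
```

Because positions in the vector follow `inputCols`, the learned coefficients later line up with carat, depth, and table in that order.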

Split Data

# Split data into training and testing sets
(training_data, testing_data) = diamond_transformed_data.randomSplit([0.7, 0.3], seed=42)
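`randomSplit([0.7, 0.3], seed=42)` does not produce an exact 70/30 row count: Spark samples each row independently against the normalized weights, so split sizes are only approximately proportional, and the seed makes the split reproducible for the same data and partitioning. The sketch below shows that per-row idea in plain Python; `random_split` is a hypothetical helper and does not reproduce Spark's random number stream, so its split will differ from Spark's.

```python
import random


def random_split(rows, weights, seed):
    """Assign each row to one split by drawing one uniform number per row."""
    total = sum(weights)
    # Cumulative upper bounds, e.g. [0.7, 1.0] for weights [0.7, 0.3]
    bounds, running = [], 0.0
    for w in weights:
        running += w / total
        bounds.append(running)

    rng = random.Random(seed)  # fixed seed -> reproducible split
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        for i, upper in enumerate(bounds):
            # The last bucket also catches any float rounding at the top end
            if u < upper or i == len(bounds) - 1:
                splits[i].append(row)
                break
    return splits


data = list(range(10_000))
training, testing = random_split(data, [0.7, 0.3], seed=42)
print(len(training), len(testing))  # roughly 7000 / 3000
```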

Build/Train Model

# Train linear regression model
# Ignore any warnings

lr = LinearRegression(featuresCol="features", labelCol="price")
model = lr.fit(training_data)
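With its default parameters (`regParam=0.0`, `fitIntercept=True`), `LinearRegression` minimizes the sum of squared errors, i.e. ordinary least squares. The plain-Python sketch below fits the same kind of model by solving the normal equations (AᵀA)w = Aᵀy with Gaussian elimination. It assumes a small, well-conditioned dataset and is not how Spark solves the problem internally (Spark picks its own solver). `fit_ols` and the toy data, generated from a known rule price = 100 + 50·x1 - 2·x2 + 3·x3, are hypothetical.

```python
def fit_ols(X, y):
    """Ordinary least squares with an intercept, via the normal equations.

    X is a list of feature lists, y a list of labels. Returns (intercept, coefs).
    """
    # Prepend a constant 1.0 so the first weight acts as the intercept
    A = [[1.0] + list(row) for row in X]
    n = len(A[0])

    # Build the augmented system [A^T A | A^T y]
    M = []
    for i in range(n):
        lhs = [sum(a[i] * a[j] for a in A) for j in range(n)]
        rhs = sum(a[i] * t for a, t in zip(A, y))
        M.append(lhs + [rhs])

    # Forward elimination with partial pivoting
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]

    # Back substitution
    w = [0.0] * n
    for r in range(n - 1, -1, -1):
        s = M[r][n] - sum(M[r][c] * w[c] for c in range(r + 1, n))
        w[r] = s / M[r][r]
    return w[0], w[1:]


# Hypothetical toy data generated from an exact linear rule
X = [[1, 2, 3], [2, 1, 0], [3, 5, 1], [4, 2, 7], [5, 8, 2], [6, 1, 4]]
y = [100 + 50 * a - 2 * b + 3 * c for a, b, c in X]

intercept, coefs = fit_ols(X, y)
print(round(intercept, 6), [round(c, 6) for c in coefs])  # approximately 100.0 [50.0, -2.0, 3.0]
```

Since the toy labels are noise-free, least squares recovers the generating weights; on real data such as the diamonds set the fit only approximates the prices.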

Predict Price

# Make predictions on testing data
predictions = model.transform(testing_data)
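For a linear regression model, `model.transform` appends a `prediction` column computed as the intercept plus the dot product of the coefficients with each row's `features` vector; in PySpark the fitted values are available as `model.intercept` and `model.coefficients`. A plain-Python sketch of that computation, with hypothetical weights rather than the ones trained above:

```python
def predict(intercept, coefficients, features):
    """Linear prediction: intercept + dot(coefficients, features)."""
    return intercept + sum(w * x for w, x in zip(coefficients, features))


# Hypothetical weights and feature vector, for illustration only
print(predict(2.0, [3.0, 0.5, -1.0], [4.0, 2.0, 1.0]))  # 14.0
```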

Evaluate Model

R squared

The closer to 1, the better.

# R-squared (R2): R2 is a statistical measure that represents the proportion of variance in the dependent variable (target) that is explained by the independent variables (features).

evaluator = RegressionEvaluator(labelCol="price", predictionCol="prediction", metricName="r2")
r2 = evaluator.evaluate(predictions)
print("R Squared =", r2)

R Squared = 0.854508517843993
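The `"r2"` metric follows the standard definition R² = 1 - SS_res / SS_tot, where SS_res sums the squared prediction errors and SS_tot sums the squared deviations of the labels from their mean. A plain-Python sketch on hypothetical toy values (not the diamonds test set); note that R² drops below 0 when the predictions are worse than always predicting the mean label.

```python
def r_squared(labels, predictions):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_label = sum(labels) / len(labels)
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    ss_tot = sum((y - mean_label) ** 2 for y in labels)
    return 1.0 - ss_res / ss_tot


# Hypothetical toy labels and predictions
labels = [3.0, 5.0, 7.0, 9.0]
predictions = [2.5, 5.5, 7.0, 9.5]
r2 = r_squared(labels, predictions)
print(r2)  # ≈ 0.9625 (SS_res = 0.75, SS_tot = 20)
```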

MAE (Mean Absolute Error)

The lower, the better.

evaluator = RegressionEvaluator(labelCol="price", predictionCol="prediction", metricName="mae")
mae = evaluator.evaluate(predictions)
print("MAE =", mae)

MAE = 994.7282983463749
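MAE is the average of the absolute differences between labels and predictions, so it is expressed in the same units as `price`. A plain-Python sketch on the same hypothetical toy values used for R² above:

```python
def mean_absolute_error(labels, predictions):
    """Average absolute difference between labels and predictions."""
    return sum(abs(y - p) for y, p in zip(labels, predictions)) / len(labels)


# Hypothetical toy labels and predictions
labels = [3.0, 5.0, 7.0, 9.0]
predictions = [2.5, 5.5, 7.0, 9.5]
mae = mean_absolute_error(labels, predictions)
print(mae)  # 0.375
```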

RMSE (Root Mean Squared Error)

The lower, the better.

evaluator = RegressionEvaluator(labelCol="price", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print("RMSE =", rmse)

RMSE = 1534.8181642609825
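RMSE is the square root of the mean squared error. Because errors are squared before averaging, large misses weigh more than in MAE, and RMSE is always at least as large as MAE; a gap like the one above (about 1535 vs 995) suggests some predictions are off by much more than average. A plain-Python sketch on hypothetical toy values:

```python
import math


def root_mean_squared_error(labels, predictions):
    """Square root of the average squared difference."""
    mse = sum((y - p) ** 2 for y, p in zip(labels, predictions)) / len(labels)
    return math.sqrt(mse)


# Hypothetical toy labels and predictions
labels = [3.0, 5.0, 7.0, 9.0]
predictions = [2.5, 5.5, 7.0, 9.5]
rmse = root_mean_squared_error(labels, predictions)
mae = sum(abs(y - p) for y, p in zip(labels, predictions)) / len(labels)
print(rmse, mae)  # rmse ≈ 0.433, mae = 0.375
```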