User Churn Prediction with Spark

Struggling with churn over?

Project Overview

With slick apps and designs today, services that do not quickly adapt to their customers’ needs risk losing them to competitors. For many businesses and investors, customer churn is an important problem that needs to be not only measured but also actively addressed to ensure value for customers and revenue for the business. This project from the Udacity Data Scientist Nanodegree is aimed at solving just this problem. The code can be found here on my GitHub repository.

We will follow the CRISP-DM process: we start by understanding the business problem, then define our strategy, dive into the available data, and explore whether we can use data science to create value for the business.

Business Understanding

A music streaming company called ‘Sparkify’ gives millions of users the ability to stream music of their choice every day. To use the service, users can register for either the paid or the free subscription tier. Although the data sample we work with in this project is very small (a larger dataset is also available through Udacity), the proportion of subscription cancellations is very high, leading to a 25% loss in subscription revenue. For our task, we look at click data to define, measure and predict customer churn to create business value.


At this stage, we gain a good understanding of the available data and what our churn label really means. From here we work on creating features which help explain why users churn, e.g. total songs heard, how long they have been a user, how many friends they have invited to build a community, how much of the app they have explored, etc.

We finally start training data science models to learn how the features we created relate to user churn. We train different models and evaluate the best performing ones based on a suitable metric. We then further optimise the best performing model using techniques learned throughout the Nanodegree courses.

In the end, we share results. This strategy would help not only predict customers who are likely to cancel their subscription but also give us insight into why they might churn, using the model’s important features. With explainability in mind for the final results, we will avoid using matrix factorisation or neural network algorithms to build our models.

Data Understanding

In our data we have 286,500 events collected between 1st October 2018 and 3rd Dec 2018. We first look at our input schema:

|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

The columns we will use are indicated in bold. For this exercise we will not look at all the columns available; however, they can be explored in subsequent iterations.

The first thing we want to do before exploring the data further is clean up any issues with the raw data. After some exploration we can see that 8346 events have an empty string as their userId. We discard these events as they won’t help during modelling.

After cleaning and reformatting our date fields we are left with:

  • 225 relatively new users who registered in March earlier in the year
  • 2312 sessions
  • 278154 events
Most events are users listening to the next song

What do we mean by “churn”?

For this exercise, we will not be predicting individual churn events but rather churned users. The data therefore needs to be summarised to a customer level, a process during which we will build most of our features. Before exploring what those features are, we define churn by adding a column with value 1 representing a user who visited the ‘Cancellation Confirmation’ page and 0 for a user who never did.

Total Daily Events over October and November

Gender impact on churn

Continuing our exploration, we look at gender and its impact on churn.

As expected, gender does not have much of an impact on churn: the split of churned users across genders is fairly balanced once we take into account that there are more males in the dataset than females.

Feature Engineering

  • Total days on Sparkify
  • Total unique pages visited
  • Total events
  • Total artists heard
  • Total unique songs heard
  • Avg length
  • Min length
  • Max length
  • Subscription level — paid or free
  • Total next song events
  • Total thumbs up events
  • Total thumbs down events
  • Total settings events
  • Total add friend events
  • Total add to playlist events
  • Total about events
  • Total error events
  • Total help events
  • Days between first and last event
  • Gender
  • Days since registration (at the time of the first recorded event)

Along with these we have “userId” and “churned” as the label column.

Data Modeling

Data Preparation

1- Categorical columns need to be represented by their numeric index

2- All features need to be assembled into a feature vector

3- Features need to be scaled properly so that no feature has an outsized impact on the model simply due to the nature of the data

MLlib provides a cool feature to encapsulate these transformations in a staged pipeline using… Pipeline()! First we create the gender and subscription indexers, then the vector assembler and the standard scaler, all of which can be executed in a single pipeline using:

data_prep_pipline = Pipeline(stages=[gender_indexer, subcription_indexer, assembler, scaler])

scaled_indexed_df =

Data Modelling

We are trying to classify users into churned and not-churned groups. Our training data set contains 23% churned users, so the available data is not balanced. In this case the accuracy of a model would be 77% if it simply labelled everyone as not churned. Therefore, we either downsample our negative class or use a different metric. In this project we will use the f1 score since we want a good balance of precision and recall. The metric can be changed based on requirements: for example, if the cost of a wrong churn prediction is high, we would favour precision so that each prediction is correct, even if we don’t capture every user who churns.

Base Models

After the last step, our input dataset is ready to be used for training. We start by using 3 models with default parameters: Logistic Regression, Random Forest and Gradient Boosted Trees. Again we use the Pipeline object to wrap each model and execute it. After running the three base models we get the following f1 scores:

Note that these vary from training run to training run depending on the input data selected. On a larger data sample these would be more stable.

Of the base models, gradient boosted trees look to have had the best performance, with an f1 score of 0.83 in this iteration.

Confusion Matrix for base GBT model

We already get a decent model which does a better job than marking everyone as not churned: it makes 15 churned predictions, 13 of which are correct.


We will use the gradient boosted trees model to further optimise the f1 score using a parameter grid and cross validator. The params used in the parameter map for our tuned GBT model are: maxDepth: [2,5,10], maxBins: [5,10], maxIter: [5,10,50]

We can see that parameter tuning did not improve the model performance. This was not expected, so we’ll evaluate both GBTs. First let’s also look at the confusion matrix for this model:

The tuned model makes 9 mistakes compared to the default’s 7. However, the tuned model has better precision, as each of its churn predictions is correct, even though it misses more actual churners (more false negatives).
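The counts quoted for the two confusion matrices can be turned into churn-class precision, recall and f1 with a few lines of arithmetic. The breakdown below is my derivation, assuming both models were scored on the same test set and that "mistakes" means false positives plus false negatives: the default model has 13 true positives, 2 false positives and therefore 5 false negatives, which implies 18 actual churners; the tuned model has no false positives, so its 9 mistakes are all false negatives, leaving 9 true positives.

```python
def precision_recall_f1(tp, fp, fn):
    """Precision, recall and f1 for the positive (churned) class."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Counts derived from the text under the assumptions stated above
default_gbt = precision_recall_f1(tp=13, fp=2, fn=5)
tuned_gbt = precision_recall_f1(tp=9, fp=0, fn=9)

for name, (p, r, f1) in [("default", default_gbt), ("tuned", tuned_gbt)]:
    print(f"{name}: precision={p:.3f} recall={r:.3f} f1={f1:.3f}")
```

These churn-class figures differ from the weighted f1 that MLlib's evaluator reports, because the weighted score also includes the much larger not-churned class.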

Results and Evaluation

GBT with default params (best model)

Params: maxBins: 32, maxDepth: 5 and maxIter: 20.

Feature Importances:

GBT with params tuning

Params: maxBins: 10, maxDepth: 2 and maxIter: 5.

Feature Importances:

Both models tend to rank total_unique_pages_visited and days between first and last event as important features. This can be explained by users with a higher number of pages visited and longer periods of usage tending to have a more informed opinion of whether they want to continue using the service or not. Interestingly, neither model points to adding friends as a key indicator, which I expected to play a bigger role in a user’s likelihood of churning.


One-off execution on the entire 12GB dataset

If the business intends to execute a single marketing campaign targeted at all those customers who are likely to churn, then this model can be executed from within the notebook, which is already written in the Spark framework. The results can be stored on disk and shared with the relevant stakeholders or applications.

Real time event based prediction

In this scenario, the model would be trained periodically based on predefined criteria and hosted as a web service that is provided with a user’s prepared data. Preparing that user’s data and invoking the model’s prediction can be triggered by real-time events so that customers who are showing signs of churning can be contacted in real time.

Scheduled batch predictions

Another solution can be to parameterise and schedule the Spark job to collect, clean, and predict customers’ likelihood of churning at given time intervals.


All of the code, other than a few visualisations, is written using PySpark, so it can next be executed within a distributed environment with 12GB of data (and a lot of processing power for the parameter grid!).
