In [ ]:
import ibmos2spark
# @hidden_cell
credentials = {
    'endpoint': 'https://s3-api.us-geo.objectstorage.service.networklayer.com',
    'service_id': 'iam-ServiceId-4749bc73-2bb7-4155-b345-28d6fde6d97c',
    'iam_service_endpoint': 'https://iam.ng.bluemix.net/oidc/token',
    'api_key': '********'
}

configuration_name = 'os_362c3529afdf436e8dc3c5cb207484d1_configs'
cos = ibmos2spark.CloudObjectStorage(sc, credentials, configuration_name, 'bluemix_cos')

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
# Since JSON data can be semi-structured and contain additional metadata, it is possible that you might face issues with the DataFrame layout.
# Please read the documentation of 'SparkSession.read()' to learn more about the possibilities to adjust the data loading.
# PySpark documentation: http://spark.apache.org/docs/2.0.2/api/python/pyspark.sql.html#pyspark.sql.DataFrameReader.json
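# For example (a sketch, not applied here), an explicit schema can be passed to skip schema
# inference; only a few of the 18 columns are listed below for brevity, and 'path_or_cos_url'
# stands for the cos.url(...) value used in the actual load below:
#   from pyspark.sql.types import StructType, StructField, StringType, LongType
#   schema = StructType([StructField('artist', StringType()),
#                        StructField('page', StringType()),
#                        StructField('ts', LongType()),
#                        StructField('userId', StringType())])
#   spark_data = spark.read.json(path_or_cos_url, schema=schema)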

spark_data = spark.read.json(cos.url('medium-sparkify-event-data.json', 'sparkify-donotdelete-pr-igw07vjrf2o5ef'))
spark_data.take(5)
Out[ ]:
[Row(artist='Martin Orford', auth='Logged In', firstName='Joseph', gender='M', itemInSession=20, lastName='Morales', length=597.55057, level='free', location='Corpus Christi, TX', method='PUT', page='NextSong', registration=1532063507000, sessionId=292, song='Grand Designs', status=200, ts=1538352011000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"', userId='293'),
 Row(artist="John Brown's Body", auth='Logged In', firstName='Sawyer', gender='M', itemInSession=74, lastName='Larson', length=380.21179, level='free', location='Houston-The Woodlands-Sugar Land, TX', method='PUT', page='NextSong', registration=1538069638000, sessionId=97, song='Bulls', status=200, ts=1538352025000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='98'),
 Row(artist='Afroman', auth='Logged In', firstName='Maverick', gender='M', itemInSession=184, lastName='Santiago', length=202.37016, level='paid', location='Orlando-Kissimmee-Sanford, FL', method='PUT', page='NextSong', registration=1535953455000, sessionId=178, song='Because I Got High', status=200, ts=1538352118000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='179'),
 Row(artist=None, auth='Logged In', firstName='Maverick', gender='M', itemInSession=185, lastName='Santiago', length=None, level='paid', location='Orlando-Kissimmee-Sanford, FL', method='PUT', page='Logout', registration=1535953455000, sessionId=178, song=None, status=307, ts=1538352119000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='179'),
 Row(artist='Lily Allen', auth='Logged In', firstName='Gianna', gender='F', itemInSession=22, lastName='Campos', length=194.53342, level='paid', location='Mobile, AL', method='PUT', page='NextSong', registration=1535931018000, sessionId=245, song='Smile (Radio Edit)', status=200, ts=1538352124000, userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0', userId='246')]
In [ ]:
# import libraries
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, IntegerType
from pyspark.sql.functions import *
from pyspark.sql.functions import sum as Fsum

from time import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

from pyspark.ml.feature import OneHotEncoder, StringIndexer, StandardScaler, Imputer, VectorAssembler, MinMaxScaler
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator 
from pyspark.mllib.evaluation import MulticlassMetrics 
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression, DecisionTreeClassifier, GBTClassifier

from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
In [ ]:
spark_data.printSchema()
root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)

In [ ]:
print("Number of records in the dataset:", spark_data.count())
spark_data.describe('sessionId').show()
spark_data.describe('userId').show()
spark_data.describe('artist').show()
Number of records in the dataset: 543705
+-------+------------------+
|summary|         sessionId|
+-------+------------------+
|  count|            543705|
|   mean|2040.8143533717732|
| stddev| 1434.338931078271|
|    min|                 1|
|    max|              4808|
+-------+------------------+

+-------+------------------+
|summary|            userId|
+-------+------------------+
|  count|            543705|
|   mean| 60268.42669103512|
| stddev|109898.82324176628|
|    min|                  |
|    max|                99|
+-------+------------------+

+-------+-----------------+
|summary|           artist|
+-------+-----------------+
|  count|           432877|
|   mean|527.5289537712895|
| stddev|966.1072451772758|
|    min|              !!!|
|    max|Ólafur Arnalds|
+-------+-----------------+

In [ ]:
# Drop rows with null userId or sessionId
spark_data_valid = spark_data.dropna(how='any', subset=['userId', 'sessionId'])

# Rows where userId is blank
spark_data_invalid = spark_data_valid.filter(spark_data_valid['userId'] == '')

# Filter out rows where userId is blank
spark_data_valid = spark_data_valid.filter(spark_data_valid['userId'] != '')
In [ ]:
spark_data_invalid.take(5)
Out[ ]:
[Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=186, lastName=None, length=None, level='paid', location=None, method='GET', page='Home', registration=None, sessionId=178, song=None, status=200, ts=1538352148000, userAgent=None, userId=''),
 Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=187, lastName=None, length=None, level='paid', location=None, method='GET', page='Home', registration=None, sessionId=178, song=None, status=200, ts=1538352151000, userAgent=None, userId=''),
 Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=188, lastName=None, length=None, level='paid', location=None, method='GET', page='Home', registration=None, sessionId=178, song=None, status=200, ts=1538352168000, userAgent=None, userId=''),
 Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=189, lastName=None, length=None, level='paid', location=None, method='PUT', page='Login', registration=None, sessionId=178, song=None, status=307, ts=1538352169000, userAgent=None, userId=''),
 Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=114, lastName=None, length=None, level='free', location=None, method='GET', page='Home', registration=None, sessionId=442, song=None, status=200, ts=1538353292000, userAgent=None, userId='')]
In [ ]:
print('# of Records after removing blank userID:', spark_data_valid.count())
print('# of Records with blank userID:', spark_data.count()-spark_data_valid.count())
# of Records after removing blank userID: 528005
# of Records with blank userID: 15700
In [ ]:
# List of page values
spark_data_valid.select(['userId','page']).groupBy('page').count().sort(desc('count')).show()
+--------------------+------+
|                page| count|
+--------------------+------+
|            NextSong|432877|
|           Thumbs Up| 23826|
|                Home| 19089|
|     Add to Playlist| 12349|
|          Add Friend|  8087|
|         Roll Advert|  7773|
|              Logout|  5990|
|         Thumbs Down|  4911|
|           Downgrade|  3811|
|            Settings|  2964|
|                Help|  2644|
|               About|  1026|
|             Upgrade|   968|
|       Save Settings|   585|
|               Error|   503|
|      Submit Upgrade|   287|
|    Submit Downgrade|   117|
|              Cancel|    99|
|Cancellation Conf...|    99|
+--------------------+------+

In [ ]:
# number of users by userAgent (dedup userId)
spark_data_valid.select(['userId','userAgent']).dropDuplicates().groupBy('userAgent').count().sort(desc('count')).show(100)
+--------------------+-----+
|           userAgent|count|
+--------------------+-----+
|"Mozilla/5.0 (Win...|   39|
|Mozilla/5.0 (Wind...|   34|
|"Mozilla/5.0 (Mac...|   30|
|"Mozilla/5.0 (Mac...|   25|
|"Mozilla/5.0 (Mac...|   22|
|"Mozilla/5.0 (Win...|   22|
|"Mozilla/5.0 (Mac...|   21|
|Mozilla/5.0 (Maci...|   17|
|Mozilla/5.0 (Wind...|   15|
|"Mozilla/5.0 (Mac...|   14|
|"Mozilla/5.0 (Win...|   12|
|Mozilla/5.0 (Wind...|   11|
|"Mozilla/5.0 (iPh...|   11|
|"Mozilla/5.0 (Win...|    9|
|"Mozilla/5.0 (Win...|    9|
|Mozilla/5.0 (X11;...|    7|
|"Mozilla/5.0 (iPh...|    6|
|Mozilla/5.0 (comp...|    6|
|"Mozilla/5.0 (Win...|    6|
|"Mozilla/5.0 (X11...|    6|
|"Mozilla/5.0 (iPa...|    6|
|"Mozilla/5.0 (Win...|    6|
|"Mozilla/5.0 (Win...|    6|
|"Mozilla/5.0 (Win...|    5|
|Mozilla/5.0 (Wind...|    5|
|"Mozilla/5.0 (Win...|    5|
|"Mozilla/5.0 (Win...|    4|
|Mozilla/5.0 (Wind...|    4|
|Mozilla/5.0 (Wind...|    4|
|Mozilla/5.0 (Wind...|    3|
|"Mozilla/5.0 (Mac...|    3|
|"Mozilla/5.0 (Mac...|    3|
|"Mozilla/5.0 (Mac...|    3|
|"Mozilla/5.0 (Mac...|    3|
|Mozilla/5.0 (comp...|    3|
|"Mozilla/5.0 (iPa...|    3|
|"Mozilla/5.0 (Mac...|    3|
|"Mozilla/5.0 (Mac...|    3|
|"Mozilla/5.0 (X11...|    3|
|"Mozilla/5.0 (Win...|    3|
|"Mozilla/5.0 (Mac...|    2|
|Mozilla/5.0 (comp...|    2|
|Mozilla/5.0 (Maci...|    2|
|Mozilla/5.0 (comp...|    2|
|Mozilla/5.0 (Maci...|    2|
|Mozilla/5.0 (X11;...|    2|
|"Mozilla/5.0 (X11...|    2|
|"Mozilla/5.0 (Mac...|    2|
|Mozilla/5.0 (Wind...|    2|
|"Mozilla/5.0 (Mac...|    2|
|"Mozilla/5.0 (Mac...|    2|
|Mozilla/5.0 (X11;...|    2|
|"Mozilla/5.0 (Mac...|    2|
|"Mozilla/5.0 (Mac...|    2|
|"Mozilla/5.0 (Mac...|    2|
|"Mozilla/5.0 (Mac...|    2|
|Mozilla/5.0 (X11;...|    2|
|Mozilla/5.0 (Wind...|    1|
|"Mozilla/5.0 (Mac...|    1|
|"Mozilla/5.0 (iPh...|    1|
|Mozilla/5.0 (Maci...|    1|
|"Mozilla/5.0 (Win...|    1|
|"Mozilla/5.0 (Mac...|    1|
|"Mozilla/5.0 (Win...|    1|
|Mozilla/5.0 (Maci...|    1|
|"Mozilla/5.0 (Mac...|    1|
|Mozilla/5.0 (Wind...|    1|
|"Mozilla/5.0 (X11...|    1|
|Mozilla/5.0 (Wind...|    1|
|"Mozilla/5.0 (X11...|    1|
|"Mozilla/5.0 (Mac...|    1|
+--------------------+-----+

In [ ]:
# Can one user have multiple user agents? Apparently not.
spark_data_valid.select(['userId','userAgent']).dropDuplicates().groupBy('userId').count().sort(desc('count')).show(20)
+------+-----+
|userId|count|
+------+-----+
|   124|    1|
|     7|    1|
|    51|    1|
|100010|    1|
|200002|    1|
|   296|    1|
|200037|    1|
|   125|    1|
|   205|    1|
|   169|    1|
|   272|    1|
|    54|    1|
|   232|    1|
|   282|    1|
|   234|    1|
|    15|    1|
|   155|    1|
|200043|    1|
|   154|    1|
|   132|    1|
+------+-----+
only showing top 20 rows

In [ ]:
# number of unique users
print('# Users:',spark_data_valid.select(['userId']).dropDuplicates().count())

# number of users by Gender (dedup userId)
spark_data_valid.select(['userId','gender']).dropDuplicates().groupBy('gender').count().show()

# number of songs played by the most played artists
artist_chart = spark_data_valid.dropna(how='any',subset=['artist'])\
.groupBy('artist').count().sort(desc('count')).show(10)

# how many distinct sessions per user
spark_data_valid.groupBy('userID')\
    .agg(countDistinct("sessionId"))\
    .withColumnRenamed("count(DISTINCT sessionId)", "sessionIdCount") \
    .sort(desc("sessionIdCount")).show()
# Users: 448
+------+-----+
|gender|count|
+------+-----+
|     F|  198|
|     M|  250|
+------+-----+

+--------------------+-----+
|              artist|count|
+--------------------+-----+
|       Kings Of Leon| 3497|
|            Coldplay| 3439|
|Florence + The Ma...| 2314|
|                Muse| 2194|
|       Dwight Yoakam| 2187|
|      The Black Keys| 2160|
|            Björk| 2150|
|       Justin Bieber| 2096|
|        Jack Johnson| 2049|
|           Radiohead| 1694|
+--------------------+-----+
only showing top 10 rows

+------+--------------+
|userID|sessionIdCount|
+------+--------------+
|300049|            92|
|300035|            88|
|    92|            72|
|300017|            71|
|   140|            69|
|    87|            64|
|300011|            60|
|200023|            59|
|   293|            58|
|   195|            58|
|   230|            56|
|   101|            53|
|    42|            52|
|300021|            51|
|   250|            46|
|300031|            45|
|   121|            43|
|   100|            39|
|300038|            39|
|    29|            39|
+------+--------------+
only showing top 20 rows

In [ ]:
#### How many songs do users listen to, on average, between visits to the home page?
# SOLUTION USING SPARK DATAFRAME

flag_home_func = udf(lambda x: 1 if x == 'Home' else 0, IntegerType())

#Create User Window
user_window = Window\
    .partitionBy('userId')\
    .orderBy('ts')\
    .rangeBetween(Window.unboundedPreceding,0)

# create 'homeSession' to flag Home events and 'songCount' as the cumulative number of Home visits per user (a period id)
cusum = spark_data_valid.filter((spark_data_valid.page=='Home')| (spark_data_valid.page=='NextSong')) \
    .select(['userID','page','ts'])\
    .withColumn('homeSession',flag_home_func('page'))\
    .withColumn('songCount',Fsum('homeSession').over(user_window))

#cusum.filter((cusum.userID == 126)).show()

# calculate the average number of songs played per user between visits to the home page
cusum.filter(cusum.page == 'NextSong') \
    .groupBy('userID', 'songCount') \
    .agg(count('songCount'))\
    .agg(avg('count(songCount)'))\
    .show()
+---------------------+
|avg(count(songCount))|
+---------------------+
|   23.692025614361558|
+---------------------+

In [ ]:
# SOLUTION USING SPARK SQL
spark_data_valid.createOrReplaceTempView("log_table")

# SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
is_home = spark.sql("SELECT userID, page, ts, CASE WHEN page = 'Home' THEN 1 ELSE 0 END AS is_home FROM log_table \
            WHERE (page = 'NextSong') or (page = 'Home') \
            ")

# keep the results in a new view
is_home.createOrReplaceTempView("is_home_table")

# find the cumulative sum over the is_home column
cumulative_sum = spark.sql("SELECT *, SUM(is_home) OVER \
    (PARTITION BY userID ORDER BY ts DESC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS period \
    FROM is_home_table")

# keep the results in a view
cumulative_sum.createOrReplaceTempView("period_table")

# find the average count for NextSong
spark.sql("SELECT AVG(count_results) FROM \
          (SELECT COUNT(*) AS count_results FROM period_table \
GROUP BY userID, period, page HAVING page = 'NextSong') AS counts").show()
+------------------+
|avg(count_results)|
+------------------+
|23.672591053264792|
+------------------+

Exploratory Data Analysis¶

When you're working with the full dataset, perform EDA by loading a small subset of the data and doing basic manipulations within Spark. In this workspace, a small subset of the data is already provided for you to explore.
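
On the full dataset, one way to get such a subset is a small random sample (a minimal sketch; the 5% fraction and the seed are arbitrary choices, not part of this workspace's setup):

# draw a ~5% random sample of events for quick exploration
subset = spark_data.sample(withReplacement=False, fraction=0.05, seed=42)
subset.count()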

Define Churn¶

Once you've done some preliminary analysis, create a column Churn to use as the label for your model. I suggest using the Cancellation Confirmation events to define your churn, which happen for both paid and free users. As a bonus task, you can also look into the Downgrade events.
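
A minimal sketch of that label, assuming the cleaned spark_data_valid from above (the cells below build the same row-level flag with the reusable func_event helper; the user-level churned_user column is an illustrative extra, not used later):

from pyspark.sql import Window
from pyspark.sql.functions import when, col, max as Fmax

# 1 on the 'Cancellation Confirmation' event itself, 0 elsewhere
churn_event = when(col('page') == 'Cancellation Confirmation', 1).otherwise(0)

# propagate the flag to every row of a user who churned at any point
user_window = Window.partitionBy('userId')
labeled = spark_data_valid \
    .withColumn('Churn', churn_event) \
    .withColumn('churned_user', Fmax('Churn').over(user_window))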

1. FLAG EVENTS¶

In [ ]:
# helper to flag page events
def func_event(df, event, colName):
    '''Flag occurrences of a page event in the dataframe.
    INPUT:
    df - Spark dataframe
    event - string - page event to flag
    colName - string - name of the new flag column

    OUTPUT:
    df - original df with a new 0/1 flag column added
    '''
    event_func = udf(lambda x: 1 if x == event else 0, IntegerType())
    df = df.withColumn(colName, event_func('page'))
    return df
In [ ]:
spark_data_valid=func_event(spark_data_valid,'Submit Downgrade','downgraded')
spark_data_valid=func_event(spark_data_valid,'Cancellation Confirmation','Churn')
spark_data_valid=func_event(spark_data_valid,'Submit Upgrade','upgraded')
spark_data_valid=func_event(spark_data_valid,'Thumbs Up','thumbs_up')
spark_data_valid=func_event(spark_data_valid,'Thumbs Down','thumbs_down')
spark_data_valid=func_event(spark_data_valid,'Add Friend','add_friend')
spark_data_valid=func_event(spark_data_valid,'Roll Advert','advert')
spark_data_valid=func_event(spark_data_valid,'Error','error')
spark_data_valid=func_event(spark_data_valid,'Add to Playlist','add_playlist')
spark_data_valid=func_event(spark_data_valid,'Settings','settings')
In [ ]:
#create paid level flag
flag_paid = udf(lambda x: 1 if x=='paid' else 0, IntegerType())
spark_data_valid = spark_data_valid.withColumn('paid',flag_paid('level'))
In [ ]:
spark_data_valid.take(5)
Out[ ]:
[Row(artist='Martin Orford', auth='Logged In', firstName='Joseph', gender='M', itemInSession=20, lastName='Morales', length=597.55057, level='free', location='Corpus Christi, TX', method='PUT', page='NextSong', registration=1532063507000, sessionId=292, song='Grand Designs', status=200, ts=1538352011000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"', userId='293', downgraded=0, Churn=0, upgraded=0, thumbs_up=0, thumbs_down=0, add_friend=0, advert=0, error=0, add_playlist=0, settings=0, paid=0),
 Row(artist="John Brown's Body", auth='Logged In', firstName='Sawyer', gender='M', itemInSession=74, lastName='Larson', length=380.21179, level='free', location='Houston-The Woodlands-Sugar Land, TX', method='PUT', page='NextSong', registration=1538069638000, sessionId=97, song='Bulls', status=200, ts=1538352025000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='98', downgraded=0, Churn=0, upgraded=0, thumbs_up=0, thumbs_down=0, add_friend=0, advert=0, error=0, add_playlist=0, settings=0, paid=0),
 Row(artist='Afroman', auth='Logged In', firstName='Maverick', gender='M', itemInSession=184, lastName='Santiago', length=202.37016, level='paid', location='Orlando-Kissimmee-Sanford, FL', method='PUT', page='NextSong', registration=1535953455000, sessionId=178, song='Because I Got High', status=200, ts=1538352118000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='179', downgraded=0, Churn=0, upgraded=0, thumbs_up=0, thumbs_down=0, add_friend=0, advert=0, error=0, add_playlist=0, settings=0, paid=1),
 Row(artist=None, auth='Logged In', firstName='Maverick', gender='M', itemInSession=185, lastName='Santiago', length=None, level='paid', location='Orlando-Kissimmee-Sanford, FL', method='PUT', page='Logout', registration=1535953455000, sessionId=178, song=None, status=307, ts=1538352119000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='179', downgraded=0, Churn=0, upgraded=0, thumbs_up=0, thumbs_down=0, add_friend=0, advert=0, error=0, add_playlist=0, settings=0, paid=1),
 Row(artist='Lily Allen', auth='Logged In', firstName='Gianna', gender='F', itemInSession=22, lastName='Campos', length=194.53342, level='paid', location='Mobile, AL', method='PUT', page='NextSong', registration=1535931018000, sessionId=245, song='Smile (Radio Edit)', status=200, ts=1538352124000, userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0', userId='246', downgraded=0, Churn=0, upgraded=0, thumbs_up=0, thumbs_down=0, add_friend=0, advert=0, error=0, add_playlist=0, settings=0, paid=1)]
In [ ]:
spark_data_valid.count()
Out[ ]:
528005
In [ ]:
spark_data_valid.select(['sessionId', 'page']).filter(spark_data_valid['sessionId']=='292').show(50)
+---------+---------------+
|sessionId|           page|
+---------+---------------+
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|      Thumbs Up|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|         Logout|
|      292|           Home|
|      292|     Add Friend|
|      292|       NextSong|
|      292|       NextSong|
|      292|Add to Playlist|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|    Roll Advert|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|Add to Playlist|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|    Roll Advert|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       Settings|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
+---------+---------------+
only showing top 50 rows

In [ ]:
# identify each user's most recent subscription level
flag_levels = udf(lambda x: 1 if x=="paid" else 0, IntegerType())
levels = spark_data_valid.select(['userId', 'level', 'ts'])\
    .orderBy(desc('ts'))\
    .dropDuplicates(['userId'])\
    .select(['userId', 'level'])\
    .withColumn('last_level', flag_levels('level').cast(IntegerType()))

levels.show()
levels.count()
+------+-----+----------+
|userId|level|last_level|
+------+-----+----------+
|100010| free|         0|
|200002| paid|         1|
|   296| paid|         1|
|   125| free|         0|
|   124| paid|         1|
|    51| paid|         1|
|     7| free|         0|
|200037| free|         0|
|   169| free|         0|
|   205| paid|         1|
|   272| free|         0|
|    15| paid|         1|
|   232| free|         0|
|   234| paid|         1|
|   282| paid|         1|
|    54| paid|         1|
|   155| paid|         1|
|200043| paid|         1|
|100014| paid|         1|
|   132| paid|         1|
+------+-----+----------+
only showing top 20 rows

Out[ ]:
448

2. EXPLORE DATA¶

Once you've defined churn, perform some exploratory data analysis to observe the behavior of users who stayed versus users who churned. You can start by exploring aggregates on these two groups, observing how often they performed specific actions per unit of time or per number of songs played.

In [ ]:
# identify user_ids who churned
churn_user_list = spark_data_valid.filter(spark_data_valid.Churn==1).select(['userId']).collect()
churn_user_array = [int(row.userId) for row in churn_user_list]
print("#users who churned:", len(churn_user_array))
#users who churned: 99
In [ ]:
# identify user_ids who stayed
all_user_list = spark_data_valid.select(['userId']).dropDuplicates().collect()
all_user_array=[int(row.userId) for row in all_user_list]
stay_user_array = list(set(all_user_array)-set(churn_user_array))
print("#users who stayed:", len(stay_user_array))
#users who stayed: 349

Compare churned users and stayed users in terms of:¶

  • % paid vs. % free users among churned users and stayed users
  • how many songs churned users and stayed users played on average (a quick sketch follows this list)
  • how much time they spent on the site
  • whether they downgraded before churning
  • how many thumbs-up or thumbs-down events they had
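
A minimal sketch of one such comparison, using the churn_user_array built above (the per-user aggregation in the cells below covers the rest of the list):

from pyspark.sql.functions import avg, col

# songs played per user
songs_per_user = spark_data_valid.filter(col('page') == 'NextSong') \
    .groupBy('userId').count()

churned_ids = [str(u) for u in churn_user_array]
print('avg songs, churned:',
      songs_per_user.filter(col('userId').isin(churned_ids)).agg(avg('count')).first()[0])
print('avg songs, stayed:',
      songs_per_user.filter(~col('userId').isin(churned_ids)).agg(avg('count')).first()[0])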
In [ ]:
# Register temp views for the event-level data and the per-user latest-level lookup
spark_data_valid.createOrReplaceTempView("combinedView")
levels.createOrReplaceTempView("levels")
In [ ]:
# aggregate the event-level data to one row per user (churned and stayed alike)
sqlCombined= 'select v.userId,userAgent,l.last_level, case when gender="F" then 0 else 1 end as gender,\
    count(distinct sessionId) as distinctSession,\
    count(distinct artist) as artistCount,\
    count(page)/count(distinct sessionId) as avgPageSess,\
    count(song) as songCount, \
    sum(advert) as totalAdvert,\
    sum(add_playlist) as addPlaylist,\
    sum(error) as totalError,\
    sum(length) as totalTime,\
    sum(downgraded) as downgraded,\
    sum(upgraded) as upgraded,\
    sum(thumbs_up) as thumbs_up,\
    sum(thumbs_down) as thumbs_down,\
    sum(Churn) as churn,\
    case when sum(paid)>0 then 1\
    when sum(paid)=0 then 0\
    end as paid\
    from combinedView v\
    join levels l\
    on v.userId = l.userId\
    group by v.userId,gender,userAgent,l.last_level'

combined_agg_df = spark.sql(sqlCombined)
combined_agg_df.take(5)
Out[ ]:
[Row(userId='138', userAgent='"Mozilla/5.0 (iPad; CPU OS 7_1_1 like Mac OS X) AppleWebKit/537.51.2 (KHTML, like Gecko) Version/7.0 Mobile/11D201 Safari/9537.53"', last_level=0, gender=1, distinctSession=14, artistCount=370, avgPageSess=41.142857142857146, songCount=434, totalAdvert=37, addPlaylist=10, totalError=1, totalTime=107916.26365000002, downgraded=0, upgraded=0, thumbs_up=27, thumbs_down=4, churn=0, paid=0),
 Row(userId='300018', userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"', last_level=1, gender=1, distinctSession=22, artistCount=707, avgPageSess=54.0, songCount=888, totalAdvert=56, addPlaylist=27, totalError=1, totalTime=223907.28055000026, downgraded=0, upgraded=1, thumbs_up=79, thumbs_down=7, churn=0, paid=1),
 Row(userId='200004', userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.78.2 (KHTML, like Gecko) Version/7.0.6 Safari/537.78.2"', last_level=0, gender=1, distinctSession=32, artistCount=1004, avgPageSess=57.90625, songCount=1430, totalAdvert=98, addPlaylist=40, totalError=2, totalTime=354256.41273000004, downgraded=0, upgraded=1, thumbs_up=80, thumbs_down=42, churn=0, paid=1),
 Row(userId='202', userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_5) AppleWebKit/537.77.4 (KHTML, like Gecko) Version/6.1.5 Safari/537.77.4"', last_level=1, gender=1, distinctSession=2, artistCount=186, avgPageSess=122.0, songCount=202, totalAdvert=1, addPlaylist=5, totalError=1, totalTime=49361.72388000003, downgraded=0, upgraded=1, thumbs_up=9, thumbs_down=0, churn=0, paid=1),
 Row(userId='153', userAgent='"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"', last_level=1, gender=1, distinctSession=9, artistCount=1029, avgPageSess=190.0, songCount=1454, totalAdvert=1, addPlaylist=36, totalError=1, totalTime=365271.0599899996, downgraded=0, upgraded=0, thumbs_up=77, thumbs_down=17, churn=0, paid=1)]
In [ ]:
# convert the Spark dataframe to pandas for visualization
print(combined_agg_df.count())
combined_agg_pd = combined_agg_df.toPandas()
448
In [ ]:
combined_agg_pd.head()
Out[ ]:
userId userAgent last_level gender distinctSession artistCount avgPageSess songCount totalAdvert addPlaylist totalError totalTime downgraded upgraded thumbs_up thumbs_down churn paid
0 300018 "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4... 1 1 22 707 54.000000 888 56 27 1 223907.28055 0 1 79 7 0 1
1 138 "Mozilla/5.0 (iPad; CPU OS 7_1_1 like Mac OS X... 0 1 14 370 41.142857 434 37 10 1 107916.26365 0 0 27 4 0 0
2 200004 "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4... 0 1 32 1004 57.906250 1430 98 40 2 354256.41273 0 1 80 42 0 1
3 202 "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_5... 1 1 2 186 122.000000 202 1 5 1 49361.72388 0 1 9 0 0 1
4 153 "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK... 1 1 9 1029 190.000000 1454 1 36 1 365271.05999 0 0 77 17 0 1
In [ ]:
combined_agg_df.select(['userId','paid','churn']).dropDuplicates().groupBy('churn','paid').count().sort(desc('count')).show()
+-----+----+-----+
|churn|paid|count|
+-----+----+-----+
|    0|   1|  246|
|    0|   0|  103|
|    1|   1|   75|
|    1|   0|   24|
+-----+----+-----+

In [ ]:
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn")
g.map(sns.distplot,'totalTime')
g.add_legend();
In [ ]:
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn")
g.map(sns.distplot,'totalAdvert')
g.add_legend();
In [ ]:
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn")
g.map(sns.distplot,'songCount')
g.add_legend();
In [ ]:
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn")
g.map(sns.distplot,'artistCount')
g.add_legend();
In [ ]:
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn",row="paid",hue="gender")
g.map(plt.scatter,'distinctSession','songCount',s=50,alpha=0.3)
g.add_legend();
In [ ]:
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn",row="paid",hue="gender")
g.map(plt.scatter,'thumbs_up','thumbs_down',s=50,alpha=0.3)
g.add_legend();

From the graphs above we can observe the following:

  1. A higher percentage of non-churned users are paid subscribers.
  2. Non-churned users spend more time on the site, play more songs, and show more activity overall than users who churned.

Feature Engineering¶

In [ ]:
numericCols = combined_agg_df.columns
numericCols.remove("churn")
numericCols.remove("userId")
numericCols.remove("userAgent")
print(numericCols)
['last_level', 'gender', 'distinctSession', 'artistCount', 'avgPageSess', 'songCount', 'totalAdvert', 'addPlaylist', 'totalError', 'totalTime', 'downgraded', 'upgraded', 'thumbs_up', 'thumbs_down', 'paid']
In [ ]:
stages = []

# Transform all features into a vector using VectorAssembler
assemblerInputs = numericCols
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
standardscaler=StandardScaler().setInputCol("features").setOutputCol("scaledFeatures")
stages += [assembler,standardscaler]
In [ ]:
stages
Out[ ]:
[VectorAssembler_470da410c3ddaba6ae67, StandardScaler_4bfc868f66b52c7821f6]
In [ ]:
#Run the stages as a Pipeline. 
partialPipeline = Pipeline().setStages(stages)
pipelineModel = partialPipeline.fit(combined_agg_df)
preppedDataDF = pipelineModel.transform(combined_agg_df)
In [ ]:
preppedDataDF.select("features","scaledFeatures").show(5)
+--------------------+--------------------+
|            features|      scaledFeatures|
+--------------------+--------------------+
|[0.0,1.0,14.0,370...|[0.0,2.0113616675...|
|[1.0,1.0,22.0,707...|[2.04127203248654...|
|[1.0,1.0,2.0,186....|[2.04127203248654...|
|[0.0,1.0,23.0,820...|[0.0,2.0113616675...|
|[1.0,1.0,11.0,706...|[2.04127203248654...|
+--------------------+--------------------+
only showing top 5 rows

Modeling¶

Split the full dataset into train, test, and validation sets. Test out several of the machine learning methods you learned. Evaluate the accuracy of the various models, tuning parameters as necessary. Determine your winning model based on test accuracy and report results on the validation set. Since the churned users are a fairly small subset, I suggest using F1 score as the metric to optimize.
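
A minimal sketch of the suggested three-way split and an F1-oriented evaluator (variable names here are illustrative; the cells below use a simpler 70/30 train/test split):

train_df, test_df, valid_df = preppedDataDF.randomSplit([0.6, 0.2, 0.2], seed=42)
f1_evaluator = MulticlassClassificationEvaluator(labelCol='churn',
                                                 predictionCol='prediction',
                                                 metricName='f1')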

In [ ]:
# Split the dataset into training and test sets
(training, test) = preppedDataDF.randomSplit([0.7, 0.3], seed=42)
print('training count:', training.count())
print('test count:', test.count())
training count: 314
test count: 138
In [ ]:
training.take(1)
Out[ ]:
[Row(userId='138', userAgent='"Mozilla/5.0 (iPad; CPU OS 7_1_1 like Mac OS X) AppleWebKit/537.51.2 (KHTML, like Gecko) Version/7.0 Mobile/11D201 Safari/9537.53"', last_level=0, gender=1, distinctSession=14, artistCount=370, avgPageSess=41.142857142857146, songCount=434, totalAdvert=37, addPlaylist=10, totalError=1, totalTime=107916.26365000002, downgraded=0, upgraded=0, thumbs_up=27, thumbs_down=4, churn=0, paid=0, features=DenseVector([0.0, 1.0, 14.0, 370.0, 41.1429, 434.0, 37.0, 10.0, 1.0, 107916.2637, 0.0, 0.0, 27.0, 4.0, 0.0]), scaledFeatures=DenseVector([0.0, 2.0114, 1.0629, 0.5912, 0.9043, 0.3771, 1.7289, 0.2969, 0.5946, 0.377, 0.0, 0.0, 0.3876, 0.2995, 0.0]))]
In [ ]:
def trainingModel(classifier, training, test):
    '''
    Train a classification model and score it on the test set.

    INPUT:
    classifier - classifier (DecisionTree, RandomForest, LogisticRegression, etc.)
    training - Spark dataframe for training
    test - Spark dataframe for testing

    OUTPUT:
    model - trained model
    predict_test - test set with predictions appended
    '''
    start = time()
    model = classifier.fit(training)
    predict_train = model.transform(training)
    predict_test = model.transform(test)
    end = time()
    prediction_time = end - start
    print("Training time:", prediction_time)
    return model, predict_test
In [ ]:
def evaluateModel(predict_test):
    '''Return a dict of weighted F1, precision, recall and accuracy, plus a confusion matrix.'''
    summary = {}
    evaluator = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="churn")
    summary['f1_score'] = evaluator.evaluate(predict_test, {evaluator.metricName: "f1"})
    summary['precision'] = evaluator.evaluate(predict_test, {evaluator.metricName: "weightedPrecision"})
    summary['recall'] = evaluator.evaluate(predict_test, {evaluator.metricName: "weightedRecall"})
    summary['accuracy'] = evaluator.evaluate(predict_test, {evaluator.metricName: "accuracy"})
    confusion_matrix = predict_test.groupby("Churn").pivot("prediction").count()
    return summary, confusion_matrix
In [ ]:
# Try logistic regression without parameter tuning
lr = LogisticRegression(labelCol="churn", featuresCol="scaledFeatures", maxIter=10)

model_lr,predict_test_lr = trainingModel(lr,training,test)
Training time: 54.073190689086914
In [ ]:
summary_lr,confusion_matrix_lr = evaluateModel(predict_test_lr)
In [ ]:
# Try a random forest without parameter tuning
rf = RandomForestClassifier(labelCol="churn", featuresCol="scaledFeatures")
model_rf,predict_test_rf = trainingModel(rf,training,test)
Training time: 105.55780577659607
In [ ]:
summary_rf, confusion_matrix_rf = evaluateModel(predict_test_rf)
print(summary_rf)
confusion_matrix_rf.show()
{'f1_score': 0.7702669303015324, 'precision': 0.7798353909465021, 'recall': 0.8074074074074075, 'accuracy': 0.7928571428571428}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0|103|  4|
|    1| 26|  4|
+-----+---+---+

In [ ]:
# Try gradient-boosted trees without parameter tuning
gbtc = GBTClassifier(labelCol="churn", featuresCol="scaledFeatures")
model_gbtc,predict_test_gbtc = trainingModel(gbtc,training,test)
Training time: 213.31944227218628
In [ ]:
summary_gbtc, confusion_matrix_gbtc = evaluateModel(predict_test_gbtc)
print(summary_gbtc)
confusion_matrix_gbtc.show()
{'f1_score': 0.7471422611597831, 'precision': 0.7164179104477612, 'recall': 0.746376811594203, 'accuracy': 0.7388059701492538}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0| 89| 16|
|    1| 22|  9|
+-----+---+---+

In [ ]:
print("Logistics Regression:", summary_lr)
confusion_matrix_lr.show()
print("Random Forest:", summary_rf)
confusion_matrix_rf.show()
print("Gradient-Boosted Trees (GBTs):", summary_gbtc)
confusion_matrix_gbtc.show()
Logistic Regression: {'f1_score': 0.732241250930752, 'precision': 0.7723076923076923, 'recall': 0.801470588235294, 'accuracy': 0.8074074074074075}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0|108|  3|
|    1| 26|  4|
+-----+---+---+

Random Forest: {'f1_score': 0.7702669303015324, 'precision': 0.7798353909465021, 'recall': 0.8074074074074075, 'accuracy': 0.7928571428571428}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0|109|  4|
|    1| 24|  4|
+-----+---+---+

Gradient-Boosted Trees (GBTs): {'f1_score': 0.7471422611597831, 'precision': 0.7164179104477612, 'recall': 0.746376811594203, 'accuracy': 0.7388059701492538}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0| 91| 15|
|    1| 19|  7|
+-----+---+---+

In [ ]:
# Build model with 5-fold cross-validation

def trainModelCV(classifier, paramGrid):
    '''
    Train a classification model with 5-fold cross-validation and score it on the test set.
    Uses the `training` and `test` dataframes defined earlier in the notebook.

    INPUT:
    classifier - classifier (DecisionTree, RandomForest, etc.)
    paramGrid - parameter grid for the cross-validation search

    OUTPUT:
    model - fitted CrossValidatorModel
    predict_test - test set with predictions appended
    '''
    start = time()
    evaluator = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="churn")

    # Create a 5-fold CrossValidator
    cv = CrossValidator(estimator=classifier, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=5)

    # Run the cross-validation; this can take a while because of the number of models being fit
    model = cv.fit(training)

    # Use the held-out test set to measure performance on new data
    predict_test = model.transform(test)

    end = time()
    prediction_time = end - start
    print("Training time:", prediction_time)

    return model, predict_test
In [ ]:
# build RandomForest model

rf = RandomForestClassifier(labelCol="churn", featuresCol="scaledFeatures")
paramRF = (ParamGridBuilder()
             .addGrid(rf.maxDepth, [2, 4, 6])
             .addGrid(rf.maxBins, [20, 60])
             .addGrid(rf.numTrees, [5, 20])
             .build())

modelRF,predict_testCV_RF = trainModelCV(rf,paramRF)
Training time: 1350.1863057613373
In [ ]:
summary_rf,confusion_matrix_rf = evaluateModel(predict_testCV_RF)
print(summary_rf)
confusion_matrix_rf.show()
{'f1_score': 0.7406385883407707, 'precision': 0.7363184079601991, 'recall': 0.7938931297709924, 'accuracy': 0.8045112781954887}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0|100|  5|
|    1| 23|  6|
+-----+---+---+

In [ ]: