import ibmos2spark
# @hidden_cell
# credentials redacted for publication -- never share service IDs or API keys
credentials = {
    'endpoint': 'https://s3-api.us-geo.objectstorage.service.networklayer.com',
    'service_id': 'iam-ServiceId-****',
    'iam_service_endpoint': 'https://iam.ng.bluemix.net/oidc/token',
    'api_key': '****'
}
configuration_name = 'os_362c3529afdf436e8dc3c5cb207484d1_configs'
cos = ibmos2spark.CloudObjectStorage(sc, credentials, configuration_name, 'bluemix_cos')
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
# Since JSON data can be semi-structured and contain additional metadata, you may need to adjust the DataFrame layout.
# See the documentation of 'DataFrameReader.json()' for the available options to control data loading.
# PySpark documentation: http://spark.apache.org/docs/2.0.2/api/python/pyspark.sql.html#pyspark.sql.DataFrameReader.json
spark_data = spark.read.json(cos.url('medium-sparkify-event-data.json', 'sparkify-donotdelete-pr-igw07vjrf2o5ef'))
spark_data.take(5)
[Row(artist='Martin Orford', auth='Logged In', firstName='Joseph', gender='M', itemInSession=20, lastName='Morales', length=597.55057, level='free', location='Corpus Christi, TX', method='PUT', page='NextSong', registration=1532063507000, sessionId=292, song='Grand Designs', status=200, ts=1538352011000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"', userId='293'), Row(artist="John Brown's Body", auth='Logged In', firstName='Sawyer', gender='M', itemInSession=74, lastName='Larson', length=380.21179, level='free', location='Houston-The Woodlands-Sugar Land, TX', method='PUT', page='NextSong', registration=1538069638000, sessionId=97, song='Bulls', status=200, ts=1538352025000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='98'), Row(artist='Afroman', auth='Logged In', firstName='Maverick', gender='M', itemInSession=184, lastName='Santiago', length=202.37016, level='paid', location='Orlando-Kissimmee-Sanford, FL', method='PUT', page='NextSong', registration=1535953455000, sessionId=178, song='Because I Got High', status=200, ts=1538352118000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='179'), Row(artist=None, auth='Logged In', firstName='Maverick', gender='M', itemInSession=185, lastName='Santiago', length=None, level='paid', location='Orlando-Kissimmee-Sanford, FL', method='PUT', page='Logout', registration=1535953455000, sessionId=178, song=None, status=307, ts=1538352119000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='179'), Row(artist='Lily Allen', auth='Logged In', firstName='Gianna', gender='F', itemInSession=22, lastName='Campos', length=194.53342, level='paid', location='Mobile, AL', method='PUT', page='NextSong', registration=1535931018000, sessionId=245, song='Smile (Radio Edit)', status=200, ts=1538352124000, userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0', userId='246')]
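If the inferred layout needs adjusting, one option (an illustrative sketch, not part of the original run) is to pass an explicit schema to the reader; only the listed fields are loaded and no inference pass is needed:

from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType
# hypothetical partial schema -- the remaining fields follow the same pattern
schema = StructType([
    StructField('artist', StringType()),
    StructField('length', DoubleType()),
    StructField('ts', LongType()),
    StructField('userId', StringType()),
])
# spark_data = spark.read.json(path, schema=schema)  # 'path' stands in for the COS URL above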
# import libraries
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, IntegerType
from pyspark.sql.functions import *
# alias Spark's sum() to avoid clashing with Python's built-in sum
from pyspark.sql.functions import sum as Fsum
from time import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, StandardScaler, Imputer, VectorAssembler, MinMaxScaler
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression, DecisionTreeClassifier, GBTClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
spark_data.printSchema()
root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)
print("Number of records in the dataset:", spark_data.count())
spark_data.describe('sessionId').show()
spark_data.describe('userId').show()
spark_data.describe('artist').show()
Number of records in the dataset: 543705
+-------+------------------+
|summary|         sessionId|
+-------+------------------+
|  count|            543705|
|   mean|2040.8143533717732|
| stddev| 1434.338931078271|
|    min|                 1|
|    max|              4808|
+-------+------------------+

+-------+------------------+
|summary|            userId|
+-------+------------------+
|  count|            543705|
|   mean| 60268.42669103512|
| stddev|109898.82324176628|
|    min|                  |
|    max|                99|
+-------+------------------+

+-------+-----------------+
|summary|           artist|
+-------+-----------------+
|  count|           432877|
|   mean|527.5289537712895|
| stddev|966.1072451772758|
|    min|              !!!|
|    max|ÃÂlafur Arnalds|
+-------+-----------------+
# Drop null values in the userId and sessionId columns
spark_data_valid = spark_data.dropna(how='any', subset=['userId', 'sessionId'])
# rows where userId is blank
spark_data_invalid = spark_data_valid.filter(spark_data_valid['userId'] == '')
# filter out rows where userId is blank
spark_data_valid = spark_data_valid.filter(spark_data_valid['userId'] != '')
spark_data_invalid.take(5)
[Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=186, lastName=None, length=None, level='paid', location=None, method='GET', page='Home', registration=None, sessionId=178, song=None, status=200, ts=1538352148000, userAgent=None, userId=''), Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=187, lastName=None, length=None, level='paid', location=None, method='GET', page='Home', registration=None, sessionId=178, song=None, status=200, ts=1538352151000, userAgent=None, userId=''), Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=188, lastName=None, length=None, level='paid', location=None, method='GET', page='Home', registration=None, sessionId=178, song=None, status=200, ts=1538352168000, userAgent=None, userId=''), Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=189, lastName=None, length=None, level='paid', location=None, method='PUT', page='Login', registration=None, sessionId=178, song=None, status=307, ts=1538352169000, userAgent=None, userId=''), Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=114, lastName=None, length=None, level='free', location=None, method='GET', page='Home', registration=None, sessionId=442, song=None, status=200, ts=1538353292000, userAgent=None, userId='')]
print('# of Records after removing blank userID:', spark_data_valid.count())
print('# of Records with blank userID:', spark_data.count()-spark_data_valid.count())
# of Records after removing blank userID: 528005
# of Records with blank userID: 15700
# count of events per page value
spark_data_valid.select(['userId','page']).groupBy('page').count().sort(desc('count')).show()
+--------------------+------+
|                page| count|
+--------------------+------+
|            NextSong|432877|
|           Thumbs Up| 23826|
|                Home| 19089|
|     Add to Playlist| 12349|
|          Add Friend|  8087|
|         Roll Advert|  7773|
|              Logout|  5990|
|         Thumbs Down|  4911|
|           Downgrade|  3811|
|            Settings|  2964|
|                Help|  2644|
|               About|  1026|
|             Upgrade|   968|
|       Save Settings|   585|
|               Error|   503|
|      Submit Upgrade|   287|
|    Submit Downgrade|   117|
|              Cancel|    99|
|Cancellation Conf...|    99|
+--------------------+------+
# number of users by userAgent (dedup userId)
spark_data_valid.select(['userId','userAgent']).dropDuplicates().groupBy('userAgent').count().sort(desc('count')).show(100)
+--------------------+-----+
|           userAgent|count|
+--------------------+-----+
|"Mozilla/5.0 (Win...|   39|
|Mozilla/5.0 (Wind...|   34|
|"Mozilla/5.0 (Mac...|   30|
|"Mozilla/5.0 (Mac...|   25|
|"Mozilla/5.0 (Mac...|   22|
|"Mozilla/5.0 (Win...|   22|
|"Mozilla/5.0 (Mac...|   21|
|Mozilla/5.0 (Maci...|   17|
|Mozilla/5.0 (Wind...|   15|
|"Mozilla/5.0 (Mac...|   14|
|"Mozilla/5.0 (Win...|   12|
|Mozilla/5.0 (Wind...|   11|
|"Mozilla/5.0 (iPh...|   11|
|"Mozilla/5.0 (Win...|    9|
|"Mozilla/5.0 (Win...|    9|
|Mozilla/5.0 (X11;...|    7|
|"Mozilla/5.0 (iPh...|    6|
|Mozilla/5.0 (comp...|    6|
|"Mozilla/5.0 (Win...|    6|
|"Mozilla/5.0 (X11...|    6|
|"Mozilla/5.0 (iPa...|    6|
|"Mozilla/5.0 (Win...|    6|
|"Mozilla/5.0 (Win...|    6|
|"Mozilla/5.0 (Win...|    5|
|Mozilla/5.0 (Wind...|    5|
|"Mozilla/5.0 (Win...|    5|
|"Mozilla/5.0 (Win...|    4|
|Mozilla/5.0 (Wind...|    4|
|Mozilla/5.0 (Wind...|    4|
|Mozilla/5.0 (Wind...|    3|
|"Mozilla/5.0 (Mac...|    3|
|"Mozilla/5.0 (Mac...|    3|
|"Mozilla/5.0 (Mac...|    3|
|"Mozilla/5.0 (Mac...|    3|
|Mozilla/5.0 (comp...|    3|
|"Mozilla/5.0 (iPa...|    3|
|"Mozilla/5.0 (Mac...|    3|
|"Mozilla/5.0 (Mac...|    3|
|"Mozilla/5.0 (X11...|    3|
|"Mozilla/5.0 (Win...|    3|
|"Mozilla/5.0 (Mac...|    2|
|Mozilla/5.0 (comp...|    2|
|Mozilla/5.0 (Maci...|    2|
|Mozilla/5.0 (comp...|    2|
|Mozilla/5.0 (Maci...|    2|
|Mozilla/5.0 (X11;...|    2|
|"Mozilla/5.0 (X11...|    2|
|"Mozilla/5.0 (Mac...|    2|
|Mozilla/5.0 (Wind...|    2|
|"Mozilla/5.0 (Mac...|    2|
|"Mozilla/5.0 (Mac...|    2|
|Mozilla/5.0 (X11;...|    2|
|"Mozilla/5.0 (Mac...|    2|
|"Mozilla/5.0 (Mac...|    2|
|"Mozilla/5.0 (Mac...|    2|
|"Mozilla/5.0 (Mac...|    2|
|Mozilla/5.0 (X11;...|    2|
|Mozilla/5.0 (Wind...|    1|
|"Mozilla/5.0 (Mac...|    1|
|"Mozilla/5.0 (iPh...|    1|
|Mozilla/5.0 (Maci...|    1|
|"Mozilla/5.0 (Win...|    1|
|"Mozilla/5.0 (Mac...|    1|
|"Mozilla/5.0 (Win...|    1|
|Mozilla/5.0 (Maci...|    1|
|"Mozilla/5.0 (Mac...|    1|
|Mozilla/5.0 (Wind...|    1|
|"Mozilla/5.0 (X11...|    1|
|Mozilla/5.0 (Wind...|    1|
|"Mozilla/5.0 (X11...|    1|
|"Mozilla/5.0 (Mac...|    1|
+--------------------+-----+
# Can a single user have multiple userAgents? Apparently not :)
spark_data_valid.select(['userId','userAgent']).dropDuplicates().groupBy('userId').count().sort(desc('count')).show(20)
+------+-----+
|userId|count|
+------+-----+
|   124|    1|
|     7|    1|
|    51|    1|
|100010|    1|
|200002|    1|
|   296|    1|
|200037|    1|
|   125|    1|
|   205|    1|
|   169|    1|
|   272|    1|
|    54|    1|
|   232|    1|
|   282|    1|
|   234|    1|
|    15|    1|
|   155|    1|
|200043|    1|
|   154|    1|
|   132|    1|
+------+-----+
only showing top 20 rows
# number of unique users
print('# Users:',spark_data_valid.select(['userId']).dropDuplicates().count())
# number of users by Gender (dedup userId)
spark_data_valid.select(['userId','gender']).dropDuplicates().groupBy('gender').count().show()
# number of plays for the most-played artists
# note: show() returns None, so there is no point assigning its result
spark_data_valid.dropna(how='any', subset=['artist'])\
    .groupBy('artist').count().sort(desc('count')).show(10)
# how many distinct sessions per user
spark_data_valid.groupBy('userId')\
    .agg(countDistinct("sessionId"))\
    .withColumnRenamed("count(DISTINCT sessionId)", "sessionIdCount")\
    .sort(desc("sessionIdCount")).show()
# Users: 448
+------+-----+
|gender|count|
+------+-----+
|     F|  198|
|     M|  250|
+------+-----+

+--------------------+-----+
|              artist|count|
+--------------------+-----+
|       Kings Of Leon| 3497|
|            Coldplay| 3439|
|Florence + The Ma...| 2314|
|                Muse| 2194|
|       Dwight Yoakam| 2187|
|      The Black Keys| 2160|
|               Björk| 2150|
|       Justin Bieber| 2096|
|        Jack Johnson| 2049|
|           Radiohead| 1694|
+--------------------+-----+
only showing top 10 rows

+------+--------------+
|userId|sessionIdCount|
+------+--------------+
|300049|            92|
|300035|            88|
|    92|            72|
|300017|            71|
|   140|            69|
|    87|            64|
|300011|            60|
|200023|            59|
|   293|            58|
|   195|            58|
|   230|            56|
|   101|            53|
|    42|            52|
|300021|            51|
|   250|            46|
|300031|            45|
|   121|            43|
|   100|            39|
|300038|            39|
|    29|            39|
+------+--------------+
only showing top 20 rows
#### How many songs do users listen to, on average, between visits to our home page?
# SOLUTION USING SPARK DATAFRAME
flag_home_func = udf(lambda x: 1 if x == 'Home' else 0, IntegerType())
# create a per-user window ordered by timestamp
user_window = Window\
    .partitionBy('userId')\
    .orderBy('ts')\
    .rangeBetween(Window.unboundedPreceding, 0)
# 'homeSession' flags Home events; its cumulative sum ('songCount') numbers the periods between Home visits
cusum = spark_data_valid.filter((spark_data_valid.page == 'Home') | (spark_data_valid.page == 'NextSong'))\
    .select(['userId', 'page', 'ts'])\
    .withColumn('homeSession', flag_home_func('page'))\
    .withColumn('songCount', Fsum('homeSession').over(user_window))
# cusum.filter(cusum.userId == 126).show()  # spot-check a single user
# calculate the average number of songs per user between visits to the home page
cusum.filter(cusum.page == 'NextSong')\
    .groupBy('userId', 'songCount')\
    .agg(count('songCount'))\
    .agg(avg('count(songCount)'))\
    .show()
+---------------------+
|avg(count(songCount))|
+---------------------+
|   23.692025614361558|
+---------------------+
# SOLUTION USING SPARK SQL
spark_data_valid.createOrReplaceTempView("log_table")
is_home = spark.sql("""
    SELECT userID, page, ts,
           CASE WHEN page = 'Home' THEN 1 ELSE 0 END AS is_home
    FROM log_table
    WHERE page = 'NextSong' OR page = 'Home'
""")
# keep the results in a new view
is_home.createOrReplaceTempView("is_home_table")
# find the cumulative sum over the is_home column
cumulative_sum = spark.sql("""
    SELECT *, SUM(is_home) OVER
        (PARTITION BY userID ORDER BY ts DESC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS period
    FROM is_home_table
""")
# keep the results in a view
cumulative_sum.createOrReplaceTempView("period_table")
# find the average count for NextSong
spark.sql("""
    SELECT AVG(count_results) FROM
        (SELECT COUNT(*) AS count_results FROM period_table
         GROUP BY userID, period, page HAVING page = 'NextSong') AS counts
""").show()
+------------------+
|avg(count_results)|
+------------------+
|23.672591053264792|
+------------------+
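Note the small difference from the DataFrame result (23.692 vs. 23.673): the DataFrame version accumulates over ts ascending while the SQL version orders ts DESC, so the two likely group the boundary Home events into periods slightly differently.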
When you're working with the full dataset, perform EDA by loading a small subset of the data and doing basic manipulations within Spark. In this workspace, you are already provided a small subset of data you can explore.
Once you've done some preliminary analysis, create a column Churn to use as the label for your model. I suggest using the Cancellation Confirmation events to define your churn, which happen for both paid and free users. As a bonus task, you can also look into the Downgrade events.
# event label
def func_event(df, event, colName):
    '''Flag occurrences of a given page event with a 0/1 column.
    INPUT:
    df - Spark dataframe
    event - string - event name as it appears in the page column
    colName - string - name of the new flag column
    OUTPUT:
    df - original df with the new flag column added
    '''
    event_func = udf(lambda x: 1 if x == event else 0, IntegerType())
    df = df.withColumn(colName, event_func('page'))
    return df
spark_data_valid = func_event(spark_data_valid, 'Submit Downgrade', 'downgraded')
spark_data_valid = func_event(spark_data_valid, 'Cancellation Confirmation', 'Churn')
spark_data_valid = func_event(spark_data_valid, 'Submit Upgrade', 'upgraded')
spark_data_valid = func_event(spark_data_valid, 'Thumbs Up', 'thumbs_up')
spark_data_valid = func_event(spark_data_valid, 'Thumbs Down', 'thumbs_down')
spark_data_valid = func_event(spark_data_valid, 'Add Friend', 'add_friend')
spark_data_valid = func_event(spark_data_valid, 'Roll Advert', 'advert')
spark_data_valid = func_event(spark_data_valid, 'Error', 'error')
spark_data_valid = func_event(spark_data_valid, 'Add to Playlist', 'add_playlist')
spark_data_valid = func_event(spark_data_valid, 'Settings', 'settings')
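The Churn flag above marks only the single Cancellation Confirmation event per user. If a user-wide label were needed on every row, one option (a sketch; this notebook instead sums the flag during aggregation) is a window max over each user's events:

from pyspark.sql.functions import max as Fmax
# hypothetical: propagate the event-level Churn flag to every record of the same user
user_w = Window.partitionBy('userId')
labeled = spark_data_valid.withColumn('churn_user', Fmax('Churn').over(user_w))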
# create paid-level flag
flag_paid = udf(lambda x: 1 if x == 'paid' else 0, IntegerType())
spark_data_valid = spark_data_valid.withColumn('paid', flag_paid('level'))
spark_data_valid.take(5)
[Row(artist='Martin Orford', auth='Logged In', firstName='Joseph', gender='M', itemInSession=20, lastName='Morales', length=597.55057, level='free', location='Corpus Christi, TX', method='PUT', page='NextSong', registration=1532063507000, sessionId=292, song='Grand Designs', status=200, ts=1538352011000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"', userId='293', downgraded=0, Churn=0, upgraded=0, thumbs_up=0, thumbs_down=0, add_friend=0, advert=0, error=0, add_playlist=0, settings=0, paid=0), Row(artist="John Brown's Body", auth='Logged In', firstName='Sawyer', gender='M', itemInSession=74, lastName='Larson', length=380.21179, level='free', location='Houston-The Woodlands-Sugar Land, TX', method='PUT', page='NextSong', registration=1538069638000, sessionId=97, song='Bulls', status=200, ts=1538352025000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='98', downgraded=0, Churn=0, upgraded=0, thumbs_up=0, thumbs_down=0, add_friend=0, advert=0, error=0, add_playlist=0, settings=0, paid=0), Row(artist='Afroman', auth='Logged In', firstName='Maverick', gender='M', itemInSession=184, lastName='Santiago', length=202.37016, level='paid', location='Orlando-Kissimmee-Sanford, FL', method='PUT', page='NextSong', registration=1535953455000, sessionId=178, song='Because I Got High', status=200, ts=1538352118000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='179', downgraded=0, Churn=0, upgraded=0, thumbs_up=0, thumbs_down=0, add_friend=0, advert=0, error=0, add_playlist=0, settings=0, paid=1), Row(artist=None, auth='Logged In', firstName='Maverick', gender='M', itemInSession=185, lastName='Santiago', length=None, level='paid', location='Orlando-Kissimmee-Sanford, FL', method='PUT', page='Logout', registration=1535953455000, sessionId=178, song=None, status=307, ts=1538352119000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId='179', downgraded=0, Churn=0, upgraded=0, thumbs_up=0, thumbs_down=0, add_friend=0, advert=0, error=0, add_playlist=0, settings=0, paid=1), Row(artist='Lily Allen', auth='Logged In', firstName='Gianna', gender='F', itemInSession=22, lastName='Campos', length=194.53342, level='paid', location='Mobile, AL', method='PUT', page='NextSong', registration=1535931018000, sessionId=245, song='Smile (Radio Edit)', status=200, ts=1538352124000, userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0', userId='246', downgraded=0, Churn=0, upgraded=0, thumbs_up=0, thumbs_down=0, add_friend=0, advert=0, error=0, add_playlist=0, settings=0, paid=1)]
spark_data_valid.count()
528005
spark_data_valid.select(['sessionId', 'page']).filter(spark_data_valid['sessionId']=='292').show(50)
+---------+---------------+
|sessionId|           page|
+---------+---------------+
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|      Thumbs Up|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|         Logout|
|      292|           Home|
|      292|     Add Friend|
|      292|       NextSong|
|      292|       NextSong|
|      292|Add to Playlist|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|    Roll Advert|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|Add to Playlist|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|    Roll Advert|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       Settings|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
|      292|       NextSong|
+---------+---------------+
only showing top 50 rows
# identify the latest level of each user
flag_levels = udf(lambda x: 1 if x == 'paid' else 0, IntegerType())
levels = spark_data_valid.select(['userId', 'level', 'ts'])\
    .orderBy(desc('ts'))\
    .dropDuplicates(['userId'])\
    .select(['userId', 'level'])\
    .withColumn('last_level', flag_levels('level'))
levels.show()
levels.count()
+------+-----+----------+
|userId|level|last_level|
+------+-----+----------+
|100010| free|         0|
|200002| paid|         1|
|   296| paid|         1|
|   125| free|         0|
|   124| paid|         1|
|    51| paid|         1|
|     7| free|         0|
|200037| free|         0|
|   169| free|         0|
|   205| paid|         1|
|   272| free|         0|
|    15| paid|         1|
|   232| free|         0|
|   234| paid|         1|
|   282| paid|         1|
|    54| paid|         1|
|   155| paid|         1|
|200043| paid|         1|
|100014| paid|         1|
|   132| paid|         1|
+------+-----+----------+
only showing top 20 rows
448
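Note that orderBy followed by dropDuplicates is not guaranteed to keep the newest row once Spark shuffles the data; a window with row_number makes "latest record per user" explicit. A minimal alternative sketch:

from pyspark.sql.functions import row_number
# rank each user's events from newest to oldest and keep only the newest
latest_w = Window.partitionBy('userId').orderBy(desc('ts'))
levels_alt = spark_data_valid\
    .withColumn('rn', row_number().over(latest_w))\
    .filter(col('rn') == 1)\
    .select(['userId', 'level'])\
    .withColumn('last_level', flag_levels('level'))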
Once you've defined churn, perform some exploratory data analysis to observe the behavior of users who stayed vs. users who churned. You can start by exploring aggregates on these two groups, observing how often they performed a specific action per unit of time or per number of songs played.
# identify user_ids who churned
churn_user_list = spark_data_valid.filter(spark_data_valid.Churn==1).select(['userId']).collect()
churn_user_array = [int(row.userId) for row in churn_user_list]
print("#users who churned:", len(churn_user_array))
#users who churned: 99
# identify user_ids who stayed
all_user_list = spark_data_valid.select(['userId']).dropDuplicates().collect()
all_user_array=[int(row.userId) for row in all_user_list]
stay_user_array = list(set(all_user_array)-set(churn_user_array))
print("#users who stayed:", len(stay_user_array))
#users who stayed: 349
# Create a View table for churn and stay spark dataframe
spark_data_valid.createOrReplaceTempView("combinedView")
levels.createOrReplaceTempView("levels")
# aggregate per-user behavior for churned and staying users alike, joined with each user's latest level
sqlCombined = '''
    SELECT v.userId, userAgent, l.last_level,
           CASE WHEN gender = "F" THEN 0 ELSE 1 END AS gender,
           COUNT(DISTINCT sessionId) AS distinctSession,
           COUNT(DISTINCT artist) AS artistCount,
           COUNT(page) / COUNT(DISTINCT sessionId) AS avgPageSess,
           COUNT(song) AS songCount,
           SUM(advert) AS totalAdvert,
           SUM(add_playlist) AS addPlaylist,
           SUM(error) AS totalError,
           SUM(length) AS totalTime,
           SUM(downgraded) AS downgraded,
           SUM(upgraded) AS upgraded,
           SUM(thumbs_up) AS thumbs_up,
           SUM(thumbs_down) AS thumbs_down,
           SUM(Churn) AS churn,
           CASE WHEN SUM(paid) > 0 THEN 1
                WHEN SUM(paid) = 0 THEN 0
           END AS paid
    FROM combinedView v
    JOIN levels l
      ON v.userId = l.userId
    GROUP BY v.userId, gender, userAgent, l.last_level
'''
combined_agg_df = spark.sql(sqlCombined)
combined_agg_df.take(5)
[Row(userId='138', userAgent='"Mozilla/5.0 (iPad; CPU OS 7_1_1 like Mac OS X) AppleWebKit/537.51.2 (KHTML, like Gecko) Version/7.0 Mobile/11D201 Safari/9537.53"', last_level=0, gender=1, distinctSession=14, artistCount=370, avgPageSess=41.142857142857146, songCount=434, totalAdvert=37, addPlaylist=10, totalError=1, totalTime=107916.26365000002, downgraded=0, upgraded=0, thumbs_up=27, thumbs_down=4, churn=0, paid=0), Row(userId='300018', userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"', last_level=1, gender=1, distinctSession=22, artistCount=707, avgPageSess=54.0, songCount=888, totalAdvert=56, addPlaylist=27, totalError=1, totalTime=223907.28055000026, downgraded=0, upgraded=1, thumbs_up=79, thumbs_down=7, churn=0, paid=1), Row(userId='200004', userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.78.2 (KHTML, like Gecko) Version/7.0.6 Safari/537.78.2"', last_level=0, gender=1, distinctSession=32, artistCount=1004, avgPageSess=57.90625, songCount=1430, totalAdvert=98, addPlaylist=40, totalError=2, totalTime=354256.41273000004, downgraded=0, upgraded=1, thumbs_up=80, thumbs_down=42, churn=0, paid=1), Row(userId='202', userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_5) AppleWebKit/537.77.4 (KHTML, like Gecko) Version/6.1.5 Safari/537.77.4"', last_level=1, gender=1, distinctSession=2, artistCount=186, avgPageSess=122.0, songCount=202, totalAdvert=1, addPlaylist=5, totalError=1, totalTime=49361.72388000003, downgraded=0, upgraded=1, thumbs_up=9, thumbs_down=0, churn=0, paid=1), Row(userId='153', userAgent='"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"', last_level=1, gender=1, distinctSession=9, artistCount=1029, avgPageSess=190.0, songCount=1454, totalAdvert=1, addPlaylist=36, totalError=1, totalTime=365271.0599899996, downgraded=0, upgraded=0, thumbs_up=77, thumbs_down=17, churn=0, paid=1)]
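For reference, the same per-user aggregation can be written with the DataFrame API instead of SQL. A partial sketch using the flag columns created earlier (only some of the metrics shown):

from pyspark.sql.functions import count, countDistinct, when
combined_agg_alt = spark_data_valid\
    .groupBy('userId', 'gender', 'userAgent')\
    .agg(countDistinct('sessionId').alias('distinctSession'),
         countDistinct('artist').alias('artistCount'),
         (count('page') / countDistinct('sessionId')).alias('avgPageSess'),
         count('song').alias('songCount'),
         Fsum('Churn').alias('churn'),
         Fsum('paid').alias('paidEvents'))\
    .withColumn('paid', when(col('paidEvents') > 0, 1).otherwise(0))\
    .join(levels.select(['userId', 'last_level']), on='userId')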
# convert the aggregated dataframe to pandas for visualization
print(combined_agg_df.count())
combined_agg_pd = combined_agg_df.toPandas()
448
combined_agg_pd.head()
|   | userId | userAgent | last_level | gender | distinctSession | artistCount | avgPageSess | songCount | totalAdvert | addPlaylist | totalError | totalTime | downgraded | upgraded | thumbs_up | thumbs_down | churn | paid |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 300018 | "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4... | 1 | 1 | 22 | 707 | 54.000000 | 888 | 56 | 27 | 1 | 223907.28055 | 0 | 1 | 79 | 7 | 0 | 1 |
| 1 | 138 | "Mozilla/5.0 (iPad; CPU OS 7_1_1 like Mac OS X... | 0 | 1 | 14 | 370 | 41.142857 | 434 | 37 | 10 | 1 | 107916.26365 | 0 | 0 | 27 | 4 | 0 | 0 |
| 2 | 200004 | "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4... | 0 | 1 | 32 | 1004 | 57.906250 | 1430 | 98 | 40 | 2 | 354256.41273 | 0 | 1 | 80 | 42 | 0 | 1 |
| 3 | 202 | "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_5... | 1 | 1 | 2 | 186 | 122.000000 | 202 | 1 | 5 | 1 | 49361.72388 | 0 | 1 | 9 | 0 | 0 | 1 |
| 4 | 153 | "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebK... | 1 | 1 | 9 | 1029 | 190.000000 | 1454 | 1 | 36 | 1 | 365271.05999 | 0 | 0 | 77 | 17 | 0 | 1 |
combined_agg_df.select(['userId','paid','churn']).dropDuplicates().groupBy('churn','paid').count().sort(desc('count')).show()
+-----+----+-----+
|churn|paid|count|
+-----+----+-----+
|    0|   1|  246|
|    0|   0|  103|
|    1|   1|   75|
|    1|   0|   24|
+-----+----+-----+
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn")
g.map(sns.distplot,'totalTime')
g.add_legend();
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn")
g.map(sns.distplot,'totalAdvert')
g.add_legend();
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn")
g.map(sns.distplot,'songCount')
g.add_legend();
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn")
g.map(sns.distplot,'artistCount')
g.add_legend();
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn",row="paid",hue="gender")
g.map(plt.scatter,'distinctSession','songCount',s=50,alpha=0.3)
g.add_legend();
g = sns.FacetGrid(combined_agg_pd, palette="Set1",col="churn",row="paid",hue="gender")
g.map(plt.scatter,'thumbs_up','thumbs_down',s=50,alpha=0.3)
g.add_legend();
From the graphs above we can observe how total listening time, adverts served, songs played, artist variety, session activity, and thumbs-up/down behavior are distributed for users who churned versus users who stayed.
# feature columns: everything except the label (churn) and the identifier columns
numericCols = combined_agg_df.columns
numericCols.remove("churn")
numericCols.remove("userId")
numericCols.remove("userAgent")
print(numericCols)
['last_level', 'gender', 'distinctSession', 'artistCount', 'avgPageSess', 'songCount', 'totalAdvert', 'addPlaylist', 'totalError', 'totalTime', 'downgraded', 'upgraded', 'thumbs_up', 'thumbs_down', 'paid']
stages = []
# Transform all features into a single vector using VectorAssembler
assemblerInputs = numericCols
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
# StandardScaler defaults to withStd=True, withMean=False (scales, does not center)
standardscaler = StandardScaler(inputCol="features", outputCol="scaledFeatures")
stages += [assembler, standardscaler]
stages
[VectorAssembler_470da410c3ddaba6ae67, StandardScaler_4bfc868f66b52c7821f6]
# Run the stages as a Pipeline
partialPipeline = Pipeline(stages=stages)
pipelineModel = partialPipeline.fit(combined_agg_df)
preppedDataDF = pipelineModel.transform(combined_agg_df)
preppedDataDF.select("features","scaledFeatures").show(5)
+--------------------+--------------------+
|            features|      scaledFeatures|
+--------------------+--------------------+
|[0.0,1.0,14.0,370...|[0.0,2.0113616675...|
|[1.0,1.0,22.0,707...|[2.04127203248654...|
|[1.0,1.0,2.0,186....|[2.04127203248654...|
|[0.0,1.0,23.0,820...|[0.0,2.0113616675...|
|[1.0,1.0,11.0,706...|[2.04127203248654...|
+--------------------+--------------------+
only showing top 5 rows
Split the full dataset into train, test, and validation sets. Test out several of the machine learning methods you learned. Evaluate the accuracy of the various models, tuning parameters as necessary. Determine your winning model based on test accuracy and report results on the validation set. Since the churned users are a fairly small subset, I suggest using F1 score as the metric to optimize.
# Split the dataset into training and test sets
(training, test) = preppedDataDF.randomSplit([0.7, 0.3], seed=42)
print('training count:', training.count())
print('test count:', test.count())
training count: 314
test count: 138
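The instructions above also mention a validation set; this notebook proceeds with a two-way split, but a three-way split would look like this (illustrative only):

# hypothetical three-way split: 60% train, 20% validation, 20% test
train_v, validation, test_v = preppedDataDF.randomSplit([0.6, 0.2, 0.2], seed=42)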
training.take(1)
[Row(userId='138', userAgent='"Mozilla/5.0 (iPad; CPU OS 7_1_1 like Mac OS X) AppleWebKit/537.51.2 (KHTML, like Gecko) Version/7.0 Mobile/11D201 Safari/9537.53"', last_level=0, gender=1, distinctSession=14, artistCount=370, avgPageSess=41.142857142857146, songCount=434, totalAdvert=37, addPlaylist=10, totalError=1, totalTime=107916.26365000002, downgraded=0, upgraded=0, thumbs_up=27, thumbs_down=4, churn=0, paid=0, features=DenseVector([0.0, 1.0, 14.0, 370.0, 41.1429, 434.0, 37.0, 10.0, 1.0, 107916.2637, 0.0, 0.0, 27.0, 4.0, 0.0]), scaledFeatures=DenseVector([0.0, 2.0114, 1.0629, 0.5912, 0.9043, 0.3771, 1.7289, 0.2969, 0.5946, 0.377, 0.0, 0.0, 0.3876, 0.2995, 0.0]))]
def trainingModel(classifier, training, test):
    '''
    Train a classifier and generate predictions on the test set.
    INPUT:
    classifier - classifier (DecisionTree, Random Forest, etc.)
    training - Spark dataframe for training
    test - Spark dataframe for testing
    OUTPUT:
    model - trained model
    predict_test - predictions on the test set
    '''
    start = time()
    model = classifier.fit(training)
    predict_test = model.transform(test)
    end = time()
    print("Training time:", end - start)
    return model, predict_test
def evaluateModel(predict_test):
    '''Compute F1, weighted precision, weighted recall, accuracy, and a confusion matrix.'''
    summary = {}
    evaluator = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="churn")
    summary['f1_score'] = evaluator.evaluate(predict_test, {evaluator.metricName: "f1"})
    summary['precision'] = evaluator.evaluate(predict_test, {evaluator.metricName: "weightedPrecision"})
    summary['recall'] = evaluator.evaluate(predict_test, {evaluator.metricName: "weightedRecall"})
    summary['accuracy'] = evaluator.evaluate(predict_test, {evaluator.metricName: "accuracy"})
    # rows are the actual Churn label, columns the predicted label
    confusion_matrix = predict_test.groupby("Churn").pivot("prediction").count()
    return summary, confusion_matrix
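Since the weighted metrics above can mask performance on the small churn class, per-class figures can be pulled from MulticlassMetrics (imported earlier but otherwise unused). A sketch, treating churn (label 1.0) as the positive class:

def perClassMetrics(predict_test):
    # MulticlassMetrics expects an RDD of (prediction, label) pairs of doubles
    pred_and_label = predict_test.select(col('prediction'), col('churn').cast('double')).rdd.map(tuple)
    metrics = MulticlassMetrics(pred_and_label)
    return {'precision_churn': metrics.precision(1.0),
            'recall_churn': metrics.recall(1.0),
            'f1_churn': metrics.fMeasure(1.0)}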
# Try Logistic Regression with no parameter tuning
lr = LogisticRegression(labelCol="churn", featuresCol="scaledFeatures", maxIter=10)
model_lr, predict_test_lr = trainingModel(lr, training, test)
Training time: 54.073190689086914
summary_lr, confusion_matrix_lr = evaluateModel(predict_test_lr)
# Try Random Forest with no parameter tuning
rf = RandomForestClassifier(labelCol="churn", featuresCol="scaledFeatures")
model_rf, predict_test_rf = trainingModel(rf, training, test)
Training time: 105.55780577659607
summary_rf, confusion_matrix_rf = evaluateModel(predict_test_rf)
{'f1_score': 0.7702669303015324, 'precision': 0.7798353909465021, 'recall': 0.8074074074074075, 'accuracy': 0.7928571428571428}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0|103|  4|
|    1| 26|  4|
+-----+---+---+
# Try GBTClassifier with no parameter tuning
gbtc = GBTClassifier(labelCol="churn", featuresCol="scaledFeatures")
model_gbtc, predict_test_gbtc = trainingModel(gbtc, training, test)
Training time: 213.31944227218628
summary_gbtc, confusion_matrix_gbtc = evaluateModel(predict_test_gbtc)
{'f1_score': 0.7471422611597831, 'precision': 0.7164179104477612, 'recall': 0.746376811594203, 'accuracy': 0.7388059701492538}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0| 89| 16|
|    1| 22|  9|
+-----+---+---+
print("Logistics Regression:", summary_lr)
confusion_matrix_lr.show()
print("Random Forest:", summary_rf)
confusion_matrix_rf.show()
print("Gradient-Boosted Trees (GBTs):", summary_gbtc)
confusion_matrix_gbtc.show()
Logistic Regression: {'f1_score': 0.732241250930752, 'precision': 0.7723076923076923, 'recall': 0.801470588235294, 'accuracy': 0.8074074074074075}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0|108|  3|
|    1| 26|  4|
+-----+---+---+
Random Forest: {'f1_score': 0.7702669303015324, 'precision': 0.7798353909465021, 'recall': 0.8074074074074075, 'accuracy': 0.7928571428571428}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0|109|  4|
|    1| 24|  4|
+-----+---+---+
Gradient-Boosted Trees (GBTs): {'f1_score': 0.7471422611597831, 'precision': 0.7164179104477612, 'recall': 0.746376811594203, 'accuracy': 0.7388059701492538}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0| 91| 15|
|    1| 19|  7|
+-----+---+---+
# Build model with cross validation
def trainModelCV(classifier, paramGrid, training, test):
    '''
    Train a classifier with 5-fold cross validation and evaluate it on the test set.
    INPUT:
    classifier - classifier (DecisionTree, Random Forest, etc.)
    paramGrid - parameter grid for cross validation
    training - Spark dataframe for training
    test - Spark dataframe for testing
    OUTPUT:
    model - trained CrossValidatorModel
    predict_test - predictions on the test set
    '''
    start = time()
    # MulticlassClassificationEvaluator optimizes F1 by default
    evaluator = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="churn")
    # Create 5-fold CrossValidator
    cv = CrossValidator(estimator=classifier, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=5)
    # Run cross validation; this can take a while because a model is fit for every
    # parameter combination in every fold
    model = cv.fit(training)
    # Use the test set to measure performance on unseen data
    predict_test = model.transform(test)
    end = time()
    print("Training time:", end - start)
    return model, predict_test
# build Random Forest model with a parameter grid
rf = RandomForestClassifier(labelCol="churn", featuresCol="scaledFeatures")
paramRF = (ParamGridBuilder()
           .addGrid(rf.maxDepth, [2, 4, 6])
           .addGrid(rf.maxBins, [20, 60])
           .addGrid(rf.numTrees, [5, 20])
           .build())
modelRF, predict_testCV_RF = trainModelCV(rf, paramRF, training, test)
Training time: 1350.1863057613373
summary_rf, confusion_matrix_rf = evaluateModel(predict_testCV_RF)
print(summary_rf)
confusion_matrix_rf.show()
{'f1_score': 0.7406385883407707, 'precision': 0.7363184079601991, 'recall': 0.7938931297709924, 'accuracy': 0.8045112781954887}
+-----+---+---+
|Churn|0.0|1.0|
+-----+---+---+
|    0|100|  5|
|    1| 23|  6|
+-----+---+---+
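To see which grid combination won, the fitted CrossValidatorModel can be inspected: avgMetrics holds the mean cross-validation F1 for each parameter map in the grid. A sketch using the objects above:

# pair each parameter combination with its mean CV metric and sort descending
results = sorted(zip(modelRF.avgMetrics, paramRF), key=lambda x: -x[0])
best_metric, best_params = results[0]
print('best mean CV F1:', best_metric)
for param, value in best_params.items():
    print(param.name, '=', value)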