ALS, Alternating Least Squares
SparkML supports ALS (Alternating Least Squares) as its recommendation algorithm.
Let's load movie rating data and try out the ALS model directly in Spark.
[ALS Concept]
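As a quick refresher on the idea (this is the standard matrix-factorization formulation, summarized here rather than taken from the linked post): ALS approximates the sparse user-item rating matrix with low-rank user factors $u_u$ and item factors $v_i$ by minimizing

$$\min_{U, V} \sum_{(u, i) \in \text{observed}} \left( r_{ui} - u_u^{\top} v_i \right)^2 + \lambda \Big( \sum_u \lVert u_u \rVert^2 + \sum_i \lVert v_i \rVert^2 \Big)$$

It alternates between the two factor matrices: hold the item factors fixed and solve an ordinary least-squares problem for each user vector, then hold the user factors fixed and solve for each item vector, and repeat. In Spark's ALS, $\lambda$ corresponds to `regParam` and the dimension of the factor vectors is `rank`.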
Basic Settings
from matplotlib import font_manager, rc
# register a Korean font so matplotlib can render Hangul labels
font_path = 'C:\\WINDOWS\\Fonts\\HBATANG.TTF'
font = font_manager.FontProperties(fname=font_path).get_name()
rc('font', family=font)
import os
import findspark
findspark.init(os.environ.get("SPARK_HOME"))
import pyspark
from pyspark import SparkConf, SparkContext
import pandas as pd
import faulthandler
faulthandler.enable()
from pyspark.sql import SparkSession
# give the driver and executors enough memory for the 25M-rating dataset
MAX_MEMORY = "5g"
spark = SparkSession.builder.master('local').appName("movie-recommendation")\
    .config("spark.executor.memory", MAX_MEMORY)\
    .config("spark.driver.memory", MAX_MEMORY).getOrCreate()
Downloading the movie rating data
(1) Go to the MovieLens site.
https://grouplens.org/datasets/movielens/25m/
(2) Click the zip link below the black box to download the data.
(3) Extract the ml-25m folder to a path of your choice, e.g. C:\your-path\ml-25m.
Using ALS
(1) Load the data
ratings_file = 'C:/your-path/ml-25m/ratings.csv'
ratings_df = spark.read.csv(f"file:///{ratings_file}", inferSchema=True, header=True)
(2) Keep only the needed columns and check the data
ratings_df.show()
# the timestamp column is not needed, so drop it
ratings_df = ratings_df.select(['userId', 'movieId', 'rating'])
+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
| 1| 296| 5.0|1147880044|
| 1| 306| 3.5|1147868817|
| 1| 307| 5.0|1147868828|
| 1| 665| 5.0|1147878820|
| 1| 899| 3.5|1147868510|
| 1| 1088| 4.0|1147868495|
| 1| 1175| 3.5|1147868826|
| 1| 1217| 3.5|1147878326|
| 1| 1237| 5.0|1147868839|
| 1| 1250| 4.0|1147868414|
| 1| 1260| 3.5|1147877857|
| 1| 1653| 4.0|1147868097|
| 1| 2011| 2.5|1147868079|
| 1| 2012| 2.5|1147868068|
| 1| 2068| 2.5|1147869044|
| 1| 2161| 3.5|1147868609|
| 1| 2351| 4.5|1147877957|
| 1| 2573| 4.0|1147878923|
| 1| 2632| 5.0|1147878248|
| 1| 2692| 5.0|1147869100|
+------+-------+------+----------+
only showing top 20 rows
----
ratings_df.printSchema()
root
|-- userId: integer (nullable = true)
|-- movieId: integer (nullable = true)
|-- rating: double (nullable = true)
The data is organized as movie IDs and their ratings per userId. Since no user has watched every movie, there will be movies in this dataset that user 1 has not seen.
The dataset contains 59,047 distinct movies in total.
ratings_df.createOrReplaceTempView('ratings_df')
query = '''
SELECT
COUNT(DISTINCT movieId) as movie_count
FROM
ratings_df
'''
spark.sql(query).show()
+-----------+
|movie_count|
+-----------+
| 59047|
+-----------+
The first user (userId 1) has rated a total of 70 movies.
query = '''
SELECT
COUNT(*) as 1_count
FROM
ratings_df
WHERE
userId == 1
'''
spark.sql(query).show()
+-------+
|1_count|
+-------+
| 70|
+-------+
Our goal in this post is to use the ALS algorithm to recommend, from among the movies user 1 has not seen, the ones they are most likely to enjoy.
(3) Rating statistics
# rating statistics
ratings_df.select('rating').describe().show()
+-------+------------------+
|summary| rating|
+-------+------------------+
| count| 25000095|
| mean| 3.533854451353085|
| stddev|1.0607439611423535|
| min| 0.5|
| max| 5.0|
+-------+------------------+
(4) Splitting into train and test data
# split into train and test sets
train_df, test_df = ratings_df.randomSplit([0.8, 0.2])
print(f'train_df count: {train_df.count()}')
print(f'test_df count: {test_df.count()}')
train_df count: 19999764
test_df count: 5000331
(5) Build and train the model
from pyspark.ml.recommendation import ALS
# recommendation algorithm
als = ALS(
    maxIter=5,
    regParam=0.1,
    userCol='userId',
    itemCol='movieId',
    ratingCol='rating',
    # how to handle users/items that were not seen during training
    coldStartStrategy='drop'
)
'userCol' takes the user ID column.
'itemCol' takes the item ID column.
'ratingCol' takes the rating column.
Just knowing the concept of ALS and how to call it is enough to put the model to work!
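Besides the three column parameters, ALS exposes a few other hyperparameters that are worth knowing. Here is the same constructor with the remaining knobs spelled out at their default values (illustrative only, not part of the original run; the behavior is identical to the model above):

# illustrative only: other hyperparameters that pyspark.ml's ALS exposes
als = ALS(
    rank=10,                  # number of latent factors per user/item (default)
    maxIter=5,                # number of alternating optimization passes
    regParam=0.1,             # L2 regularization strength (lambda)
    nonnegative=False,        # constrain factors to be non-negative if True
    implicitPrefs=False,      # switch to the implicit-feedback variant
    userCol='userId',
    itemCol='movieId',
    ratingCol='rating',
    coldStartStrategy='drop'
)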
# fit the model
model = als.fit(train_df)
(6) Inference and performance check
# inference
prediction = model.transform(test_df)
prediction.show()
+------+-------+------+----------+
|userId|movieId|rating|prediction|
+------+-------+------+----------+
| 1| 1088| 4.0| 2.5297563|
| 3| 175197| 3.5| 2.874376|
| 4| 1580| 4.5| 3.097906|
| 9| 1088| 5.0| 3.9108486|
| 12| 8638| 4.0| 3.8486152|
| 13| 3175| 4.0| 3.6777968|
| 20| 1580| 4.0| 4.033978|
| 23| 1959| 5.0| 3.9067807|
| 30| 3175| 4.5| 3.8995147|
| 31| 8638| 2.0| 2.5974488|
| 41| 1580| 4.0| 3.5832882|
| 41| 2366| 3.0| 3.150863|
| 57| 1580| 3.0| 3.604269|
| 58| 6658| 5.0| 3.2883408|
| 63| 68135| 4.0| 3.1181765|
| 70| 1580| 3.0| 3.017872|
| 72| 1591| 2.0| 2.358852|
| 75| 1088| 3.5| 3.070664|
| 75| 1959| 2.0| 3.243212|
| 80| 1342| 2.0| 2.482828|
+------+-------+------+----------+
only showing top 20 rows
----
prediction.select('rating', 'prediction').describe().show()
+-------+------------------+-----------------+
|summary| rating| prediction|
+-------+------------------+-----------------+
| count| 4996844| 4996844|
| mean| 3.534328668255403|3.399862379635884|
| stddev|1.0605845133909824| 0.63975911880358|
| min| 0.5| -2.4304335|
| max| 5.0| 6.9266715|
+-------+------------------+-----------------+
Comparing the statistics of the actual ratings and the predictions, the means are encouragingly close. However, the standard deviations differ, and the minimum prediction is negative, so there is clearly room for improvement.
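One quick way to address the out-of-range predictions (a minimal sketch, not part of the original post) is to clip them to the valid 0.5-5.0 MovieLens rating range before evaluating:

from pyspark.sql import functions as F
# clip predictions into the valid rating range [0.5, 5.0]
clipped = prediction.withColumn(
    'prediction',
    F.when(F.col('prediction') < 0.5, 0.5)
     .when(F.col('prediction') > 5.0, 5.0)
     .otherwise(F.col('prediction'))
)
clipped.select('rating', 'prediction').describe().show()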
# evaluation
from pyspark.ml.evaluation import RegressionEvaluator
evaluator = RegressionEvaluator(metricName='rmse', labelCol='rating', predictionCol='prediction')
# RMSE
rmse = evaluator.evaluate(prediction)
rmse
0.8134903395643102
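An RMSE of roughly 0.81 leaves room for improvement. One common next step, sketched here with an illustrative (untuned) grid rather than anything from the original run, is to search over rank and regParam with cross-validation:

from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
# illustrative grid; on the full 25M ratings this is expensive, so consider a sample first
param_grid = ParamGridBuilder()\
    .addGrid(als.rank, [10, 50])\
    .addGrid(als.regParam, [0.05, 0.1])\
    .build()
cv = CrossValidator(estimator=als,
                    estimatorParamMaps=param_grid,
                    evaluator=evaluator,
                    numFolds=3)
cv_model = cv.fit(train_df)
best_model = cv_model.bestModel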
(7) Top 5 recommendations for every user
# top-5 recommendations for every user
model.recommendForAllUsers(5).show()
+------+--------------------+
|userId| recommendations|
+------+--------------------+
| 1|[{202231, 5.65169...|
| 3|[{194434, 6.35166...|
| 4|[{194434, 6.07157...|
| 5|[{194434, 6.13072...|
| 6|[{162436, 6.21624...|
| 7|[{185645, 5.45231...|
| 8|[{194434, 5.93970...|
| 9|[{185645, 6.51190...|
| 10|[{194434, 6.12319...|
| 12|[{194434, 5.63328...|
| 13|[{194434, 6.47734...|
| 15|[{194434, 6.69569...|
| 16|[{194434, 6.56789...|
| 17|[{199187, 6.23324...|
| 19|[{194434, 5.82160...|
| 20|[{194434, 6.83623...|
| 21|[{194434, 6.51367...|
| 22|[{185645, 7.19147...|
| 23|[{194434, 6.35103...|
| 24|[{203086, 6.53243...|
+------+--------------------+
only showing top 20 rows
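The recommendations column is an array of (movieId, rating) structs. If a flat table is easier to work with downstream, the array can be exploded into one row per recommendation (a small sketch, not part of the original post):

from pyspark.sql import functions as F
# one row per (user, recommended movie) pair
flat_recs = model.recommendForAllUsers(5)\
    .select('userId', F.explode('recommendations').alias('rec'))\
    .select('userId', 'rec.movieId', 'rec.rating')
flat_recs.show(5)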
(8) Recommending the top 3 users for every item
# recommend the top 3 users for each item
model.recommendForAllItems(3).show()
+-------+--------------------+
|movieId| recommendations|
+-------+--------------------+
| 1|[{18230, 5.483052...|
| 2|[{87426, 5.260374...|
| 3|[{87426, 5.19888}...|
| 4|[{52924, 4.736040...|
| 5|[{52924, 4.899741...|
| 6|[{156252, 5.42339...|
| 7|[{10417, 5.039373...|
| 8|[{87426, 5.294695...|
| 9|[{87426, 5.268173...|
| 10|[{87426, 5.282108...|
| 11|[{10417, 5.357073...|
| 12|[{87426, 5.257176...|
| 13|[{108346, 5.20422...|
| 14|[{105801, 4.85644...|
| 15|[{87426, 5.336809...|
| 16|[{96740, 5.368758...|
| 17|[{58248, 5.512191...|
| 18|[{87426, 5.399836...|
| 19|[{87426, 5.231883...|
| 20|[{87426, 5.285846...|
+-------+--------------------+
only showing top 20 rows
(9) Building a recommendation API for a specific user
Let's build an API that, given a user, returns that user's top recommended movies.
First, let's build the 'userId' input that will go into the API and look at the shape of the 'recommendations' data that comes back out.
from pyspark.sql.types import IntegerType
# check what the userId input will look like
user_list = [65, 78, 93]
user_df = spark.createDataFrame(user_list, IntegerType()).toDF('userId')
user_df.show()
+------+
|userId|
+------+
| 65|
| 78|
| 93|
+------+
----
# check what the recommendations look like
user_recommend = model.recommendForUserSubset(user_df, 5)
# extract the top-5 recommendations for the first user
movies_list = user_recommend.collect()[0].recommendations
recs_df = spark.createDataFrame(movies_list)
recs_df.show()
+-------+-----------------+
|movieId| rating|
+-------+-----------------+
| 205277|6.759361743927002|
| 159761|6.485182762145996|
| 169606|6.246584415435791|
| 137363|5.943302154541016|
| 203633|5.908850193023682|
+-------+-----------------+
We have recommended movie IDs to the user, but we still do not know which movies they are. Let's join the movie IDs with their titles and genres.
movies_file = 'C:/your-path/ml-25m/movies.csv'
movies_df = spark.read.csv(f"file:///{movies_file}", inferSchema=True, header=True)
movies_df.show()
+-------+--------------------+--------------------+
|movieId| title| genres|
+-------+--------------------+--------------------+
| 1| Toy Story (1995)|Adventure|Animati...|
| 2| Jumanji (1995)|Adventure|Childre...|
| 3|Grumpier Old Men ...| Comedy|Romance|
| 4|Waiting to Exhale...|Comedy|Drama|Romance|
| 5|Father of the Bri...| Comedy|
| 6| Heat (1995)|Action|Crime|Thri...|
| 7| Sabrina (1995)| Comedy|Romance|
| 8| Tom and Huck (1995)| Adventure|Children|
| 9| Sudden Death (1995)| Action|
| 10| GoldenEye (1995)|Action|Adventure|...|
| 11|American Presiden...|Comedy|Drama|Romance|
| 12|Dracula: Dead and...| Comedy|Horror|
| 13| Balto (1995)|Adventure|Animati...|
| 14| Nixon (1995)| Drama|
| 15|Cutthroat Island ...|Action|Adventure|...|
| 16| Casino (1995)| Crime|Drama|
| 17|Sense and Sensibi...| Drama|Romance|
| 18| Four Rooms (1995)| Comedy|
| 19|Ace Ventura: When...| Comedy|
| 20| Money Train (1995)|Action|Comedy|Cri...|
+-------+--------------------+--------------------+
only showing top 20 rows
Now that we have movieIds in the recommendations, let's write a query that joins them to the movie titles and genres.
recs_df.createOrReplaceTempView('recommendations')
movies_df.createOrReplaceTempView('movies')
# SQL query
query = '''
Select
*
From
recommendations r
Join movies m
On r.movieId = m.movieId
ORDER BY
rating desc
'''
recommendation_movies = spark.sql(query)
recommendation_movies.show()
+-------+-----------------+-------+--------------------+--------------------+
|movieId| rating|movieId| title| genres|
+-------+-----------------+-------+--------------------+--------------------+
| 205277|6.759361743927002| 205277| Inside Out (1991)|Comedy|Drama|Romance|
| 159761|6.485182762145996| 159761| Loot (1970)| Comedy|Crime|
| 169606|6.246584415435791| 169606|Dara O'Briain Cro...| Comedy|
| 137363|5.943302154541016| 137363|The Mother Of Inv...| Comedy|
| 203633|5.908850193023682| 203633| The Bribe (2018)| Comedy|Crime|
+-------+-----------------+-------+--------------------+--------------------+
API
# in practice it is convenient to wrap all of this in a single function
query = '''
Select
*
From
recommendations r
Join movies m
On r.movieId = m.movieId
ORDER BY
rating desc
'''
def get_recommendation(user_id, num_recs):
    # take the userId as input
    user_df = spark.createDataFrame([user_id], IntegerType()).toDF('userId')
    # generate recommendations for that user
    user_recs_df = model.recommendForUserSubset(user_df, num_recs)
    # turn the recommendations into a dataframe
    recs_list = user_recs_df.collect()[0].recommendations
    recs_df = spark.createDataFrame(recs_list)
    recs_df.createOrReplaceTempView('recommendations')
    movies_df.createOrReplaceTempView('movies')
    # join the recommendations with the movie title/genre data via the SQL query
    recommend_movies = spark.sql(query)
    return recommend_movies
# recommend 10 movies to the user with userId 456
recs = get_recommendation(456, 10)
# convert to pandas
recs.toPandas()
Scenario: 'Let us recommend a movie for you!'
userId 100 logs in to Netflix.
While the loading screen is up, the network hands the backend developer a mission:
'Recommend 5 movies that userId 100 would like. And do it nicely!'
The backend developer was just going to run the code above, but worried that shipping a raw pandas DataFrame around might cause trouble, so they decided to write a new function.
def recommendation(user_id, num_recs):
    recs = get_recommendation(user_id, num_recs)
    r = recs.toPandas()
    # pull out the list of titles
    title_list = list(r['title'])
    # pull out the list of genres
    genre_list = list(r['genres'])
    # print a friendly recommendation message for the user
    for i in range(0, num_recs):
        print(f'Recommendation #{i+1}: {title_list[i]} / {genre_list[i]}')
Now, whenever userId 100 shows up, we just call the function below.
recommendation(100, 5)
Recommendation #1: Adrenaline (1990) / (no genres listed)
Recommendation #2: Truth and Justice (2019) / Drama
Recommendation #3: School of Babel (2014) / Documentary
Recommendation #4: National Theatre Live: One Man, Two Guvnors (2011) / Comedy
Recommendation #5: Les Luthiers: El Grosso Concerto (2001) / (no genres listed)
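In a real service you would not want to refit ALS on every request. A minimal sketch of persisting the trained model and loading it back in the serving process (the path below is illustrative):

from pyspark.ml.recommendation import ALSModel
# save once after training
model.write().overwrite().save('C:/your-path/als-movie-model')
# later, in the serving process, load it back without retraining
serving_model = ALSModel.load('C:/your-path/als-movie-model')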
That wraps up using the ALS recommendation algorithm in Spark.
I hope this was helpful, and thank you for following along.