Spark MLlib supports the Pipeline component.
A Pipeline represents a machine learning workflow and holds a sequence of Stages. It can also be persisted, i.e. saved to disk and reloaded later.
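A minimal sketch of that persistence, using the pipeline and fitted_transformer names defined later in this post (the paths here are placeholders):

from pyspark.ml import Pipeline, PipelineModel
# an unfitted Pipeline saves only its stage definitions
pipeline.write().overwrite().save('/tmp/unfit-pipeline')
same_pipeline = Pipeline.load('/tmp/unfit-pipeline')
# a fitted PipelineModel also saves the learned parameters of each stage
fitted_transformer.write().overwrite().save('/tmp/fit-pipeline')
same_model = PipelineModel.load('/tmp/fit-pipeline')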
So how do we write the actual code?
Let's build an MLlib Pipeline with the taxi data covered in the previous post.
The basic setup will be familiar if you have read the earlier Spark posts.
[Data: https://mengu.tistory.com/50?category=932924]
"This post is based on notes organized while taking a FastCampus course."
Basic Settings
This is the basic setup: the imports we need are done up front, and a SparkSession is opened.
# font setup (for Korean labels in matplotlib)
from matplotlib import font_manager, rc
font_path = 'C:\\WINDOWS\\Fonts\\HBATANG.TTF'
font = font_manager.FontProperties(fname=font_path).get_name()
rc('font', family=font)
# basic settings
import os
import findspark
findspark.init(os.environ.get("SPARK_HOME"))
import pyspark
from pyspark import SparkConf, SparkContext
import pandas as pd
import faulthandler
faulthandler.enable()
from pyspark.sql import SparkSession
MAX_MEMORY = "5g"
spark = SparkSession.builder.master('local').appName("taxi-fare-prediction")\
.config("spark.executor.memory", MAX_MEMORY)\
.config("spark.driver.memory", MAX_MEMORY).getOrCreate()
# paths to the data files
zone_data = "C:/DE study/data-engineering/01-spark/data/taxi_zone_lookup.csv"
trip_files = "C:/DE study/data-engineering/01-spark/data/trips/*"
# load the data
trips_df = spark.read.csv(f"file:///{trip_files}", inferSchema = True, header = True)
zone_df = spark.read.csv(f"file:///{zone_data}", inferSchema = True, header = True)
# print the schemas
trips_df.printSchema()
zone_df.printSchema()
# register temp views for SQL with createOrReplaceTempView()
trips_df.createOrReplaceTempView("trips")
zone_df.createOrReplaceTempView("zone")
root
|-- VendorID: integer (nullable = true)
|-- tpep_pickup_datetime: string (nullable = true)
|-- tpep_dropoff_datetime: string (nullable = true)
|-- passenger_count: integer (nullable = true)
|-- trip_distance: double (nullable = true)
|-- RatecodeID: integer (nullable = true)
|-- store_and_fwd_flag: string (nullable = true)
|-- PULocationID: integer (nullable = true)
|-- DOLocationID: integer (nullable = true)
|-- payment_type: integer (nullable = true)
|-- fare_amount: double (nullable = true)
|-- extra: double (nullable = true)
|-- mta_tax: double (nullable = true)
|-- tip_amount: double (nullable = true)
|-- tolls_amount: double (nullable = true)
|-- improvement_surcharge: double (nullable = true)
|-- total_amount: double (nullable = true)
|-- congestion_surcharge: double (nullable = true)
root
|-- LocationID: integer (nullable = true)
|-- Borough: string (nullable = true)
|-- Zone: string (nullable = true)
|-- service_zone: string (nullable = true)
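With the temp views registered, the data can now be queried directly in SQL. As a quick sanity check (the count depends on which months of trip data you downloaded):

spark.sql('SELECT COUNT(*) FROM trips').show()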
Selecting the DataFrame We Want
We pull the columns needed to predict the fare:
(1) payment method
(2) pickup time
(3) pickup location
(4) dropoff location
(5) pickup day of week
(6) passenger count
(7) trip distance
(8) fare
Each of these columns contains outliers and missing values, so we filter as we select them with SQL. The preprocessing details can be found in the post linked above.
query = '''
SELECT
PULocationID as pickup_location_id,
DOLocationID as dropoff_location_id,
payment_type,
HOUR(tpep_pickup_datetime) as pickup_time,
DATE_FORMAT(TO_DATE(tpep_pickup_datetime), 'EEEE') AS day_of_week,
passenger_count,
trip_distance,
total_amount
FROM
trips
WHERE
total_amount < 5000
AND total_amount > 0
AND trip_distance > 0
AND trip_distance < 500
AND passenger_count < 4
AND TO_DATE(tpep_pickup_datetime) >= '2021-01-01'
AND TO_DATE(tpep_pickup_datetime) < '2021-08-01'
'''
# fetch a DataFrame via the SQL query
data_df = spark.sql(query)
data_df.createOrReplaceTempView('data')
data_df.printSchema()
root
|-- pickup_location_id: integer (nullable = true)
|-- dropoff_location_id: integer (nullable = true)
|-- payment_type: integer (nullable = true)
|-- pickup_time: integer (nullable = true)
|-- day_of_week: string (nullable = true)
|-- passenger_count: integer (nullable = true)
|-- trip_distance: double (nullable = true)
|-- total_amount: double (nullable = true)
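Before splitting, it's worth eyeballing a few rows to confirm the query selected and filtered what we intended (output omitted; it depends on your local data):

data_df.show(5)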
Splitting into Train and Test Sets
# split into train and test DataFrames
train_df, test_df = data_df.randomSplit([0.8, 0.2], seed=1)
# save to parquet up front to cut loading time later
data_dir = "C:/save_path"  # placeholder directory
train_df.write.mode("overwrite").parquet(f'{data_dir}/train/')
test_df.write.mode("overwrite").parquet(f'{data_dir}/test/')
# reload from the saved parquet files
train_df = spark.read.parquet(f'{data_dir}/train/')
test_df = spark.read.parquet(f'{data_dir}/test/')
Building the Pipeline
Before building the pipeline, the most important step is deciding how to preprocess. For this dataset, the minimum preprocessing needed for training is (1) one-hot encoding, (2) normalization, and (3) vectorization. With that plan in place, let's assemble the pipeline.
(1) One-hot encoding
Only the categorical columns need one-hot encoding; a toy run follows the list below.
- StringIndexer: before one-hot encoding, maps each category to a numeric index (0-based, most frequent label first), e.g. (strawberry, banana, choco) -> (0, 1, 2)
- OneHotEncoder: turns the index into a one-hot vector, e.g. index 0 -> [1, 0, 0] (Spark stores these as sparse vectors and drops the last category by default)
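To make the fruit example concrete, here is a minimal sketch on a toy DataFrame (the fruit column and its values are hypothetical):

from pyspark.ml.feature import OneHotEncoder, StringIndexer
toy_df = spark.createDataFrame(
    [('banana',), ('banana',), ('strawberry',), ('choco',)], ['fruit'])
# StringIndexer: the most frequent label gets index 0.0
indexed = StringIndexer(inputCol='fruit', outputCol='fruit_idx').fit(toy_df).transform(toy_df)
# OneHotEncoder: index -> sparse one-hot vector (last category dropped by default)
encoded = OneHotEncoder(inputCols=['fruit_idx'], outputCols=['fruit_onehot']).fit(indexed).transform(indexed)
encoded.show()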
# one-hot encoding
from pyspark.ml.feature import OneHotEncoder, StringIndexer
# categorical columns
cat_feats = [
'pickup_location_id',
'dropoff_location_id',
'day_of_week'
]
# list of stages that will make up the pipeline
stages = []
# index each categorical column, then apply one-hot encoding to the indexed result
for c in cat_feats:
    cat_indexer = StringIndexer(inputCol=c, outputCol=c + "_idx").setHandleInvalid("keep")
    onehot_encoder = OneHotEncoder(inputCols=[cat_indexer.getOutputCol()], outputCols=[c + '_onehot'])
    stages += [cat_indexer, onehot_encoder]
(2) Numeric columns: vectorization and scaling
Each numeric column is wrapped into a single-element vector and then scaled. StandardScaler operates on vector columns, which is why the VectorAssembler step comes first.
# vectorize + apply scaler
from pyspark.ml.feature import VectorAssembler, StandardScaler
# numeric columns
num_feats = [
'passenger_count',
'trip_distance',
'pickup_time'
]
# vectorize each numeric column first, then apply the scaler
for n in num_feats:
    num_assembler = VectorAssembler(inputCols=[n], outputCol=n + '_vector')
    num_scaler = StandardScaler(inputCol=num_assembler.getOutputCol(), outputCol=n + '_scaled')
    stages += [num_assembler, num_scaler]
* Checking stages midway
stages
[StringIndexer_115f74e6efea,
OneHotEncoder_714f494271bb,
StringIndexer_806a3b8e8a32,
OneHotEncoder_f7bc9266f650,
StringIndexer_2f125ebb95a8,
OneHotEncoder_eb212d50e427,
VectorAssembler_5b7bfff3be42,
StandardScaler_3e59d49af9ad,
VectorAssembler_a61858dada0f,
StandardScaler_c739fc7f7d49,
VectorAssembler_f683b2eeb4d2,
StandardScaler_96e5ba925088]
(3) Run everything through one final VectorAssembler so the result is a dataset ready for training.
# input columns for the final assembler
assembler_inputs = [c + '_onehot' for c in cat_feats] + [n + '_scaled' for n in num_feats]
print(assembler_inputs)
['pickup_location_id_onehot',
'dropoff_location_id_onehot',
'day_of_week_onehot',
'passenger_count_scaled',
'trip_distance_scaled',
'pickup_time_scaled']
# add the final VectorAssembler to stages
assembler = VectorAssembler(inputCols=assembler_inputs, outputCol='feature_vector')
stages += [assembler]
stages
[StringIndexer_dfc09cc586be,
OneHotEncoder_987fbfa36a2d,
StringIndexer_bf2338365d7f,
OneHotEncoder_5a91ea5195e8,
StringIndexer_c416d64272f1,
OneHotEncoder_0dfab0742066,
VectorAssembler_4c5f47a3740c,
StandardScaler_65dfe2363318,
VectorAssembler_7e0a4e81ec39,
StandardScaler_22d11d283c0a,
VectorAssembler_c2b692153924,
StandardScaler_debc924ffa61,
VectorAssembler_c2c382815ebb]
(4) Finally, tie everything together with Pipeline()
# build the pipeline
from pyspark.ml import Pipeline
transform_stages = stages
pipeline = Pipeline(stages = transform_stages)
fitted_transformer = pipeline.fit(train_df)
Prediction & Inference
With the pipeline fit(), pass the incoming DataFrames through it to preprocess them.
vtrain_df = fitted_transformer.transform(train_df)
vtest_df = fitted_transformer.transform(test_df)
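To confirm the preprocessing behaves as expected, it can help to peek at the assembled column (a quick check; the exact output depends on your data):

vtrain_df.select('feature_vector', 'total_amount').show(3, truncate=False)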
Create a linear regression model, train it, and make predictions
# LinearRegression lives in pyspark.ml.regression
from pyspark.ml.regression import LinearRegression
lr = LinearRegression(
    maxIter=50,
    solver='normal',
    labelCol='total_amount',
    featuresCol='feature_vector'
)
# train
model = lr.fit(vtrain_df)
# run inference
prediction = model.transform(vtest_df)
# inspect the predictions
prediction.select(['trip_distance', 'day_of_week', 'total_amount', 'prediction']).show(5)
+-------------+-----------+------------+------------------+
|trip_distance|day_of_week|total_amount| prediction|
+-------------+-----------+------------+------------------+
| 0.8| Thursday| 120.3| 89.980284493094|
| 24.94| Saturday| 70.8| 131.2247006891777|
| 0.01| Wednesday| 102.36|13.995815841451657|
| 0.1| Monday| 71.85|10.013635081914366|
| 0.5| Tuesday| 7.8| 12.046081421887|
+-------------+-----------+------------+------------------+
only showing top 5 rows
Evaluator
(1) RMSE
model.summary.rootMeanSquaredError
5.818945295076586
(2) R2
model.summary.r2
0.7997047915616821
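One caveat: model.summary reports metrics computed on the training data. To score the held-out test set instead, a RegressionEvaluator can be run over the prediction DataFrame (a sketch using the names defined above):

from pyspark.ml.evaluation import RegressionEvaluator
evaluator = RegressionEvaluator(labelCol='total_amount', predictionCol='prediction')
test_rmse = evaluator.evaluate(prediction, {evaluator.metricName: 'rmse'})
test_r2 = evaluator.evaluate(prediction, {evaluator.metricName: 'r2'})
print(test_rmse, test_r2)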
Performance improved slightly compared to the previous post. In this post we covered how to set up machine learning with Spark. Far more preprocessing tools exist, and combined differently they can surely produce a better model, so I encourage you to experiment with them yourself.
Thank you for reading.