[SparkML] Building an MLlib Pipeline

Spark MLlib์€ Pipeline ์ปดํฌ๋„ŒํŠธ๋ฅผ ์ง€์›ํ•œ๋‹ค.
Pipeline์€ ๋จธ์‹ ๋Ÿฌ๋‹์˜ ์›Œํฌํ”Œ๋กœ์šฐ๋ฅผ ๋งํ•˜๊ธฐ๋„ ํ•˜๋ฉฐ, ์—ฌ๋Ÿฌ Stage๋ฅผ ๋‹ด๊ณ  ์žˆ๋‹ค. persist() ํ•จ์ˆ˜๋ฅผ ํ†ตํ•ด ์ €์žฅ๋„ ๊ฐ€๋Šฅํ•˜๋‹ค.

So how do we write the actual code?

Let's build an MLlib Pipeline with the taxi data covered in the previous post.

The basic setup should be familiar from the earlier Spark posts.

[๋ฐ์ดํ„ฐ https://mengu.tistory.com/50?category=932924]

 

[SparkSQL] ํƒ์‹œ ๋ฐ์ดํ„ฐ ๋‹ค์šด/์ „์ฒ˜๋ฆฌ/๋ถ„์„ feat. TLC

์ด์ „ ํฌ์ŠคํŒ…์—์„œ ๊ณต๋ถ€ํ•œ SparkSQL ์ง€์‹์„ ๋ฐ”ํƒ•์œผ๋กœ, ์‹ค์ œ Taxi ๋ฐ์ดํ„ฐ๋ฅผ ์ „์ฒ˜๋ฆฌํ•ด๋ณด์ž. * ์ „์ฒ˜๋ฆฌ๋ž€? ์ด์ƒ์น˜ ์ œ๊ฑฐ, ๊ทธ๋ฃนํ™” ๋“ฑ ๋ฐ์ดํ„ฐ ๋ถ„์„์ด ์šฉ์ดํ•˜๋„๋ก ๋ฐ์ดํ„ฐ๋ฅผ ๋ณ€ํ˜•ํ•˜๋Š” ๊ณผ์ •์„ ๋งํ•œ๋‹ค. TLC Trip Recor

mengu.tistory.com

"๋ณธ ํฌ์ŠคํŒ…์€ ํŒจ์ŠคํŠธ์บ ํผ์Šค์˜ ๊ฐ•์˜๋ฅผ ๋“ฃ๊ณ , ์ •๋ฆฌํ•œ ์ž๋ฃŒ์ž„์„ ๋ฐํž™๋‹ˆ๋‹ค."

Basic Settings

๊ธฐ๋ณธ ์„ธํŒ…์ด๋‹ค. ์ž„ํฌํŠธ ํ•ด์ค˜์•ผ ํ•  ๊ฒƒ๋“ค์€ ๋ฏธ๋ฆฌ ํ•ด๋†จ๊ณ , SparkSession์„ ์—ด์–ด๋‘์—ˆ๋‹ค.

 

# Font settings (for Korean labels in matplotlib plots)
from matplotlib import font_manager, rc
font_path = 'C:\\WINDOWS\\Fonts\\HBATANG.TTF'
font = font_manager.FontProperties(fname=font_path).get_name()
rc('font', family=font)

# Basic settings
import os
import findspark
findspark.init(os.environ.get("SPARK_HOME"))
import pyspark
from pyspark import SparkConf, SparkContext
import pandas as pd
import faulthandler
faulthandler.enable()
from pyspark.sql import SparkSession

MAX_MEMORY = "5g"
spark = SparkSession.builder.master('local').appName("taxi-fare-prediction")\
.config("spark.executor.memory", MAX_MEMORY)\
.config("spark.driver.memory", MAX_MEMORY).getOrCreate()

# Data file paths
zone_data = "C:/DE study/data-engineering/01-spark/data/taxi_zone_lookup.csv"
trip_files = "C:/DE study/data-engineering/01-spark/data/trips/*"

# Load the data
trips_df = spark.read.csv(f"file:///{trip_files}", inferSchema=True, header=True)
zone_df = spark.read.csv(f"file:///{zone_data}", inferSchema=True, header=True)

# Print the schemas
trips_df.printSchema()
zone_df.printSchema()

# Register temp views for SQL
trips_df.createOrReplaceTempView("trips")
zone_df.createOrReplaceTempView("zone")


root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: string (nullable = true)
 |-- tpep_dropoff_datetime: string (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)

root
 |-- LocationID: integer (nullable = true)
 |-- Borough: string (nullable = true)
 |-- Zone: string (nullable = true)
 |-- service_zone: string (nullable = true)

Pulling the Desired DataFrame


We pull the columns needed to predict the fare amount.

 

(1) Payment type

(2) Pickup time

(3) Pickup location

(4) Drop-off location

(5) Day of week

(6) Passenger count

(7) Trip distance

(8) Fare amount

 

๊ฐ ๋ฐ์ดํ„ฐ๋“ค์— ์ด์ƒ์น˜์™€ ๊ฒฐ์ธก์น˜๊ฐ€ ์กด์žฌํ•จ์œผ๋กœ, SQL์„ ํ†ตํ•ด ๊ฐ€์ ธ์˜ฌ ๋•Œ ํ•„ํ„ฐ๋ง์„ ํ•ด์„œ ๊ฐ€์ ธ์˜จ๋‹ค. ์ž์„ธํ•œ ์ „์ฒ˜๋ฆฌ ๋‚ด์šฉ์€ ๋‹ค์Œ ํฌ์ŠคํŒ…์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

 

query = '''
SELECT
    PULocationID as pickup_location_id,
    DOLocationID as dropoff_location_id,
    payment_type,
    HOUR(tpep_pickup_datetime) as pickup_time,
    DATE_FORMAT(TO_DATE(tpep_pickup_datetime), 'EEEE') AS day_of_week,
    passenger_count,
    trip_distance,
    total_amount
FROM
    trips
WHERE
    total_amount < 5000
    AND total_amount > 0
    AND trip_distance > 0
    AND trip_distance < 500
    AND passenger_count < 4
    AND TO_DATE(tpep_pickup_datetime) >= '2021-01-01'
    AND TO_DATE(tpep_pickup_datetime) < '2021-08-01'
'''

# Pull the DataFrame via the SQL query
data_df = spark.sql(query)
data_df.createOrReplaceTempView('data')
data_df.printSchema()


root
 |-- pickup_location_id: integer (nullable = true)
 |-- dropoff_location_id: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- pickup_time: integer (nullable = true)
 |-- day_of_week: string (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- total_amount: double (nullable = true)

Splitting into Train and Test Datasets

 

# Split into train and test DataFrames
train_df, test_df = data_df.randomSplit([0.8, 0.2], seed=1)


# Save the splits as parquet up front to cut reload time later
data_dir = "C:/path/to/save"
train_df.write.parquet(f'{data_dir}/train/')
test_df.write.parquet(f'{data_dir}/test/')

# Reload the saved splits
train_df = spark.read.parquet(f'{data_dir}/train/')
test_df = spark.read.parquet(f'{data_dir}/test/')
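
Parquet stores the schema alongside the data, so the reloaded DataFrames come back fully typed with no need for inferSchema. As a quick sanity check (not in the original code), the split sizes can be compared:

# Roughly an 80/20 split; count() scans the data, so use it sparingly
print(train_df.count(), test_df.count())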

Building the Pipeline


Before composing the pipeline, the most important step is deciding how to preprocess. The minimum this data needs for machine-learning training is (1) one-hot encoding, (2) standardization, and (3) vectorization. Once these are settled, we compose the pipeline.

(1) One-hot encoding

Only the categorical columns need one-hot encoding.

 

- StringIndexer : before one-hot encoding, assigns each category a numeric index (0-based, ordered by frequency). ex) (strawberry, banana, chocolate) -> (0, 2, 1)

- OneHotEncoder : turns the assigned index into a vector. ex) 0 -> [1, 0, 0]
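
To make the mapping concrete, here is a small toy example (the fruit column and its values are made up for illustration):

from pyspark.ml.feature import OneHotEncoder, StringIndexer

toy_df = spark.createDataFrame(
    [("strawberry",), ("banana",), ("banana",), ("chocolate",)], ["fruit"])

# StringIndexer: the most frequent label (banana) gets index 0.0
indexed = StringIndexer(inputCol="fruit", outputCol="fruit_idx").fit(toy_df).transform(toy_df)

# OneHotEncoder: with dropLast=True (the default), 3 categories become size-2
# sparse vectors, e.g. index 0.0 -> (2,[0],[1.0])
OneHotEncoder(inputCols=["fruit_idx"], outputCols=["fruit_onehot"]).fit(indexed).transform(indexed).show()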

 

# ์›-ํ•ซ ์ธ์ฝ”๋”ฉ
from pyspark.ml.feature import OneHotEncoder, StringIndexer


# ์นดํ…Œ๊ณ ๋ฆฌํ˜• ์นผ๋Ÿผ
cat_feats = [
    'pickup_location_id',
    'dropoff_location_id',
    'day_of_week'
]

# ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌ์„ฑ์„ ์œ„ํ•œ stages ๋ฆฌ์ŠคํŠธ
stages = []


# index๋ฅผ ๋ฐ”๊พธ๊ณ , ๊ทธ ๋ฐ”๋€ indexer์— ์›-ํ•ซ ์ธ์ฝ”๋”ฉ์„ ์ ์šฉํ•ด์ค€๋‹ค.
for c in cat_feats:
    cat_indexer = StringIndexer(inputCol=c, outputCol = c + "_idx").setHandleInvalid("keep")
    onehot_encoder = OneHotEncoder(inputCols = [cat_indexer.getOutputCol()], outputCols=[c + '_onehot'])
    stages += [cat_indexer, onehot_encoder]

(2) ์ˆซ์žํ˜• ๋ฐ์ดํ„ฐ, ๋ฒกํ„ฐํ™” ๋ฐ ์Šค์ผ€์ผ๋Ÿฌ ์ ์šฉ

์ˆซ์žํ˜• ๋ฐ์ดํ„ฐ๋ฅผ ํ•˜๋‚˜์˜ ๋ฒกํ„ฐ๋กœ ๋ฌถ๊ณ , ์Šค์ผ€์ผ๋Ÿฌ๋ฅผ ์ ์šฉํ•œ๋‹ค.

 

# ๋ฒกํ„ฐํ™” + ์Šค์ผ€์ผ๋Ÿฌ ์ ์šฉ
from pyspark.ml.feature import VectorAssembler, StandardScaler

# numericํ˜• ์นผ๋Ÿผ
num_feats = [
    'passenger_count',
    'trip_distance',
    'pickup_time'
]


# ๋ฒกํ„ฐํ™” ํ•œ ํ›„, ์Šค์ผ€์ผ๋Ÿฌ๋ฅผ ์ ์šฉํ•œ๋‹ค.
for n in num_feats:
    num_assembler = VectorAssembler(inputCols=[n], outputCol= n + '_vecotr')
    num_scaler = StandardScaler(inputCol=num_assembler.getOutputCol(), outputCol = n + '_scaled')
    stages += [num_assembler, num_scaler]
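
StandardScaler operates on vector columns, which is why each numeric column is first wrapped by a one-column VectorAssembler. By default it scales to unit standard deviation without centering (withStd=True, withMean=False); a hypothetical variant with centering enabled would look like this:

# Variant (not used here): also center to zero mean, which forces dense vectors
num_scaler = StandardScaler(inputCol='trip_distance_vector', outputCol='trip_distance_scaled',
                            withMean=True, withStd=True)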

* Intermediate check of stages

 

stages


[StringIndexer_115f74e6efea,
 OneHotEncoder_714f494271bb,
 StringIndexer_806a3b8e8a32,
 OneHotEncoder_f7bc9266f650,
 StringIndexer_2f125ebb95a8,
 OneHotEncoder_eb212d50e427,
 VectorAssembler_5b7bfff3be42,
 StandardScaler_3e59d49af9ad,
 VectorAssembler_a61858dada0f,
 StandardScaler_c739fc7f7d49,
 VectorAssembler_f683b2eeb4d2,
 StandardScaler_96e5ba925088]

(3) Run everything through a final VectorAssembler to complete a training-ready dataset, as built below.

 

# Input columns for the final assembler
assembler_inputs = [c + '_onehot' for c in cat_feats] + [n + '_scaled' for n in num_feats]
print(assembler_inputs)


['pickup_location_id_onehot',
 'dropoff_location_id_onehot',
 'day_of_week_onehot',
 'passenger_count_scaled',
 'trip_distance_scaled',
 'pickup_time_scaled']
 
 
 
# Add the final VectorAssembler to stages
assembler = VectorAssembler(inputCols=assembler_inputs, outputCol='feature_vector')
stages += [assembler]
stages


[StringIndexer_dfc09cc586be,
 OneHotEncoder_987fbfa36a2d,
 StringIndexer_bf2338365d7f,
 OneHotEncoder_5a91ea5195e8,
 StringIndexer_c416d64272f1,
 OneHotEncoder_0dfab0742066,
 VectorAssembler_4c5f47a3740c,
 StandardScaler_65dfe2363318,
 VectorAssembler_7e0a4e81ec39,
 StandardScaler_22d11d283c0a,
 VectorAssembler_c2b692153924,
 StandardScaler_debc924ffa61,
 VectorAssembler_c2c382815ebb]

(4) Finally, bundle it all together with Pipeline()

 

# Build the pipeline
from pyspark.ml import Pipeline

transform_stages = stages
pipeline = Pipeline(stages = transform_stages)
fitted_transformer = pipeline.fit(train_df)
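
fit() returns a PipelineModel, i.e. a fitted transformer. As mentioned at the top of the post, it can be saved and reloaded; a minimal sketch, with the save path assumed for illustration:

from pyspark.ml import PipelineModel

# Persist the fitted pipeline so a later session can skip refitting
fitted_transformer.write().overwrite().save(f"{data_dir}/fitted_pipeline")
fitted_transformer = PipelineModel.load(f"{data_dir}/fitted_pipeline")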

์˜ˆ์ธก & ์ถ”๋ก 


pipeline fit()๋ฅผ ํ•ด์ฃผ๊ณ , ๋’ค์ด์–ด ๋ฐ์ดํ„ฐ๋“ค์„ pipeline์— ํ†ต๊ณผ์‹œ์ผœ ์ „์ฒ˜๋ฆฌํ•ด์ฃผ์ž.

 

vtrain_df = fitted_transformer.transform(train_df)
vtest_df = fitted_transformer.transform(test_df)

 

 

Create a linear regression model, train it, and make predictions.

 

from pyspark.ml.regression import LinearRegression

lr = LinearRegression(
    maxIter=50,
    solver='normal',
    labelCol='total_amount',
    featuresCol='feature_vector'
)


# Train
model = lr.fit(vtrain_df)


# Predict
prediction = model.transform(vtest_df)


# Inspect the predictions
prediction.select(['trip_distance', 'day_of_week', 'total_amount', 'prediction']).show(5)


+-------------+-----------+------------+------------------+
|trip_distance|day_of_week|total_amount|        prediction|
+-------------+-----------+------------+------------------+
|          0.8|   Thursday|       120.3|   89.980284493094|
|        24.94|   Saturday|        70.8| 131.2247006891777|
|         0.01|  Wednesday|      102.36|13.995815841451657|
|          0.1|     Monday|       71.85|10.013635081914366|
|          0.5|    Tuesday|         7.8|   12.046081421887|
+-------------+-----------+------------+------------------+
only showing top 5 rows

Evaluator


 

(1) RMSE
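
RMSE is the square root of the mean squared error between label and prediction:

$\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2}$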

 

model.summary.rootMeanSquaredError


5.818945295076586

 

 

(2) R2
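
R2 is the share of the label's variance the model explains (1 is perfect; 0 is no better than predicting the mean):

$R^2 = 1 - \frac{\sum_i (y_i - \hat{y}_i)^2}{\sum_i (y_i - \bar{y})^2}$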

 

model.summary.r2


0.7997047915616821
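
Note that model.summary reports metrics on the training data. To score the held-out test set, evaluate the prediction DataFrame with a RegressionEvaluator; a brief sketch:

from pyspark.ml.evaluation import RegressionEvaluator

# Evaluate on the test-set predictions rather than the training summary
evaluator = RegressionEvaluator(labelCol='total_amount', predictionCol='prediction', metricName='rmse')
print('test RMSE:', evaluator.evaluate(prediction))
print('test R2:', evaluator.setMetricName('r2').evaluate(prediction))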

์ด์ „ ํฌ์ŠคํŒ…์— ๋น„ํ•ด ์„ฑ๋Šฅ์ด ์‚ด์ง ํ–ฅ์ƒ๋œ ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค. ์ด๋ฒˆ ํฌ์ŠคํŒ…์—์„  Spark๋กœ ๋จธ์‹ ๋Ÿฌ๋‹ ๊ตฌ์ถ•ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ์•„๋ณด์•˜๋‹ค. ํ›จ์”ฌ ๋‹ค์–‘ํ•œ ์ „์ฒ˜๋ฆฌ ๊ธฐ๊ตฌ๋“ค์ด ์กด์žฌํ•˜๋ฉฐ, ๋‹ค๋ฅด๊ฒŒ ์กฐํ•ฉํ•œ๋‹ค๋ฉด ๋ถ„๋ช… ๋” ์ข‹์€ ์„ฑ๋Šฅ์˜ ๋ชจ๋ธ์ด ๋งŒ๋“ค์–ด์งˆ ์ˆ˜ ์žˆ๋‹ค. ๊ทธ๋Ÿฌ๋‹ˆ ์Šค์Šค๋กœ ์ด๊ฒƒ์ €๊ฒƒ ๋งŒ์ ธ๋ณด๊ธธ ๋ฐ”๋ž€๋‹ค.

์ˆ˜๊ณ ํ•˜์…จ์Šต๋‹ˆ๋‹ค.