Parameter Tuning
If you have dealt with machine learning, this will be a familiar concept. Even when using the same algorithm, a model's performance changes depending on how its internal parameters are tuned. Naturally, MLlib provides parameter tuning functionality. Let's walk through the code to see how the tuning is done.
[Know nothing about Spark yet?]
https://mengu.tistory.com/26?category=932924
[New to SparkML?]
https://mengu.tistory.com/56?category=932924
[Don't have the practice data yet?]
https://mengu.tistory.com/50?category=932924
"This post is based on notes I organized while taking a FastCampus course."
Basic Settings
Basic setup. The required imports are done up front, and a SparkSession has been opened.
# font settings for matplotlib
from matplotlib import font_manager, rc
font_path = 'C:\\WINDOWS\\Fonts\\HBATANG.TTF'
font = font_manager.FontProperties(fname=font_path).get_name()
rc('font', family=font)
# basic settings
import os
import findspark
findspark.init(os.environ.get("SPARK_HOME"))
import pyspark
from pyspark import SparkConf, SparkContext
import pandas as pd
import faulthandler
faulthandler.enable()
from pyspark.sql import SparkSession
MAX_MEMORY = "5g"
spark = SparkSession.builder.master('local').appName("taxi-fare-prediction")\
.config("spark.executor.memory", MAX_MEMORY)\
.config("spark.driver.memory", MAX_MEMORY).getOrCreate()
# data file paths
zone_data = "C:/DE study/data-engineering/01-spark/data/taxi_zone_lookup.csv"
trip_files = "C:/DE study/data-engineering/01-spark/data/trips/*"
# load the data
trips_df = spark.read.csv(f"file:///{trip_files}", inferSchema = True, header = True)
zone_df = spark.read.csv(f"file:///{zone_data}", inferSchema = True, header = True)
# data schemas
trips_df.printSchema()
zone_df.printSchema()
# register temp views via createOrReplaceTempView()
trips_df.createOrReplaceTempView("trips")
zone_df.createOrReplaceTempView("zone")
root
|-- VendorID: integer (nullable = true)
|-- tpep_pickup_datetime: string (nullable = true)
|-- tpep_dropoff_datetime: string (nullable = true)
|-- passenger_count: integer (nullable = true)
|-- trip_distance: double (nullable = true)
|-- RatecodeID: integer (nullable = true)
|-- store_and_fwd_flag: string (nullable = true)
|-- PULocationID: integer (nullable = true)
|-- DOLocationID: integer (nullable = true)
|-- payment_type: integer (nullable = true)
|-- fare_amount: double (nullable = true)
|-- extra: double (nullable = true)
|-- mta_tax: double (nullable = true)
|-- tip_amount: double (nullable = true)
|-- tolls_amount: double (nullable = true)
|-- improvement_surcharge: double (nullable = true)
|-- total_amount: double (nullable = true)
|-- congestion_surcharge: double (nullable = true)
root
|-- LocationID: integer (nullable = true)
|-- Borough: string (nullable = true)
|-- Zone: string (nullable = true)
|-- service_zone: string (nullable = true)
DataFrame
First, the data needs to be loaded as a DataFrame. We also keep only the desired input and output columns and strip out the rest.
query = '''
SELECT
trip_distance,
total_amount
FROM
trips
WHERE
total_amount < 5000
AND total_amount > 0
AND trip_distance > 0
AND trip_distance < 500
AND passenger_count < 4
AND TO_DATE(tpep_pickup_datetime) >= '2021-01-01'
AND TO_DATE(tpep_pickup_datetime) < '2021-08-01'
'''
# keep only the data we want
data_df = spark.sql(query)
data_df.createOrReplaceTempView('data')
data_df.show()
+-------------+------------+
|trip_distance|total_amount|
+-------------+------------+
| 16.5| 70.07|
| 1.13| 11.16|
| 2.68| 18.59|
| 12.4| 43.8|
| 9.7| 32.3|
| 9.3| 43.67|
| 9.58| 46.1|
| 16.2| 45.3|
| 3.58| 19.3|
| 0.91| 14.8|
| 2.57| 12.8|
| 0.4| 5.3|
| 3.26| 17.3|
| 13.41| 47.25|
| 18.3| 61.42|
| 1.53| 14.16|
| 2.0| 11.8|
| 16.6| 54.96|
| 15.5| 56.25|
| 1.3| 16.8|
+-------------+------------+
only showing top 20 rows
The goal of this post is parameter tuning. We repeatedly test models on a fixed, small amount of data and compare which parameters give the best performance. So how large should that dataset be?
toy_df = data_df.sample(False, 0.01, seed=1)
Usually you sample 10% of the original data. But if the original dataset is too large, your machine can still blow up. The taxi dataset is enormous, and when I sampled 10% and ran it, my computer crashed, so I set it to 1%.
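To sanity-check the sample, a quick row count of both DataFrames (note that count() triggers a full scan, so it can take a while on the full data):
# compare row counts of the full and sampled data
print(data_df.count(), toy_df.count())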
Preprocessing & Pipeline
We rebuild the preprocessing pipeline exactly as in the previous post.
(1) One-hot encoding
Only the categorical columns need one-hot encoding.
- StringIndexer: before one-hot encoding, assigns a numeric index to each category (0-based, ordered by label frequency). e.g. (strawberry, banana, chocolate) -> (0, 2, 1)
- OneHotEncoder: converts the assigned index into a binary vector. e.g. index 0 -> [1, 0, 0]
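As a minimal, self-contained sketch of these two stages (using a hypothetical fruit column, not the taxi data):
from pyspark.ml.feature import OneHotEncoder, StringIndexer

# tiny toy DataFrame with one categorical column
fruit_df = spark.createDataFrame(
    [("strawberry",), ("banana",), ("chocolate",), ("banana",)], ["fruit"])

indexer = StringIndexer(inputCol="fruit", outputCol="fruit_idx")
encoder = OneHotEncoder(inputCols=["fruit_idx"], outputCols=["fruit_onehot"])

# fit the indexer, then fit/apply the encoder on the indexed output
indexed = indexer.fit(fruit_df).transform(fruit_df)
encoder.fit(indexed).transform(indexed).show()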
# one-hot encoding
from pyspark.ml.feature import OneHotEncoder, StringIndexer
# categorical columns
cat_feats = [
'pickup_location_id',
'dropoff_location_id',
'day_of_week'
]
# stages list for assembling the pipeline
stages = []
# index each categorical column, then one-hot encode the indexed column
for c in cat_feats:
    cat_indexer = StringIndexer(inputCol=c, outputCol=c + "_idx").setHandleInvalid("keep")
    onehot_encoder = OneHotEncoder(inputCols=[cat_indexer.getOutputCol()], outputCols=[c + '_onehot'])
    stages += [cat_indexer, onehot_encoder]
(2) Numeric data: vectorization and scaling
Each numeric column is wrapped into a single-element vector and then scaled. (StandardScaler only operates on vector columns, which is why each column goes through its own VectorAssembler first.)
# vectorization + scaling
from pyspark.ml.feature import VectorAssembler, StandardScaler
# numeric columns
num_feats = [
'passenger_count',
'trip_distance',
'pickup_time'
]
# vectorize each column, then apply the scaler
for n in num_feats:
    num_assembler = VectorAssembler(inputCols=[n], outputCol=n + '_vector')
    num_scaler = StandardScaler(inputCol=num_assembler.getOutputCol(), outputCol=n + '_scaled')
    stages += [num_assembler, num_scaler]
* Intermediate check of stages
stages
[StringIndexer_115f74e6efea,
OneHotEncoder_714f494271bb,
StringIndexer_806a3b8e8a32,
OneHotEncoder_f7bc9266f650,
StringIndexer_2f125ebb95a8,
OneHotEncoder_eb212d50e427,
VectorAssembler_5b7bfff3be42,
StandardScaler_3e59d49af9ad,
VectorAssembler_a61858dada0f,
StandardScaler_c739fc7f7d49,
VectorAssembler_f683b2eeb4d2,
StandardScaler_96e5ba925088]
(3) Run everything through a final VectorAssembler to produce a dataset ready for training.
# input columns
assembler_inputs = [c + '_onehot' for c in cat_feats] + [n + '_scaled' for n in num_feats]
print(assembler_inputs)
['pickup_location_id_onehot',
'dropoff_location_id_onehot',
'day_of_week_onehot',
'passenger_count_scaled',
'trip_distance_scaled',
'pickup_time_scaled']
# add the final VectorAssembler to stages
assembler = VectorAssembler(inputCols=assembler_inputs, outputCol= 'feature_vector')
stages += [assembler]
stages
[StringIndexer_dfc09cc586be,
OneHotEncoder_987fbfa36a2d,
StringIndexer_bf2338365d7f,
OneHotEncoder_5a91ea5195e8,
StringIndexer_c416d64272f1,
OneHotEncoder_0dfab0742066,
VectorAssembler_4c5f47a3740c,
StandardScaler_65dfe2363318,
VectorAssembler_7e0a4e81ec39,
StandardScaler_22d11d283c0a,
VectorAssembler_c2b692153924,
StandardScaler_debc924ffa61,
VectorAssembler_c2c382815ebb]
(4) Finally, bundle it all together with Pipeline()
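One thing to note: train_df is not defined in this post; it comes from the train/test split in the previous post. A minimal sketch of that split, assuming an 80/20 ratio:
# assumed train/test split from the previous post (the 80/20 ratio is a guess)
train_df, test_df = data_df.randomSplit([0.8, 0.2], seed=1)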
# build the pipeline
from pyspark.ml import Pipeline
transform_stages = stages
pipeline = Pipeline(stages = transform_stages)
fitted_transformer = pipeline.fit(train_df)
Hyperparameter Tuning
from pyspark.ml import Pipeline
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import RegressionEvaluator
(1) Add linear regression to the pipeline
Including both the preprocessing stages and the model in a single pipeline matters for cross-validation: each fold re-fits the preprocessing on its own training split, which avoids leaking information across folds.
lr = LinearRegression(
maxIter=30,
solver='normal',
labelCol='total_amount',
featuresCol='feature_vector'
)
cv_stages = stages + [lr]
# assemble the full pipeline (preprocessing + model)
cv_pipeline = Pipeline(stages = cv_stages)
(2) Build the parameter map
In MLlib you build a parameter map separately and pass it into the cross-validator as an argument; the cross-validator then sweeps every combination of the listed parameters. With five values for each of the two parameters below, that is 5 x 5 = 25 combinations.
lr.elasticNetParam -> the elastic-net mixing parameter alpha of the linear regression (0 = pure L2, 1 = pure L1).
lr.regParam -> the regularization strength of the linear regression.
# '\' is the line-continuation character
param_grid = ParamGridBuilder()\
.addGrid(lr.elasticNetParam, [0.1, 0.2, 0.3, 0.4, 0.5])\
.addGrid(lr.regParam, [0.01, 0.02, 0.03, 0.04, 0.05]).build()
(3) Build the cross-validator
# estimator: the pipeline
# estimatorParamMaps: the parameter map
# evaluator: the evaluation metric (RegressionEvaluator defaults to RMSE)
# numFolds: the number of cross-validation folds
cross_val = CrossValidator(estimator=cv_pipeline,
                           estimatorParamMaps=param_grid,
                           evaluator=RegressionEvaluator(labelCol='total_amount'),
                           numFolds=5)
(4) Start the cross-validation
Keep in mind this fits 25 parameter combinations x 5 folds = 125 pipeline models, which is exactly why we run it on toy_df rather than the full data.
cv_model = cross_val.fit(toy_df)
(5) Extract the best parameters
alpha = cv_model.bestModel.stages[-1]._java_obj.getElasticNetParam()
reg_param = cv_model.bestModel.stages[-1]._java_obj.getRegParam()
print(f'alpha is {alpha}')
print(f'reg_param is {reg_param}')
alpha is 0.4
reg_param is 0.05
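Reaching into _java_obj works, but it is a private detail. On recent PySpark (3.0+), the fitted model should expose the param getters directly; a sketch under that assumption:
# the last stage of the best pipeline model is the fitted LinearRegressionModel
best_lr = cv_model.bestModel.stages[-1]
print(best_lr.getElasticNetParam(), best_lr.getRegParam())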
Training & Inference
(1) Pass the train/test data through the pipeline
transform_stages = stages
pipeline = Pipeline(stages = transform_stages)
fitted_transformer = pipeline.fit(train_df)
vtrain_df = fitted_transformer.transform(train_df)
vtest_df = fitted_transformer.transform(test_df)
(2) Training
lr = LinearRegression(
    maxIter=50,
    solver='normal',
    labelCol='total_amount',
    featuresCol='feature_vector',
    elasticNetParam=alpha,
    regParam=reg_param
)
model = lr.fit(vtrain_df)
(3) Inference and evaluation
predictions = model.transform(vtest_df)
predictions.select(["trip_distance", "day_of_week", "total_amount", "prediction"]).show()
+-------------+-----------+------------+------------------+
|trip_distance|day_of_week|total_amount| prediction|
+-------------+-----------+------------+------------------+
| 3.1| Saturday| 22.55| 18.44882896087039|
| 7.9| Saturday| 30.3|28.898380309569866|
| 1.4| Saturday| 16.0|13.635679102878225|
| 1.1| Tuesday| 12.95|14.050765065622219|
| 0.7| Saturday| 9.8|12.151950922741019|
| 1.1| Monday| 11.8|14.302981883348586|
| 3.7| Friday| 24.35|19.745504411653762|
| 2.4| Saturday| 14.75|16.012957291356248|
| 3.4| Saturday| 20.15|17.933430832644525|
| 6.5| Friday| 32.75|27.541169751290077|
| 2.6| Saturday| 17.8|16.614148948305857|
| 0.2| Saturday| 5.15| 7.565190314421683|
| 0.9| Sunday| 6.3| 9.491154409419867|
| 9.9| Saturday| 46.35|38.455267206674144|
| 4.5| Saturday| 22.85|15.766824456415733|
| 0.6| Tuesday| 4.8| 7.422348054207551|
| 1.4| Tuesday| 10.8|12.006766375046109|
| 4.3| Saturday| 20.3|19.040664139319205|
| 3.9| Thursday| 17.8|18.389827313235557|
| 5.8| Tuesday| 27.0|23.073838896480723|
+-------------+-----------+------------+------------------+
only showing top 20 rows
RMSE
# before tuning
model.summary.rootMeanSquaredError
5.818945295076586
# after tuning
model.summary.rootMeanSquaredError
5.610492491836879
R2
# before tuning
model.summary.r2
0.7997047915616821
# after tuning
model.summary.r2
0.8108436137289087
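One caveat: model.summary reports metrics on the training data. To measure performance on the held-out test set, a sketch using RegressionEvaluator on the predictions computed above:
from pyspark.ml.evaluation import RegressionEvaluator

# evaluate the test-set predictions produced earlier
evaluator = RegressionEvaluator(labelCol='total_amount', predictionCol='prediction')
print(evaluator.evaluate(predictions, {evaluator.metricName: 'rmse'}))
print(evaluator.evaluate(predictions, {evaluator.metricName: 'r2'}))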
The performance gap before and after parameter tuning is quite noticeable.
Saving and Loading the Model
# save the model
model_dir = "save_path"  # placeholder for your save location
model.save(model_dir)
# load the model
from pyspark.ml.regression import LinearRegressionModel
lr_model = LinearRegressionModel.load(model_dir)
# run inference right away
predictions = lr_model.transform(vtest_df)
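Note that the saved file covers only the regression model; new raw data still has to pass through the fitted preprocessing pipeline first. A sketch of persisting that as well (PipelineModel supports the same save/load interface; the path here is hypothetical):
from pyspark.ml import PipelineModel

# persist the fitted preprocessing pipeline alongside the model
fitted_transformer.write().overwrite().save("pipeline_path")

# later: load it, transform the new data, then predict
loaded_transformer = PipelineModel.load("pipeline_path")
new_predictions = lr_model.transform(loaded_transformer.transform(test_df))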
We tuned the parameters and improved the model's performance with the best ones. Going one step further, we saved the model and checked how to load it back. Now try applying other models yourself and get comfortable with MLlib!
GOOD BYE!