[SparkML] MLlib Parameter Tuning: Concepts and Code

2022-05-22 18:13 · Data Engineering/Apache Spark

 

Parameter Tuning

๋จธ์‹ ๋Ÿฌ๋‹์„ ๋‹ค๋ค„๋ณธ ์‚ฌ๋žŒ์ด๋ผ๋ฉด ์ต์ˆ™ํ•œ ๊ฐœ๋…์ผ ๊ฒƒ์ด๋‹ค.

๋จธ์‹ ๋Ÿฌ๋‹์„ ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ์ด์šฉํ•˜๋”๋ผ๋„, ๊ทธ ์•ˆ์˜ ๋ณ€์ˆ˜๋ฅผ ์–ด๋–ป๊ฒŒ ์กฐ์ •ํ•˜๋ƒ์— ๋”ฐ๋ผ์„œ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์ด ๋‹ฌ๋ผ์ง„๋‹ค.

MLlib์€ ๋‹น์—ฐํžˆ๋„ Paramter Tuning ๊ธฐ๋Šฅ์„ ์ œ๊ณตํ•œ๋‹ค.

์ฝ”๋“œ๋ฅผ ์‚ดํŽด๋ณด๋ฉฐ ์–ด๋–ป๊ฒŒ ํŠœ๋‹์„ ํ•˜๋Š”์ง€ ์‚ดํŽด๋ณด์ž.

 


[New to Apache Spark?]

https://mengu.tistory.com/26?category=932924 

 


 

 

 

[SparkML์„ ๋ชจ๋ฅธ๋‹ค๋ฉด?]

https://mengu.tistory.com/56?category=932924 

 


 

 

 

[Don't have the practice data yet?]

https://mengu.tistory.com/50?category=932924 

 


 

 

 

"๋ณธ ํฌ์ŠคํŒ…์€ ํŒจ์ŠคํŠธ์บ ํผ์Šค์˜ ๊ฐ•์˜๋ฅผ ๋“ฃ๊ณ , ์ •๋ฆฌํ•œ ์ž๋ฃŒ์ž„์„ ๋ฐํž™๋‹ˆ๋‹ค."

 


Basic Settings

The basic setup: the imports we need are done up front, and a SparkSession has been created.

 

# font settings (for matplotlib plots)
from matplotlib import font_manager, rc
font_path = 'C:\\WINDOWS\\Fonts\\HBATANG.TTF'
font = font_manager.FontProperties(fname=font_path).get_name()
rc('font', family=font)

# basic settings
import os
import findspark
findspark.init(os.environ.get("SPARK_HOME"))
import pyspark
from pyspark import SparkConf, SparkContext
import pandas as pd
import faulthandler
faulthandler.enable()
from pyspark.sql import SparkSession

MAX_MEMORY = "5g"
spark = SparkSession.builder.master('local').appName("taxi-fare-prediction")\
    .config("spark.executor.memory", MAX_MEMORY)\
    .config("spark.driver.memory", MAX_MEMORY).getOrCreate()

# data file paths
zone_data = "C:/DE study/data-engineering/01-spark/data/taxi_zone_lookup.csv"
trip_files = "C:/DE study/data-engineering/01-spark/data/trips/*"

# load the data
trips_df = spark.read.csv(f"file:///{trip_files}", inferSchema=True, header=True)
zone_df = spark.read.csv(f"file:///{zone_data}", inferSchema=True, header=True)

# print the schemas
trips_df.printSchema()
zone_df.printSchema()

# register the DataFrames as temp views
trips_df.createOrReplaceTempView("trips")
zone_df.createOrReplaceTempView("zone")


root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: string (nullable = true)
 |-- tpep_dropoff_datetime: string (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)

root
 |-- LocationID: integer (nullable = true)
 |-- Borough: string (nullable = true)
 |-- Zone: string (nullable = true)
 |-- service_zone: string (nullable = true)

 


DataFrame


First, load the data into a DataFrame, keeping the input features and the output (label) column and dropping everything else.

 

 

query = '''
SELECT
    passenger_count,
    PULocationID AS pickup_location_id,
    DOLocationID AS dropoff_location_id,
    trip_distance,
    HOUR(tpep_pickup_datetime) AS pickup_time,
    DATE_FORMAT(TO_DATE(tpep_pickup_datetime), 'EEEE') AS day_of_week,
    total_amount
FROM
    trips
WHERE
    total_amount < 5000
    AND total_amount > 0
    AND trip_distance > 0
    AND trip_distance < 500
    AND passenger_count < 4
    AND TO_DATE(tpep_pickup_datetime) >= '2021-01-01'
    AND TO_DATE(tpep_pickup_datetime) < '2021-08-01'
'''

# keep only the rows and columns we want; note that the feature columns used
# later (pickup_location_id, day_of_week, pickup_time, ...) must be selected here
data_df = spark.sql(query)
data_df.createOrReplaceTempView('data')


data_df.select('trip_distance', 'total_amount').show()


+-------------+------------+
|trip_distance|total_amount|
+-------------+------------+
|         16.5|       70.07|
|         1.13|       11.16|
|         2.68|       18.59|
|         12.4|        43.8|
|          9.7|        32.3|
|          9.3|       43.67|
|         9.58|        46.1|
|         16.2|        45.3|
|         3.58|        19.3|
|         0.91|        14.8|
|         2.57|        12.8|
|          0.4|         5.3|
|         3.26|        17.3|
|        13.41|       47.25|
|         18.3|       61.42|
|         1.53|       14.16|
|          2.0|        11.8|
|         16.6|       54.96|
|         15.5|       56.25|
|          1.3|        16.8|
+-------------+------------+
only showing top 20 rows

 

 

 

์ด๋ฒˆ ํฌ์ŠคํŒ…์˜ ๋ชฉ์ ์€ ํŒŒ๋ผ๋ฏธํ„ฐ ํŠœ๋‹์ด๋‹ค. ์ผ์ •ํ•œ ์–‘์˜ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ง€๊ณ  ๊ณ„์† ๋ชจ๋ธ์„ ํ…Œ์ŠคํŠธํ•˜์—ฌ, ์–ด๋–ค ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ ์ข‹์€ ์„ฑ๋Šฅ์„ ๋‚ด๋Š”์ง€ ๋น„๊ต ๋ถ„์„ํ•ด์•ผ ํ•œ๋‹ค. ๊ทธ๋ ‡๋‹ค๋ฉด ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ๋Š” ์–ผ๋งˆ๋กœ ์žก์•„์•ผ ํ• ๊นŒ?

 

 

toy_df = data_df.sample(False, 0.01, seed=1)  # 1% sample, without replacement

 

 

๋ณดํ†ต์€ ์›๋ž˜ ๋ฐ์ดํ„ฐ์˜ 10%๋ฅผ ๋ฝ‘์•„์„œ ์‚ฌ์šฉํ•œ๋‹ค. ํ•˜์ง€๋งŒ ๊ธฐ์กด ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ๊ฐ€ ๋„ˆ๋ฌด ํฐ ๊ฒฝ์šฐ, ์ปดํ“จํ„ฐ๊ฐ€ ํ„ฐ์ ธ๋ฒ„๋ฆด ์ˆ˜ ์žˆ๋‹ค. ํƒ์‹œ ๋ฐ์ดํ„ฐ๋Š” ํฌ๊ธฐ๊ฐ€ ์–ด๋งˆ์–ด๋งˆํ•˜๊ธฐ์—... ๊ทธ๋ฆฌ๊ณ  10%๋ฅผ ๋ฝ‘์•„์„œ ๋Œ๋ ค๋ดค๋Š”๋ฐ ์ปดํ“จํ„ฐ๊ฐ€ ํ„ฐ์ ธ์„œ 1%๋กœ ์„ค์ •ํ–ˆ๋‹ค.

 


Preprocessing & Pipeline


We rebuild the same preprocessing pipeline used in the previous post.

 

 

 

(1) One-hot encoding

Only the categorical columns need one-hot encoding.

 

- StringIndexer: assigns each category a numeric index before one-hot encoding, ordered by descending frequency. e.g. (strawberry, banana, chocolate) -> (0.0, 1.0, 2.0)

- OneHotEncoder: turns each index into a binary vector. e.g. 0.0 -> [1, 0, 0]
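The two steps can be sketched in plain Python (a toy illustration of the semantics, not MLlib itself; note that MLlib's OneHotEncoder actually drops the last category by default, while this sketch keeps all of them):

```python
# Toy sketch of StringIndexer + OneHotEncoder semantics (not MLlib itself).
from collections import Counter

values = ["Saturday", "Sunday", "Saturday", "Monday", "Saturday", "Sunday"]

# StringIndexer: index by descending frequency (0.0 = most frequent category).
freq_order = [v for v, _ in Counter(values).most_common()]
index_map = {v: float(i) for i, v in enumerate(freq_order)}
indexed = [index_map[v] for v in values]  # Saturday -> 0.0, Sunday -> 1.0, ...

# OneHotEncoder: turn each index into a 0/1 vector
# (MLlib drops the last category by default; kept here for clarity).
n = len(index_map)
onehot = [[1.0 if i == int(idx) else 0.0 for i in range(n)] for idx in indexed]
```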

 

# ์›-ํ•ซ ์ธ์ฝ”๋”ฉ
from pyspark.ml.feature import OneHotEncoder, StringIndexer


# ์นดํ…Œ๊ณ ๋ฆฌํ˜• ์นผ๋Ÿผ
cat_feats = [
    'pickup_location_id',
    'dropoff_location_id',
    'day_of_week'
]

# ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌ์„ฑ์„ ์œ„ํ•œ stages ๋ฆฌ์ŠคํŠธ
stages = []


# index๋ฅผ ๋ฐ”๊พธ๊ณ , ๊ทธ ๋ฐ”๋€ indexer์— ์›-ํ•ซ ์ธ์ฝ”๋”ฉ์„ ์ ์šฉํ•ด์ค€๋‹ค.
for c in cat_feats:
    cat_indexer = StringIndexer(inputCol=c, outputCol = c + "_idx").setHandleInvalid("keep")
    onehot_encoder = OneHotEncoder(inputCols = [cat_indexer.getOutputCol()], outputCols=[c + '_onehot'])
    stages += [cat_indexer, onehot_encoder]

 


(2) ์ˆซ์žํ˜• ๋ฐ์ดํ„ฐ, ๋ฒกํ„ฐํ™” ๋ฐ ์Šค์ผ€์ผ๋Ÿฌ ์ ์šฉ

์ˆซ์žํ˜• ๋ฐ์ดํ„ฐ๋ฅผ ํ•˜๋‚˜์˜ ๋ฒกํ„ฐ๋กœ ๋ฌถ๊ณ , ์Šค์ผ€์ผ๋Ÿฌ๋ฅผ ์ ์šฉํ•œ๋‹ค.

 

# vectorize + apply scaler
from pyspark.ml.feature import VectorAssembler, StandardScaler

# numeric columns
num_feats = [
    'passenger_count',
    'trip_distance',
    'pickup_time'
]


# vectorize each column, then apply a standard scaler
for n in num_feats:
    num_assembler = VectorAssembler(inputCols=[n], outputCol=n + '_vector')
    num_scaler = StandardScaler(inputCol=num_assembler.getOutputCol(), outputCol=n + '_scaled')
    stages += [num_assembler, num_scaler]

 

 

* Checking the stages so far

 

stages


[StringIndexer_115f74e6efea,
 OneHotEncoder_714f494271bb,
 StringIndexer_806a3b8e8a32,
 OneHotEncoder_f7bc9266f650,
 StringIndexer_2f125ebb95a8,
 OneHotEncoder_eb212d50e427,
 VectorAssembler_5b7bfff3be42,
 StandardScaler_3e59d49af9ad,
 VectorAssembler_a61858dada0f,
 StandardScaler_c739fc7f7d49,
 VectorAssembler_f683b2eeb4d2,
 StandardScaler_96e5ba925088]

 


(3) Assemble everything with a final VectorAssembler to produce a training-ready feature vector.

 

# input columns for the final assembler
assembler_inputs = [c + '_onehot' for c in cat_feats] + [n + '_scaled' for n in num_feats]
print(assembler_inputs)


['pickup_location_id_onehot',
 'dropoff_location_id_onehot',
 'day_of_week_onehot',
 'passenger_count_scaled',
 'trip_distance_scaled',
 'pickup_time_scaled']
 
 
 
# add the final VectorAssembler to stages
assembler = VectorAssembler(inputCols=assembler_inputs, outputCol='feature_vector')
stages += [assembler]
stages


[StringIndexer_dfc09cc586be,
 OneHotEncoder_987fbfa36a2d,
 StringIndexer_bf2338365d7f,
 OneHotEncoder_5a91ea5195e8,
 StringIndexer_c416d64272f1,
 OneHotEncoder_0dfab0742066,
 VectorAssembler_4c5f47a3740c,
 StandardScaler_65dfe2363318,
 VectorAssembler_7e0a4e81ec39,
 StandardScaler_22d11d283c0a,
 VectorAssembler_c2b692153924,
 StandardScaler_debc924ffa61,
 VectorAssembler_c2c382815ebb]

 


(4) Finally, tie it all together with Pipeline()

 

# build the pipeline
from pyspark.ml import Pipeline

# train/test split (assumed 8:2 here, matching the earlier MLlib post)
train_df, test_df = data_df.randomSplit([0.8, 0.2], seed=1)

transform_stages = stages
pipeline = Pipeline(stages=transform_stages)
fitted_transformer = pipeline.fit(train_df)

 


Hyperparameter Tuning


 

from pyspark.ml import Pipeline
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import RegressionEvaluator

 

 

(1) ํŒŒ์ดํ”„๋ผ์ธ์— ์„ ํ˜• ํšŒ๊ท€ ์ถ”๊ฐ€ํ•˜๊ธฐ

 

lr = LinearRegression(
    maxIter=30,
    solver='normal',
    labelCol='total_amount',
    featuresCol='feature_vector'
)

cv_stages = stages + [lr]


# complete the cross-validation pipeline
cv_pipeline = Pipeline(stages=cv_stages)

 

 

 

(2) Build the parameter map

In MLlib you build a parameter map separately and pass it to the cross validator; the validator then cycles through every combination of the listed parameters.

 

lr.elasticNetParam -> the ElasticNet mixing parameter (alpha) of the linear regression: 0 means pure L2 (ridge), 1 means pure L1 (lasso).

lr.regParam -> the overall regularization strength (lambda).

 

# '\' lets the expression continue on the next line
param_grid = ParamGridBuilder()\
    .addGrid(lr.elasticNetParam, [0.1, 0.2, 0.3, 0.4, 0.5])\
    .addGrid(lr.regParam, [0.01, 0.02, 0.03, 0.04, 0.05])\
    .build()
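Under the hood, ParamGridBuilder.build() enumerates the Cartesian product of the listed values. A plain-Python equivalent of the 5 x 5 grid above (25 candidate settings):

```python
# Plain-Python equivalent of what ParamGridBuilder.build() enumerates:
# the Cartesian product of every listed value, 5 x 5 = 25 candidate settings.
from itertools import product

elastic_net_params = [0.1, 0.2, 0.3, 0.4, 0.5]
reg_params = [0.01, 0.02, 0.03, 0.04, 0.05]

grid = [{"elasticNetParam": e, "regParam": r}
        for e, r in product(elastic_net_params, reg_params)]
```

Each dict plays the role of one ParamMap; the cross validator scores all 25 of them.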

 

 

 

(3) ๊ต์ฐจ ๊ฒ€์ฆ ํ•จ์ˆ˜ ๋งŒ๋“ค๊ธฐ

 

# estimator์—๋Š” ํŒŒ์ดํ”„๋ผ์ธ์„,
# ParamMaps์—๋Š” ํŒŒ๋ผ๋ฏธํ„ฐ ๋งต์„
# evaluator์—๋Š” ํ‰๊ฐ€ ์ง€ํ‘œ๋ฅผ,
# numFolds์—๋Š” ๊ต์ฐจ ๊ฒ€์ฆ์˜ ํด๋“œ ์ˆ˜๋ฅผ ์ž…๋ ฅํ–ˆ๋‹ค.

cross_val = CrossValidator(estimator = cv_pipeline,
                           estimatorParamMaps=param_grid,
                           evaluator=RegressionEvaluator(labelCol='total_amount'),
                           numFolds=5)

 

 

 

(4) ๊ต์ฐจ ๊ฒ€์ฆ Start

 

cv_model = cross_val.fit(toy_df)  # tune on the 1% toy sample

 

 

 

(5) Extract the best parameters

 

# the last stage of the best pipeline is the fitted LinearRegression model
alpha = cv_model.bestModel.stages[-1]._java_obj.getElasticNetParam()
reg_param = cv_model.bestModel.stages[-1]._java_obj.getRegParam()

print(f'alpha is {alpha}')
print(f'reg_param is {reg_param}')


alpha is 0.4
reg_param is 0.05

 


Training & Inference


 

(1) Run the train/test data through the pipeline

 

transform_stages = stages
pipeline = Pipeline(stages = transform_stages)
fitted_transformer = pipeline.fit(train_df)


vtrain_df = fitted_transformer.transform(train_df)
vtest_df = fitted_transformer.transform(test_df)

 

 

 

(2) Training

 

lr = LinearRegression(
    maxIter=50,
    solver='normal',
    labelCol='total_amount',
    featuresCol='feature_vector'
    elasticNetParam = alpha,
    regParam = reg_param
)


model = lr.fit(vtrain_df)

 

 

 

(3) Inference and evaluation

 

predictions = model.transform(vtest_df)


predictions.select(["trip_distance", "day_of_week", "total_amount", "prediction"]).show()


+-------------+-----------+------------+------------------+
|trip_distance|day_of_week|total_amount|        prediction|
+-------------+-----------+------------+------------------+
|          3.1|   Saturday|       22.55| 18.44882896087039|
|          7.9|   Saturday|        30.3|28.898380309569866|
|          1.4|   Saturday|        16.0|13.635679102878225|
|          1.1|    Tuesday|       12.95|14.050765065622219|
|          0.7|   Saturday|         9.8|12.151950922741019|
|          1.1|     Monday|        11.8|14.302981883348586|
|          3.7|     Friday|       24.35|19.745504411653762|
|          2.4|   Saturday|       14.75|16.012957291356248|
|          3.4|   Saturday|       20.15|17.933430832644525|
|          6.5|     Friday|       32.75|27.541169751290077|
|          2.6|   Saturday|        17.8|16.614148948305857|
|          0.2|   Saturday|        5.15| 7.565190314421683|
|          0.9|     Sunday|         6.3| 9.491154409419867|
|          9.9|   Saturday|       46.35|38.455267206674144|
|          4.5|   Saturday|       22.85|15.766824456415733|
|          0.6|    Tuesday|         4.8| 7.422348054207551|
|          1.4|    Tuesday|        10.8|12.006766375046109|
|          4.3|   Saturday|        20.3|19.040664139319205|
|          3.9|   Thursday|        17.8|18.389827313235557|
|          5.8|    Tuesday|        27.0|23.073838896480723|
+-------------+-----------+------------+------------------+
only showing top 20 rows

 


RMSE

 

# before tuning
model.summary.rootMeanSquaredError

5.818945295076586



# after tuning
model.summary.rootMeanSquaredError

5.610492491836879
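For reference, RMSE is just the square root of the mean squared difference between label and prediction; a minimal hand-rolled version:

```python
# RMSE by hand: sqrt of the mean squared error between labels and predictions.
import math

def rmse(labels, preds):
    return math.sqrt(sum((y - p) ** 2 for y, p in zip(labels, preds)) / len(labels))
```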

 


R2

 

# before tuning
model.summary.r2

0.7997047915616821



# after tuning
model.summary.r2

0.8108436137289087

 

 

ํŒŒ๋ผ๋ฏธ๋‰ด ํŠœ๋‹์„ ํ•˜๊ธฐ ์ „๊ณผ ์„ฑ๋Šฅ ์ฐจ์ด๊ฐ€ ๊ฝค ๋‚œ๋‹ค๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

 


๋ชจ๋ธ ์ €์žฅ ๋ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ


# save the model
model_dir = "model-save-path"  # placeholder: set to your own save path
model.save(model_dir)


# load it back
from pyspark.ml.regression import LinearRegressionModel
lr_model = LinearRegressionModel.load(model_dir)


# run inference right away
predictions = lr_model.transform(vtest_df)



ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ํŠœ๋‹ํ•˜๊ณ , ์ตœ์ ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ๋†’์—ฌ๋ณด์•˜๋‹ค. ๋” ๋‚˜์•„๊ฐ€ ๋ชจ๋ธ์„ ์ €์žฅํ•˜๊ณ , ์–ด๋–ป๊ฒŒ ๋ถˆ๋Ÿฌ์˜ค๋Š”์ง€๋„ ํ™•์ธํ–ˆ๋‹ค. ์ด์ œ ์Šค์Šค๋กœ ๋‹ค๋ฅธ ๋ชจ๋ธ๋„ ์ ์šฉํ•˜๋ฉด์„œ MLlib์— ์ต์ˆ™ํ•ด์ ธ ๋ณด์ž!

GOOD BYE!