[SparkSQL] UDF Concept and Code

2022. 5. 8. 12:57 · 🛠 Data Engineering/Apache Spark

 

 

UDF

 

User-Defined Function: in other words, a function written by the user.

In the previous post, we looked at several functions for processing DataFrames.

In SparkSQL, you can also write and name your own function for a given task, register it with Spark, and then call it from your queries.

Let's take a closer look.

"This post is a summary of notes I took while following a Fast Campus lecture."

 

 

 

 

Basic Setting

 

import os
import faulthandler

import findspark

# Point findspark at the local Spark installation before importing pyspark.
findspark.init(os.environ.get("SPARK_HOME"))

from pyspark.sql import SparkSession

# Dump a Python traceback if the interpreter crashes hard (e.g. a segfault).
faulthandler.enable()

spark = SparkSession.builder.master('local').appName("dataframe").getOrCreate()

# Data

stocks = [
    ('Google', 'GOOGL', 'USA', 2984, 'USD'), 
    ('Netflix', 'NFLX', 'USA', 645, 'USD'),
    ('Amazon', 'AMZN', 'USA', 3518, 'USD'),
    ('Tesla', 'TSLA', 'USA', 1222, 'USD'),
    ('Tencent', '0700', 'Hong Kong', 483, 'HKD'),
    ('Toyota', '7203', 'Japan', 2006, 'JPY'),
    ('Samsung', '005930', 'Korea', 70600, 'KRW'),
    ('Kakao', '035720', 'Korea', 125000, 'KRW'),
]


# Schema

stockSchema = ['name', 'ticker', 'country', 'price', 'currency']


# createDataFrame()
df = spark.createDataFrame(data = stocks, schema=stockSchema)

# createOrReplaceTempView()
# The view must be named "stocks" because every query below selects from stocks.
df.createOrReplaceTempView("stocks")

 

 

 

 

UDF


A UDF (User-Defined Function) is a function that the user defines directly.

 

 

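1. spark.udf.register()

Registers an already-defined function so that it can be called from Spark SQL. The third argument is the data type of the returned value; if it is omitted, the result is treated as a string.

spark.udf.register("function name", function, return data type)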

 

from pyspark.sql.types import LongType

def squared(n):
    return n * n

spark.udf.register("squared", squared, LongType())
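
Why LongType and not a plain integer type? Spark's IntegerType is a 32-bit signed integer, while LongType is 64-bit. The sketch below is plain Python (no Spark session needed) and only checks the value ranges; the constants INT32_MAX and INT64_MAX are names introduced here for illustration.

```python
# Upper bounds of Spark's 32-bit IntegerType and 64-bit LongType.
INT32_MAX = 2**31 - 1
INT64_MAX = 2**63 - 1

def squared(n):
    return n * n

# A few prices taken from the stocks data in this post.
prices = {'Netflix': 645, 'Samsung': 70600, 'Kakao': 125000}

for name, price in prices.items():
    value = squared(price)
    size = "fits in 32 bits" if value <= INT32_MAX else "needs 64 bits"
    print(name, value, size)
```

The squared KRW prices are well past the 32-bit limit, so a 64-bit return type is the safer choice for this function.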

 

Usage

 

# Square the price.
spark.sql("select name, squared(price) from stocks").show()


+-------+--------------+
|   name|squared(price)|
+-------+--------------+
| Google|       8904256|
|Netflix|        416025|
| Amazon|      12376324|
|  Tesla|       1493284|
|Tencent|        233289|
| Toyota|       4024036|
|Samsung|    4984360000|
|  Kakao|   15625000000|
+-------+--------------+

----

# Show only the rows whose squared price is below 1,000,000.
spark.sql("select name, squared(price) from stocks where squared(price) < 1000000").show()


+-------+--------------+
|   name|squared(price)|
+-------+--------------+
|Netflix|        416025|
|Tencent|        233289|
+-------+--------------+
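
As a quick sanity check, the filtered result above can be reproduced in plain Python without Spark. This is only a cross-check of the arithmetic, not how Spark evaluates the query; the list name prices and the variable below_million are introduced here for illustration.

```python
# (name, price) pairs copied from the stocks data defined earlier in the post.
prices = [
    ('Google', 2984), ('Netflix', 645), ('Amazon', 3518), ('Tesla', 1222),
    ('Tencent', 483), ('Toyota', 2006), ('Samsung', 70600), ('Kakao', 125000),
]

def squared(n):
    return n * n

# Same condition as the WHERE clause: keep rows whose squared price is below 1,000,000.
below_million = [(name, squared(p)) for name, p in prices if squared(p) < 1_000_000]
print(below_million)
```

Only Netflix and Tencent pass the condition, which matches the two rows Spark returned.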

 

 

 

2. Let's write a slightly more practical function.

 

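# Convert the currency code to its Korean name so it can be appended to the price.
def currency_ko(n):
    if n == 'USD':
        return '달러'  # dollar
    elif n == 'KRW':
        return '원'  # won
    elif n == 'JPY':
        return '엔'  # yen
    else:
        return '위안'  # yuan; every other code (including HKD) falls through to here

# No return type is given, so the result is treated as a string.
spark.udf.register("currency_ko", currency_ko)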


<function __main__.currency_ko(n)>

----

# Use SQL's CONCAT() function to join the price and the Korean currency name.
spark.sql("select name, concat(price, currency_ko(currency)) as price from stocks").show()


+-------+--------+
|   name|   price|
+-------+--------+
| Google|2984달러|
|Netflix| 645달러|
| Amazon|3518달러|
|  Tesla|1222달러|
|Tencent| 483위안|
| Toyota|  2006엔|
|Samsung| 70600원|
|  Kakao|125000원|
+-------+--------+
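
The if/elif chain in currency_ko can also be written as a dictionary lookup, which is easier to extend when more currencies are added. This is a minimal sketch in plain Python, not the lecture's code: the names CURRENCY_KO, currency_ko_lookup and price_with_currency are hypothetical, and the fallback value is kept the same as in the original function so that the output does not change.

```python
# Currency code -> Korean name, same pairs as the if/elif version.
CURRENCY_KO = {'USD': '달러', 'KRW': '원', 'JPY': '엔'}

def currency_ko_lookup(code, default='위안'):
    # Unknown codes (HKD in this data) fall back to the default,
    # mirroring the else branch of the original function.
    return CURRENCY_KO.get(code, default)

def price_with_currency(price, code):
    # Plain-Python equivalent of concat(price, currency_ko(currency)).
    return f"{price}{currency_ko_lookup(code)}"

print(price_with_currency(2984, 'USD'))
print(price_with_currency(483, 'HKD'))
```

A function like currency_ko_lookup should be registrable the same way as before, with spark.udf.register and a name of your choice.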

 

 

 

That covers UDFs.

In the next post, we'll look at SparkSQL's backend processing.

Thanks for reading.