PySpark 3 higher order function to extract into columns

Question

I have a Spark DataFrame with an ArrayType column, e.g. ['db1.schema1.table1', 'schema5.table3', 'table4'].

My objective is to create three ArrayType columns: dbs, e.g. ['db1'], schemas ['schema1', 'schema5'], and tables ['table1', 'table3', 'table4'].

For this I had used a Python UDF, which was extremely slow and memory-inefficient.

My task is to use native PySpark functions. I'm on Spark 3.3.

from pyspark.sql.types import *
from pyspark.sql.functions import *

cSchema = StructType([StructField("WordList", ArrayType(StringType()))])
test_list = [['db1.s1.t1']], [['t3', 'd1.s1.t1', 's2.t2']]
df = spark.createDataFrame(test_list, schema=cSchema)
df = df.withColumn("random", expr("uuid()"))  # per-row key to regroup after the explode

# explode each array element into its own row, then split it on "."
df = df.select('*', explode("WordList").alias("x"))
df = df.withColumn("x_split", split(col("x"), "\\."))
df = df.withColumn("size", size(col("x_split")))

# the last part is always the table; db/schema depend on how many parts there are
df = df.withColumn("table", element_at(col("x_split"), col("size")))
df = df.withColumn("database", when(col("size") == 3, element_at(col("x_split"), 1)).otherwise(lit("na")))
df = df.withColumn("schema", when(col("size") > 1, element_at(col("x_split"), col("size") - 1)).otherwise(lit("na")))

# then do window collect_set per uuid... then drop duplicates

But this risks shuffling at the explode/drop-duplicates steps.
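
Roughly what I have in mind for that last step (a sketch only, built on the columns above; not tested):

from pyspark.sql import Window

# hypothetical completion of the approach above: every exploded row of a given
# uuid ends up with the same collected sets, then one row per uuid is kept.
# Both the window aggregate and dropDuplicates shuffle the data.
w = Window.partitionBy("random")
df = (df
      .withColumn("dbs", collect_set("database").over(w))
      .withColumn("schemas", collect_set("schema").over(w))
      .withColumn("tables", collect_set("table").over(w))
      .select("WordList", "random", "dbs", "schemas", "tables")
      .dropDuplicates(["random"]))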

Can this be solved using higher order functions?
https://docs.databricks.com/_extras/notebooks/source/higher-order-functions-tutorial-python.html

I tried but couldn't make progress.

Answer 1

Score: 1

I would first split each value on "." using the transform function.
Then, still working on the split values with transform, extract the right part based on the size of each split array.
Lastly, remove the nulls and apply array_distinct (not a shuffle!).

# assumes the wildcard import from the question: from pyspark.sql.functions import *
def clean_array(arr):
    # drop the nulls produced by the when() branches below, then deduplicate;
    # both are per-row array functions, not shuffles
    return array_distinct(filter(arr, lambda v: v.isNotNull()))


df \
.withColumn(
    'split_words',
    transform(
        'wordlist',
        lambda v: split(v, '[.]')   # split each qualified name on "."
    )
) \
.select(
    clean_array(transform(
        'split_words',
        lambda v: when(size(v) == 3, v[0])   # a db only exists for 3-part names
    )).alias('dbs'),
    clean_array(transform(
        'split_words',
        lambda v: when(size(v) == 3, v[1]).when(size(v) == 2, v[0])
    )).alias('schemas'),
    clean_array(transform(
        'split_words',
        lambda v: when(size(v) == 3, v[2]).when(size(v) == 2, v[1]).otherwise(v[0])
    )).alias('tables')
) \
.show(truncate=0)
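
With the two sample rows from the question, this should print something close to the following (I wouldn't rely on the exact ordering inside the arrays):

+-----+--------+------------+
|dbs  |schemas |tables      |
+-----+--------+------------+
|[db1]|[s1]    |[t1]        |
|[d1] |[s1, s2]|[t3, t1, t2]|
+-----+--------+------------+

Since transform, filter and array_distinct all run per row, there is no explode step and no wide shuffle stage in this plan.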
