How to create Predicate for reading data using Spark SQL in Scala

Question

I can read the Oracle table using this simple Scala program:

import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder
  .master("local[4]")
  .config("spark.sql.sources.partitionColumnTypeInference.enabled", false)
  .config("spark.executor.memory", "8g")
  .config("spark.executor.cores", 4)
  .config("spark.task.cpus", 1)
  .appName("Spark SQL basic example")
  .config("spark.some.config.option", "some-value")
  .getOrCreate()

val jdbcDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:oracle:thin:@x.x.x.x:1521:orcl")
  .option("dbtable", "big_table")
  .option("user", "test")
  .option("password", "123456")
  .load()

jdbcDF.show()

However, the table is huge and each node has to read only part of it, so I must use a hash function to distribute the data among the Spark nodes. Spark supports this through JDBC predicates. I already did this in Python: the table has a column named NUM, and a hash function maps each of its values to an integer between 0 and num_partitions. The predicate list is built like this:

hash_function = lambda x: 'ora_hash({}, {})'.format(x, num_partitions)
hash_df = connection.read_sql_full(
    'SELECT distinct {0} hash FROM {1}'.format(hash_function(var.hash_col), source_table_name))
hash_values = list(hash_df.loc[:, 'HASH'])

hash_values for num_partitions=19 is:

hash_values=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19]

predicates = [
    "to_date({0},'YYYYMMDD','nls_calendar=persian')= to_date({1} ,'YYYYMMDD','nls_calendar=persian') " \
    "and hash_func({2},{3}) = {4}"
        .format(partition_key, current_date, hash_col, num_partitions, hash_val) for hash_val in
    hash_values]

Then I read the table based on the predicates like this:

dataframe = spark.read \
    .option('driver', 'oracle.jdbc.driver.OracleDriver') \
    .jdbc(url=spark_url,
          table=table_name,
          predicates=predicates)

Would you please guide me on how to create the predicates list in Scala, the same way I did it in Python?

Any help is really appreciated.


Answer 1

Score: 1

Problem solved.

I changed the code as follows, and then it works:

import org.apache.spark.sql.SparkSession
import java.sql.Connection
import oracle.jdbc.pool.OracleDataSource

object main extends App {

  def read_spark(): Unit = {
    val numPartitions = 19
    val partitionColumn = "name"
    val field_date = "test"
    val current_date = "********"
    // Define the JDBC connection properties
    val url = "jdbc:oracle:thin:@//x.x.x.x:1521/orcl"
    val properties = new java.util.Properties()
    properties.put("url", url)
    properties.put("user", "user")
    properties.put("password", "pass")
    // Define the WHERE clauses that assign each row to a partition.
    // ora_hash(col, N) returns bucket values 0..N inclusive, so use an
    // inclusive range so that no bucket (and no row) is skipped.
    val predicateFct = (partition: Int) => s"""ora_hash("$partitionColumn",$numPartitions) = $partition"""
    val predicates = (0 to numPartitions).map{partition => predicateFct(partition)}.toArray

    val test_table = s"(SELECT * FROM table where $field_date=$current_date) dbtable"
    // Load the table into Spark, one JDBC query per predicate
    val df = spark.read
      .format("jdbc")
      .option("driver", "oracle.jdbc.driver.OracleDriver")
      .option("dbtable", test_table)
      .jdbc(url, test_table, predicates, properties)
    df.show()
  }
  val spark = SparkSession
    .builder
    .master("local[4]")
    .config("spark.sql.sources.partitionColumnTypeInference.enabled", false)
    .config("spark.executor.memory", "8g")
    .config("spark.executor.cores", 4)
    .config("spark.task.cpus", 1)
    .appName("Spark SQL基本示例")
    .config("spark.some.config.option", "some-value")
    .getOrCreate()

  read_spark()

}
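
For comparison with the Python version in the question, a more literal Scala translation would keep the to_date filter inside each predicate instead of moving it into a subquery. The following is only a sketch: it assumes ora_hash as the Oracle-side hash function (standing in for the question's hash_func), and the column names, date literal, and connection credentials are placeholders rather than values from the original post.

import org.apache.spark.sql.SparkSession

object PredicatesSketch extends App {
  val spark = SparkSession.builder
    .master("local[4]")
    .appName("predicates sketch")
    .getOrCreate()

  val numPartitions = 19
  val partitionKey  = "PARTITION_KEY"   // placeholder: date column compared against the current date
  val hashCol       = "HASH_COL"        // placeholder: column fed to ora_hash
  val currentDate   = "'14011119'"      // placeholder date literal in YYYYMMDD form

  // ora_hash(col, N) yields buckets 0..N inclusive, hence the inclusive range:
  // one WHERE clause per bucket, so every row falls into exactly one partition.
  val predicates: Array[String] = (0 to numPartitions).map { hashVal =>
    s"to_date($partitionKey,'YYYYMMDD','nls_calendar=persian') = " +
      s"to_date($currentDate,'YYYYMMDD','nls_calendar=persian') " +
      s"and ora_hash($hashCol, $numPartitions) = $hashVal"
  }.toArray

  val props = new java.util.Properties()
  props.put("user", "user")             // placeholder credentials
  props.put("password", "pass")
  props.put("driver", "oracle.jdbc.driver.OracleDriver")

  // Each predicate becomes the WHERE clause of one JDBC query and one Spark partition.
  val df = spark.read.jdbc("jdbc:oracle:thin:@//x.x.x.x:1521/orcl", "big_table", predicates, props)
  println(df.rdd.getNumPartitions)      // equals predicates.length
  df.show()
}

Because every row matches exactly one predicate, the union of the partitions is the same data a plain read would return; a quick sanity check is to compare df.count() against an unpartitioned read of the same table.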


