How to extend built-in aggregate function in Spark SQL (using Scala)?
Question
Basically the end goal would be to create something like dollarSum, which would return the same values as ROUND(SUM(col), 2).
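For example, with a hypothetical DataFrame df that has a numeric column col, the two should be interchangeable:
df.selectExpr("round(sum(col), 2)") // what the built-ins give today
df.selectExpr("dollarSum(col)")     // the desired shorthand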
I'm using Databricks runtime 10.4 LTS ML, which apparently corresponds to Spark 3.2.1 and Scala 2.12.
I was able to follow the tutorial / example code for UDAFs, and used it to create something analogous to the built-in EVERY function. But that seems to be more like an ImperativeAggregate, whereas what I want might be more like a DeclarativeAggregate, cf. the comments in the Spark source code.
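For context, the Aggregator route from that tutorial looks roughly like this (a reconstructed sketch rather than my exact code):
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
// EVERY-style aggregate: logical AND over a Boolean column.
object every extends Aggregator[Boolean, Boolean, Boolean] {
  def zero: Boolean = true
  def reduce(buffer: Boolean, row: Boolean): Boolean = buffer && row
  def merge(b1: Boolean, b2: Boolean): Boolean = b1 && b2
  def finish(reduction: Boolean): Boolean = reduction
  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}
// registered with e.g. spark.udf.register("every_agg", functions.udaf(every)), given an implicit Encoder for the input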
Overall I haven't been able to find any documentation online on how you would extend built-in aggregate functions in a simple way, where you only modify the "finish" or "evaluate" step, and even then just by adding extra behavior.
What I have tried so far:
I have tried at least four things so far, and none of them work.
Attempt 1:
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{sum, round}
object dollarSum extends Aggregator[Double, Double, Double] {
  def zero: Double = sum.zero
  def reduce(buffer: Double, row: Double): Double = sum.reduce
  def merge(buffer1: Double, buffer2: Double): Double = sum.merge
  def finish(reduction: Double): Double = {
    sum.finish(reduction)
    round(reduction, 2)
  }
  def bufferEncoder: Encoder[Double] = sum.bufferEncoder
  def outputEncoder: Encoder[Double] = sum.outputEncoder
}
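(For reference, a version of this shape that does compile, with the sum written out by hand and the rounding done in the finish step instead of delegating to functions.sum, is sketched below; it sidesteps the built-in Sum entirely, though, which is what I was hoping to reuse.)
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
// Hand-rolled sum; only the finish step differs from a plain sum.
object dollarSumByHand extends Aggregator[Double, Double, Double] {
  def zero: Double = 0.0
  def reduce(buffer: Double, row: Double): Double = buffer + row
  def merge(b1: Double, b2: Double): Double = b1 + b2
  def finish(reduction: Double): Double =
    BigDecimal(reduction).setScale(2, BigDecimal.RoundingMode.HALF_UP).toDouble
  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}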
Attempt 2: I tried to copy-paste-modify code from here. This seems to fail because most of the attributes and methods of the built-in Sum class appear to be private (probably because the developers didn't want people like me, who don't know what they're doing, to break the code). But I don't know what public interface / API I could use instead to get what I want.
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.functions.round
import org.apache.spark.sql.catalyst.expressions.EvalMode
import org.apache.spark.sql.types.DecimalType
trait dollarSum extends Sum {
  override lazy val evaluateExpression: Expression = {
    Sum.resultType match {
      case d: DecimalType =>
        val checkOverflowInSum =
          CheckOverflowInSum(Sum.sum, d, evalMode != EvalMode.ANSI, getContextOrNull())
        If(isEmpty, Literal.create(null, Sum.resultType), checkOverflowInSum)
      case _ if shouldTrackIsEmpty =>
        If(isEmpty, Literal.create(null, Sum.resultType), Sum.sum)
      case _ => round(Sum.sum, 2)
    }
  }
}
This would probably still fail due to some other missing imports, but again I wasn't able to get that far in debugging due to trying to access private methods and attributes that probably shouldn't be accessed.
Attempt 3: The source code for try_sum in the same file seemed closer to using a "public API" for sum, so I tried copy-paste-modifying that instead. But ExpressionBuilder also seems like it's a private class, so this fails too.
import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
import org.apache.spark.sql.catalyst.expressions.Expression
object DollarSumExpressionBuilder extends ExpressionBuilder {
  override def build(funcName: String, expressions: Seq[Expression]): Expression = {
    val numArgs = expressions.length
    if (numArgs == 1) {
      round(Sum(expressions.head), 2)
    } else {
      throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(1, 2), numArgs)
    }
  }
}
Then the idea would be that if that worked, I would try registering the function the same way that TRY_SUM is registered with Spark SQL in the source code, cf. here. But I got an error about ExpressionBuilder not existing, which seems to indicate that it is also private to that package and thus not a public interface I could use to extend SUM.
Also it's not clear to me what the return type is for the SUM constructor; I think it might be AggregateExpression, which inherits from Expression. And I'm not certain what the input type is for round; it seems like it might be org.apache.spark.sql.Column, and if so, I'm not sure how to convert from Expression to Column.
E.g. whether, in the above, round(org.apache.spark.sql.Column((Sum(expressions.head)),2) or round(org.apache.spark.sql.functions.col((Sum(expressions.head)),2) would be able to achieve the desired type conversion (seemingly neither works).
Attempt 4: Along the lines of the above, not knowing which types are needed, how to convert between them, or what the public interface for SUM is, I tried using org.apache.spark.sql.functions.sum as the "public interface" for SUM instead, but this also didn't work.
Specifically
import org.apache.spark.sql.functions.{round, sum}
import org.apache.spark.sql.Column
// originally I had `expression: org.apache.spark.sql.catalyst.expressions.Expression` but that didn't work
def dollarSum(expression: Column): Column = {round(sum(expression), 2)}
actually doesn't throw any errors, but then when I try to actually register the resulting object as a(n aggregate) function, it fails. Specifically,
spark.udf.register("dollar_sum", functions.udaf(dollarSum))
doesn't work, nor does
spark.udf.register("dollar_sum", functions.udf(dollarSum))
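For what it's worth, the Column-based helper does seem usable when called directly through the DataFrame API (df below standing in for any DataFrame with an amount column); it's only the SQL-side registration that fails:
import org.apache.spark.sql.functions.col
df.agg(dollarSum(col("amount"))).show()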
Answer 1
Score: 1
Wow, lots of fun stuff in this question and awfully familiar: Quality's agg_expr was my journey into that space.
To build a custom expression you may need to put code into the org.apache.spark.sql package, e.g. registerFunction. Using createOrReplaceTempFunction on the SparkSession instance's FunctionRegistry (e.g. SparkSession.getActiveSession.get.sessionState.functionRegistry) you can use the function within a session. If you need it in Hive views etc. you must use SparkSessionExtensions for scope and FunctionRegistry.builtin.registerFunction.
The actual registration ExpressionBuilder is just an alias for Seq[Expression] => Expression, representing the parameters passed in to construct your expression.
So, depending on Spark version (the internal api changes a lot):
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Round, Literal, EvalMode}
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
// Session-scoped registration: the builder turns the single argument into the
// built-in Sum (as an AggregateExpression) wrapped in a catalyst Round with scale 2.
SparkSession.getActiveSession.get.sessionState.functionRegistry.
  createOrReplaceTempFunction("dollarSum", exps => Round(
    Sum(exps.head, EvalMode.TRY).toAggregateExpression(), Literal(2)), "built-in")
val seq = Seq(1.245, 242.535, 65656.234425, 2343.666)
import sparkSession.implicits._ // sparkSession is your active SparkSession (e.g. `spark` on Databricks)
seq.toDF("amount")//.selectExpr("round(sum(amount), 2)").show
  .selectExpr("dollarSum(amount)").show // should show 68243.68 for this data
NB/FYI: An obvious idea with Quality would be to use a lambda:
import com.sparkutils.quality.{LambdaFunction, Id, registerLambdaFunctions, registerQualityFunctions}
registerQualityFunctions()
registerLambdaFunctions(Seq(
  LambdaFunction("dollarSum", "a -> round(sum(a), 2)", Id(1,1))
))
This however fails, as Spark LambdaFunctions and AggregateFunctions don't readily mix. The direct FunctionRegistry route doesn't involve a LambdaFunction and so works correctly.
Extra info, per the questions in the comments...
Why "built-in", it's used to specify sources, you can't create the function unless it's a valid source (from ExpressionInfo):
private static final Set<String> validSources =
new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf", "java_udf"));
as such only built-in is close. The name refers to the static FunctionRegistry.builtin instance which houses all the normal spark sql functions - and what you need to use if you want to use the function in create view etc.
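A rough sketch of that route, assuming the Spark 3.2-era Sum(child) constructor and using a placeholder class name in the ExpressionInfo:
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo, Literal, Round}
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
// Register into the global built-in registry, which is what view resolution uses.
FunctionRegistry.builtin.registerFunction(
  FunctionIdentifier("dollarSum"),
  new ExpressionInfo("com.example.DollarSum", "dollarSum"), // placeholder class name
  exps => Round(Sum(exps.head).toAggregateExpression(), Literal(2)))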
Re the builder - as I wrote above, it's a function that takes expressions and returns an expression, i.e. the constructor. You will need to call createOrReplaceTempFunction (or the others I mention above) to actually register it, but it's just a name and Seq[Expression] => Expression pair, easy enough to manage differently. As Spark's interface for this changes every couple of releases, the actual call in Quality is made in different Spark compatibility layers, e.g. 10.4 LTS or OSS 2.4; the functions themselves are however managed here and below.
In order to provide some useful errors on parameter handling, I also specify the parameter combinations handled here.
Now, in order to implement more complicated logic you will have to understand each of the Spark Expressions themselves, and many of them change each release. Worse, as you are using Databricks, the OSS version advertised only holds for the public interfaces, which means you must sometimes guess or use reflection to figure out what the Expressions actually look like on Databricks. Typically this is just backports of future releases, but not always; there have been traits that were swapped for abstract classes, leading to hideous workarounds like this where I have to shim the types to correctly compile under OSS with a target of DBR 9.1, caveat emptor.
That said, there is the odd surprise and risk waiting, e.g. a DBR version backports a fix or feature that breaks the interface without bumping the version. So you are calmly and happily using your Sum code on 10.4, but overnight 10.4 stops working and your DECIMAL sums are clearly suffering overflow of some kind. Every other user of 10.4 gets a nice performance bump, but you get broken math... So be prepared to continuously test and to be able to make fixes quickly; this is the price for using internals.
To be really clear - I deeply appreciate the Databricks product and team; this issue is not one of their making, it'd be yours (and clearly mine) for using internal APIs.
The core Spark team has also openly stated they don't approve of such usage wrt. Frameless (3.2.0 to 3.2.1 changed the internal Encoder API, breaking Frameless users). Clearly they too should be free to innovate and re-organise internal APIs. The performance and flexibility of using them, though...