如何在Scala中定义一个通用函数,接受所有可“计算”的数据类型?

huangapple go评论87阅读模式
英文:

How can I define a generic function in scala that accepts all 'computable' Datatypes?

问题

我遇到过这个问题几次:

我有一个计算某些内容的函数,比如说:

def square(n: Int): Int = n * n

(这只是一个非常简单的示例,但这足够说明问题)

然后,我对另一种数据类型使用相同的“算法”,比如说 long:

def square(n: Long): Long = n * n

然后对于 BigInt、Short、Byte 等等也是如此。

如果我的算法比这个示例更复杂且更长,那么就会有很多重复的代码。

我想要的是一个通用的定义,如下所示:

def square[T <: AnyVal](n: T): T = n * n

但这不起作用,因为在类层次结构中,除了 AnyVal 中有 Int、Long 和 Float 外,还有 Boolean 和 Unit。
而对于 Boolean 和 Unit,术语 n * n 没有意义,我会收到编译器错误(这是正确的)。

我只想让我的函数适用于“可计算”的数据类型,比如 Int、Long、Float 等等,
这些数据类型都具有常见的数学运算符,如+、*、/、<等等,
然后一次性使用这些运算符为所有数据类型制定我的算法或计算。

当然,我可以匹配函数的输入变量 n,然后分别处理每种情况,
但在那里我还是会重复所有的代码,就像之前的重载一样。

我尝试过创建自己的 trait 'Computables',然后扩展到其他类,比如 Int、Long 等等,但编译器会抱怨 '... 无法扩展最终类 Int'。

这是否可能?我有什么遗漏吗?

英文:

I've had this problem a couple of times now:

I have a function that computes something, lets say

def square(n: Int): Int = n * n

(Very simple example, but this will do)

Then I have the same 'algorithm' for another datatype, lets say long:

def square(n: Long): Long = n * n

And then for BigInt, Short, Byte and so on.

If my algorithm is more complex and longer than in this example, I have a lot of code repetition.

What I would like to have is a generic definition like:

def square[T :&gt; AnyVal](n: T): T = n * n

But this does not work, because in the class hirachy, below AnyVal with Int and Long and Float there also is Boolean and Unit.
And for Boolean and Unit the term n * n does not make sense and I get a compiler error (correctly).

I only want my function to work for the 'computable' Datatypes like Int, Long, Float, ...
that have all the usual math operators like +, *, /, < and so on
and then formulate my algorithm or calculation with this operators for all at once.

Of course I can match on the functions input variable n, and then handle each case
differently, but there I also will repeat all the code like before with overloading.

I tried to make my own trait 'Computables' and then extend to the other classes Int, Long, ..., but the compiler complains '... cannot extend final class Int'

Is this even possible? Am I missing something?

答案1

得分: 7

你可以使用Numeric特质:

def square[T](n: T)(using numeric: Numeric[T]): T = numeric.times(n,n)

或者

def square[T](n: T)(using numeric: Numeric[T]): T = {
  import numeric._  
  n * n
}

演示 @scastie

英文:

You can use the Numeric trait:

def square[T](n: T)(using numeric: Numeric[T]): T = numeric.times(n,n)

or

def square[T](n: T)(using numeric: Numeric[T]): T = {
  import numeric._
  n * n
}

Demo @scastie

答案2

得分: 7

我认为你正在寻找一个type class

其他答案中提到的Numeric是描述所有数值类型的类型类的一个示例。但你可以将其泛化以描述任何种类的行为。

请查看我上面提到的链接,它包含了详细的例子,但在高层次上,思想是这样的:

def someFunction[T : SomeType](t: T)` 等同于 `def someFunction[T](t: T)(implicit ev: SomeType[T])

这意味着像这样的东西:

val foo: Foo = ???
someFunction(foo)

只有在你的作用域中有一个SomeType[Foo]类型的隐式实例时,才会编译通过。

所以,这解决了你问题的第一部分:我们为所有"数值"类型定义了Numeric实例,但没有为字符串或布尔类型定义,因此你可以通过这种方式限制可以传递给你的函数的类型。

类型类的另一个目的是表达它包括的类型的通用行为。你可以"召唤"类型类的隐式"证明"来访问它:

def someFunction[T : SomeType](t: T) = implicitly[SomeType[T]].doStuff(t)

在你的情况下,对于square函数:

def square[T : Numeric](t: T) = implicitly[Numeric[T]].times(t, t)
英文:

I think, you are looking for a type class

The Numeric mentioned in the other answer is an example of a type class describing all numeric types. But you can generalize that to describe any kind of behavior.

Check out the link I mentioned above, it has the details and examples, but at a high level, the idea is this:

def someFunction[T : SomeType](t: T) is equivalent to def someFunction[T](t: T)(implicit ev: SomeType[T])

This means that something like

val foo: Foo = ???
someFunction(foo)

will compile if and only if, you have an implicit instance of type SomeType[Foo] somewhere in scope.

So, that solves the first part of your problem: we have Numeric instances defined for all the "numeric" types, but not for strings or boolean, so this way you are restricting the types that can be sent to your function to a specific "class" (thus "type class") of types.

The other purpose of the type class is to express the common behaviors of the types it includes. You can "summon" the implicit "evidence" of the type class to access it:

def someFunction[T : SomeType](t: T) = implicitly[SomeType[T]].doStuff(t)

In your case, with the square:

def square[T : Numeric](t: T) = implicitly[Numeric[T]].times(t, t)

答案3

得分: 1

以下是您要翻译的内容:

"对我来说,很难决定接受哪个答案。
结果,我选择的简单示例太简单了,所以我决定尝试一个真实的例子。但我不想更改问题,因为已经给出了答案。我主要是为了自己的参考而发布这个。

所以这是我的新例子。我有一个计算整数平方根的函数:

def iSqrt1(n: Int): Int =
  if n < 0 then throw IllegalArgumentException("Argument to iSqrt can not be negative! Called with: " + n)
  else if n < 2 then n
  else
    val x0 = n / 2

    @tailrec
    def iterate(x0: Int, x1: Int): Int = if x1 < x0 then iterate(x1, (x1 + n / x1) / 2) else x0

    iterate(x0, (x0 + n / x0) / 2)

我不想在这里详细讨论。如果您对算法感兴趣,我直接从 维基百科 中获取了它(参见'仅使用整数除法')。

我为 Int、Long 和 BigInt 类型实现了这个函数。对于每种类型,除了有时为了提高性能而不同的起始值 x0 外,它几乎看起来完全相同。

所以,我尽力编写了一种通用的 isqrt 函数,适用于 Int、Long 和 BigInt 类型,并且我使用了 Scala 版本 3.3.0。

我首先尝试了使用类型类的方法:

import scala.annotation.{tailrec, targetName}

trait MathOperators[T] {
  def plus(x: T, y: T): T

  def minus(x: T, y: T): T

  def multiply(x: T, y: T): T

  def divide(x: T, y: T): T

  def divide(x: T, y: Int): T

  def lessThan(x: T, y: T): Boolean

  def lessThan(x: T, y: Int): Boolean

  def toBigInt(x: T): BigInt

  def startValueFrom(x: T): T
}

given MathOperators[Int] with {
  // 实现 Int 类型的 MathOperators
  // ...
}

given MathOperators[Long] with {
  // 实现 Long 类型的 MathOperators
  // ...
}

given MathOperators[BigInt] with {
  // 实现 BigInt 类型的 MathOperators
  // ...
}

def startValue[A](n: A)(using mathOp: MathOperators[A]): A = {
  mathOp.startValueFrom(n)
}

implicit class MathOperatorsSyntax[T](x: T)(using op: MathOperators[T]) {
  // 定义一些操作符的扩展方法
  // ...
}

def iSqrt[T](n: T)(using op: MathOperators[T]): T = {
  // 实现 iSqrt 函数
  // ...
}

@main
def main(): Unit = {
  val n1 = BigInt("1236549865413213456498765432136546879854651321")
  println(iSqrt(n1))
}

正如 Dima 在上面评论中提到的:

但你可以编写自己的 [类型类] 来执行任何你想要的操作。

是的,可以。了解如何执行这些操作肯定是一个好事。但在我的例子中,这看起来对我来说有点啰嗦。(或者至少我无法找出如何更短地执行这些操作。)

所以,我研究了第二种方法:使用已经构建好的类型类,如 Numeric。

对于我的新示例,我选择了 Integral 类。起初,这也不是很令人满意,直到我找到了 infixOrderingOps 和 infixIntegralOps。有了这些导入,我可以使用诸如 +、-、%、< 等操作符:

import scala.annotation.tailrec
import scala.math.Integral
import scala.math.Integral.Implicits
import math.Ordering.Implicits.infixOrderingOps
import math.Integral.Implicits.infixIntegralOps
import collection.immutable.Range.BigDecimal.bigDecAsIntegral

def iSqrt[T](input: T)(using integral: Integral[T]): T = {
  // 实现 iSqrt 函数
  // ...
}

我甚至设法在某些类型中偷偷加入了特殊的起始值 x0,但是这无疑不如第一个版本干净。

尝试了所有这些后,第一个答案看起来是我想要的正确选择,所以我接受了它。

"

英文:

It was very hard for me to decide wich answer to accept.
It turned out the simple example I choose was a bit to simple, so i decided to try this with a real life example. But I didn't want to change the question, because of the already given answers. I post this here mostly for my own reference.

So here is my new example. I have a function that computes an integer square root:

def iSqrt1(n: Int): Int =
if n &lt; 0 then throw IllegalArgumentException(&quot;Argument to iSqrt can not be negative! Called with: &quot; + n)
else if n &lt; 2 then n
else
val x0 = n / 2
@tailrec
def iterate(x0: Int, x1: Int): Int = if x1 &lt; x0 then iterate(x1, (x1 + n / x1) / 2) else x0
iterate(x0, (x0 + n / x0) / 2)

I don't want to go into details here. If you are interested in the algorithm, I took it straight form wikipedia (see 'Using only integer division').

I implemented this function for the types Int, Long, and BigInt. And it almost looks excactly the same for each type, except for the starting value x0, that is different sometimes to have a better performance.

So I tried to write a generalised isqrt function for the classes Int, Long and BigInt to the best of my abilities, and I used scala Version 3.3.0 for it.

I first tried the approach with type classes:

import scala.annotation.{tailrec, targetName}
trait MathOperators[T] {
def plus(x: T, y: T): T
def minus(x: T, y: T): T
def multiply(x: T, y: T): T
def divide(x: T, y: T): T
def divide(x: T, y: Int): T
def lessThan(x: T, y: T): Boolean
def lessThan(x: T, y: Int): Boolean
def toBigInt(x: T): BigInt
def startValueFrom(x: T): T
}
given MathOperators[Int] with {
def plus(x: Int, y: Int): Int = x + y
def minus(x: Int, y: Int): Int = x - y
def multiply(x: Int, y: Int): Int = x * y
def divide(x: Int, y: Int): Int = x / y
def lessThan(x: Int, y: Int): Boolean = x &lt; y
def toBigInt(x: Int): BigInt = BigInt(x)
def startValueFrom(n: Int): Int = n / 2
}
given MathOperators[Long] with {
def plus(x: Long, y: Long): Long = x + y
def minus(x: Long, y: Long): Long = x - y
def multiply(x: Long, y: Long): Long = x * y
def divide(x: Long, y: Long): Long = x / y
def divide(x: Long, y: Int): Long = x / y
def lessThan(x: Long, y: Long): Boolean = x &lt; y
def lessThan(x: Long, y: Int): Boolean = x &lt; y
def toBigInt(x: Long): BigInt = BigInt(x)
def startValueFrom(n: Long): Long = n / 2
}
given MathOperators[BigInt] with {
def plus(x: BigInt, y: BigInt): BigInt = x + y
def minus(x: BigInt, y: BigInt): BigInt = x - y
def multiply(x: BigInt, y: BigInt): BigInt = x * y
def divide(x: BigInt, y: BigInt): BigInt = x / y
def divide(x: BigInt, y: Int): BigInt = x / y
def lessThan(x: BigInt, y: BigInt): Boolean = x &lt; y
def lessThan(x: BigInt, y: Int): Boolean = x &lt; y
def toBigInt(x: BigInt): BigInt = x
def startValueFrom(n: BigInt): BigInt = BigInt(2).pow((n.bitLength / 2) + 1)
}
def startValue[A](n: A)(using mathOp: MathOperators[A]): A = {
mathOp.startValueFrom(n)
}
implicit class MathOperatorsSyntax[T](x: T)(using op: MathOperators[T]) {
@targetName(&quot;plus&quot;)
def +(y: T): T = op.plus(x, y)
@targetName(&quot;minus&quot;)
def -(y: T): T = op.minus(x, y)
@targetName(&quot;multiply&quot;)
def *(y: T): T = op.multiply(x, y)
@targetName(&quot;divide&quot;)
def /(y: T): T = op.divide(x, y)
@targetName(&quot;divide&quot;)
def /(y: Int): T = op.divide(x, y)
@targetName(&quot;lessThen&quot;)
def &lt;(y: T): Boolean = op.lessThan(x, y)
@targetName(&quot;lessThen&quot;)
def &lt;(y: Int): Boolean = op.lessThan(x, y)
}
def iSqrt[T](n: T)(using op: MathOperators[T]): T =
if n &lt; 0 then throw IllegalArgumentException(&quot;Argument to iSqrt can not be negative! Called with: &quot; + n)
else if n &lt; 2 then n
else
val x0: T = startValue(n)
@tailrec
def iterate(x0: T, x1: T): T = if x1 &lt; x0 then iterate(x1, (x1 + n / x1) / 2) else x0
iterate(x0, (x0 + n / x0) / 2)
@main
def main(): Unit = {
val n1 = BigInt(&quot;1236549865413213456498765432136546879854651321&quot;)
println(iSqrt(n1))
}

As Dima commented above:
> But you can just write your own [type class] to do whatever you like.

Yes, One can. It for shure is a good thing to know how to do it just in case. But in my example here it looks like some boilerplate to me. (Or at least I could not figure out how to do this shorter.)

So I looked into the second approach: use a already build in type class like Numeric

For my new example, I picked the class Integral.
At first this was not very satisfying as well, until I found infixOrderingOps and infixIntegralOps. With these imports I could use operators like +, -, %, < and so on:

import scala.annotation.tailrec
import scala.math.Integral
import scala.math.Integral.Implicits
import math.Ordering.Implicits.infixOrderingOps
import math.Integral.Implicits.infixIntegralOps
import collection.immutable.Range.BigDecimal.bigDecAsIntegral
def iSqrt[T](input: T)(using integral: Integral[T]): T =
val two: T = integral.fromInt(2)
val zero: T = integral.zero
val n: T = input match {
case _: BigDecimal =&gt; throw IllegalArgumentException(&quot;iSqrt does not accept arguments of type BigDecimal&quot;)
case x: T if x &lt; zero =&gt; throw IllegalArgumentException(&quot;Argument to iSqrt can not be negative! Called with: &quot; + x)
case x: T =&gt; x
}
val x0: T =
val startValue: Any =
n match
case x: BigInt =&gt; BigInt(2).pow(x.bitLength / 2 + 1) 
case x: T =&gt; x / two
startValue match
case x: T =&gt; x
@tailrec
def iterate(x0: T, x1: T): T = if x1 &lt; x0 then iterate(x1, (x1 + n / x1) / two) else x0
if n &lt; two then n
else
iterate(x0, (x0 + n / x0) / two)

I even managed to sneak in the speacial start value x0 for certain types, but this is admittedly not so clean as in the first version.

After trying all that out, the first answer looks like the right choice for what I wanted, so I accepted it.

huangapple
  • 本文由 发表于 2023年7月12日 20:52:17
  • 转载请务必保留本文链接:https://go.coder-hub.com/76670783.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定