How can I define a generic function in Scala that accepts all 'computable' data types?




I've had this problem a couple of times now:

I have a function that computes something, let's say

```scala
def square(n: Int): Int = n * n
```

(A very simple example, but it will do.)

Then I have the same 'algorithm' for another data type, let's say Long:

```scala
def square(n: Long): Long = n * n
```

And then for BigInt, Short, Byte and so on.

If my algorithm is more complex and longer than in this example, I have a lot of code repetition.

What I would like to have is a generic definition like:

```scala
def square[T <: AnyVal](n: T): T = n * n
```

But this does not work, because in the class hierarchy below AnyVal, alongside Int, Long and Float, there are also Boolean and Unit.
And for Boolean and Unit the term n * n does not make sense, so I get a compiler error (correctly).

I only want my function to work for the 'computable' data types like Int, Long, Float, ...
that have all the usual math operators like +, *, /, < and so on,
and then formulate my algorithm or calculation with these operators for all of them at once.

Of course I can match on the function's input variable n and then handle each case
differently, but then I will also repeat all the code, just like with overloading.

I tried to make my own trait 'Computables' and have the classes Int, Long, ... extend it, but the compiler complains '... cannot extend final class Int'.

Is this even possible? Am I missing something?


Score: 7



You can use the Numeric trait:

```scala
def square[T](n: T)(using numeric: Numeric[T]): T = numeric.times(n, n)
```

or, with the instance's operators imported so you can write `*` directly:

```scala
def square[T](n: T)(using numeric: Numeric[T]): T = {
  import numeric._
  n * n
}
```

Demo @scastie
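A quick usage sketch of the first version above (the `@main` name `squareDemo` is only for illustration): the same definition works for every type that has a `Numeric` instance in the standard library, while a call such as `square(true)` is rejected at compile time because there is no `Numeric[Boolean]`.

```scala
// The Numeric-based square from this answer, called with several numeric types.
def square[T](n: T)(using numeric: Numeric[T]): T = numeric.times(n, n)

@main def squareDemo(): Unit =
  println(square(3))                  // 9 (Int)
  println(square(3L))                 // 9 (Long)
  println(square(1.5))                // 2.25 (Double)
  println(square(BigInt(10).pow(20))) // 10^40 as a BigInt
  // square(true) // does not compile: no given Numeric[Boolean] is available
```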


Score: 7

I think you are looking for a type class.

The Numeric mentioned in the other answer is an example of a type class describing all numeric types. But you can generalize that to describe any kind of behavior.

Check out the link I mentioned above, it has the details and examples, but at a high level, the idea is this:

`def someFunction[T : SomeType](t: T)` is equivalent to `def someFunction[T](t: T)(implicit ev: SomeType[T])`

This means that something like

```scala
val foo: Foo = ???
someFunction(foo)
```

will compile if and only if you have an implicit instance of type `SomeType[Foo]` somewhere in scope.

So, that solves the first part of your problem: we have `Numeric` instances defined for all the "numeric" types, but not for strings or booleans, so this way you are restricting the types that can be sent to your function to a specific "class" (thus "type class") of types.

The other purpose of the type class is to express the common behaviors of the types it includes. You can "summon" the implicit "evidence" of the type class to access it:

```scala
def someFunction[T : SomeType](t: T) = implicitly[SomeType[T]].doStuff(t)
```
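To make this concrete, here is a minimal sketch with a made-up type class; `SomeType`, `doStuff`, the two instances and `typeClassDemo` are hypothetical names, not library API. Only types that have a given instance can be passed to `someFunction`.

```scala
// Hypothetical type class: `SomeType` and `doStuff` are illustrative names only.
trait SomeType[T]:
  def doStuff(t: T): String

// Instances exist only for the types that should be accepted.
given SomeType[Int] with
  def doStuff(t: Int): String = s"Int: $t"

given SomeType[String] with
  def doStuff(t: String): String = s"String of length ${t.length}"

// [T : SomeType] adds an extra (using SomeType[T]) parameter; implicitly retrieves it.
def someFunction[T : SomeType](t: T): String = implicitly[SomeType[T]].doStuff(t)

@main def typeClassDemo(): Unit =
  println(someFunction(42))    // Int: 42
  println(someFunction("abc")) // String of length 3
  // someFunction(true)        // does not compile: no SomeType[Boolean] in scope
```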

In your case, with the square:

```scala
def square[T : Numeric](t: T) = implicitly[Numeric[T]].times(t, t)
```
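A small variation, not from the original answer: the standard library's `scala.math.Numeric.Implicits.infixNumericOps` conversion lets the body use `*` and `+` directly, and in Scala 3 `summon` can be used in place of `implicitly`. The helper `sumOfSquares` is only an illustration.

```scala
import scala.math.Numeric.Implicits.infixNumericOps

// Same context bound as above; infixNumericOps turns `t * t` into Numeric[T].times(t, t).
def square[T : Numeric](t: T): T = t * t

// Hypothetical helper: summon[Numeric[T]] is the Scala 3 spelling of implicitly[Numeric[T]].
def sumOfSquares[T : Numeric](xs: Seq[T]): T =
  xs.map(x => x * x).foldLeft(summon[Numeric[T]].zero)(_ + _)
```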


Score: 1




It was very hard for me to decide which answer to accept.
It turned out the simple example I chose was a bit too simple, so I decided to try this with a real-life example. But I didn't want to change the question because of the answers already given. I post this here mostly for my own reference.

So here is my new example. I have a function that computes an integer square root:

```scala
import scala.annotation.tailrec

def iSqrt1(n: Int): Int =
  if n < 0 then throw IllegalArgumentException("Argument to iSqrt can not be negative! Called with: " + n)
  else if n < 2 then n
  else
    val x0 = n / 2
    @tailrec
    def iterate(x0: Int, x1: Int): Int = if x1 < x0 then iterate(x1, (x1 + n / x1) / 2) else x0
    iterate(x0, (x0 + n / x0) / 2)
```

I don't want to go into details here. If you are interested in the algorithm, I took it straight from Wikipedia (see 'Using only integer division').
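Written out, the iteration in `iSqrt1` (restated from the code above, for $n \ge 2$) is:

$$
x_0 = \left\lfloor \frac{n}{2} \right\rfloor, \qquad
x_{k+1} = \left\lfloor \frac{x_k + \lfloor n / x_k \rfloor}{2} \right\rfloor
$$

It stops at the first $k$ with $x_{k+1} \ge x_k$ and returns $x_k = \lfloor \sqrt{n} \rfloor$. For example, with $n = 26$: $x_0 = 13$, $x_1 = \lfloor (13 + 2)/2 \rfloor = 7$, $x_2 = \lfloor (7 + 3)/2 \rfloor = 5$, $x_3 = \lfloor (5 + 5)/2 \rfloor = 5$; since $x_3 \ge x_2$, the result is $5 = \lfloor \sqrt{26} \rfloor$. Any start value $x_0 \ge \lfloor \sqrt{n} \rfloor$ leads to the same result, which is why the start value can be tuned per type.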

I implemented this function for the types Int, Long and BigInt, and it looks almost exactly the same for each type, except for the starting value x0, which is sometimes different for better performance.

So I tried, to the best of my abilities, to write a generalised iSqrt function for Int, Long and BigInt, using Scala 3.3.0.

I first tried the approach with type classes:

```scala
import scala.annotation.{tailrec, targetName}

trait MathOperators[T] {
  def plus(x: T, y: T): T
  def minus(x: T, y: T): T
  def multiply(x: T, y: T): T
  def divide(x: T, y: T): T
  def divide(x: T, y: Int): T
  def lessThan(x: T, y: T): Boolean
  def lessThan(x: T, y: Int): Boolean
  def toBigInt(x: T): BigInt
  def startValueFrom(x: T): T
}

given MathOperators[Int] with {
  // For T = Int the (T, Int) overloads coincide with the (T, T) ones.
  def plus(x: Int, y: Int): Int = x + y
  def minus(x: Int, y: Int): Int = x - y
  def multiply(x: Int, y: Int): Int = x * y
  def divide(x: Int, y: Int): Int = x / y
  def lessThan(x: Int, y: Int): Boolean = x < y
  def toBigInt(x: Int): BigInt = BigInt(x)
  def startValueFrom(n: Int): Int = n / 2
}

given MathOperators[Long] with {
  def plus(x: Long, y: Long): Long = x + y
  def minus(x: Long, y: Long): Long = x - y
  def multiply(x: Long, y: Long): Long = x * y
  def divide(x: Long, y: Long): Long = x / y
  def divide(x: Long, y: Int): Long = x / y
  def lessThan(x: Long, y: Long): Boolean = x < y
  def lessThan(x: Long, y: Int): Boolean = x < y
  def toBigInt(x: Long): BigInt = BigInt(x)
  def startValueFrom(n: Long): Long = n / 2
}

given MathOperators[BigInt] with {
  def plus(x: BigInt, y: BigInt): BigInt = x + y
  def minus(x: BigInt, y: BigInt): BigInt = x - y
  def multiply(x: BigInt, y: BigInt): BigInt = x * y
  def divide(x: BigInt, y: BigInt): BigInt = x / y
  def divide(x: BigInt, y: Int): BigInt = x / y
  def lessThan(x: BigInt, y: BigInt): Boolean = x < y
  def lessThan(x: BigInt, y: Int): Boolean = x < y
  def toBigInt(x: BigInt): BigInt = x
  def startValueFrom(n: BigInt): BigInt = BigInt(2).pow((n.bitLength / 2) + 1)
}

def startValue[A](n: A)(using mathOp: MathOperators[A]): A = {
  mathOp.startValueFrom(n)
}

// Infix syntax on top of the type class instance.
implicit class MathOperatorsSyntax[T](x: T)(using op: MathOperators[T]) {
  @targetName("plus")
  def +(y: T): T = op.plus(x, y)
  @targetName("minus")
  def -(y: T): T = op.minus(x, y)
  @targetName("multiply")
  def *(y: T): T = op.multiply(x, y)
  @targetName("divide")
  def /(y: T): T = op.divide(x, y)
  @targetName("divide")
  def /(y: Int): T = op.divide(x, y)
  @targetName("lessThan")
  def <(y: T): Boolean = op.lessThan(x, y)
  @targetName("lessThan")
  def <(y: Int): Boolean = op.lessThan(x, y)
}

def iSqrt[T](n: T)(using op: MathOperators[T]): T =
  if n < 0 then throw IllegalArgumentException("Argument to iSqrt can not be negative! Called with: " + n)
  else if n < 2 then n
  else
    val x0: T = startValue(n)
    @tailrec
    def iterate(x0: T, x1: T): T = if x1 < x0 then iterate(x1, (x1 + n / x1) / 2) else x0
    iterate(x0, (x0 + n / x0) / 2)

@main
def main(): Unit = {
  val n1 = BigInt("1236549865413213456498765432136546879854651321")
  println(iSqrt(n1))
}
```

As Dima commented above:
> But you can just write your own [type class] to do whatever you like.

Yes, one can, and it is certainly good to know how to do it, just in case. But in my example it looks like a lot of boilerplate to me. (Or at least I could not figure out how to make it shorter.)
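One possible way to cut that boilerplate in Scala 3, sketched under my own assumptions rather than taken from the answers above: declare the operators as `extension` methods directly inside the type class, so no separate syntax class or `@targetName` annotations are needed. `IntLike`, `fromInt`, `startValue` and `iSqrtDemo` are hypothetical names; the iteration itself is the one from `iSqrt1`.

```scala
import scala.annotation.tailrec

// Hypothetical, more compact type class: the operators are extension methods
// declared inside the trait, so any `using IntLike[T]` brings them into scope.
trait IntLike[T]:
  def fromInt(i: Int): T
  def startValue(n: T): T
  extension (x: T)
    def +(y: T): T
    def /(y: T): T
    def <(y: T): Boolean

given IntLike[Int] with
  def fromInt(i: Int): Int = i
  def startValue(n: Int): Int = n / 2
  extension (x: Int)
    def +(y: Int): Int = x + y // Int's own member +, not this extension
    def /(y: Int): Int = x / y
    def <(y: Int): Boolean = x < y

given IntLike[Long] with
  def fromInt(i: Int): Long = i.toLong
  def startValue(n: Long): Long = n / 2
  extension (x: Long)
    def +(y: Long): Long = x + y
    def /(y: Long): Long = x / y
    def <(y: Long): Boolean = x < y

given IntLike[BigInt] with
  def fromInt(i: Int): BigInt = BigInt(i)
  def startValue(n: BigInt): BigInt = BigInt(2).pow(n.bitLength / 2 + 1)
  extension (x: BigInt)
    def +(y: BigInt): BigInt = x + y
    def /(y: BigInt): BigInt = x / y
    def <(y: BigInt): Boolean = x < y

def iSqrt[T](n: T)(using ops: IntLike[T]): T =
  val zero = ops.fromInt(0)
  val two = ops.fromInt(2)
  if n < zero then
    throw IllegalArgumentException(s"Argument to iSqrt can not be negative! Called with: $n")
  else if n < two then n
  else
    @tailrec
    def iterate(x0: T, x1: T): T = if x1 < x0 then iterate(x1, (x1 + n / x1) / two) else x0
    val x0 = ops.startValue(n)
    iterate(x0, (x0 + n / x0) / two)

@main def iSqrtDemo(): Unit =
  println(iSqrt(BigInt("1236549865413213456498765432136546879854651321")))
```

Each new type then costs one `given` with six short definitions, and the algorithm is written once against `IntLike[T]`.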

So I looked into the second approach: use an already built-in type class like Numeric.

For my new example, I picked the type class Integral.
At first this was not very satisfying either, until I found infixOrderingOps and infixIntegralOps. With these imports I could use operators like +, -, %, < and so on:
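A minimal illustration of what those two imports enable, independent of iSqrt (the helpers `midpoint` and `isEven` are hypothetical):

```scala
import scala.math.Integral.Implicits.infixIntegralOps
import scala.math.Ordering.Implicits.infixOrderingOps

// Works for Int, Long, BigInt, ...: Integral[T] extends Ordering[T], so the same
// given also satisfies infixOrderingOps, which is what provides `<`.
def midpoint[T](a: T, b: T)(using integral: Integral[T]): T =
  val two = integral.fromInt(2)
  if a < b then a + (b - a) / two else b + (a - b) / two

// `%` comes from infixIntegralOps.
def isEven[T](x: T)(using integral: Integral[T]): Boolean =
  x % integral.fromInt(2) == integral.zero
```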

```scala
import scala.annotation.tailrec
import scala.math.Integral
import scala.math.Integral.Implicits
import math.Ordering.Implicits.infixOrderingOps
import math.Integral.Implicits.infixIntegralOps
import collection.immutable.Range.BigDecimal.bigDecAsIntegral

def iSqrt[T](input: T)(using integral: Integral[T]): T =
  val two: T = integral.fromInt(2)
  val zero: T = integral.zero
  val n: T = input match {
    case _: BigDecimal => throw IllegalArgumentException("iSqrt does not accept arguments of type BigDecimal")
    case x: T if x < zero => throw IllegalArgumentException("Argument to iSqrt can not be negative! Called with: " + x)
    case x: T => x
  }
  val x0: T =
    val startValue: Any =
      n match
        case x: BigInt => BigInt(2).pow(x.bitLength / 2 + 1)
        case x: T => x / two
    startValue match
      case x: T => x
  @tailrec
  def iterate(x0: T, x1: T): T = if x1 < x0 then iterate(x1, (x1 + n / x1) / two) else x0
  if n < two then n
  else
    iterate(x0, (x0 + n / x0) / two)
```

I even managed to sneak in the special start value x0 for certain types, but this is admittedly not as clean as the first version.
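If the runtime type tests bother you, one alternative (my own sketch, not something from the answers) is to move the start value into a second, hypothetical type class `StartValue`: a generic `n / 2` fallback serves every `Integral` type, and a BigInt-specific instance wins because it is defined in the companion object itself while the fallback is inherited from a parent trait (the usual low-priority pattern). It is written with Scala 3.3 in mind; `StartValue`, `halfStart` and `bigIntStart` are illustrative names.

```scala
import scala.annotation.tailrec
import scala.math.Integral.Implicits.infixIntegralOps
import scala.math.Ordering.Implicits.infixOrderingOps

// Hypothetical type class choosing the starting value x0 of the iteration.
trait StartValue[T]:
  def apply(n: T): T

// Fallback for any Integral type: x0 = n / 2. Lower priority because it lives in a parent trait.
trait LowPriorityStartValue:
  given halfStart[T](using integral: Integral[T]): StartValue[T] with
    def apply(n: T): T = n / integral.fromInt(2)

object StartValue extends LowPriorityStartValue:
  // Special case for BigInt: a power of two just above sqrt(n).
  given bigIntStart: StartValue[BigInt] with
    def apply(n: BigInt): BigInt = BigInt(2).pow(n.bitLength / 2 + 1)

def iSqrt[T](n: T)(using integral: Integral[T], start: StartValue[T]): T =
  val two = integral.fromInt(2)
  require(n >= integral.zero, s"Argument to iSqrt can not be negative! Called with: $n")
  if n < two then n
  else
    @tailrec
    def iterate(x0: T, x1: T): T = if x1 < x0 then iterate(x1, (x1 + n / x1) / two) else x0
    val x0 = start(n)
    iterate(x0, (x0 + n / x0) / two)
```

The algorithm body no longer needs any `match` on types; supporting a new special case means adding one more `given` to `object StartValue`.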

After trying all that out, the first answer looks like the right choice for what I wanted, so I accepted it.

