Why can't the Rust compiler auto-vectorize this FP dot product implementation?

Question

Consider a simple reduction, such as a dot product:

```rust
pub fn add(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).fold(0.0, |c, (x, y)| c + x * y)
}
```

Using rustc 1.68 with `-C opt-level=3 -C target-feature=+avx2,+fma`, I get:

```asm
.LBB0_5:
        vmovss  xmm1, dword ptr [rdi + 4*rsi]
        vmulss  xmm1, xmm1, dword ptr [rdx + 4*rsi]
        vmovss  xmm2, dword ptr [rdi + 4*rsi + 4]
        vaddss  xmm0, xmm0, xmm1
        vmulss  xmm1, xmm2, dword ptr [rdx + 4*rsi + 4]
        vaddss  xmm0, xmm0, xmm1
        vmovss  xmm1, dword ptr [rdi + 4*rsi + 8]
        vmulss  xmm1, xmm1, dword ptr [rdx + 4*rsi + 8]
        vaddss  xmm0, xmm0, xmm1
        vmovss  xmm1, dword ptr [rdi + 4*rsi + 12]
        vmulss  xmm1, xmm1, dword ptr [rdx + 4*rsi + 12]
        lea     rax, [rsi + 4]
        vaddss  xmm0, xmm0, xmm1
        mov     rsi, rax
        cmp     rcx, rax
        jne     .LBB0_5
```

This is a scalar implementation with loop unrolling, which does not even contract the multiplies and adds into FMAs. Going from this code to SIMD code should be easy, so why doesn't rustc perform this optimization?
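As an aside (my addition, not part of the original question): the FMA contraction alone can be requested explicitly on stable Rust with `f32::mul_add`, which computes `x * y + c` with a single rounding and lowers to an FMA instruction when the target supports it. This addresses only the missing contraction, not the vectorization:

```rust
// Sketch: the same dot product, but with the multiply-add fused explicitly.
// f32::mul_add is stable; it performs one rounding instead of two, which is
// exactly the semantic change the compiler will not make silently.
pub fn add_fma(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).fold(0.0, |c, (x, y)| x.mul_add(*y, c))
}

fn main() {
    let a = [1.0f32, 2.0, 3.0];
    let b = [4.0f32, 5.0, 6.0];
    println!("{}", add_fma(&a, &b)); // 4 + 10 + 18 = 32
}
```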

If I replace f32 with i32, I get the desired auto-vectorization:

```asm
.LBB0_5:
        vmovdqu ymm4, ymmword ptr [rdx + 4*rax]
        vmovdqu ymm5, ymmword ptr [rdx + 4*rax + 32]
        vmovdqu ymm6, ymmword ptr [rdx + 4*rax + 64]
        vmovdqu ymm7, ymmword ptr [rdx + 4*rax + 96]
        vpmulld ymm4, ymm4, ymmword ptr [rdi + 4*rax]
        vpaddd  ymm0, ymm4, ymm0
        vpmulld ymm4, ymm5, ymmword ptr [rdi + 4*rax + 32]
        vpaddd  ymm1, ymm4, ymm1
        vpmulld ymm4, ymm6, ymmword ptr [rdi + 4*rax + 64]
        vpmulld ymm5, ymm7, ymmword ptr [rdi + 4*rax + 96]
        vpaddd  ymm2, ymm4, ymm2
        vpaddd  ymm3, ymm5, ymm3
        add     rax, 32
        cmp     r8, rax
        jne     .LBB0_5
```

Answer 1

Score: 5

This is because floating-point addition is not associative: in general, a + (b + c) != (a + b) + c. Summing floating-point numbers therefore becomes a serial task, because the compiler will not reorder ((a + b) + c) + d into (a + b) + (c + d). The latter can be vectorized; the former cannot.
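The non-associativity is easy to demonstrate (a small illustration of my own, not from the answer): with `f32`, adding 1.0 to ±1.0e8 in different groupings gives different results, because 1.0 is below the rounding granularity at that magnitude:

```rust
fn main() {
    let a: f32 = 1.0e8;
    let b: f32 = -1.0e8;
    let c: f32 = 1.0;
    // (a + b) + c: the large values cancel first, so the 1.0 survives.
    let left = (a + b) + c; // 1.0
    // a + (b + c): b + c rounds back to -1.0e8 (the ULP of f32 at 1e8 is 8),
    // so the 1.0 is lost entirely.
    let right = a + (b + c); // 0.0
    assert_ne!(left, right);
    println!("{left} vs {right}"); // 1 vs 0
}
```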

In most cases, the programmer does not care about this difference in summation order.

gcc and clang provide the `-fassociative-math` flag, which allows the compiler to reorder floating-point operations for performance.

rustc does not expose this option, and as far as I know, LLVM does not accept a flag that changes this behavior either.

On nightly Rust you can use `#![feature(core_intrinsics)]` to get the optimization:

```rust
#![feature(core_intrinsics)]

pub fn add(a: &[f32], b: &[f32]) -> f32 {
    unsafe {
        a.iter()
            .zip(b.iter())
            .fold(0.0, |c, (x, y)| std::intrinsics::fadd_fast(c, x * y))
    }
}
```

This does not use FMA. To get FMA as well, you can write:

```rust
#![feature(core_intrinsics)]

pub fn add(a: &[f32], b: &[f32]) -> f32 {
    unsafe {
        a.iter().zip(b.iter()).fold(0.0, |c, (&x, &y)| {
            std::intrinsics::fadd_fast(c, std::intrinsics::fmul_fast(x, y))
        })
    }
}
```

I am not aware of a stable-Rust solution that does not involve explicit SIMD intrinsics.
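One partial workaround on stable Rust, without intrinsics, is to reassociate the reduction by hand using several independent accumulators (a sketch of my own, not from the answer; whether the backend actually turns the inner loop into SIMD still depends on the compiler version and target features):

```rust
// Sketch: manual reassociation of the dot product. By keeping LANES
// independent partial sums, the summation order is changed explicitly in
// the source, so the optimizer no longer needs to break IEEE semantics.
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;
    let n = a.len().min(b.len());
    let mut acc = [0.0f32; LANES];
    let mut i = 0;
    // Main loop: LANES independent dependency chains per iteration.
    while i + LANES <= n {
        for l in 0..LANES {
            acc[l] += a[i + l] * b[i + l];
        }
        i += LANES;
    }
    // Fold the partial sums, then handle the scalar tail.
    let mut sum: f32 = acc.iter().sum();
    while i < n {
        sum += a[i] * b[i];
        i += 1;
    }
    sum
}

fn main() {
    let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let b = [1.0f32; 9];
    println!("{}", dot(&a, &b)); // 45
}
```

Note that this computes a different (but usually equally acceptable) summation order than the sequential fold, which is exactly the trade-off the answer describes.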


huangapple
  • Published on 2023-04-19 21:23:12
  • Please retain this link when reposting: https://go.coder-hub.com/76055058.html