# Find all subsets of a mask which are divisible by a certain number
All subsets of the 4-bit number 1101 are 0000, 0001, 0100, 0101, 1000, 1001, 1100, and 1101. All subsets of this mask which are divisible by 2 are 0000, 0100, 1000, and 1100.
Given a 64-bit mask M and a 64-bit integer P, how do I iterate over all subsets of M which are divisible by P?
To iterate over subsets of a bit mask, I can do

```c
uint64_t superset = ...;
uint64_t subset = 0;
do {
    print(subset);
    subset = (subset - superset) & superset;
} while (subset != 0);
```
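As a self-contained illustration of this enumeration trick (the helper names are mine, not from the question), the following counts the subsets of 1101₂ that a given divisor accepts, visiting 0000, 0100, 1000, 1100 for p = 2 exactly as in the example above:

```c
#include <stdint.h>

/* Advance to the next subset of `superset` in increasing order;
 * wraps around to 0 after the full mask itself. */
static uint64_t next_subset(uint64_t subset, uint64_t superset)
{
    return (subset - superset) & superset;
}

/* Count the subsets of `mask` divisible by `p`, using the same
 * do/while pattern as above. For mask = 0xD (1101) and p = 2 the
 * loop visits 0000, 0100, 1000, 1100 among the eight subsets,
 * so it returns 4. */
static unsigned count_divisible_subsets(uint64_t mask, uint64_t p)
{
    unsigned count = 0;
    uint64_t subset = 0;
    do {
        if (subset % p == 0) {
            ++count;
        }
        subset = next_subset(subset, mask);
    } while (subset != 0);
    return count;
}
```

The subtraction is performed modulo 2⁶⁴ on `uint64_t`, which is exactly what makes `(subset - superset) & superset` carry into the next-higher mask bit.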
If M is `~0` I can just start with 0 and keep adding P to iterate over all multiples of P. If P is a power of two I can just do `M &= ~(P - 1)` to chop off bits which are never going to be set.
But if I have none of the constraints above, do I have a better shot than naively checking each and every subset for divisibility by P? On average, this naive algorithm takes O(P) operations to find the next subset divisible by P. Can I do better than O(P)?
# Answer 1

Score: 1
## A Parallel Algorithm
There are inputs for which it is vastly more efficient to check the multiples of the factor than the subsets of the mask, and inputs where it’s the other way around. For example, when _M_ is `0xFFFFFFFFFFFFFFFF` and _P_ is `0x4000000000000000`, checking the three multiples of _P_ is nigh-instantaneous, but even if you could crunch and check a billion subsets of _M_ each second, enumerating them all would take thirty years. The optimization of finding only subsets greater than or equal to _P_ would only cut that to four years.
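This choice can be made at run time by comparing the two candidate counts: the mask has 2^popcount(M) subsets, while there are only M/P + 1 multiples of P (including 0) not exceeding M. A minimal sketch (`prefer_multiples` is my name, and it assumes the GCC/Clang `__builtin_popcountll` builtin):

```c
#include <stdbool.h>
#include <stdint.h>

/* Returns true when enumerating multiples of `factor` touches fewer
 * candidates than enumerating subsets of `mask`.
 * Subsets of mask:      1ULL << popcount(mask)  (overflows at 64 set bits)
 * Multiples <= mask:    mask/factor + 1, counting zero
 */
static bool prefer_multiples(uint64_t mask, uint64_t factor)
{
    int bits = __builtin_popcountll(mask);
    uint64_t n_multiples = mask / factor + 1U;
    if (bits >= 64) {
        return true; /* 2^64 subsets always loses */
    }
    return n_multiples < (1ULL << bits);
}
```

For the example above (M all-ones, P = 0x4000000000000000) there are only four multiples to check, versus 2⁶⁴ subsets, so the multiples side wins decisively.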
However, there is a strong reason to enumerate and check the multiples of _P_ instead of the subsets of _M_: parallelism. I want to emphasize this, because of incorrect comments on this code elsewhere: the algorithm in the OP is **inherently sequential**, because each value of `subset` uses the previous value of `subset`. It cannot run until all the lower subsets have already been calculated. It cannot be vectorized to use AVX registers or similar: you cannot load four values into an AVX2 register and run SIMD instructions on them, because you would need to calculate the first value to initialize the second element, the second to initialize the third, and all three to initialize the final one, and then you are back to computing only one value at a time. Nor can it be split between worker threads on different CPU cores, which is a separate question from vectorization. (The accepted answer can be modified to do the latter, but not the former without a total refactoring.) You cannot divide the workload into subsets 0 to 63, subsets 64 to 127, and so on, and have different threads work on each in parallel, because you cannot start on the sixty-fourth subset until you know what the sixty-third subset is, for which you need the sixty-second, and so on.
If you take nothing else away from this, I **highly** recommend that you [try this code out on Godbolt](https://godbolt.org/z/PMao7K76c) with full optimizations enabled, and see for yourself that it compiles to sequential code. If you’re familiar with OpenMP, try adding `#pragma omp simd` and `#pragma omp parallel` directives and see what happens. The problem isn’t with the compiler; it’s that the algorithm is inherently sequential. But seeing what real compilers do should *at least* convince you that compilers in the year 2023 are not able to vectorize code like this.
For reference, here is what Clang 16 does with `Find`:
```asm
Find:                                   # @Find
        push    r15
        push    r14
        push    r12
        push    rbx
        push    rax
        mov     rbx, rdi
        cmp     rdi, rsi
        jne     .LBB1_1
.LBB1_6:
        lea     rdi, [rip + .L.str]
        mov     rsi, rbx
        xor     eax, eax
        add     rsp, 8
        pop     rbx
        pop     r12
        pop     r14
        pop     r15
        jmp     printf@PLT              # TAILCALL
.LBB1_1:
        mov     r14, rdx
        mov     r15, rsi
        jmp     .LBB1_2
.LBB1_5:                                #   in Loop: Header=BB1_2 Depth=1
        imul    r12, r14
        add     r15, r12
        cmp     r15, rbx
        je      .LBB1_6
.LBB1_2:                                # =>This Inner Loop Header: Depth=1
        cmp     r15, rbx
        ja      .LBB1_7
        mov     rax, r15
        xor     rax, rbx
        blsi    r12, rax
        test    r12, rbx
        je      .LBB1_5
        mov     rdi, rbx
        sub     rdi, r12
        mov     rsi, r15
        mov     rdx, r14
        call    Find
        jmp     .LBB1_5
.LBB1_7:
        add     rsp, 8
        pop     rbx
        pop     r12
        pop     r14
        pop     r15
        ret
```
## Enumerate and Check the Multiples Instead of the Subsets
In addition to having more parallelism, this has several advantages in speed:

- Finding the successor, or `(i+4)*p` given `i*p` to use this on a vector of four elements, can be strength-reduced to a single addition.
- Testing whether a multiple is a subset is a single AND operation, whereas testing whether a subset is a multiple requires a `%` operation, which most CPUs do not have as a native instruction, and which is always the slowest ALU operation even when it is there.
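The two membership tests, side by side (a micro-sketch; the helper names are mine):

```c
#include <stdbool.h>
#include <stdint.h>

/* Cheap: is the multiple m a subset of mask? One AND, one compare. */
static bool is_subset(uint64_t m, uint64_t mask)
{
    return (m & ~mask) == 0;
}

/* Expensive: is the subset s a multiple of factor? With a run-time
 * factor this compiles to a hardware divide; only a compile-time
 * constant divisor lets the compiler strength-reduce it. */
static bool is_multiple(uint64_t s, uint64_t factor)
{
    return s % factor == 0;
}
```

With the complement `~mask` hoisted out of the loop, `is_subset` is one AND plus a compare per candidate, which is also exactly the form a vectorizer can turn into `vptestnmq`.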
So, here is a version of this code that uses both multi-threading and SIMD for speed-up:
```c
#include <assert.h>
#include <omp.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

typedef uint_fast32_t word;

/* Sets each element results[i], where i <= mask/factor, to true if factor*i
 * is a subset of the mask, false otherwise. The results array MUST have at
 * least (mask/factor + 1U) elements. The capacity of results in elements is
 * required and checked, just in case.
 *
 * Returns a pointer to the results.
 */
static bool* check_multiples( const word mask,
                              const word factor,
                              const size_t n,
                              bool results[n] )
{
    const word end = mask/factor;
    const word complement = ~mask;

    assert(&results);
    assert(n > end);

    #pragma omp parallel for simd schedule(static)
    for (word i = 0; i <= end; ++i) {
        results[i] = (factor*i & complement) == 0;
    }

    return results;
}

/* Replace these with non-constants so that the compiler actually
 * instantiates the function:
 */
/*
#define MASK 0xA0A0UL
#define FACTOR 0x50UL
#define NRESULTS (MASK/FACTOR + 1U)
*/
extern const word MASK, FACTOR;
#define NRESULTS 1024UL

int main(void)
{
    bool are_subsets[NRESULTS] = {false};

    (void)check_multiples(MASK, FACTOR, NRESULTS, are_subsets);

    for (word i = 0; i < NRESULTS; ++i) {
        if (are_subsets[i]) {
            const unsigned long long multiple = (unsigned long long)FACTOR*i;
            printf("%llx ", multiple);
            assert((multiple & MASK) == multiple && (multiple & ~MASK) == 0U);
        }
    }

    return EXIT_SUCCESS;
}
```
The inner loop of `check_multiples` compiles, on ICX 2022, to:
```asm
.LBB1_5:                                # =>This Inner Loop Header: Depth=1
        vpmullq ymm15, ymm1, ymm0
        vpmullq ymm16, ymm2, ymm0
        vpmullq ymm17, ymm3, ymm0
        vpmullq ymm18, ymm4, ymm0
        vpmullq ymm19, ymm5, ymm0
        vpmullq ymm20, ymm6, ymm0
        vpmullq ymm21, ymm7, ymm0
        vpmullq ymm22, ymm8, ymm0
        vptestnmq k0, ymm22, ymm9
        vptestnmq k1, ymm21, ymm9
        kshiftlb k1, k1, 4
        korb    k0, k0, k1
        vptestnmq k1, ymm20, ymm9
        vptestnmq k2, ymm19, ymm9
        kshiftlb k2, k2, 4
        korb    k1, k1, k2
        kunpckbw k0, k1, k0
        vptestnmq k1, ymm18, ymm9
        vptestnmq k2, ymm17, ymm9
        kshiftlb k2, k2, 4
        korb    k1, k1, k2
        vptestnmq k2, ymm16, ymm9
        vptestnmq k3, ymm15, ymm9
        kshiftlb k3, k3, 4
        korb    k2, k2, k3
        kunpckbw k1, k2, k1
        kunpckwd k1, k1, k0
        vmovdqu8 ymm15 {k1} {z}, ymm10
        vmovdqu ymmword ptr [rbx + rsi], ymm15
        vpaddq  ymm15, ymm11, ymm7
        vpaddq  ymm16, ymm6, ymm11
        vpaddq  ymm17, ymm5, ymm11
        vpaddq  ymm18, ymm4, ymm11
        vpaddq  ymm19, ymm3, ymm11
        vpaddq  ymm20, ymm2, ymm11
        vpaddq  ymm21, ymm1, ymm11
        vpmullq ymm21, ymm21, ymm0
        vpmullq ymm20, ymm20, ymm0
        vpmullq ymm19, ymm19, ymm0
        vpmullq ymm18, ymm18, ymm0
        vpmullq ymm17, ymm17, ymm0
        vpmullq ymm16, ymm16, ymm0
        vpmullq ymm15, ymm15, ymm0
        vpaddq  ymm22, ymm8, ymm11
        vpmullq ymm22, ymm22, ymm0
        vptestnmq k0, ymm22, ymm9
        vptestnmq k1, ymm15, ymm9
        kshiftlb k1, k1, 4
        korb    k0, k0, k1
        vptestnmq k1, ymm16, ymm9
        vptestnmq k2, ymm17, ymm9
        kshiftlb k2, k2, 4
        korb    k1, k1, k2
        kunpckbw k0, k1, k0
        vptestnmq k1, ymm18, ymm9
        vptestnmq k2, ymm19, ymm9
        kshiftlb k2, k2, 4
        korb    k1, k1, k2
        vptestnmq k2, ymm20, ymm9
        vptestnmq k3, ymm21, ymm9
        kshiftlb k3, k3, 4
        korb    k2, k2, k3
        kunpckbw k1, k2, k1
        kunpckwd k1, k1, k0
        vmovdqu8 ymm15 {k1} {z}, ymm10
        vmovdqu ymmword ptr [rbx + rsi + 32], ymm15
        vpaddq  ymm15, ymm12, ymm7
        vpaddq  ymm16, ymm6, ymm12
        vpaddq  ymm17, ymm5, ymm12
        vpaddq  ymm18, ymm4, ymm12
        vpaddq  ymm19, ymm3, ymm12
        vpaddq  ymm20, ymm2, ymm12
        vpaddq  ymm21, ymm1, ymm12
        vpmullq ymm21, ymm21, ymm0
        vpmullq ymm20, ymm20, ymm0
        vpmullq ymm19, ymm19, ymm0
        vpmullq ymm18, ymm18, ymm0
        vpmullq ymm17, ymm17, ymm0
        vpmullq ymm16, ymm16, ymm0
        vpmullq ymm15, ymm15, ymm0
        vpaddq  ymm22, ymm8, ymm12
        vpmullq ymm22, ymm22, ymm0
        vptestnmq k0, ymm22, ymm9
        vptestnmq k1, ymm15, ymm9
        kshiftlb k1, k1, 4
        korb    k0, k0, k1
        vptestnmq k1, ymm16, ymm9
        vptestnmq k2, ymm17, ymm9
        kshiftlb k2, k2, 4
        korb    k1, k1, k2
        kunpckbw k0, k1, k0
        vptestnmq k1, ymm18, ymm9
        vptestnmq k2, ymm19, ymm9
        kshiftlb k2, k2, 4
        korb    k1, k1, k2
        vptestnmq k2, ymm20, ymm9
        vptestnmq k3, ymm21, ymm9
        kshiftlb k3, k3, 4
        korb    k2, k2, k3
        kunpckbw k1, k2, k1
        kunpckwd k1, k1, k0
        vmovdqu8 ymm15 {k1} {z}, ymm10
        vmovdqu ymmword ptr [rbx + rsi + 64], ymm15
        vpaddq  ymm15, ymm13, ymm7
        vpaddq  ymm16, ymm6, ymm13
        vpaddq  ymm17, ymm5, ymm13
        vpaddq  ymm18, ymm4, ymm13
        vpaddq  ymm19, ymm3, ymm13
        vpaddq  ymm20, ymm2, ymm13
        vpaddq  ymm21, ymm1, ymm13
        vpmullq ymm21, ymm21, ymm0
        vpmullq ymm20, ymm20, ymm0
        vpmullq ymm19, ymm19, ymm0
        vpmullq ymm18, ymm18, ymm0
        vpmullq ymm17, ymm17, ymm0
        vpmullq ymm16, ymm16, ymm0
        vpmullq ymm15, ymm15, ymm0
        vpaddq  ymm22, ymm8, ymm13
        vpmullq ymm22, ymm22, ymm0
        vptestnmq k0, ymm22, ymm9
        vptestnmq k1, ymm15, ymm9
        kshiftlb k1, k1, 4
        korb    k0, k0, k1
        vptestnmq k1, ymm16, ymm9
        vptestnmq k2, ymm17, ymm9
        kshiftlb k2, k2, 4
        korb    k1, k1, k2
        kunpckbw k0, k1, k0
        vptestnmq k1, ymm18, ymm9
        vptestnmq k2, ymm19, ymm9
        kshiftlb k2, k2, 4
        korb    k1, k1, k2
        vptestnmq k2, ymm20, ymm9
        vptestnmq k3, ymm21, ymm9
        kshiftlb k3, k3, 4
        korb    k2, k2, k3
        kunpckbw k1, k2, k1
        kunpckwd k1, k1, k0
        vmovdqu8 ymm15 {k1} {z}, ymm10
        vmovdqu ymmword ptr [rbx + rsi + 96], ymm15
        vpaddq  ymm8, ymm8, ymm14
        vpaddq  ymm7, ymm14, ymm7
        vpaddq  ymm6, ymm14, ymm6
        vpaddq  ymm5, ymm14, ymm5
        vpaddq  ymm4, ymm14, ymm4
        vpaddq  ymm3, ymm14, ymm3
        vpaddq  ymm2, ymm14, ymm2
        vpaddq  ymm1, ymm14, ymm1
        sub     rsi, -128
        add     rdi, -4
        jne     .LBB1_5
```
I encourage you to try your variations on the algorithm in this compiler, under the same settings, and see what happens. If you think it should be possible to generate vectorized code over the subsets that is as good as this, you should get some practice first.
## A Possible Improvement
The number of candidates to check could get extremely large, but one way to limit it is to also compute the multiplicative inverse of P, and use that if it is better.
Every value of P decomposes into 2ⁱ · Q, where Q is odd. Since Q and 2⁶⁴ are coprime, Q will have a modular multiplicative inverse, Q', whose product QQ' = 1 (mod 2⁶⁴). You can find this with the extended Euclidean algorithm (but not the method I proposed here initially).
This is useful for optimizing the algorithm because, for many values of P, Q′ < P. If m is a solution, m = nP for some integer n. Multiplying both sides by Q′ gives Q′m = Q′Pn = 2ⁱ·n (mod 2⁶⁴), since QQ′ = 1. This means we can enumerate (with a bit of extra logic to make sure they have enough trailing zero bits) the multiples of Q′ or of P, whichever is smaller. Note that, since Q′ is odd, it is not necessary to check all multiples of Q′; if 2ⁱ is 4, for example, you need only check the multiples of 4·Q′.
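The answer suggests the extended Euclidean algorithm for finding Q′; an equivalent and shorter route for inverses modulo a power of two is Newton–Hensel iteration, which doubles the number of correct low bits each step. A sketch (`mod_inverse_pow2` is my name for it):

```c
#include <stdint.h>

/* Multiplicative inverse of an odd q modulo 2^64.
 * For odd q, x = q is already a correct inverse mod 2^3 (q*q = 1 mod 8);
 * each step x *= 2 - q*x doubles the number of correct low bits
 * (3 -> 6 -> 12 -> 24 -> 48 -> 96), so five iterations cover 64 bits.
 * All arithmetic wraps modulo 2^64 on uint64_t, which is exactly the
 * modulus we want. */
static uint64_t mod_inverse_pow2(uint64_t q)
{
    uint64_t x = q;             /* correct mod 2^3 */
    for (int i = 0; i < 5; ++i) {
        x *= 2 - q * x;
    }
    return x;
}
```

Given P = 2ⁱ·Q with Q odd (i being the number of trailing zero bits of P), `mod_inverse_pow2(P >> i)` yields Q′ with QQ′ ≡ 1 (mod 2⁶⁴).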