# Find all subsets of a mask which are divisible by a certain number

## Question

All subsets of the 4-bit number 1101 are 0000, 0001, 0100, 0101, 1000, 1001, 1100, 1101. All subsets of this mask which are divisible by 2 are 0000, 0100, 1000, 1100.

Given a 64-bit mask M and a 64-bit integer P, how do I iterate over all subsets of M which are divisible by P?

To iterate over subsets of a bit mask, I can do

```c
uint64_t superset = ...;
uint64_t subset = 0;
do {
    print(subset);
    subset = (subset - superset) & superset;
} while (subset != 0);
```

If M is ~0 I can just start with 0 and keep adding P to iterate over all multiples of P. If P is a power of two I can just do M &= ~(P - 1) to chop off bits which are never going to be set.

But if I have none of the constraints above, do I have a better shot than naively checking each and every subset for divisibility by P? This naive algorithm on average to get the next subset which is divisible by P takes O(P) operations. Can I do better than O(P)?

## Answer 1 (score: 1)

## A Parallel Algorithm

There are inputs for which it is vastly more efficient to check the multiples of the factor than the subsets of the mask, and inputs where it’s the other way around.  For example, when _M_ is `0xFFFFFFFFFFFFFFFF` and _P_ is `0x4000000000000000`, checking the three multiples of _P_ is nigh-instantaneous, but even if you could crunch and check a billion subsets of _M_ each second, enumerating them all would take thirty years.  The optimization of finding only subsets greater than or equal to _P_ would only cut that to four years.

However, there is a strong reason to enumerate and check the multiples of _P_ instead of the subsets of _M_: parallelism. I want to emphasize this, because of incorrect comments on this code elsewhere: the algorithm in the OP is **inherently sequential**, because each value of `subset` is computed from the previous value of `subset`. It cannot run until all the lower subsets have already been calculated.  It cannot be vectorized to use AVX registers or similar.  You cannot load four values into an AVX2 register and run SIMD instructions on them, because you would need the first value to initialize the second element, the second to initialize the third, and all three to initialize the final one, and then you are back to computing only one value at a time.  Nor can the work be split between worker threads on different CPU cores, which is a separate question from vectorization. (The accepted answer can be modified to split across threads, but not to vectorize, without a total refactoring.)  You cannot divide the workload into subsets 0 to 63, subsets 64 to 127, and so on, and have different threads work on each in parallel, because you cannot start on the sixty-fourth subset until you know what the sixty-third subset is, for which you need the sixty-second, and so on.

If you take nothing else away from this, I **highly** recommend that you [try this code out on Godbolt](https://godbolt.org/z/PMao7K76c) with full optimizations enabled, and see for yourself that it compiles to sequential code.  If you’re familiar with OpenMP, try adding `#pragma omp simd` and `#pragma omp parallel` directives and see what happens.  The problem isn’t with the compiler, it’s that the algorithm is inherently sequential.  But seeing what real compilers do should *at least* convince you that compilers in the year 2023 are not able to vectorize code like this.

For reference, here is what Clang 16 does with `find`:

```asm
Find:                                   # @Find
        push    r15
        push    r14
        push    r12
        push    rbx
        push    rax
        mov     rbx, rdi
        cmp     rdi, rsi
        jne     .LBB1_1
.LBB1_6:
        lea     rdi, [rip + .L.str]
        mov     rsi, rbx
        xor     eax, eax
        add     rsp, 8
        pop     rbx
        pop     r12
        pop     r14
        pop     r15
        jmp     printf@PLT                      # TAILCALL
.LBB1_1:
        mov     r14, rdx
        mov     r15, rsi
        jmp     .LBB1_2
.LBB1_5:                                #   in Loop: Header=BB1_2 Depth=1
        imul    r12, r14
        add     r15, r12
        cmp     r15, rbx
        je      .LBB1_6
.LBB1_2:                                # =>This Inner Loop Header: Depth=1
        cmp     r15, rbx
        ja      .LBB1_7
        mov     rax, r15
        xor     rax, rbx
        blsi    r12, rax
        test    r12, rbx
        je      .LBB1_5
        mov     rdi, rbx
        sub     rdi, r12
        mov     rsi, r15
        mov     rdx, r14
        call    Find
        jmp     .LBB1_5
.LBB1_7:
        add     rsp, 8
        pop     rbx
        pop     r12
        pop     r14
        pop     r15
        ret

```

## Enumerate and Check the Multiples Instead of the Subsets

In addition to having more parallelism, this has several advantages in speed:

- Finding the successor, or `(i+4)*p` given `i*p` to use this on a vector of four elements, can be strength-reduced to a single addition.
- Testing whether a multiple is a subset is a single AND (`&`) operation, whereas testing whether a subset is a multiple requires a `%` operation, which most CPUs do not have as a native instruction, and which is the slowest ALU operation even when it is there.

So, a version of this code that uses both multi-threading and SIMD for speed-up:

```c
#include <assert.h>
#include <omp.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>


typedef uint_fast32_t word;

/* Sets each element results[i], where i <= mask/factor, to true if factor*i
 * is a subset of the mask, false otherwise.  The results array MUST have at
 * least (mask/factor + 1U) elements.  The capacity of results in elements is
 * required and checked, just in case.
 *
 * Returns a pointer to the results.
 */
static bool* check_multiples( const word mask,
                              const word factor,
                              const size_t n,
                              bool results[n] )
{
    const word end = mask/factor;
    const word complement = ~mask;
    assert(results);
    assert(n > end);

    #pragma omp parallel for simd schedule(static)
    for (word i = 0; i <= end; ++i) {
        results[i] = (factor*i & complement) == 0;
    }

    return results;
}

/* Replace these with non-constants so that the compiler actually
 * instantiates the function:
 */
/*
#define MASK 0xA0A0UL
#define FACTOR 0x50UL
#define NRESULTS (MASK/FACTOR + 1U)
 */
extern const word MASK, FACTOR;
#define NRESULTS 1024UL

int main(void)
{
    bool are_subsets[NRESULTS] = {false};
    (void)check_multiples(MASK, FACTOR, NRESULTS, are_subsets);

    for (word i = 0; i < NRESULTS; ++i) {
        if (are_subsets[i]) {
            const unsigned long long multiple = (unsigned long long)FACTOR*i;
            printf("%llx ", multiple);
            assert((multiple & MASK) == multiple && (multiple & ~MASK) == 0U);
        }
    }

    return EXIT_SUCCESS;
}
```

The inner loop of `check_multiples` compiles, on ICX 2022, to:

```asm
.LBB1_5:                                # =>This Inner Loop Header: Depth=1
        vpmullq         ymm15, ymm1, ymm0
        vpmullq         ymm16, ymm2, ymm0
        vpmullq         ymm17, ymm3, ymm0
        vpmullq         ymm18, ymm4, ymm0
        vpmullq         ymm19, ymm5, ymm0
        vpmullq         ymm20, ymm6, ymm0
        vpmullq         ymm21, ymm7, ymm0
        vpmullq         ymm22, ymm8, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm21, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm20, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm16, ymm9
        vptestnmq       k3, ymm15, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi], ymm15
        vpaddq          ymm15, ymm11, ymm7
        vpaddq          ymm16, ymm6, ymm11
        vpaddq          ymm17, ymm5, ymm11
        vpaddq          ymm18, ymm4, ymm11
        vpaddq          ymm19, ymm3, ymm11
        vpaddq          ymm20, ymm2, ymm11
        vpaddq          ymm21, ymm1, ymm11
        vpmullq         ymm21, ymm21, ymm0
        vpmullq         ymm20, ymm20, ymm0
        vpmullq         ymm19, ymm19, ymm0
        vpmullq         ymm18, ymm18, ymm0
        vpmullq         ymm17, ymm17, ymm0
        vpmullq         ymm16, ymm16, ymm0
        vpmullq         ymm15, ymm15, ymm0
        vpaddq          ymm22, ymm8, ymm11
        vpmullq         ymm22, ymm22, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm15, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm16, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm20, ymm9
        vptestnmq       k3, ymm21, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi + 32], ymm15
        vpaddq          ymm15, ymm12, ymm7
        vpaddq          ymm16, ymm6, ymm12
        vpaddq          ymm17, ymm5, ymm12
        vpaddq          ymm18, ymm4, ymm12
        vpaddq          ymm19, ymm3, ymm12
        vpaddq          ymm20, ymm2, ymm12
        vpaddq          ymm21, ymm1, ymm12
        vpmullq         ymm21, ymm21, ymm0
        vpmullq         ymm20, ymm20, ymm0
        vpmullq         ymm19, ymm19, ymm0
        vpmullq         ymm18, ymm18, ymm0
        vpmullq         ymm17, ymm17, ymm0
        vpmullq         ymm16, ymm16, ymm0
        vpmullq         ymm15, ymm15, ymm0
        vpaddq          ymm22, ymm8, ymm12
        vpmullq         ymm22, ymm22, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm15, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm16, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm20, ymm9
        vptestnmq       k3, ymm21, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi + 64], ymm15
        vpaddq          ymm15, ymm13, ymm7
        vpaddq          ymm16, ymm6, ymm13
        vpaddq          ymm17, ymm5, ymm13
        vpaddq          ymm18, ymm4, ymm13
        vpaddq          ymm19, ymm3, ymm13
        vpaddq          ymm20, ymm2, ymm13
        vpaddq          ymm21, ymm1, ymm13
        vpmullq         ymm21, ymm21, ymm0
        vpmullq         ymm20, ymm20, ymm0
        vpmullq         ymm19, ymm19, ymm0
        vpmullq         ymm18, ymm18, ymm0
        vpmullq         ymm17, ymm17, ymm0
        vpmullq         ymm16, ymm16, ymm0
        vpmullq         ymm15, ymm15, ymm0
        vpaddq          ymm22, ymm8, ymm13
        vpmullq         ymm22, ymm22, ymm0
        vptestnmq       k0, ymm22, ymm9
        vptestnmq       k1, ymm15, ymm9
        kshiftlb        k1, k1, 4
        korb            k0, k0, k1
        vptestnmq       k1, ymm16, ymm9
        vptestnmq       k2, ymm17, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        kunpckbw        k0, k1, k0
        vptestnmq       k1, ymm18, ymm9
        vptestnmq       k2, ymm19, ymm9
        kshiftlb        k2, k2, 4
        korb            k1, k1, k2
        vptestnmq       k2, ymm20, ymm9
        vptestnmq       k3, ymm21, ymm9
        kshiftlb        k3, k3, 4
        korb            k2, k2, k3
        kunpckbw        k1, k2, k1
        kunpckwd        k1, k1, k0
        vmovdqu8        ymm15 {k1} {z}, ymm10
        vmovdqu         ymmword ptr [rbx + rsi + 96], ymm15
        vpaddq          ymm8, ymm8, ymm14
        vpaddq          ymm7, ymm14, ymm7
        vpaddq          ymm6, ymm14, ymm6
        vpaddq          ymm5, ymm14, ymm5
        vpaddq          ymm4, ymm14, ymm4
        vpaddq          ymm3, ymm14, ymm3
        vpaddq          ymm2, ymm14, ymm2
        vpaddq          ymm1, ymm14, ymm1
        sub             rsi, -128
        add             rdi, -4
        jne             .LBB1_5
```

I encourage you to try your variations on the algorithm in this compiler, under the same settings, and see what happens. If you think it should be possible to generate vectorized code on the subsets as good as that, you should get some practice.

## A Possible Improvement

The number of candidates to check could get extremely large, but one way to limit it is to also compute the multiplicative inverse of P, and use that if it is better.

Every value of P decomposes into 2ⁱ · Q, where Q is odd. Since Q and 2⁶⁴ are coprime, Q has a modular multiplicative inverse Q' satisfying QQ' = 1 (mod 2⁶⁴). You can find this with the extended Euclidean algorithm (but not the method I proposed here initially).

This is useful for optimizing the algorithm because, for many values of P, Q' < P. If m is a solution, m = nP for some integer n. Multiplying both sides by Q' gives Q'm = Q'Pn = 2ⁱn (mod 2⁶⁴). This means we can enumerate (with a bit of extra logic to make sure they have enough trailing zero bits) the multiples of Q' or of P, whichever is better. Note that, since Q' is odd, it is not necessary to check all multiples of Q'; if the constant 2ⁱ is 4, for example, you need only check the multiples of 4·Q'.

---

Posted by huangapple on 2023-06-05 05:25:53.
Reposts must retain this link: https://go.coder-hub.com/76402482.html