如何在使用多进程池时传递可迭代对象的索引

huangapple go评论61阅读模式
英文:

How to pass the index of the iterable when using multiprocessing pool

问题

I would like to call a function task() in parallel N times. The function accepts two arguments, one is an array and the second is an index to write the return result in to the array:

我想并行调用函数 task() N 次。该函数接受两个参数,一个是数组,第二个是要将返回结果写入数组的索引:

def task(arr, index):
    arr[index] = "some result to return"

To be explicit, the reason for the array is so I can process all the parallel tasks once they have completed. I presume this is ok?

明确一下,数组的原因是为了在并行任务完成后能够处理它们。我假设这样做没问题?

I have created a multiprocessing pool and it calls task():

我创建了一个多进程池,并在其中调用 task()

def main():
    N = 10
    arr = np.empty(N)
    
    pool = Pool(os.cpu_count())
    pool.map(task, arr)
    pool.close()

    # Process results in arr

However, the problem is because map() is already iterable, how do I explicitly pass in the index? Each call to task() should pass in 0, 1, 2.... N.

然而,问题在于 map() 已经是可迭代的,如何明确传递索引?每次调用 task() 应该传入 0、1、2.... N。

英文:

I would like to call a function task() in parallel N times. The function accepts two arguments, one is an array and the second is an index to write the return result in to the array:

def task(arr, index):
    arr[index] = "some result to return"

To be explicit, the reason for the array is so I can process all the parallel tasks once they have completed. I presume this is ok?

I have created a multiprocessing pool and it calls task():

def main():
    N = 10
    arr = np.empty(N)
    
    pool = Pool(os.cpu_count())
    pool.map(task, arr)
    pool.close()

    # Process results in arr

However, the problem is because map() is already iterable, how do I explicitly pass in the index? Each call to task() should pass in 0, 1, 2.... N.

答案1

得分: 1

import multiprocessing as mp
import numpy as np

def task(index, arr):
    print(index, arr)

if __name__ == '__main__':
    N = 10
    arr = np.empty(N)
    with mp.Pool(mp.cpu_count()) as pool:
        pool.starmap(task, enumerate(arr))

Output:

0 6.9180446290108e-310
1 6.9180446290108e-310
2 6.91804453329406e-310
3 6.91804425777776e-310
4 6.9180448957438e-310
5 6.9180105412701e-310
6 6.9180443068017e-310
7 6.91804453327193e-310
9 6.9180449088978e-310
8 6.91804436388567e-310
英文:

You can use:

import multiprocessing as mp
import numpy as np

def task(index, arr):
    print(index, arr)

if __name__ == '__main__':
    N = 10
    arr = np.empty(N)
    with mp.Pool(mp.cpu_count()) as pool:
        pool.starmap(task, enumerate(arr))

Output:

0 6.9180446290108e-310
1 6.9180446290108e-310
2 6.91804453329406e-310
3 6.91804425777776e-310
4 6.9180448957438e-310
5 6.9180105412701e-310
6 6.9180443068017e-310
7 6.91804453327193e-310
9 6.9180449088978e-310
8 6.91804436388567e-310

huangapple
  • 本文由 发表于 2023年3月7日 23:45:16
  • 转载请务必保留本文链接:https://go.coder-hub.com/75664160.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定