如何在使用多进程池时传递可迭代对象的索引

huangapple go评论92阅读模式
英文:

How to pass the index of the iterable when using multiprocessing pool

问题

I would like to call a function task() in parallel N times. The function accepts two arguments, one is an array and the second is an index to write the return result in to the array:

我想并行调用函数 task() N 次。该函数接受两个参数,一个是数组,第二个是要将返回结果写入数组的索引:

  1. def task(arr, index):
  2. arr[index] = "some result to return"

To be explicit, the reason for the array is so I can process all the parallel tasks once they have completed. I presume this is ok?

明确一下,数组的原因是为了在并行任务完成后能够处理它们。我假设这样做没问题?

I have created a multiprocessing pool and it calls task():

我创建了一个多进程池,并在其中调用 task()

  1. def main():
  2. N = 10
  3. arr = np.empty(N)
  4. pool = Pool(os.cpu_count())
  5. pool.map(task, arr)
  6. pool.close()
  7. # Process results in arr

However, the problem is because map() is already iterable, how do I explicitly pass in the index? Each call to task() should pass in 0, 1, 2.... N.

然而,问题在于 map() 已经是可迭代的,如何明确传递索引?每次调用 task() 应该传入 0、1、2.... N。

英文:

I would like to call a function task() in parallel N times. The function accepts two arguments, one is an array and the second is an index to write the return result in to the array:

  1. def task(arr, index):
  2. arr[index] = "some result to return"

To be explicit, the reason for the array is so I can process all the parallel tasks once they have completed. I presume this is ok?

I have created a multiprocessing pool and it calls task():

  1. def main():
  2. N = 10
  3. arr = np.empty(N)
  4. pool = Pool(os.cpu_count())
  5. pool.map(task, arr)
  6. pool.close()
  7. # Process results in arr

However, the problem is because map() is already iterable, how do I explicitly pass in the index? Each call to task() should pass in 0, 1, 2.... N.

答案1

得分: 1

  1. import multiprocessing as mp
  2. import numpy as np
  3. def task(index, arr):
  4. print(index, arr)
  5. if __name__ == '__main__':
  6. N = 10
  7. arr = np.empty(N)
  8. with mp.Pool(mp.cpu_count()) as pool:
  9. pool.starmap(task, enumerate(arr))

Output:

  1. 0 6.9180446290108e-310
  2. 1 6.9180446290108e-310
  3. 2 6.91804453329406e-310
  4. 3 6.91804425777776e-310
  5. 4 6.9180448957438e-310
  6. 5 6.9180105412701e-310
  7. 6 6.9180443068017e-310
  8. 7 6.91804453327193e-310
  9. 9 6.9180449088978e-310
  10. 8 6.91804436388567e-310
英文:

You can use:

  1. import multiprocessing as mp
  2. import numpy as np
  3. def task(index, arr):
  4. print(index, arr)
  5. if __name__ == '__main__':
  6. N = 10
  7. arr = np.empty(N)
  8. with mp.Pool(mp.cpu_count()) as pool:
  9. pool.starmap(task, enumerate(arr))

Output:

  1. 0 6.9180446290108e-310
  2. 1 6.9180446290108e-310
  3. 2 6.91804453329406e-310
  4. 3 6.91804425777776e-310
  5. 4 6.9180448957438e-310
  6. 5 6.9180105412701e-310
  7. 6 6.9180443068017e-310
  8. 7 6.91804453327193e-310
  9. 9 6.9180449088978e-310
  10. 8 6.91804436388567e-310

huangapple
  • 本文由 发表于 2023年3月7日 23:45:16
  • 转载请务必保留本文链接:https://go.coder-hub.com/75664160.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定