|
- import torch.multiprocessing as mp
- from queue import Queue
- import time
- import random
-
def dummy_func(dev, cfg):
    """Placeholder workload: block for a random duration below two seconds.

    ``dev`` and ``cfg`` are accepted so the function has the ``f(dev, cfg)``
    shape, but neither is used.
    """
    delay = random.random() * 2
    time.sleep(delay)
def dummy_config():
    """Return a sample config list: the integers 0 through 19."""
    sample = [*range(20)]
    return sample
-
def mp_exec(resources, configs, func):
    """Run ``func(dev, cfg)`` for every config, one task per free device.

    A device is handed to the next pending config as soon as the task that
    was using it finishes, whether it succeeded or raised. Each device
    therefore runs at most one config at a time.

    Args:
        resources: list of gpu devices.
        configs: list of params, one per task.
        func: callable invoked as ``func(dev, cfg)`` inside a worker process.
            It is submitted through the pool, so it must be picklable.

    Returns:
        A list of ``[cfg, args]`` entries in completion order. ``args`` is a
        1-tuple holding either the value returned by ``func`` or the
        exception it raised.

    Raises:
        ValueError: if there are configs to run but no devices to run them on.
    """
    devices = list(resources)
    pending = list(configs)
    if not pending:
        return []
    if not devices:
        # Without this guard the dispatch loop would block forever waiting
        # for a device that can never be returned.
        raise ValueError("mp_exec needs at least one resource to run configs")

    # Plain thread-safe queues are enough here: pool callbacks are executed
    # in the parent process, not in the workers.
    free_devices = Queue()
    finished = Queue()
    for dev in devices:
        free_devices.put(dev)

    def put_back_dev(dev, cfg):
        # Build a completion handler bound to this (dev, cfg) pair; it is
        # used for both the success and the error path.
        def callback(*args):
            print(f"Device {dev} Finish cfg {cfg} ")
            free_devices.put(dev)
            finished.put([cfg, args])
            print(*args)
        return callback

    # At most one task per device is ever in flight, so more workers than
    # that would only sit idle.
    n_workers = min(len(devices), len(pending))
    # The context manager terminates the workers if dispatch is interrupted.
    with mp.Pool(processes=n_workers) as pool:
        for cfg in pending:
            dev = free_devices.get()  # blocks until some device is free
            print(f"Start config {cfg} on device {dev}")
            on_done = put_back_dev(dev, cfg)
            pool.apply_async(
                func,
                args=[dev, cfg],
                callback=on_done,
                error_callback=on_done,
            )

        pool.close()
        pool.join()

    results = []
    while not finished.empty():
        results.append(finished.get())
    return results
|