You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 1.0 kB

3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import torch.multiprocessing as mp
  2. from queue import Queue
  3. import time
  4. import random
  5. def dummy_func(dev,cfg):
  6. time.sleep(random.random()*2)
  7. def dummy_config():
  8. return list(range(20))
  9. def mp_exec(resources,configs,func):
  10. '''
  11. @ resources : list of gpu devices
  12. @ configs : list of params
  13. @ func : f(dev,cfg)
  14. '''
  15. q=Queue()
  16. ret=Queue()
  17. for res in resources:
  18. q.put(res)
  19. pool=mp.Pool()
  20. def put_back_dev(dev,cfg):
  21. def callback(*args):
  22. print(f"Device {dev} Finish cfg {cfg} ")
  23. q.put(dev)
  24. ret.put([cfg,args])
  25. print(*args)
  26. return callback
  27. for idx,cfg in enumerate(configs):
  28. dev = q.get()
  29. print(f"Start config {cfg} on device {dev}")
  30. pool.apply_async(func,args=[dev,cfg],callback=put_back_dev(dev,cfg),error_callback=put_back_dev(dev,cfg))
  31. pool.close()
  32. pool.join()
  33. lret=[]
  34. while not ret.empty():
  35. lret.append(ret.get())
  36. return lret