You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test.py 6.9 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import pickle
  4. import shutil
  5. import tempfile
  6. import time
  7. import mmcv
  8. import torch
  9. import torch.distributed as dist
  10. from mmcv.image import tensor2imgs
  11. from mmcv.runner import get_dist_info
  12. from mmdet.core import encode_mask_results
  13. def single_gpu_test(model,
  14. data_loader,
  15. show=False,
  16. out_dir=None,
  17. show_score_thr=0.3):
  18. model.eval()
  19. results = []
  20. dataset = data_loader.dataset
  21. prog_bar = mmcv.ProgressBar(len(dataset))
  22. for i, data in enumerate(data_loader):
  23. with torch.no_grad():
  24. result = model(return_loss=False, rescale=True, **data)
  25. batch_size = len(result)
  26. if show or out_dir:
  27. if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
  28. img_tensor = data['img'][0]
  29. else:
  30. img_tensor = data['img'][0].data[0]
  31. img_metas = data['img_metas'][0].data[0]
  32. imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
  33. assert len(imgs) == len(img_metas)
  34. for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
  35. h, w, _ = img_meta['img_shape']
  36. img_show = img[:h, :w, :]
  37. ori_h, ori_w = img_meta['ori_shape'][:-1]
  38. img_show = mmcv.imresize(img_show, (ori_w, ori_h))
  39. if out_dir:
  40. out_file = osp.join(out_dir, img_meta['ori_filename'])
  41. else:
  42. out_file = None
  43. model.module.show_result(
  44. img_show,
  45. result[i],
  46. show=show,
  47. out_file=out_file,
  48. score_thr=show_score_thr)
  49. # encode mask results
  50. if isinstance(result[0], tuple):
  51. result = [(bbox_results, encode_mask_results(mask_results))
  52. for bbox_results, mask_results in result]
  53. results.extend(result)
  54. for _ in range(batch_size):
  55. prog_bar.update()
  56. return results
  57. def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
  58. """Test model with multiple gpus.
  59. This method tests model with multiple gpus and collects the results
  60. under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
  61. it encodes results to gpu tensors and use gpu communication for results
  62. collection. On cpu mode it saves the results on different gpus to 'tmpdir'
  63. and collects them by the rank 0 worker.
  64. Args:
  65. model (nn.Module): Model to be tested.
  66. data_loader (nn.Dataloader): Pytorch data loader.
  67. tmpdir (str): Path of directory to save the temporary results from
  68. different gpus under cpu mode.
  69. gpu_collect (bool): Option to use either gpu or cpu to collect results.
  70. Returns:
  71. list: The prediction results.
  72. """
  73. model.eval()
  74. results = []
  75. dataset = data_loader.dataset
  76. rank, world_size = get_dist_info()
  77. if rank == 0:
  78. prog_bar = mmcv.ProgressBar(len(dataset))
  79. time.sleep(2) # This line can prevent deadlock problem in some cases.
  80. for i, data in enumerate(data_loader):
  81. with torch.no_grad():
  82. result = model(return_loss=False, rescale=True, **data)
  83. # encode mask results
  84. if isinstance(result[0], tuple):
  85. result = [(bbox_results, encode_mask_results(mask_results))
  86. for bbox_results, mask_results in result]
  87. results.extend(result)
  88. if rank == 0:
  89. batch_size = len(result)
  90. for _ in range(batch_size * world_size):
  91. prog_bar.update()
  92. # collect results from all ranks
  93. if gpu_collect:
  94. results = collect_results_gpu(results, len(dataset))
  95. else:
  96. results = collect_results_cpu(results, len(dataset), tmpdir)
  97. return results
def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect results from all ranks via pickled part files on disk.

    Each rank dumps its partial results into ``tmpdir``; after a barrier,
    rank 0 loads every part, interleaves them back into dataset order and
    returns the merged list. All other ranks return ``None``.

    Args:
        result_part (list): Results produced by this rank.
        size (int): Total dataset size, used to drop padded samples.
        tmpdir (str, optional): Directory for the temporary part files.
            When ``None``, rank 0 creates one and broadcasts the path.

    Returns:
        list | None: Merged results on rank 0, ``None`` on other ranks.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace; the path is broadcast in a fixed-size buffer
        # padded with spaces so every rank can receive the same tensor.
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            # Encode the path as bytes so it can travel as a CUDA tensor.
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        # Strip the space padding to recover the path string on every rank.
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    # Wait until every rank has finished writing its part file.
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_list.append(mmcv.load(part_file))
        # sort the results: samples were distributed round-robin across
        # ranks, so interleaving the parts restores dataset order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results
def collect_results_gpu(result_part, size):
    """Collect results from all ranks with GPU collective communication.

    Each rank pickles its partial results into a uint8 CUDA tensor, pads
    all parts to a common length, and all-gathers them; rank 0 unpickles
    each part and interleaves the lists back into dataset order. Ranks
    other than 0 fall through and return ``None``.

    Args:
        result_part (list): Results produced by this rank.
        size (int): Total dataset size, used to drop padded samples.

    Returns:
        list | None: Merged results on rank 0, ``None`` on other ranks.
    """
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length: all_gather requires
    # identically shaped tensors on every rank
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)
    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            # Trim each part back to its true length before unpickling.
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # sort the results: interleave the per-rank parts to restore
        # the original dataset order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results

No Description

Contributors (3)