You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import argparse
  16. import os
  17. import time
  18. import numpy as np
  19. from mindspore import Tensor, float32, context
  20. from mindspore.common import set_seed
  21. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  22. from src.config import config
  23. from src.dataset import flip_pairs, keypoint_dataset
  24. from src.evaluate.coco_eval import evaluate
  25. from src.model import get_pose_net
  26. from src.utils.transform import flip_back
  27. from src.predict import get_final_preds
  28. def parse_args():
  29. parser = argparse.ArgumentParser(description='Train keypoints network')
  30. parser.add_argument("--train_url", type=str, default="", help="")
  31. parser.add_argument("--data_url", type=str, default="", help="data")
  32. # output
  33. parser.add_argument('--output-url',
  34. help='output dir',
  35. type=str)
  36. # training
  37. parser.add_argument('--workers',
  38. help='num of dataloader workers',
  39. default=8,
  40. type=int)
  41. parser.add_argument('--model-file',
  42. help='model state file',
  43. type=str)
  44. parser.add_argument('--use-detect-bbox',
  45. help='use detect bbox',
  46. action='store_true')
  47. parser.add_argument('--flip-test',
  48. help='use flip test',
  49. default=True,
  50. action='store_true')
  51. parser.add_argument('--post-process',
  52. help='use post process',
  53. action='store_true')
  54. parser.add_argument('--shift-heatmap',
  55. help='shift heatmap',
  56. action='store_true')
  57. parser.add_argument('--coco-bbox-file',
  58. help='coco detection bbox file',
  59. type=str)
  60. args = parser.parse_args()
  61. return args
  62. def reset_config(cfg, args):
  63. if args.use_detect_bbox:
  64. cfg.TEST.USE_GT_BBOX = not args.use_detect_bbox
  65. if args.flip_test:
  66. cfg.TEST.FLIP_TEST = args.flip_test
  67. print('use flip test:', cfg.TEST.FLIP_TEST)
  68. if args.post_process:
  69. cfg.TEST.POST_PROCESS = args.post_process
  70. if args.shift_heatmap:
  71. cfg.TEST.SHIFT_HEATMAP = args.shift_heatmap
  72. if args.model_file:
  73. cfg.TEST.MODEL_FILE = args.model_file
  74. if args.coco_bbox_file:
  75. cfg.TEST.COCO_BBOX_FILE = args.coco_bbox_file
  76. def validate(cfg, val_dataset, model, output_dir):
  77. # switch to evaluate mode
  78. model.set_train(False)
  79. # init record
  80. num_samples = val_dataset.get_dataset_size() * cfg.TEST.BATCH_SIZE
  81. all_preds = np.zeros((num_samples, cfg.MODEL.NUM_JOINTS, 3),
  82. dtype=np.float32)
  83. all_boxes = np.zeros((num_samples, 2))
  84. image_id = []
  85. idx = 0
  86. # start eval
  87. start = time.time()
  88. for item in val_dataset.create_dict_iterator():
  89. # input data
  90. inputs = item['image'].asnumpy()
  91. # compute output
  92. output = model(Tensor(inputs, float32)).asnumpy()
  93. if cfg.TEST.FLIP_TEST:
  94. inputs_flipped = Tensor(inputs[:, :, :, ::-1], float32)
  95. output_flipped = model(inputs_flipped)
  96. output_flipped = flip_back(output_flipped.asnumpy(), flip_pairs)
  97. # feature is not aligned, shift flipped heatmap for higher accuracy
  98. if cfg.TEST.SHIFT_HEATMAP:
  99. output_flipped[:, :, :, 1:] = \
  100. output_flipped.copy()[:, :, :, 0:-1]
  101. output = (output + output_flipped) * 0.5
  102. # meta data
  103. c = item['center'].asnumpy()
  104. s = item['scale'].asnumpy()
  105. score = item['score'].asnumpy()
  106. file_id = list(item['id'].asnumpy())
  107. # pred by heatmaps
  108. preds, maxvals = get_final_preds(cfg, output.copy(), c, s)
  109. num_images, _ = preds.shape[:2]
  110. all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
  111. all_preds[idx:idx + num_images, :, 2:3] = maxvals
  112. # double check this all_boxes parts
  113. all_boxes[idx:idx + num_images, 0] = np.prod(s * 200, 1)
  114. all_boxes[idx:idx + num_images, 1] = score
  115. image_id.extend(file_id)
  116. idx += num_images
  117. if idx % 1024 == 0:
  118. print('{} samples validated in {} seconds'.format(idx, time.time() - start))
  119. start = time.time()
  120. print(all_preds[:idx].shape, all_boxes[:idx].shape, len(image_id))
  121. _, perf_indicator = evaluate(
  122. cfg, all_preds[:idx], output_dir, all_boxes[:idx], image_id)
  123. print("AP:", perf_indicator)
  124. return perf_indicator
  125. def main():
  126. # init seed
  127. set_seed(1)
  128. # set context
  129. device_id = int(os.getenv('DEVICE_ID'))
  130. context.set_context(mode=context.GRAPH_MODE,
  131. device_target="Ascend", save_graphs=False, device_id=device_id)
  132. args = parse_args()
  133. # update config
  134. reset_config(config, args)
  135. # init model
  136. model = get_pose_net(config, is_train=False)
  137. # load parameters
  138. ckpt_name = config.TEST.MODEL_FILE
  139. print('loading model ckpt from {}'.format(ckpt_name))
  140. load_param_into_net(model, load_checkpoint(ckpt_name))
  141. # Data loading code
  142. valid_dataset, _ = keypoint_dataset(
  143. config,
  144. bbox_file=config.TEST.COCO_BBOX_FILE,
  145. train_mode=False,
  146. num_parallel_workers=args.workers,
  147. )
  148. # evaluate on validation set
  149. validate(config, valid_dataset, model, ckpt_name.split('.')[0])
  150. if __name__ == '__main__':
  151. main()