You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test.py 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Test centerface example
  17. """
  18. import os
  19. import time
  20. import argparse
  21. import datetime
  22. import scipy.io as sio
  23. from mindspore import context
  24. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  25. from src.utils import get_logger
  26. from src.var_init import default_recurisive_init
  27. from src.centerface import CenterfaceMobilev2, CenterFaceWithNms
  28. from src.config import ConfigCenterface
  29. from dependency.centernet.src.lib.detectors.base_detector import CenterFaceDetector
  30. from dependency.evaluate.eval import evaluation
  31. dev_id = int(os.getenv('DEVICE_ID'))
  32. context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
  33. device_target="Ascend", save_graphs=False, device_id=dev_id)
  34. parser = argparse.ArgumentParser('mindspore coco training')
  35. parser.add_argument('--data_dir', type=str, default='', help='train data dir')
  36. parser.add_argument('--test_model', type=str, default='', help='test model dir')
  37. parser.add_argument('--ground_truth_mat', type=str, default='', help='ground_truth, mat type')
  38. parser.add_argument('--save_dir', type=str, default='', help='save_path for evaluate')
  39. parser.add_argument('--ground_truth_path', type=str, default='', help='ground_truth path, contain all mat file')
  40. parser.add_argument('--eval', type=int, default=0, help='if do eval after test')
  41. parser.add_argument('--eval_script_path', type=str, default='', help='evaluate script path')
  42. parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
  43. parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
  44. parser.add_argument('--ckpt_name', type=str, default="", help='input model name')
  45. parser.add_argument('--device_num', type=int, default=1, help='device num for testing')
  46. parser.add_argument('--steps_per_epoch', type=int, default=198, help='steps for each epoch')
  47. parser.add_argument('--start', type=int, default=0, help='start loop number, used to calculate first epoch number')
  48. parser.add_argument('--end', type=int, default=18, help='end loop number, used to calculate last epoch number')
  49. args, _ = parser.parse_known_args()
  50. if __name__ == "__main__":
  51. # logger
  52. args.outputs_dir = os.path.join(args.ckpt_path,
  53. datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
  54. args.logger = get_logger(args.outputs_dir, args.rank)
  55. args.logger.save_args(args)
  56. if args.ckpt_name != "":
  57. args.start = 0
  58. args.end = 1
  59. for loop in range(args.start, args.end, 1):
  60. network = CenterfaceMobilev2()
  61. default_recurisive_init(network)
  62. if args.ckpt_name == "":
  63. ckpt_num = loop * args.device_num + args.rank + 1
  64. ckpt_name = "0-" + str(ckpt_num) + "_" + str(args.steps_per_epoch * ckpt_num) + ".ckpt"
  65. else:
  66. ckpt_name = args.ckpt_name
  67. test_model = args.test_model + ckpt_name
  68. if not test_model:
  69. args.logger.info('load_model {} none'.format(test_model))
  70. continue
  71. if os.path.isfile(test_model):
  72. param_dict = load_checkpoint(test_model)
  73. param_dict_new = {}
  74. for key, values in param_dict.items():
  75. if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
  76. continue
  77. elif key.startswith('centerface_network.'):
  78. param_dict_new[key[19:]] = values
  79. else:
  80. param_dict_new[key] = values
  81. load_param_into_net(network, param_dict_new)
  82. args.logger.info('load_model {} success'.format(test_model))
  83. else:
  84. args.logger.info('{} not exists or not a pre-trained file'.format(test_model))
  85. continue
  86. train_network_type_nms = 1 # default with num
  87. if train_network_type_nms:
  88. network = CenterFaceWithNms(network)
  89. args.logger.info('train network type with nms')
  90. network.set_train(False)
  91. args.logger.info('finish get network')
  92. config = ConfigCenterface()
  93. # test network -----------
  94. start = time.time()
  95. ground_truth_mat = sio.loadmat(args.ground_truth_mat)
  96. event_list = ground_truth_mat['event_list']
  97. file_list = ground_truth_mat['file_list']
  98. if args.ckpt_name == "":
  99. save_path = args.save_dir + str(ckpt_num) + '/'
  100. else:
  101. save_path = args.save_dir+ '/'
  102. detector = CenterFaceDetector(config, network)
  103. for index, event in enumerate(event_list):
  104. file_list_item = file_list[index][0]
  105. im_dir = event[0][0]
  106. if not os.path.exists(save_path + im_dir):
  107. os.makedirs(save_path + im_dir)
  108. args.logger.info('save_path + im_dir={}'.format(save_path + im_dir))
  109. for num, file in enumerate(file_list_item):
  110. im_name = file[0][0]
  111. zip_name = '%s/%s.jpg' % (im_dir, im_name)
  112. img_path = os.path.join(args.data_dir, zip_name)
  113. args.logger.info('img_path={}'.format(img_path))
  114. dets = detector.run(img_path)['results']
  115. f = open(save_path + im_dir + '/' + im_name + '.txt', 'w')
  116. f.write('{:s}\n'.format('%s/%s.jpg' % (im_dir, im_name)))
  117. f.write('{:d}\n'.format(len(dets)))
  118. for b in dets[1]:
  119. x1, y1, x2, y2, s = b[0], b[1], b[2], b[3], b[4]
  120. f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(x1, y1, (x2 - x1 + 1), (y2 - y1 + 1), s))
  121. f.close()
  122. args.logger.info('event:{}, num:{}'.format(index + 1, num + 1))
  123. end = time.time()
  124. args.logger.info("============num {} time {}".format(num, (end-start)*1000))
  125. start = end
  126. if args.eval:
  127. args.logger.info('==========start eval===============')
  128. args.logger.info("test output path = {}".format(save_path))
  129. if os.path.isdir(save_path):
  130. evaluation(save_path, args.ground_truth_path)
  131. else:
  132. args.logger.info('no test output path')
  133. args.logger.info('==========end eval===============')
  134. if args.ckpt_name != "":
  135. break
  136. args.logger.info('==========end testing===============')