
eval.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for Deeptext"""
import argparse
import os
import time

import numpy as np
from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16
from src.config import config
from src.dataset import data_to_mindrecord_byte_image, create_deeptext_dataset
from src.utils import metrics

from mindspore import context
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net

set_seed(1)

parser = argparse.ArgumentParser(description="Deeptext evaluation")
parser.add_argument("--checkpoint_path", type=str, default='test', help="Checkpoint file path.")
parser.add_argument("--imgs_path", type=str, required=True,
                    help="Test images files paths, multiple paths can be separated by ','.")
parser.add_argument("--annos_path", type=str, required=True,
                    help="Annotations files paths of test images, multiple paths can be separated by ','.")
parser.add_argument("--device_id", type=int, default=7, help="Device id, default is 7.")
parser.add_argument("--mindrecord_prefix", type=str, default='Deeptext-TEST', help="Prefix of mindrecord.")
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
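
# Example invocation (illustrative only; the image/annotation paths, checkpoint file
# and device id below are placeholders, adjust them to your environment):
#   python eval.py --imgs_path /path/to/test/images --annos_path /path/to/test/annotations \
#                  --checkpoint_path ./deeptext.ckpt --device_id 0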


def deeptext_eval_test(dataset_path='', ckpt_path=''):
    """Deeptext evaluation."""
    ds = create_deeptext_dataset(dataset_path, batch_size=config.test_batch_size,
                                 repeat_num=1, is_training=False)

    total = ds.get_dataset_size()
    net = Deeptext_VGG16(config)
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    eval_iter = 0

    print("\n========================================\n")
    print("Processing, please wait a moment.")
    max_num = 32

    pred_data = []
    for data in ds.create_dict_iterator():
        eval_iter = eval_iter + 1

        img_data = data['image']
        img_metas = data['image_shape']
        gt_bboxes = data['box']
        gt_labels = data['label']
        gt_num = data['valid_num']

        start = time.time()
        # run net
        output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
        gt_bboxes = gt_bboxes.asnumpy()

        gt_bboxes = gt_bboxes[gt_num.asnumpy().astype(bool), :]
        print(gt_bboxes)

        gt_labels = gt_labels.asnumpy()
        gt_labels = gt_labels[gt_num.asnumpy().astype(bool)]
        print(gt_labels)
        end = time.time()
        print("Iter {} cost time {}".format(eval_iter, end - start))

        # output
        all_bbox = output[0]
        all_label = output[1] + 1
        all_mask = output[2]

        for j in range(config.test_batch_size):
            all_bbox_squee = np.squeeze(all_bbox.asnumpy()[j, :, :])
            all_label_squee = np.squeeze(all_label.asnumpy()[j, :, :])
            all_mask_squee = np.squeeze(all_mask.asnumpy()[j, :, :])

            all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :]
            all_labels_tmp_mask = all_label_squee[all_mask_squee]

            if all_bboxes_tmp_mask.shape[0] > max_num:
                inds = np.argsort(-all_bboxes_tmp_mask[:, -1])
                inds = inds[:max_num]
                all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds]
                all_labels_tmp_mask = all_labels_tmp_mask[inds]

            pred_data.append({"boxes": all_bboxes_tmp_mask,
                              "labels": all_labels_tmp_mask,
                              "gt_bboxes": gt_bboxes,
                              "gt_labels": gt_labels})

            percent = round(eval_iter / total * 100, 2)
            print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r')
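
    # metrics() (imported from src.utils) aggregates the collected predictions and
    # ground truth into per-class precision and recall values indexed by class id.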
    precisions, recalls = metrics(pred_data)
    print("\n========================================\n")
    for i in range(config.num_classes - 1):
        j = i + 1
        f1 = (2 * precisions[j] * recalls[j]) / (precisions[j] + recalls[j] + 1e-6)
        print("class {} precision is {:.2f}%, recall is {:.2f}%, "
              "F1 is {:.2f}%".format(j, precisions[j] * 100, recalls[j] * 100, f1 * 100))
        if config.use_ambigous_sample:
            break


if __name__ == '__main__':
    prefix = args_opt.mindrecord_prefix
    config.test_images = args_opt.imgs_path
    config.test_txts = args_opt.annos_path
    mindrecord_dir = config.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, prefix)
    print("CHECKING MINDRECORD FILES ...")

    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        print("Create Mindrecord. It may take some time.")
        data_to_mindrecord_byte_image(False, prefix, file_num=1)
        print("Create Mindrecord Done, at {}".format(mindrecord_dir))

    print("CHECKING MINDRECORD FILES DONE!")
    print("Start Eval!")
    deeptext_eval_test(mindrecord_file, args_opt.checkpoint_path)