You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

eval.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """cnnctc eval"""
  16. import argparse
  17. import time
  18. import numpy as np
  19. from mindspore import Tensor, context
  20. import mindspore.common.dtype as mstype
  21. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  22. from mindspore.dataset import GeneratorDataset
  23. from src.util import CTCLabelConverter, AverageMeter
  24. from src.config import Config_CNNCTC
  25. from src.dataset import IIIT_Generator_batch
  26. from src.cnn_ctc import CNNCTC_Model
  # Configure MindSpore once at import time: graph mode on the Ascend backend,
  # graph saving disabled, and automatic mixed precision turned off.
  context.set_context(
      mode=context.GRAPH_MODE,
      device_target="Ascend",
      enable_auto_mixed_precision=False,
      save_graphs=False,
      save_graphs_path=".",
  )
  def test_dataset_creator():
      """Build the evaluation dataset from the ``IIIT_Generator_batch`` generator.

      Returns:
          A ``GeneratorDataset`` whose columns are, in order: ``img``,
          ``label_indices``, ``text``, ``sequence_length`` and ``label_str``.
      """
      column_names = ['img', 'label_indices', 'text', 'sequence_length', 'label_str']
      return GeneratorDataset(IIIT_Generator_batch, column_names)
  def _greedy_indices(model_predict):
      """Turn a batch of per-step class scores into greedy CTC decoder inputs.

      Args:
          model_predict: numpy array laid out as ``[batch, step, class]``
              (assumed layout of the squeezed network output — TODO confirm).

      Returns:
          ``(preds_index, preds_size)``: the arg-max class of every step,
          flattened batch-major into one 1-D array, and an array that repeats
          the step count once per batch element.
      """
      batch_size, num_steps = model_predict.shape[0], model_predict.shape[1]
      # Derive the batch size from the prediction itself instead of the config,
      # so a final batch smaller than TEST_BATCH_SIZE still gets one size per sample.
      preds_size = np.array([num_steps] * batch_size)
      preds_index = np.reshape(np.argmax(model_predict, 2), [-1])
      return preds_index, preds_size


  def test(config):
      """Evaluate a CNNCTC checkpoint, printing timings and word accuracy.

      Args:
          config: configuration object; reads ``NUM_CLASS``, ``HIDDEN_SIZE``,
              ``FINAL_FEATURE_WIDTH``, ``CKPT_PATH`` and ``CHARACTER``.

      Returns:
          The fraction of samples whose decoded prediction equals its
          ground-truth string, or ``None`` if the dataset yielded no samples.
      """
      ds = test_dataset_creator()
      net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
      ckpt_path = config.CKPT_PATH
      param_dict = load_checkpoint(ckpt_path)
      load_param_into_net(net, param_dict)
      print('parameters loaded! from: ', ckpt_path)

      converter = CTCLabelConverter(config.CHARACTER)
      model_run_time = AverageMeter()
      npu_to_cpu_time = AverageMeter()
      postprocess_time = AverageMeter()
      count = 0
      correct_count = 0
      for batch_idx, data in enumerate(ds.create_tuple_iterator()):
          # NOTE(review): the fifth column is named 'label_str' in
          # test_dataset_creator but is passed as the lengths argument of
          # reverse_encode below — confirm against IIIT_Generator_batch.
          img, _, text, _, length = data
          img_tensor = Tensor(img, mstype.float32)

          # perf_counter is monotonic and high-resolution, suited to intervals.
          model_run_begin = time.perf_counter()
          model_predict = net(img_tensor)
          model_run_time.update(time.perf_counter() - model_run_begin)

          # asnumpy() copies the network output to a host numpy array.
          npu_to_cpu_begin = time.perf_counter()
          model_predict = np.squeeze(model_predict.asnumpy())
          npu_to_cpu_time.update(time.perf_counter() - npu_to_cpu_begin)

          postprocess_begin = time.perf_counter()
          preds_index, preds_size = _greedy_indices(model_predict)
          preds_str = converter.decode(preds_index, preds_size)
          postprocess_time.update(time.perf_counter() - postprocess_begin)

          label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())
          if batch_idx == 0:
              # Drop the first batch's timings (presumably dominated by graph
              # compilation / warm-up) so the averages reflect steady state.
              model_run_time.reset()
              npu_to_cpu_time.reset()
              postprocess_time.reset()
          else:
              print('---------model run time--------', model_run_time.avg)
              print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
              print('---------postprocess run time--------', postprocess_time.avg)
          print("Prediction samples: \n", preds_str[:5])
          print("Ground truth: \n", label_str[:5])
          for pred, label in zip(preds_str, label_str):
              count += 1
              if pred == label:
                  correct_count += 1

      if count == 0:
          # Avoid ZeroDivisionError when the dataset produced no samples.
          print('accuracy: no samples were evaluated')
          return None
      accuracy = correct_count / count
      print('accuracy: ', accuracy)
      return accuracy
  if __name__ == '__main__':
      # Command-line entry point: parse options, build the config, run evaluation.
      parser = argparse.ArgumentParser(description="CNNCTC evaluation")
      parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
      parser.add_argument("--ckpt_path", type=str, default="", help="trained file path.")
      args_opt = parser.parse_args()
      # Apply --device_id to the MindSpore context so the option actually takes effect.
      context.set_context(device_id=args_opt.device_id)
      cfg = Config_CNNCTC()
      # An empty --ckpt_path keeps the checkpoint path configured in Config_CNNCTC.
      if args_opt.ckpt_path:
          cfg.CKPT_PATH = args_opt.ckpt_path
      test(cfg)