You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 3.6 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """DPN model eval with MindSpore"""
  16. import os
  17. import argparse
  18. from mindspore import context
  19. from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
  20. from mindspore.train.model import Model
  21. from mindspore.common import set_seed
  22. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  23. from src.dpn import dpns
  24. from src.config import config
  25. from src.imagenet_dataset import classification_dataset
  26. set_seed(1)
  27. # set context
  28. device_id = int(os.getenv('DEVICE_ID'))
  29. context.set_context(mode=context.GRAPH_MODE,
  30. device_target="Ascend", save_graphs=False, device_id=device_id)
  31. def parse_args():
  32. """parameters"""
  33. parser = argparse.ArgumentParser('dpn evaluating')
  34. # dataset related
  35. parser.add_argument('--data_dir', type=str, default='', help='eval data dir')
  36. # network related
  37. parser.add_argument('--pretrained', type=str, default='', help='ckpt path to load')
  38. args, _ = parser.parse_known_args()
  39. args.image_size = config.image_size
  40. args.num_classes = config.num_classes
  41. args.batch_size = config.batch_size
  42. args.num_parallel_workers = config.num_parallel_workers
  43. args.backbone = config.backbone
  44. args.loss_scale_num = config.loss_scale_num
  45. args.rank = config.rank
  46. args.group_size = config.group_size
  47. args.dataset = config.dataset
  48. return args
  49. def dpn_evaluate(args):
  50. # create evaluate dataset
  51. eval_path = os.path.join(args.data_dir, 'val')
  52. eval_dataset = classification_dataset(eval_path,
  53. image_size=args.image_size,
  54. num_parallel_workers=args.num_parallel_workers,
  55. per_batch_size=args.batch_size,
  56. max_epoch=1,
  57. rank=args.rank,
  58. shuffle=False,
  59. group_size=args.group_size,
  60. mode='eval')
  61. # create network
  62. net = dpns[args.backbone](num_classes=args.num_classes)
  63. # load checkpoint
  64. load_param_into_net(net, load_checkpoint(args.pretrained))
  65. print("load checkpoint from [{}].".format(args.pretrained))
  66. # loss
  67. if args.dataset == "imagenet-1K":
  68. loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
  69. else:
  70. if not args.label_smooth:
  71. args.label_smooth_factor = 0.0
  72. loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
  73. # create model
  74. model = Model(net, amp_level="O2", keep_batchnorm_fp32=False, loss_fn=loss,
  75. metrics={'top_1_accuracy', 'top_5_accuracy'})
  76. # evaluate
  77. output = model.eval(eval_dataset)
  78. print(f'Evaluation result: {output}.')
  79. if __name__ == '__main__':
  80. dpn_evaluate(parse_args())
  81. print('DPN evaluate success!')