You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 3.8 kB

4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. ######################## eval alexnet example ########################
  17. eval alexnet according to model file:
  18. python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
  19. """
  20. import os
  21. # import sys
  22. # sys.path.append(os.path.join(os.getcwd(), 'utils'))
  23. from utils.config import config
  24. from utils.moxing_adapter import moxing_wrapper
  25. from utils.device_adapter import get_device_id, get_device_num
  26. from src.dataset import create_dataset_cifar10, create_dataset_imagenet
  27. from src.alexnet import AlexNet
  28. import mindspore.nn as nn
  29. from mindspore import context
  30. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  31. from mindspore.train import Model
  32. from mindspore.nn.metrics import Accuracy
  33. from mindspore.communication.management import init
  34. if os.path.exists(config.data_path_local):
  35. config.data_path = config.data_path_local
  36. load_path = config.ckpt_path_local
  37. else:
  38. load_path = os.path.join(config.data_path, 'checkpoint_alexnet-30_1562.ckpt')
  39. def modelarts_process():
  40. pass
  41. @moxing_wrapper(pre_process=modelarts_process)
  42. def eval_alexnet():
  43. print("============== Starting Testing ==============")
  44. device_num = get_device_num()
  45. if device_num > 1:
  46. # context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
  47. context.set_context(mode=context.GRAPH_MODE, device_target='Davinci', save_graphs=False)
  48. if config.device_target == "Ascend":
  49. context.set_context(device_id=get_device_id())
  50. init()
  51. elif config.device_target == "GPU":
  52. init()
  53. if config.dataset_name == 'cifar10':
  54. network = AlexNet(config.num_classes, phase='test')
  55. loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  56. opt = nn.Momentum(network.trainable_params(), config.learning_rate, config.momentum)
  57. ds_eval = create_dataset_cifar10(config.data_path, config.batch_size, status="test", \
  58. target=config.device_target)
  59. param_dict = load_checkpoint(load_path)
  60. print("load checkpoint from [{}].".format(load_path))
  61. load_param_into_net(network, param_dict)
  62. network.set_train(False)
  63. model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
  64. elif config.dataset_name == 'imagenet':
  65. network = AlexNet(config.num_classes, phase='test')
  66. loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  67. ds_eval = create_dataset_imagenet(config.data_path, config.batch_size, training=False)
  68. param_dict = load_checkpoint(load_path)
  69. print("load checkpoint from [{}].".format(load_path))
  70. load_param_into_net(network, param_dict)
  71. network.set_train(False)
  72. model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
  73. else:
  74. raise ValueError("Unsupported dataset.")
  75. if ds_eval.get_dataset_size() == 0:
  76. raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
  77. result = model.eval(ds_eval, dataset_sink_mode=config.dataset_sink_mode)
  78. print("result : {}".format(result))
  79. if __name__ == "__main__":
  80. eval_alexnet()