You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

inference.py 3.6 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. """
  2. 示例选用的数据集是MnistDataset_mindspore.zip
  3. 数据集结构是:
  4. MnistDataset_mindspore.zip
  5. ├── test
  6. │ ├── t10k-images-idx3-ubyte
  7. │ └── t10k-labels-idx1-ubyte
  8. └── train
  9. ├── train-images-idx3-ubyte
  10. └── train-labels-idx1-ubyte
  11. 使用注意事项:
  12. 1、在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题
  13. 2、用户需要调用c2net的python sdk包
  14. """
  15. import time
  16. import os
  17. import argparse
  18. from config import mnist_cfg as cfg
  19. from dataset import create_dataset
  20. from lenet import LeNet5
  21. import mindspore.nn as nn
  22. import numpy as np
  23. from mindspore import context
  24. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  25. from mindspore import load_checkpoint, load_param_into_net
  26. from mindspore.train import Model
  27. from mindspore import Tensor
  28. #导入c2net包
  29. from c2net.context import prepare, upload_output
  30. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  31. parser.add_argument(
  32. '--device_target',
  33. type=str,
  34. default="Ascend",
  35. choices=['Ascend', 'CPU'],
  36. help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
  37. parser.add_argument('--epoch_size',
  38. type=int,
  39. default=5,
  40. help='Training epochs.')
  41. if __name__ == "__main__":
  42. ###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题
  43. args, unknown = parser.parse_known_args()
  44. #初始化导入数据集和预训练模型到容器内
  45. c2net_context = prepare()
  46. #获取数据集路径
  47. MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore"
  48. #获取预训练模型路径
  49. Mindspore_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Mindspore_MNIST_Example_Model"
  50. #获取输出路径
  51. save_path = c2net_context.output_path
  52. context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
  53. network = LeNet5(cfg.num_classes)
  54. net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  55. repeat_size = cfg.epoch_size
  56. net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
  57. #model = Model(network, net_loss, net_opt, metrics={"Accuracy"})
  58. model = Model(network, net_loss, net_opt)
  59. print("============== Starting Testing ==============")
  60. load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt")))
  61. ds_test = create_dataset(os.path.join(MnistDataset_mindspore_path, "test"), batch_size=1).create_dict_iterator()
  62. data = next(ds_test)
  63. images = data["image"].asnumpy()
  64. labels = data["label"].asnumpy()
  65. print('Tensor:', Tensor(data['image']))
  66. output = model.predict(Tensor(data['image']))
  67. predicted = np.argmax(output.asnumpy(), axis=1)
  68. pred = np.argmax(output.asnumpy(), axis=1)
  69. print('predicted:', predicted)
  70. print('pred:', pred)
  71. print(f'Predicted: "{predicted[0]}", Actual: "{labels[0]}"')
  72. filename = 'result.txt'
  73. file_path = os.path.join(save_path, filename)
  74. with open(file_path, 'a+') as file:
  75. file.write(" {}: {:.2f} \n".format("Predicted", predicted[0]))
  76. ###上传训练结果到启智平台,注意必须将要输出的模型存储在c2net_context.output_path
  77. upload_output()

No Description