
train_for_c2net_testcopy.py 3.3 kB

  1. """
  2. ######################## train lenet example ########################
  3. train lenet and get network model files(.ckpt)
  4. """
  5. #!/usr/bin/python
  6. #coding=utf-8
  7. import os
  8. import argparse
  9. from config import mnist_cfg as cfg
  10. from dataset import create_dataset
  11. from lenet import LeNet5
  12. import mindspore.nn as nn
  13. from mindspore import context
  14. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  15. from mindspore.train import Model
  16. from mindspore.nn.metrics import Accuracy
  17. from mindspore.common import set_seed
  18. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  19. parser.add_argument(
  20. '--device_target',
  21. type=str,
  22. default="Ascend",
  23. choices=['Ascend', 'CPU'],
  24. help='device where the code will be implemented (default: CPU),若要在启智平台上使用NPU,需要在启智平台训练界面上加上运行参数device_target=Ascend')
  25. parser.add_argument('--epoch_size',
  26. type=int,
  27. default=5,
  28. help='Training epochs.')
  29. set_seed(1)
  30. if __name__ == "__main__":
  31. args = parser.parse_args()
  32. print('args:')
  33. print(args)
  34. train_dir = '/cache/output'
  35. data_dir = '/cache/dataset'
  36. #注意:这里很重要,指定了训练所用的设备CPU还是Ascend NPU
  37. context.set_context(mode=context.GRAPH_MODE,
  38. device_target=args.device_target)
  39. #创建数据集
  40. ds_train = create_dataset(os.path.join(data_dir, "train"),
  41. cfg.batch_size)
  42. if ds_train.get_dataset_size() == 0:
  43. raise ValueError(
  44. "Please check dataset size > 0 and batch_size <= dataset size")
  45. #创建网络
  46. network = LeNet5(cfg.num_classes)
  47. net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  48. net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
  49. time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
  50. if args.device_target != "Ascend":
  51. model = Model(network,
  52. net_loss,
  53. net_opt,
  54. metrics={"accuracy": Accuracy()})
  55. else:
  56. model = Model(network,
  57. net_loss,
  58. net_opt,
  59. metrics={"accuracy": Accuracy()},
  60. amp_level="O2")
  61. config_ck = CheckpointConfig(
  62. save_checkpoint_steps=cfg.save_checkpoint_steps,
  63. keep_checkpoint_max=cfg.keep_checkpoint_max)
  64. #定义模型输出路径
  65. ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
  66. directory=train_dir,
  67. config=config_ck)
  68. #开始训练
  69. print("============== Starting Training ==============")
  70. epoch_size = cfg['epoch_size']
  71. if (args.epoch_size):
  72. epoch_size = args.epoch_size
  73. print('epoch_size is: ', epoch_size)
  74. # 测试代码。结果回传
  75. os.system("cd /cache/script_for_grampus/ &&./uploader_for_npu " + "/cache/code/")
  76. model.train(epoch_size,
  77. ds_train,
  78. callbacks=[time_cb, ckpoint_cb,
  79. LossMonitor()])
  80. print("============== Finish Training ==============")