You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train_for_c2net.py 3.4 kB

3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """
  2. ######################## train lenet example ########################
  3. train lenet and get network model files(.ckpt)
  4. The training of the intelligent computing network currently supports single-dataset training, and does not require
  5. the OBS copy process. It only needs to define two parameters and then call it directly:
  6. train_dir = '/cache/output' #The location of the output
  7. data_dir = '/cache/dataset' #The location of the dataset
  8. """
  9. #!/usr/bin/python
  10. #coding=utf-8
  11. import os
  12. import argparse
  13. from config import mnist_cfg as cfg
  14. from dataset import create_dataset
  15. from lenet import LeNet5
  16. import mindspore.nn as nn
  17. from mindspore import context
  18. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  19. from mindspore.train import Model
  20. from mindspore.nn.metrics import Accuracy
  21. from mindspore.common import set_seed
  22. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  23. parser.add_argument(
  24. '--device_target',
  25. type=str,
  26. default="Ascend",
  27. choices=['Ascend', 'CPU'],
  28. help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
  29. parser.add_argument('--epoch_size',
  30. type=int,
  31. default=5,
  32. help='Training epochs.')
# Fix the global random seed so training runs are reproducible.
set_seed(1)
  34. if __name__ == "__main__":
  35. args, unknown = parser.parse_known_args()
  36. print('args:')
  37. print(args)
  38. ###define two parameters and then call it directly###
  39. train_dir = '/cache/output'
  40. data_dir = '/cache/dataset'
  41. ###Specifies the device CPU or Ascend NPU used for training###
  42. context.set_context(mode=context.GRAPH_MODE,
  43. device_target=args.device_target)
  44. ds_train = create_dataset(os.path.join(data_dir, "train"),
  45. cfg.batch_size)
  46. if ds_train.get_dataset_size() == 0:
  47. raise ValueError(
  48. "Please check dataset size > 0 and batch_size <= dataset size")
  49. network = LeNet5(cfg.num_classes)
  50. net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  51. net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
  52. time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
  53. if args.device_target != "Ascend":
  54. model = Model(network,
  55. net_loss,
  56. net_opt,
  57. metrics={"accuracy": Accuracy()})
  58. else:
  59. model = Model(network,
  60. net_loss,
  61. net_opt,
  62. metrics={"accuracy": Accuracy()},
  63. amp_level="O2")
  64. config_ck = CheckpointConfig(
  65. save_checkpoint_steps=cfg.save_checkpoint_steps,
  66. keep_checkpoint_max=cfg.keep_checkpoint_max)
  67. ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
  68. directory=train_dir,
  69. config=config_ck)
  70. print("============== Starting Training ==============")
  71. epoch_size = cfg['epoch_size']
  72. if (args.epoch_size):
  73. epoch_size = args.epoch_size
  74. print('epoch_size is: ', epoch_size)
  75. model.train(epoch_size,
  76. ds_train,
  77. callbacks=[time_cb, ckpoint_cb,
  78. LossMonitor()])
  79. print("============== Finish Training ==============")