#!/usr/bin/python
# coding=utf-8
"""
######################## train lenet example ########################
Train LeNet and save the network model files (.ckpt).

Training on the intelligent computing network currently supports a single
dataset and does not require an OBS copy step. Only two paths need to be
defined before training is started directly:

    train_dir = '/cache/output'   # where output (checkpoints) is written
    data_dir = '/cache/dataset'   # where the dataset is read from
"""
import os
import argparse

from config import mnist_cfg as cfg
from dataset import create_dataset
from lenet import LeNet5

import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import (
    ModelCheckpoint,
    CheckpointConfig,
    LossMonitor,
    TimeMonitor,
)
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.common import set_seed

parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')

# Fix the global random seed so repeated runs are reproducible.
set_seed(1)


def main():
    """Train LeNet5 on ``<data_dir>/train`` and write checkpoints to ``train_dir``.

    Hyper-parameters (batch size, learning rate, momentum, checkpoint
    cadence, ...) come from ``config.mnist_cfg``; the device and epoch
    count come from the command line.

    Raises:
        ValueError: if the training dataset yields no batches, or if the
            resolved epoch count is not positive.
    """
    # parse_known_args (not parse_args) so that extra arguments — presumably
    # injected by the launching platform — do not abort the run.
    args, _ = parser.parse_known_args()
    print('args:')
    print(args)

    ### define two parameters and then call it directly ###
    train_dir = '/cache/output'
    data_dir = '/cache/dataset'

    ### Specifies the device CPU or Ascend NPU used for training ###
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target)

    ds_train = create_dataset(os.path.join(data_dir, "train"), cfg.batch_size)
    # Query the size once; it is needed for both the sanity check and TimeMonitor.
    steps_per_epoch = ds_train.get_dataset_size()
    if steps_per_epoch == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")

    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=steps_per_epoch)

    # Mixed precision ("O2") only on Ascend; "O0" is Model's default level,
    # so the non-Ascend path is identical to omitting amp_level.
    amp_level = "O2" if args.device_target == "Ascend" else "O0"
    model = Model(network,
                  net_loss,
                  net_opt,
                  metrics={"accuracy": Accuracy()},
                  amp_level=amp_level)

    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                 directory=train_dir,
                                 config=config_ck)

    print("============== Starting Training ==============")
    # NOTE: --epoch_size defaults to 5, so cfg.epoch_size is only used when
    # "--epoch_size 0" is passed explicitly. The CLI default is kept as-is so
    # existing invocations train for the same number of epochs.
    epoch_size = args.epoch_size or cfg.epoch_size
    if epoch_size <= 0:
        raise ValueError("epoch_size must be a positive integer, got %d"
                         % epoch_size)
    print('epoch_size is: ', epoch_size)
    model.train(epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()])
    print("============== Finish Training ==============")


if __name__ == "__main__":
    main()