#!/usr/bin/python
# coding=utf-8
"""
Train LeNet on MNIST and write network checkpoint files (.ckpt).

The script reads hyper-parameters from ``config.mnist_cfg``, loads the
training split from ``DATA_DIR/train``, trains ``LeNet5`` with a Momentum
optimizer and saves checkpoints prefixed ``checkpoint_lenet`` under
``TRAIN_DIR``.
"""

import os
import argparse

import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.common import set_seed

from config import mnist_cfg as cfg
from dataset import create_dataset
from lenet import LeNet5

# Checkpoint output and dataset input locations.
# NOTE: an earlier revision of this script used '/tmp/output' and
# '/tmp/dataset' instead.
TRAIN_DIR = '/cache/output'
DATA_DIR = '/cache/dataset'

parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    # The help text previously claimed the default was CPU; the real default
    # (see ``default=`` above) is Ascend.
    help='device where the code will be implemented (default: Ascend),若要在启智平台上使用NPU,需要在启智平台训练界面上加上运行参数device_target=Ascend')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')

# Fix the global random seed at import time.
set_seed(1)


def main():
    """Parse the command line, build the dataset and network, then train.

    Raises:
        ValueError: if the created training dataset reports a size of 0.
    """
    args = parser.parse_args()
    print('args:')
    print(args)

    # Important: this selects the device used for training
    # (CPU or Ascend NPU).
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target)

    # Create the dataset.
    ds_train = create_dataset(os.path.join(DATA_DIR, "train"), cfg.batch_size)
    if ds_train.get_dataset_size() == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")

    # Build the network, loss function and optimizer.
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())

    # The model is built identically on every device, except that Ascend
    # additionally passes amp_level="O2".
    model_kwargs = {"metrics": {"accuracy": Accuracy()}}
    if args.device_target == "Ascend":
        model_kwargs["amp_level"] = "O2"
    model = Model(network, net_loss, net_opt, **model_kwargs)

    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    # Define where model checkpoints are written.
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                 directory=TRAIN_DIR,
                                 config=config_ck)

    # Start training.
    print("============== Starting Training ==============")
    # The command-line value wins whenever it is truthy; a value of 0 falls
    # back to the configured epoch count. The config entry is only read in
    # that fallback case.
    if args.epoch_size:
        epoch_size = args.epoch_size
    else:
        epoch_size = cfg['epoch_size']
    print('epoch_size is: ', epoch_size)

    model.train(epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    print("============== Finish Training ==============")


if __name__ == "__main__":
    main()