diff --git a/npu_mnist_example/train_continue.py b/npu_mnist_example/train_continue.py new file mode 100644 index 0000000..99aaea7 --- /dev/null +++ b/npu_mnist_example/train_continue.py @@ -0,0 +1,138 @@ +##################################################################################################### +# 继续训练功能:修改训练任务时,若勾选复用上次结果,则可在新训练任务的输出路径中读取到上次结果 +# +# 示例用法 +# - 增加两个训练参数 +# 'ckpt_save_name' 此次任务的输出文件名,用于保存此次训练的模型文件名称(不带后缀) +# 'ckpt_load_name' 上一次任务的输出文件名,用于加载上一次输出的模型文件名称(不带后缀),首次训练默认为空,则不读取任何文件 +# - 训练代码中判断 'ckpt_load_name' 是否为空,若不为空,则为继续训练任务 +##################################################################################################### + + +import os +import argparse +from config import mnist_cfg as cfg +from dataset import create_dataset +from dataset_distributed import create_dataset_parallel +from lenet import LeNet5 +import mindspore.nn as nn +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore import load_checkpoint, load_param_into_net +from mindspore.train import Model +from mindspore.nn.metrics import Accuracy +from mindspore.communication.management import get_rank + +from openi import obs_copy_file +from openi import obs_copy_folder +from openi import openi_multidataset_to_env + +parser = argparse.ArgumentParser(description='MindSpore Lenet Example') +parser.add_argument('--multi_data_url', + help='path to training/inference dataset folder', + default= '[{}]') + +parser.add_argument('--train_url', + help='output folder to save/load', + default= '') + +parser.add_argument( + '--device_target', + type=str, + default="Ascend", + choices=['Ascend', 'CPU'], + help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU') + +parser.add_argument('--epoch_size', + type=int, + default=5, + help='Training epochs.') + +### continue task parameters +parser.add_argument('--ckpt_load_name', + help='model name to save/load', + default= '') + +parser.add_argument('--ckpt_save_name', + help='model name to save/load', + default= 'checkpoint') + + +if __name__ == "__main__": + args, unknown = parser.parse_known_args() + data_dir = '/cache/data' + base_path = '/cache/output' + + try: + if not os.path.exists(data_dir): + os.makedirs(data_dir) + if not os.path.exists(base_path): + os.makedirs(base_path) + except Exception as e: + print("path already exists") + + openi_multidataset_to_env(args.multi_data_url, data_dir) + + device_num = int(os.getenv('RANK_SIZE')) + if device_num == 1: + ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) + if device_num > 1: + ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) + if ds_train.get_dataset_size() == 0: + raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") + + network = LeNet5(cfg.num_classes) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) + + ### 继续训练模型加载 + if args.ckpt_load_name: + obs_copy_folder(args.train_url, base_path) + load_path = "{}/{}.ckpt".format(base_path,args.ckpt_load_name) + param_dict = load_checkpoint(load_path) + load_param_into_net(network, param_dict) + print("Successfully load ckpt file:{}, saved_net_work:{}".format(load_path,param_dict)) + ### 保存已有模型名避免重复回传结果 + outputFiles = os.listdir(base_path) + + if args.device_target != "Ascend": + model = Model(network, + net_loss, + net_opt, + metrics={"accuracy": Accuracy()}) + else: + model = Model(network, + net_loss, + net_opt, + metrics={"accuracy": Accuracy()}, + amp_level="O2") + + config_ck = CheckpointConfig( + save_checkpoint_steps=cfg.save_checkpoint_steps, + keep_checkpoint_max=cfg.keep_checkpoint_max) + #Note that this method saves the model file on each card. You need to specify the save path on each card. + # In this example, get_rank() is added to distinguish different paths. + if device_num == 1: + save_path = base_path + "/" + if device_num > 1: + save_path = base_path + "/" + str(get_rank()) + "/" + ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_save_name, + directory=save_path, + config=config_ck) + print("============== Starting Training ==============") + epoch_size = cfg['epoch_size'] + if (args.epoch_size): + epoch_size = args.epoch_size + print('epoch_size is: ', epoch_size) + model.train(epoch_size, + ds_train, + callbacks=[time_cb, ckpoint_cb, + LossMonitor()]) + + ### 将训练容器中的新输出模型 回传到启智社区 + outputFilesNew = os.listdir(base_path) + new_models = [i for i in outputFilesNew if i not in outputFiles] + for n in new_models: + ckpt_url = base_path + "/" + n + obs_ckpt_url = args.train_url + "/" + n + obs_copy_file(ckpt_url, obs_ckpt_url) \ No newline at end of file