| @@ -0,0 +1,138 @@ | |||
| ##################################################################################################### | |||
| # 继续训练功能:修改训练任务时,若勾选复用上次结果,则可在新训练任务的输出路径中读取到上次结果 | |||
| # | |||
| # 示例用法 | |||
| # - 增加两个训练参数 | |||
| # 'ckpt_save_name' 此次任务的输出文件名,用于保存此次训练的模型文件名称(不带后缀) | |||
| # 'ckpt_load_name' 上一次任务的输出文件名,用于加载上一次输出的模型文件名称(不带后缀),首次训练默认为空,则不读取任何文件 | |||
| # - 训练代码中判断 'ckpt_load_name' 是否为空,若不为空,则为继续训练任务 | |||
| ##################################################################################################### | |||
| import os | |||
| import argparse | |||
| from config import mnist_cfg as cfg | |||
| from dataset import create_dataset | |||
| from dataset_distributed import create_dataset_parallel | |||
| from lenet import LeNet5 | |||
| import mindspore.nn as nn | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||
| from mindspore import load_checkpoint, load_param_into_net | |||
| from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from mindspore.communication.management import get_rank | |||
| from openi import obs_copy_file | |||
| from openi import obs_copy_folder | |||
| from openi import openi_multidataset_to_env | |||
| parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | |||
| parser.add_argument('--multi_data_url', | |||
| help='path to training/inference dataset folder', | |||
| default= '[{}]') | |||
| parser.add_argument('--train_url', | |||
| help='output folder to save/load', | |||
| default= '') | |||
| parser.add_argument( | |||
| '--device_target', | |||
| type=str, | |||
| default="Ascend", | |||
| choices=['Ascend', 'CPU'], | |||
| help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU') | |||
| parser.add_argument('--epoch_size', | |||
| type=int, | |||
| default=5, | |||
| help='Training epochs.') | |||
| ### continue task parameters | |||
| parser.add_argument('--ckpt_load_name', | |||
| help='model name to save/load', | |||
| default= '') | |||
| parser.add_argument('--ckpt_save_name', | |||
| help='model name to save/load', | |||
| default= 'checkpoint') | |||
| if __name__ == "__main__": | |||
| args, unknown = parser.parse_known_args() | |||
| data_dir = '/cache/data' | |||
| base_path = '/cache/output' | |||
| try: | |||
| if not os.path.exists(data_dir): | |||
| os.makedirs(data_dir) | |||
| if not os.path.exists(base_path): | |||
| os.makedirs(base_path) | |||
| except Exception as e: | |||
| print("path already exists") | |||
| openi_multidataset_to_env(args.multi_data_url, data_dir) | |||
| device_num = int(os.getenv('RANK_SIZE')) | |||
| if device_num == 1: | |||
| ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) | |||
| if device_num > 1: | |||
| ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) | |||
| if ds_train.get_dataset_size() == 0: | |||
| raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | |||
| network = LeNet5(cfg.num_classes) | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||
| ### 继续训练模型加载 | |||
| if args.ckpt_load_name: | |||
| obs_copy_folder(args.train_url, base_path) | |||
| load_path = "{}/{}.ckpt".format(base_path,args.ckpt_load_name) | |||
| param_dict = load_checkpoint(load_path) | |||
| load_param_into_net(network, param_dict) | |||
| print("Successfully load ckpt file:{}, saved_net_work:{}".format(load_path,param_dict)) | |||
| ### 保存已有模型名避免重复回传结果 | |||
| outputFiles = os.listdir(base_path) | |||
| if args.device_target != "Ascend": | |||
| model = Model(network, | |||
| net_loss, | |||
| net_opt, | |||
| metrics={"accuracy": Accuracy()}) | |||
| else: | |||
| model = Model(network, | |||
| net_loss, | |||
| net_opt, | |||
| metrics={"accuracy": Accuracy()}, | |||
| amp_level="O2") | |||
| config_ck = CheckpointConfig( | |||
| save_checkpoint_steps=cfg.save_checkpoint_steps, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| #Note that this method saves the model file on each card. You need to specify the save path on each card. | |||
| # In this example, get_rank() is added to distinguish different paths. | |||
| if device_num == 1: | |||
| save_path = base_path + "/" | |||
| if device_num > 1: | |||
| save_path = base_path + "/" + str(get_rank()) + "/" | |||
| ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_save_name, | |||
| directory=save_path, | |||
| config=config_ck) | |||
| print("============== Starting Training ==============") | |||
| epoch_size = cfg['epoch_size'] | |||
| if (args.epoch_size): | |||
| epoch_size = args.epoch_size | |||
| print('epoch_size is: ', epoch_size) | |||
| model.train(epoch_size, | |||
| ds_train, | |||
| callbacks=[time_cb, ckpoint_cb, | |||
| LossMonitor()]) | |||
| ### 将训练容器中的新输出模型 回传到启智社区 | |||
| outputFilesNew = os.listdir(base_path) | |||
| new_models = [i for i in outputFilesNew if i not in outputFiles] | |||
| for n in new_models: | |||
| ckpt_url = base_path + "/" + n | |||
| obs_ckpt_url = args.train_url + "/" + n | |||
| obs_copy_file(ckpt_url, obs_ckpt_url) | |||