#####################################################################################################
# Continue-training feature: when a training task is edited and "reuse last result" is checked,
# the previous task's results can be read from the new task's output path.
#
# Example usage
# - Two extra training parameters are added:
#   'ckpt_save_name'  output file name of this task; used as the name (without suffix) of the
#                     model files saved by this run
#   'ckpt_load_name'  output file name of the previous task; used to load the model file it
#                     produced (without suffix). Empty by default for a first run, in which case
#                     nothing is loaded.
# - The training code checks whether 'ckpt_load_name' is empty; if it is not, this run is a
#   continued-training task.
#####################################################################################################
import os
import argparse

from config import mnist_cfg as cfg
from dataset import create_dataset
from lenet import LeNet5
import mindspore.nn as nn
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.communication.management import get_rank

# Import the c2net package
from c2net.context import prepare
from c2net.context.moxing_helper import obs_copy_file, obs_copy_folder

parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--train_url', help='output folder to save/load', default='')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
parser.add_argument('--epoch_size', type=int, default=5, help='Training epochs.')

### continue task parameters
parser.add_argument('--ckpt_load_name', help='model name to save/load', default='')
parser.add_argument('--ckpt_save_name', help='model name to save/load', default='checkpoint')


def _get_device_num():
    """Return the device count from the RANK_SIZE environment variable.

    Falls back to 1 (single device) when RANK_SIZE is unset or empty, instead of
    failing with ``int(None)``.
    """
    rank_size = os.getenv('RANK_SIZE')
    if not rank_size:
        return 1
    return int(rank_size)


def _snapshot_outputs(base_path):
    """Return the set of file paths (relative to ``base_path``) currently under ``base_path``.

    The walk is recursive so that checkpoints written into per-device subdirectories
    are also tracked.
    """
    snapshot = set()
    for root, _dirs, files in os.walk(base_path):
        for name in files:
            snapshot.add(os.path.relpath(os.path.join(root, name), base_path))
    return snapshot


def _upload_new_outputs(base_path, train_url, before):
    """Upload every top-level entry of ``base_path`` that gained files since ``before``.

    Args:
        base_path: local output directory inside the training container.
        train_url: destination folder passed to the c2net copy helpers.
        before: result of ``_snapshot_outputs(base_path)`` taken before training,
            so results that were already present (e.g. copied in for resuming) are
            not uploaded again.

    Plain files are copied with ``obs_copy_file``. Directories (such as the per-rank
    checkpoint folders used for multi-device runs) are copied with ``obs_copy_folder``.
    Calling ``obs_copy_file`` on a directory would fail.
    """
    changed_entries = set()
    for rel_path in _snapshot_outputs(base_path) - before:
        # Only the first path component matters: it is what gets uploaded.
        changed_entries.add(rel_path.split(os.sep, 1)[0])
    for name in sorted(changed_entries):
        local_url = base_path + "/" + name
        obs_url = train_url + "/" + name
        if os.path.isdir(local_url):
            obs_copy_folder(local_url, obs_url)
        else:
            obs_copy_file(local_url, obs_url)


if __name__ == "__main__":
    # parse_known_args ignores extra platform-supplied arguments (e.g. --ckpt_url)
    # instead of erroring out on them.
    args, unknown = parser.parse_known_args()
    # Initialize: bring the dataset and pretrained model into the container.
    c2net_context = prepare()
    # Path of the dataset inside the container.
    MnistDataset_mindspore_path = c2net_context.dataset_path + "/" + "MnistDataset_mindspore"

    data_dir = '/cache/data'
    base_path = '/cache/output'
    # exist_ok tolerates pre-existing directories without masking real errors
    # (e.g. permission problems) the way a blanket `except Exception` did.
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(base_path, exist_ok=True)

    device_num = _get_device_num()

    # The dataset is needed for every run. Creating it only for single-device runs
    # left `ds_train` undefined when device_num > 1.
    ds_train = create_dataset(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size)
    if ds_train.get_dataset_size() == 0:
        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")

    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())

    ### Continue training: copy the previous task's outputs in and load its checkpoint.
    if args.ckpt_load_name:
        obs_copy_folder(args.train_url, base_path)
        load_path = "{}/{}.ckpt".format(base_path, args.ckpt_load_name)
        param_dict = load_checkpoint(load_path)
        load_param_into_net(network, param_dict)
        print("Successfully load ckpt file:{}, saved_net_work:{}".format(load_path, param_dict))

    ### Record existing outputs so results already uploaded are not sent back again.
    # Taken unconditionally (after the optional resume copy) so it is always defined.
    existing_outputs = _snapshot_outputs(base_path)

    if args.device_target != "Ascend":
        model = Model(network, net_loss, net_opt, metrics={"accuracy": Accuracy()})
    else:
        model = Model(network, net_loss, net_opt, metrics={"accuracy": Accuracy()}, amp_level="O2")

    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    # This method saves the model file on each card, so each card needs its own save path.
    # For multi-device runs, get_rank() is added to the path to keep them apart.
    if device_num > 1:
        save_path = base_path + "/" + str(get_rank()) + "/"
    else:
        save_path = base_path + "/"
    ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_save_name, directory=save_path, config=config_ck)

    print("============== Starting Training ==============")
    epoch_size = cfg['epoch_size']
    if args.epoch_size:
        epoch_size = args.epoch_size
    print('epoch_size is: ', epoch_size)
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    ### Upload the new model outputs from the training container to the OpenI platform.
    _upload_new_outputs(base_path, args.train_url, existing_outputs)