| @@ -19,6 +19,7 @@ | |||
| import os | |||
| import argparse | |||
| from config import mnist_cfg as cfg | |||
| from dataset import create_dataset | |||
| from dataset_distributed import create_dataset_parallel | |||
| from lenet import LeNet5 | |||
| import mindspore.nn as nn | |||
| @@ -29,7 +30,6 @@ from mindspore.train import Model | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.communication.management import init, get_rank | |||
| import time | |||
| #导入openi包 | |||
| from c2net.context import prepare, upload_output | |||
| @@ -50,10 +50,7 @@ parser.add_argument('--epoch_size', | |||
| if __name__ == "__main__": | |||
| ###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 | |||
| args, unknown = parser.parse_known_args() | |||
| MnistDataset_mindspore_path = '' | |||
| Mindspore_MNIST_Example_Model_path = '' | |||
| output_path = '' | |||
| device_num = int(os.getenv('RANK_SIZE')) | |||
| #使用多卡时 | |||
| # set device_id and init for multi-card training | |||
| @@ -63,32 +60,19 @@ if __name__ == "__main__": | |||
| init() | |||
| #The OBS data does not need to be copied by every card; only cards whose RANK_ID is a multiple of 8 copy it | |||
| local_rank=int(os.getenv('RANK_ID')) | |||
| if local_rank%8==0: | |||
| #初始化导入数据集和预训练模型到容器内 | |||
| c2net_context = prepare() | |||
| #获取数据集路径 | |||
| MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" | |||
| #获取预训练模型路径 | |||
| Mindspore_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Mindspore_MNIST_Example_Model" | |||
| output_path = c2net_context.output_path | |||
| #Create a marker file to signal that the input data has been downloaded into the container. | |||
| #During multi-card training the other cards wait for this file instead of copying the dataset again. | |||
| f = open("/cache/download_input.txt", 'w') | |||
| f.close() | |||
| try: | |||
| if os.path.exists("/cache/download_input.txt"): | |||
| print("download_input succeed") | |||
| except Exception as e: | |||
| print("download_input failed") | |||
| while not os.path.exists("/cache/download_input.txt"): | |||
| time.sleep(1) | |||
| ds_train = create_dataset_parallel(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size) | |||
| #初始化导入数据集和预训练模型到容器内 | |||
| c2net_context = prepare() | |||
| #获取数据集路径 | |||
| MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" | |||
| #获取预训练模型路径 | |||
| Mindspore_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Mindspore_MNIST_Example_Model" | |||
| output_path = c2net_context.output_path | |||
| ds_train = create_dataset_parallel(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size) | |||
| network = LeNet5(cfg.num_classes) | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||
| load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) | |||
| #load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) | |||
| if args.device_target != "Ascend": | |||
| model = Model(network, | |||
| net_loss, | |||