From 4fc0d713a8cc791a30e68f1ef262c7b2fc9fd48b Mon Sep 17 00:00:00 2001 From: liuzx Date: Thu, 25 Jan 2024 09:07:21 +0800 Subject: [PATCH] update --- npu_mnist_example/train_multi_card.py | 38 ++++++++------------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/npu_mnist_example/train_multi_card.py b/npu_mnist_example/train_multi_card.py index 2741550..d53a7e8 100644 --- a/npu_mnist_example/train_multi_card.py +++ b/npu_mnist_example/train_multi_card.py @@ -19,6 +19,7 @@ import os import argparse from config import mnist_cfg as cfg +from dataset import create_dataset from dataset_distributed import create_dataset_parallel from lenet import LeNet5 import mindspore.nn as nn @@ -29,7 +30,6 @@ from mindspore.train import Model from mindspore.context import ParallelMode from mindspore.communication.management import init, get_rank import time -#导入openi包 from c2net.context import prepare, upload_output @@ -50,10 +50,7 @@ parser.add_argument('--epoch_size', if __name__ == "__main__": ###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 args, unknown = parser.parse_known_args() - MnistDataset_mindspore_path = '' - Mindspore_MNIST_Example_Model_path = '' - output_path = '' - + device_num = int(os.getenv('RANK_SIZE')) #使用多卡时 # set device_id and init for multi-card training @@ -63,32 +60,19 @@ if __name__ == "__main__": init() #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data local_rank=int(os.getenv('RANK_ID')) - if local_rank%8==0: - #初始化导入数据集和预训练模型到容器内 - c2net_context = prepare() - #获取数据集路径 - MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" - #获取预训练模型路径 - Mindspore_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Mindspore_MNIST_Example_Model" - output_path = c2net_context.output_path - #Set a cache file to determine whether the data has been copied to obs. - #If this file exists during multi-card training, there is no need to copy the dataset multiple times. - f = open("/cache/download_input.txt", 'w') - f.close() - try: - if os.path.exists("/cache/download_input.txt"): - print("download_input succeed") - except Exception as e: - print("download_input failed") - while not os.path.exists("/cache/download_input.txt"): - time.sleep(1) - ds_train = create_dataset_parallel(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size) - + #初始化导入数据集和预训练模型到容器内 + c2net_context = prepare() + #获取数据集路径 + MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" + #获取预训练模型路径 + Mindspore_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Mindspore_MNIST_Example_Model" + output_path = c2net_context.output_path + ds_train = create_dataset_parallel(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size) network = LeNet5(cfg.num_classes) net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) - load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) + #load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) if args.device_target != "Ascend": model = Model(network, net_loss,