From 4e02d75bf639d7fc19b6898f44742426744892b1 Mon Sep 17 00:00:00 2001 From: liuzx Date: Tue, 16 Jan 2024 16:20:04 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0npu=E7=BB=A7=E7=BB=AD?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E7=A4=BA=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- npu_mnist_example/train_continue.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/npu_mnist_example/train_continue.py b/npu_mnist_example/train_continue.py index 99aaea7..34f5db1 100644 --- a/npu_mnist_example/train_continue.py +++ b/npu_mnist_example/train_continue.py @@ -13,7 +13,6 @@ import os import argparse from config import mnist_cfg as cfg from dataset import create_dataset -from dataset_distributed import create_dataset_parallel from lenet import LeNet5 import mindspore.nn as nn from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor @@ -22,14 +21,11 @@ from mindspore.train import Model from mindspore.nn.metrics import Accuracy from mindspore.communication.management import get_rank -from openi import obs_copy_file -from openi import obs_copy_folder -from openi import openi_multidataset_to_env +#导入c2net包 +from c2net.context import prepare +from c2net.context.moxing_helper import obs_copy_file, obs_copy_folder parser = argparse.ArgumentParser(description='MindSpore Lenet Example') -parser.add_argument('--multi_data_url', - help='path to training/inference dataset folder', - default= '[{}]') parser.add_argument('--train_url', help='output folder to save/load', @@ -58,7 +54,13 @@ parser.add_argument('--ckpt_save_name', if __name__ == "__main__": + ###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 args, unknown = parser.parse_known_args() + #初始化导入数据集和预训练模型到容器内 + c2net_context = prepare() + #获取数据集路径 + MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" + data_dir = '/cache/data' base_path = '/cache/output' @@ -70,13 +72,10 @@ if __name__ == "__main__": except Exception as e: print("path already exists") - openi_multidataset_to_env(args.multi_data_url, data_dir) device_num = int(os.getenv('RANK_SIZE')) if device_num == 1: - ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) - if device_num > 1: - ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) + ds_train = create_dataset(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size) if ds_train.get_dataset_size() == 0: raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")