diff --git a/npu_mnist_example/train_continue.py b/npu_mnist_example/train_continue.py index 99aaea7..34f5db1 100644 --- a/npu_mnist_example/train_continue.py +++ b/npu_mnist_example/train_continue.py @@ -13,7 +13,6 @@ import os import argparse from config import mnist_cfg as cfg from dataset import create_dataset -from dataset_distributed import create_dataset_parallel from lenet import LeNet5 import mindspore.nn as nn from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor @@ -22,14 +21,11 @@ from mindspore.train import Model from mindspore.nn.metrics import Accuracy from mindspore.communication.management import get_rank -from openi import obs_copy_file -from openi import obs_copy_folder -from openi import openi_multidataset_to_env +#导入c2net包 +from c2net.context import prepare +from c2net.context.moxing_helper import obs_copy_file, obs_copy_folder parser = argparse.ArgumentParser(description='MindSpore Lenet Example') -parser.add_argument('--multi_data_url', - help='path to training/inference dataset folder', - default= '[{}]') parser.add_argument('--train_url', help='output folder to save/load', @@ -58,7 +54,13 @@ parser.add_argument('--ckpt_save_name', if __name__ == "__main__": + ###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 args, unknown = parser.parse_known_args() + #初始化导入数据集和预训练模型到容器内 + c2net_context = prepare() + #获取数据集路径 + MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" + data_dir = '/cache/data' base_path = '/cache/output' @@ -70,13 +72,10 @@ if __name__ == "__main__": except Exception as e: print("path already exists") - openi_multidataset_to_env(args.multi_data_url, data_dir) device_num = int(os.getenv('RANK_SIZE')) if device_num == 1: - ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) - if device_num > 1: - ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) + ds_train = create_dataset(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size) if ds_train.get_dataset_size() == 0: raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")