| @@ -13,7 +13,6 @@ import os | |||||
| import argparse | import argparse | ||||
| from config import mnist_cfg as cfg | from config import mnist_cfg as cfg | ||||
| from dataset import create_dataset | from dataset import create_dataset | ||||
| from dataset_distributed import create_dataset_parallel | |||||
| from lenet import LeNet5 | from lenet import LeNet5 | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | ||||
| @@ -22,14 +21,11 @@ from mindspore.train import Model | |||||
| from mindspore.nn.metrics import Accuracy | from mindspore.nn.metrics import Accuracy | ||||
| from mindspore.communication.management import get_rank | from mindspore.communication.management import get_rank | ||||
| from openi import obs_copy_file | |||||
| from openi import obs_copy_folder | |||||
| from openi import openi_multidataset_to_env | |||||
| #导入c2net包 | |||||
| from c2net.context import prepare | |||||
| from c2net.context.moxing_helper import obs_copy_file, obs_copy_folder | |||||
| parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | ||||
| parser.add_argument('--multi_data_url', | |||||
| help='path to training/inference dataset folder', | |||||
| default= '[{}]') | |||||
| parser.add_argument('--train_url', | parser.add_argument('--train_url', | ||||
| help='output folder to save/load', | help='output folder to save/load', | ||||
| @@ -58,7 +54,13 @@ parser.add_argument('--ckpt_save_name', | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| ###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 | |||||
| args, unknown = parser.parse_known_args() | args, unknown = parser.parse_known_args() | ||||
| #初始化导入数据集和预训练模型到容器内 | |||||
| c2net_context = prepare() | |||||
| #获取数据集路径 | |||||
| MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" | |||||
| data_dir = '/cache/data' | data_dir = '/cache/data' | ||||
| base_path = '/cache/output' | base_path = '/cache/output' | ||||
| @@ -70,13 +72,10 @@ if __name__ == "__main__": | |||||
| except Exception as e: | except Exception as e: | ||||
| print("path already exists") | print("path already exists") | ||||
| openi_multidataset_to_env(args.multi_data_url, data_dir) | |||||
| device_num = int(os.getenv('RANK_SIZE')) | device_num = int(os.getenv('RANK_SIZE')) | ||||
| if device_num == 1: | if device_num == 1: | ||||
| ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) | |||||
| if device_num > 1: | |||||
| ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) | |||||
| ds_train = create_dataset(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size) | |||||
| if ds_train.get_dataset_size() == 0: | if ds_train.get_dataset_size() == 0: | ||||
| raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | ||||