|
|
|
@@ -13,7 +13,6 @@ import os |
|
|
|
import argparse |
|
|
|
from config import mnist_cfg as cfg |
|
|
|
from dataset import create_dataset |
|
|
|
from dataset_distributed import create_dataset_parallel |
|
|
|
from lenet import LeNet5 |
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor |
|
|
|
@@ -22,14 +21,11 @@ from mindspore.train import Model |
|
|
|
from mindspore.nn.metrics import Accuracy |
|
|
|
from mindspore.communication.management import get_rank |
|
|
|
|
|
|
|
from openi import obs_copy_file |
|
|
|
from openi import obs_copy_folder |
|
|
|
from openi import openi_multidataset_to_env |
|
|
|
#导入c2net包 |
|
|
|
from c2net.context import prepare |
|
|
|
from c2net.context.moxing_helper import obs_copy_file, obs_copy_folder |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='MindSpore Lenet Example') |
|
|
|
parser.add_argument('--multi_data_url', |
|
|
|
help='path to training/inference dataset folder', |
|
|
|
default= '[{}]') |
|
|
|
|
|
|
|
parser.add_argument('--train_url', |
|
|
|
help='output folder to save/load', |
|
|
|
@@ -58,7 +54,13 @@ parser.add_argument('--ckpt_save_name', |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 |
|
|
|
args, unknown = parser.parse_known_args() |
|
|
|
#初始化导入数据集和预训练模型到容器内 |
|
|
|
c2net_context = prepare() |
|
|
|
#获取数据集路径 |
|
|
|
MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" |
|
|
|
|
|
|
|
data_dir = '/cache/data' |
|
|
|
base_path = '/cache/output' |
|
|
|
|
|
|
|
@@ -70,13 +72,10 @@ if __name__ == "__main__": |
|
|
|
except Exception as e: |
|
|
|
print("path already exists") |
|
|
|
|
|
|
|
openi_multidataset_to_env(args.multi_data_url, data_dir) |
|
|
|
|
|
|
|
device_num = int(os.getenv('RANK_SIZE')) |
|
|
|
if device_num == 1: |
|
|
|
ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) |
|
|
|
if device_num > 1: |
|
|
|
ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size) |
|
|
|
ds_train = create_dataset(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size) |
|
|
|
if ds_train.get_dataset_size() == 0: |
|
|
|
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") |
|
|
|
|
|
|
|
|