Browse Source

更新npu继续训练示例

liuzx-patch-1
liuzx 2 years ago
parent
commit
4e02d75bf6
1 changed file with 10 additions and 11 deletions
  1. +10 -11 npu_mnist_example/train_continue.py

+ 10
- 11
npu_mnist_example/train_continue.py View File

@@ -13,7 +13,6 @@ import os
import argparse import argparse
from config import mnist_cfg as cfg from config import mnist_cfg as cfg
from dataset import create_dataset from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5 from lenet import LeNet5
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@@ -22,14 +21,11 @@ from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.communication.management import get_rank from mindspore.communication.management import get_rank


from openi import obs_copy_file
from openi import obs_copy_folder
from openi import openi_multidataset_to_env
#导入c2net包
from c2net.context import prepare
from c2net.context.moxing_helper import obs_copy_file, obs_copy_folder


parser = argparse.ArgumentParser(description='MindSpore Lenet Example') parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--multi_data_url',
help='path to training/inference dataset folder',
default= '[{}]')


parser.add_argument('--train_url', parser.add_argument('--train_url',
help='output folder to save/load', help='output folder to save/load',
@@ -58,7 +54,13 @@ parser.add_argument('--ckpt_save_name',




if __name__ == "__main__": if __name__ == "__main__":
###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
#初始化导入数据集和预训练模型到容器内
c2net_context = prepare()
#获取数据集路径
MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore"

data_dir = '/cache/data' data_dir = '/cache/data'
base_path = '/cache/output' base_path = '/cache/output'


@@ -70,13 +72,10 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
print("path already exists") print("path already exists")


openi_multidataset_to_env(args.multi_data_url, data_dir)


device_num = int(os.getenv('RANK_SIZE')) device_num = int(os.getenv('RANK_SIZE'))
if device_num == 1: if device_num == 1:
ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
if device_num > 1:
ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
ds_train = create_dataset(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size)
if ds_train.get_dataset_size() == 0: if ds_train.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")




Loading…
Cancel
Save