Browse Source

更新npu继续训练示例

liuzx-patch-1
liuzx 2 years ago
parent
commit
4e02d75bf6
1 changed file with 10 additions and 11 deletions
  1. +10
    -11
      npu_mnist_example/train_continue.py

+ 10
- 11
npu_mnist_example/train_continue.py View File

@@ -13,7 +13,6 @@ import os
import argparse
from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5
import mindspore.nn as nn
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@@ -22,14 +21,11 @@ from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.communication.management import get_rank

from openi import obs_copy_file
from openi import obs_copy_folder
from openi import openi_multidataset_to_env
# Import the c2net package
from c2net.context import prepare
from c2net.context.moxing_helper import obs_copy_file, obs_copy_folder

parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--multi_data_url',
help='path to training/inference dataset folder',
default= '[{}]')

parser.add_argument('--train_url',
help='output folder to save/load',
@@ -58,7 +54,13 @@ parser.add_argument('--ckpt_save_name',


if __name__ == "__main__":
### Add `args, unknown = parser.parse_known_args()` to your code so that unrecognized arguments such as --ckpt_url are ignored instead of raising errors
args, unknown = parser.parse_known_args()
# Initialize: import the dataset and the pretrained model into the container
c2net_context = prepare()
# Get the dataset path
MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore"

data_dir = '/cache/data'
base_path = '/cache/output'

@@ -70,13 +72,10 @@ if __name__ == "__main__":
except Exception as e:
print("path already exists")

openi_multidataset_to_env(args.multi_data_url, data_dir)

device_num = int(os.getenv('RANK_SIZE'))
if device_num == 1:
ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
if device_num > 1:
ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
ds_train = create_dataset(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size)
if ds_train.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")



Loading…
Cancel
Save