Browse Source

update

pull/3/head
liuzx 2 years ago
parent
commit
4fc0d713a8
1 changed files with 11 additions and 27 deletions
  1. +11
    -27
      npu_mnist_example/train_multi_card.py

+ 11
- 27
npu_mnist_example/train_multi_card.py View File

@@ -19,6 +19,7 @@
import os
import argparse
from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5
import mindspore.nn as nn
@@ -29,7 +30,6 @@ from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
import time
#导入openi包
from c2net.context import prepare, upload_output




@@ -50,10 +50,7 @@ parser.add_argument('--epoch_size',
if __name__ == "__main__":
###请在代码中加入args, unknown = parser.parse_known_args()，可忽略掉--ckpt_url参数报错等参数问题
args, unknown = parser.parse_known_args()
MnistDataset_mindspore_path = ''
Mindspore_MNIST_Example_Model_path = ''
output_path = ''

device_num = int(os.getenv('RANK_SIZE'))
#使用多卡时
# set device_id and init for multi-card training
@@ -63,32 +60,19 @@ if __name__ == "__main__":
init()
#Copying obs data does not need to be executed multiple times, just let the 0th card copy the data
local_rank=int(os.getenv('RANK_ID'))
if local_rank%8==0:
#初始化导入数据集和预训练模型到容器内
c2net_context = prepare()
#获取数据集路径
MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore"
#获取预训练模型路径
Mindspore_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Mindspore_MNIST_Example_Model"
output_path = c2net_context.output_path
#Set a cache file to determine whether the data has been copied to obs.
#If this file exists during multi-card training, there is no need to copy the dataset multiple times.
f = open("/cache/download_input.txt", 'w')
f.close()
try:
if os.path.exists("/cache/download_input.txt"):
print("download_input succeed")
except Exception as e:
print("download_input failed")
while not os.path.exists("/cache/download_input.txt"):
time.sleep(1)
ds_train = create_dataset_parallel(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size)

#初始化导入数据集和预训练模型到容器内
c2net_context = prepare()
#获取数据集路径
MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore"
#获取预训练模型路径
Mindspore_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Mindspore_MNIST_Example_Model"
output_path = c2net_context.output_path
ds_train = create_dataset_parallel(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt")))
#load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt")))
if args.device_target != "Ascend":
model = Model(network,
net_loss,


Loading…
Cancel
Save