Browse Source

update train_multi_card example

pull/3/head
liuzx 2 years ago
parent
commit
1b41008660
1 changed files with 13 additions and 2 deletions
  1. +13
    -2
      npu_mnist_example/train_multi_card.py

+ 13
- 2
npu_mnist_example/train_multi_card.py View File

@@ -58,8 +58,19 @@ if __name__ == "__main__":
init()
#Copying obs data does not need to be executed multiple times, just let the 0th card copy the data
local_rank=int(os.getenv('RANK_ID'))
#初始化导入数据集和预训练模型到容器内
c2net_context = prepare()
#初始化导入数据集和预训练模型到容器内,并行任务先让0卡拷贝数据,并用一个缓存文件标记0卡已prepare完成
if local_rank == 0:
c2net_context = prepare()
f = open("/cache/prepare_completed.txt", 'w')
f.close()
try:
if os.path.exists("/cache/prepare_completed.txt"):
print("prepare completed!")
except Exception as e:
print("prepare failed")
while not os.path.exists("/cache/prepare_completed.txt"):
time.sleep(1)
c2net_context = prepare()
#获取数据集路径
MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore"



Loading…
Cancel
Save