From 1b41008660e86ffd7545b13087343a60bb75e9bb Mon Sep 17 00:00:00 2001 From: liuzx Date: Thu, 1 Feb 2024 10:03:58 +0800 Subject: [PATCH] update train_multi_card example --- npu_mnist_example/train_multi_card.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/npu_mnist_example/train_multi_card.py b/npu_mnist_example/train_multi_card.py index 8e1a8bc..ffa7ea4 100644 --- a/npu_mnist_example/train_multi_card.py +++ b/npu_mnist_example/train_multi_card.py @@ -58,8 +58,19 @@ if __name__ == "__main__": init() #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data local_rank=int(os.getenv('RANK_ID')) - #初始化导入数据集和预训练模型到容器内 - c2net_context = prepare() + #初始化导入数据集和预训练模型到容器内,并行任务先让0卡拷贝数据,并用一个缓存文件标记0卡已prepare完成 + if local_rank == 0: + c2net_context = prepare() + f = open("/cache/prepare_completed.txt", 'w') + f.close() + try: + if os.path.exists("/cache/prepare_completed.txt"): + print("prepare completed!") + except Exception as e: + print("prepare failed") + while not os.path.exists("/cache/prepare_completed.txt"): + time.sleep(1) + c2net_context = prepare() #获取数据集路径 MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore"