|
|
|
@@ -58,8 +58,19 @@ if __name__ == "__main__": |
|
|
|
init() |
|
|
|
#Copying obs data does not need to be executed multiple times, just let the 0th card copy the data |
|
|
|
local_rank=int(os.getenv('RANK_ID')) |
|
|
|
#初始化导入数据集和预训练模型到容器内 |
|
|
|
c2net_context = prepare() |
|
|
|
#初始化导入数据集和预训练模型到容器内,并行任务先让0卡拷贝数据,并用一个缓存文件标记0卡已prepare完成 |
|
|
|
if local_rank == 0: |
|
|
|
c2net_context = prepare() |
|
|
|
f = open("/cache/prepare_completed.txt", 'w') |
|
|
|
f.close() |
|
|
|
try: |
|
|
|
if os.path.exists("/cache/prepare_completed.txt"): |
|
|
|
print("prepare completed!") |
|
|
|
except Exception as e: |
|
|
|
print("prepare failed") |
|
|
|
while not os.path.exists("/cache/prepare_completed.txt"): |
|
|
|
time.sleep(1) |
|
|
|
c2net_context = prepare() |
|
|
|
#获取数据集路径 |
|
|
|
MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" |
|
|
|
|
|
|
|
|