Browse Source

!9681 resolved 8-SIM run task error in Ascend environment

From: @shuzigood
Reviewed-by: @guoqi1024,@oacjiewen
Signed-off-by: @guoqi1024
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
84c49f9d77
1 changed files with 8 additions and 9 deletions
  1. +8
    -9
      model_zoo/official/cv/yolov3_darknet53/train.py

+ 8
- 9
model_zoo/official/cv/yolov3_darknet53/train.py View File

@@ -124,15 +124,6 @@ def parse_args():
args.data_root = os.path.join(args.data_dir, 'train2014')
args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json')

# init distributed
if args.is_distributed:
if args.device_target == "Ascend":
init()
else:
init("nccl")
args.rank = get_rank()
args.group_size = get_group_size()

# select for master rank save ckpt or all rank save, compatiable for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
@@ -161,6 +152,14 @@ def train():
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, save_graphs=True, device_id=devid)
# init distributed
if args.is_distributed:
if args.device_target == "Ascend":
init()
else:
init("nccl")
args.rank = get_rank()
args.group_size = get_group_size()
if args.need_profiler:
from mindspore.profiler.profiling import Profiler
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)


Loading…
Cancel
Save