| @@ -124,15 +124,6 @@ def parse_args(): | |||||
| args.data_root = os.path.join(args.data_dir, 'train2014') | args.data_root = os.path.join(args.data_dir, 'train2014') | ||||
| args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json') | args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json') | ||||
| # init distributed | |||||
| if args.is_distributed: | |||||
| if args.device_target == "Ascend": | |||||
| init() | |||||
| else: | |||||
| init("nccl") | |||||
| args.rank = get_rank() | |||||
| args.group_size = get_group_size() | |||||
| # select for master rank save ckpt or all rank save, compatiable for model parallel | # select for master rank save ckpt or all rank save, compatiable for model parallel | ||||
| args.rank_save_ckpt_flag = 0 | args.rank_save_ckpt_flag = 0 | ||||
| if args.is_save_on_master: | if args.is_save_on_master: | ||||
| @@ -161,6 +152,14 @@ def train(): | |||||
| devid = int(os.getenv('DEVICE_ID', '0')) | devid = int(os.getenv('DEVICE_ID', '0')) | ||||
| context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, | context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, | ||||
| device_target=args.device_target, save_graphs=True, device_id=devid) | device_target=args.device_target, save_graphs=True, device_id=devid) | ||||
| # init distributed | |||||
| if args.is_distributed: | |||||
| if args.device_target == "Ascend": | |||||
| init() | |||||
| else: | |||||
| init("nccl") | |||||
| args.rank = get_rank() | |||||
| args.group_size = get_group_size() | |||||
| if args.need_profiler: | if args.need_profiler: | ||||
| from mindspore.profiler.profiling import Profiler | from mindspore.profiler.profiling import Profiler | ||||
| profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) | profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) | ||||