@@ -124,15 +124,6 @@ def parse_args():
     args.data_root = os.path.join(args.data_dir, 'train2014')
     args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json')
 
-    # init distributed
-    if args.is_distributed:
-        if args.device_target == "Ascend":
-            init()
-        else:
-            init("nccl")
-        args.rank = get_rank()
-        args.group_size = get_group_size()
-
     # select for master rank save ckpt or all rank save, compatiable for model parallel
     args.rank_save_ckpt_flag = 0
     if args.is_save_on_master:
@@ -161,6 +152,14 @@ def train():
     devid = int(os.getenv('DEVICE_ID', '0'))
     context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                         device_target=args.device_target, save_graphs=True, device_id=devid)
+    # init distributed
+    if args.is_distributed:
+        if args.device_target == "Ascend":
+            init()
+        else:
+            init("nccl")
+        args.rank = get_rank()
+        args.group_size = get_group_size()
     if args.need_profiler:
         from mindspore.profiler.profiling import Profiler
         profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)