From b34d206f84c564b870788de63abccf1a3b8f731c Mon Sep 17 00:00:00 2001 From: gengdongjie Date: Tue, 22 Dec 2020 20:01:53 +0800 Subject: [PATCH] maskrcnn support 16p --- model_zoo/official/cv/maskrcnn/train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/model_zoo/official/cv/maskrcnn/train.py b/model_zoo/official/cv/maskrcnn/train.py index e5edb3b947..941c881ffd 100644 --- a/model_zoo/official/cv/maskrcnn/train.py +++ b/model_zoo/official/cv/maskrcnn/train.py @@ -29,6 +29,7 @@ from mindspore.context import ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.nn import Momentum from mindspore.common import set_seed +from mindspore.communication.management import get_rank, get_group_size from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet @@ -56,11 +57,11 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=a if __name__ == '__main__': print("Start train for maskrcnn!") if not args_opt.do_eval and args_opt.run_distribute: - rank = args_opt.rank_id - device_num = args_opt.device_num + init() + rank = get_rank() + device_num = get_group_size() context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) - init() else: rank = 0 device_num = 1