| @@ -29,6 +29,7 @@ from mindspore.context import ParallelMode | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.nn import Momentum | from mindspore.nn import Momentum | ||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from mindspore.communication.management import get_rank, get_group_size | |||||
| from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 | from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 | ||||
| from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet | from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet | ||||
| @@ -56,11 +57,11 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=a | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| print("Start train for maskrcnn!") | print("Start train for maskrcnn!") | ||||
| if not args_opt.do_eval and args_opt.run_distribute: | if not args_opt.do_eval and args_opt.run_distribute: | ||||
| rank = args_opt.rank_id | |||||
| device_num = args_opt.device_num | |||||
| init() | |||||
| rank = get_rank() | |||||
| device_num = get_group_size() | |||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| gradients_mean=True) | gradients_mean=True) | ||||
| init() | |||||
| else: | else: | ||||
| rank = 0 | rank = 0 | ||||
| device_num = 1 | device_num = 1 | ||||