Browse Source

maskrcnn support 16p

tags/v1.1.0
gengdongjie 5 years ago
parent
commit
b34d206f84
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      model_zoo/official/cv/maskrcnn/train.py

+ 4
- 3
model_zoo/official/cv/maskrcnn/train.py View File

@@ -29,6 +29,7 @@ from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Momentum
from mindspore.common import set_seed
from mindspore.communication.management import get_rank, get_group_size

from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
@@ -56,11 +57,11 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=a
if __name__ == '__main__':
print("Start train for maskrcnn!")
if not args_opt.do_eval and args_opt.run_distribute:
rank = args_opt.rank_id
device_num = args_opt.device_num
init()
rank = get_rank()
device_num = get_group_size()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
rank = 0
device_num = 1


Loading…
Cancel
Save