|
|
|
@@ -29,6 +29,7 @@ from mindspore.context import ParallelMode |
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
|
from mindspore.nn import Momentum |
|
|
|
from mindspore.common import set_seed |
|
|
|
from mindspore.communication.management import get_rank, get_group_size |
|
|
|
|
|
|
|
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 |
|
|
|
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet |
|
|
|
@@ -56,11 +57,11 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=a |
|
|
|
if __name__ == '__main__': |
|
|
|
print("Start train for maskrcnn!") |
|
|
|
if not args_opt.do_eval and args_opt.run_distribute: |
|
|
|
rank = args_opt.rank_id |
|
|
|
device_num = args_opt.device_num |
|
|
|
init() |
|
|
|
rank = get_rank() |
|
|
|
device_num = get_group_size() |
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
gradients_mean=True) |
|
|
|
init() |
|
|
|
else: |
|
|
|
rank = 0 |
|
|
|
device_num = 1 |
|
|
|
|