diff --git a/model_zoo/official/cv/FCN8s/train.py b/model_zoo/official/cv/FCN8s/train.py index 2e0eb92324..f076ff7589 100644 --- a/model_zoo/official/cv/FCN8s/train.py +++ b/model_zoo/official/cv/FCN8s/train.py @@ -46,6 +46,8 @@ def train(): args = parse_args() cfg = FCN8s_VOC2012_cfg device_num = int(os.environ.get("DEVICE_NUM", 1)) + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False, + device_target="Ascend", device_id=args.device_id) # init multicards training if device_num > 1: parallel_mode = ParallelMode.DATA_PARALLEL @@ -54,9 +56,6 @@ def train(): args.rank = get_rank() args.group_size = get_group_size() - context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False, - device_target="Ascend", device_id=args.device_id) - # dataset dataset = data_generator.SegDataset(image_mean=cfg.image_mean, image_std=cfg.image_std,