| @@ -49,12 +49,14 @@ def train(): | |||||
| context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False, | context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False, | ||||
| device_target="Ascend", device_id=args.device_id) | device_target="Ascend", device_id=args.device_id) | ||||
| # init multicards training | # init multicards training | ||||
| args.rank = 0 | |||||
| args.group_size = 1 | |||||
| if device_num > 1: | if device_num > 1: | ||||
| parallel_mode = ParallelMode.DATA_PARALLEL | parallel_mode = ParallelMode.DATA_PARALLEL | ||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num) | context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num) | ||||
| init() | init() | ||||
| args.rank = get_rank() | |||||
| args.group_size = get_group_size() | |||||
| args.rank = get_rank() | |||||
| args.group_size = get_group_size() | |||||
| # dataset | # dataset | ||||
| dataset = data_generator.SegDataset(image_mean=cfg.image_mean, | dataset = data_generator.SegDataset(image_mean=cfg.image_mean, | ||||