@@ -98,24 +98,24 @@ if __name__ == '__main__': |
     parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
     args_opt = parser.parse_args()
 
-    # init distributed
-    if args_opt.is_distributed:
-        if os.getenv('DEVICE_ID', "not_set").isdigit():
-            context.set_context(device_id=int(os.getenv('DEVICE_ID')))
-        rank = get_rank()
-        group_size = get_group_size()
-        parallel_mode = ParallelMode.DATA_PARALLEL
-        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
-        init()
-    else:
-        rank = 0
-        group_size = 1
-        context.set_context(device_id=0)
+    if args_opt.device_target == "Ascend":
+        #train on Ascend
+        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
+
+        # init distributed
+        if args_opt.is_distributed:
+            if os.getenv('DEVICE_ID', "not_set").isdigit():
+                context.set_context(device_id=int(os.getenv('DEVICE_ID')))
+            init()
+            rank = get_rank()
+            group_size = get_group_size()
+            parallel_mode = ParallelMode.DATA_PARALLEL
+            context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
+        else:
+            rank = 0
+            group_size = 1
+            context.set_context(device_id=0)
+
+        # define network
+        net = xception(class_num=config.class_num)
+        net.to_float(mstype.float16)
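The substantive changes in this hunk are two: init() is now called before get_rank()/get_group_size(), and the Ascend-specific setup (graph mode, distributed init, float16 network) is gated on args_opt.device_target. Below is a minimal standalone sketch of the corrected ordering, assuming the MindSpore 1.x APIs used in the hunk; is_distributed here is a stand-in for args_opt.is_distributed.

import os

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

# Graph mode on Ascend, as in the added lines of the hunk.
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)

is_distributed = True  # stand-in for args_opt.is_distributed

if is_distributed:
    # Bind to the device the launcher assigned, if DEVICE_ID is set.
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))
    # init() creates the communication group; get_rank()/get_group_size()
    # query that group and raise if called before it, which is the bug
    # the reordering in this hunk fixes.
    init()
    rank = get_rank()
    group_size = get_group_size()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      device_num=group_size,
                                      gradients_mean=True)
else:
    # Single-device fallback: fixed rank and group size.
    rank = 0
    group_size = 1
    context.set_context(device_id=0)

The net.to_float(mstype.float16) cast in the added lines runs the Xception network's compute in float16, which applies only on the Ascend path, consistent with the new device_target gate.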