@@ -94,11 +94,19 @@ class MyTimeMonitor(Callback):
 def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16",
                    device_num=1):
+    if args_opt.mode == "GRAPH":
+        ds_num_parallel_worker = 4
+        map_num_parallel_worker = 8
+        batch_num_parallel_worker = None
+    else:
+        ds_num_parallel_worker = 2
+        map_num_parallel_worker = 3
+        batch_num_parallel_worker = 2
     ds.config.set_numa_enable(True)
     if device_num == 1:
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True)
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=ds_num_parallel_worker, shuffle=True)
     else:
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True,
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=ds_num_parallel_worker, shuffle=True,
                                          num_shards=device_num, shard_id=get_rank())
     image_size = 224
     mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
@@ -127,9 +135,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
     ]
     if dtype == "fp32":
         trans.append(C.HWC2CHW())
-    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
+    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=map_num_parallel_worker)
     # apply batch operations
-    data_set = data_set.batch(batch_size, drop_remainder=True)
+    data_set = data_set.batch(batch_size, drop_remainder=True, num_parallel_workers=batch_num_parallel_worker)
     # apply dataset repeat operation
     if repeat_num > 1:
         data_set = data_set.repeat(repeat_num)
@@ -165,14 +173,16 @@ def train():
     # init context
     if args_opt.mode == "GRAPH":
         mode = context.GRAPH_MODE
+        all_reduce_fusion_config = [85, 160]
     else:
         mode = context.PYNATIVE_MODE
+        all_reduce_fusion_config = [30, 90, 160]
     context.set_context(mode=mode, device_target=dev, save_graphs=False)
     if args_opt.run_distribute:
         init()
         device_num = get_group_size()
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, all_reduce_fusion_config=[85, 160])
+                                          gradients_mean=True, all_reduce_fusion_config=all_reduce_fusion_config)
         ckpt_save_dir = ckpt_save_dir + "ckpt_" + str(get_rank()) + "/"
     # create dataset