|
|
|
@@ -94,11 +94,19 @@ class MyTimeMonitor(Callback): |
|
|
|
|
|
|
|
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16", |
|
|
|
device_num=1): |
|
|
|
if args_opt.mode == "GRAPH": |
|
|
|
ds_num_parallel_worker = 4 |
|
|
|
map_num_parallel_worker = 8 |
|
|
|
batch_num_parallel_worker = None |
|
|
|
else: |
|
|
|
ds_num_parallel_worker = 2 |
|
|
|
map_num_parallel_worker = 3 |
|
|
|
batch_num_parallel_worker = 2 |
|
|
|
ds.config.set_numa_enable(True) |
|
|
|
if device_num == 1: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True) |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=ds_num_parallel_worker, shuffle=True) |
|
|
|
else: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True, |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=ds_num_parallel_worker, shuffle=True, |
|
|
|
num_shards=device_num, shard_id=get_rank()) |
|
|
|
image_size = 224 |
|
|
|
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] |
|
|
|
@@ -127,9 +135,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target=" |
|
|
|
] |
|
|
|
if dtype == "fp32": |
|
|
|
trans.append(C.HWC2CHW()) |
|
|
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8) |
|
|
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=map_num_parallel_worker) |
|
|
|
# apply batch operations |
|
|
|
data_set = data_set.batch(batch_size, drop_remainder=True) |
|
|
|
data_set = data_set.batch(batch_size, drop_remainder=True, num_parallel_workers=batch_num_parallel_worker) |
|
|
|
# apply dataset repeat operation |
|
|
|
if repeat_num > 1: |
|
|
|
data_set = data_set.repeat(repeat_num) |
|
|
|
@@ -165,14 +173,16 @@ def train(): |
|
|
|
# init context |
|
|
|
if args_opt.mode == "GRAPH": |
|
|
|
mode = context.GRAPH_MODE |
|
|
|
all_reduce_fusion_config = [85, 160] |
|
|
|
else: |
|
|
|
mode = context.PYNATIVE_MODE |
|
|
|
all_reduce_fusion_config = [30, 90, 160] |
|
|
|
context.set_context(mode=mode, device_target=dev, save_graphs=False) |
|
|
|
if args_opt.run_distribute: |
|
|
|
init() |
|
|
|
device_num = get_group_size() |
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
gradients_mean=True, all_reduce_fusion_config=[85, 160]) |
|
|
|
gradients_mean=True, all_reduce_fusion_config=all_reduce_fusion_config) |
|
|
|
ckpt_save_dir = ckpt_save_dir + "ckpt_" + str(get_rank()) + "/" |
|
|
|
|
|
|
|
# create dataset |
|
|
|
|