Browse Source

Increase num_parallel in data preprocess of mobilenetv1

tags/v1.2.0-rc1
chenhaozhe 4 years ago
parent
commit
2f05a0a441
2 changed files with 9 additions and 8 deletions
  1. +8
    -8
      model_zoo/official/cv/mobilenetv1/src/dataset.py
  2. +1
    -0
      model_zoo/official/cv/mobilenetv1/train.py

+ 8
- 8
model_zoo/official/cv/mobilenetv1/src/dataset.py View File

@@ -44,9 +44,9 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
device_num = get_group_size()

if device_num == 1:
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12, shuffle=True)
else:
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12, shuffle=True,
num_shards=device_num, shard_id=rank_id)

# define map operations
@@ -66,8 +66,8 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=

type_cast_op = C2.TypeCast(mstype.int32)

data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)

# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)
@@ -99,9 +99,9 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
device_num = get_group_size()

if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True,
num_shards=device_num, shard_id=rank_id)

image_size = 224
@@ -127,8 +127,8 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=

type_cast_op = C2.TypeCast(mstype.int32)

data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12)

# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)


+ 1
- 0
model_zoo/official/cv/mobilenetv1/train.py View File

@@ -68,6 +68,7 @@ if __name__ == '__main__':
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
context.set_auto_parallel_context(all_reduce_fusion_config=[75])
# GPU target
else:
init()


Loading…
Cancel
Save