Browse Source

!13481 fix squeezenet 8p performance degradation by adjusting parameters.

From: @anzhengqi
Reviewed-by: @heleiwang,@liucunwei
Signed-off-by: @liucunwei
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
2cef6a1143
1 changed files with 2 additions and 5 deletions
  1. +2
    -5
      model_zoo/official/cv/squeezenet/src/dataset.py

+ 2
- 5
model_zoo/official/cv/squeezenet/src/dataset.py View File

@@ -125,11 +125,9 @@ def create_dataset_imagenet(dataset_path,


if device_num == 1: if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, data_set = ds.ImageFolderDataset(dataset_path,
num_parallel_workers=8,
shuffle=True) shuffle=True)
else: else:
data_set = ds.ImageFolderDataset(dataset_path, data_set = ds.ImageFolderDataset(dataset_path,
num_parallel_workers=8,
shuffle=True, shuffle=True,
num_shards=device_num, num_shards=device_num,
shard_id=rank_id) shard_id=rank_id)
@@ -162,11 +160,10 @@ def create_dataset_imagenet(dataset_path,
type_cast_op = C2.TypeCast(mstype.int32) type_cast_op = C2.TypeCast(mstype.int32)


data_set = data_set.map(operations=type_cast_op, data_set = data_set.map(operations=type_cast_op,
input_columns="label",
num_parallel_workers=8)
input_columns="label")
data_set = data_set.map(operations=trans, data_set = data_set.map(operations=trans,
input_columns="image", input_columns="image",
num_parallel_workers=8)
num_parallel_workers=10)


# apply batch operations # apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True) data_set = data_set.batch(batch_size, drop_remainder=True)


Loading…
Cancel
Save