Browse Source

!1550 bug fix while evaluation

Merge pull request !1550 from SanjayChan/mobilenet
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
61ac727c00
2 changed files with 14 additions and 8 deletions
  1. +8
    -5
      mindspore/model_zoo/mobilenetv2/src/dataset.py
  2. +6
    -3
      mindspore/model_zoo/mobilenetv3/src/dataset.py

+ 8
- 5
mindspore/model_zoo/mobilenetv2/src/dataset.py View File

@@ -28,8 +28,8 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
repeat_num(int): the repeat times of dataset. Default: 1.
batch_size(int): the batch size of dataset. Default: 32.

Returns:
dataset
@@ -43,9 +43,12 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
elif platform == "GPU":
from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
if do_train:
from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
raise ValueError("Unsupport platform.")



+ 6
- 3
mindspore/model_zoo/mobilenetv3/src/dataset.py View File

@@ -44,9 +44,12 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id)
elif platform == "GPU":
from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
if do_train:
from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
raise ValueError("Unsupport platform.")



Loading…
Cancel
Save