Browse Source

fix smoke of resnet thor

tags/v1.2.0-rc1
mwang 5 years ago
parent
commit
fbcb3061d8
2 changed files with 16 additions and 13 deletions
  1. +12
    -10
      tests/st/networks/models/resnet50/src_thor/dataset.py
  2. +4
    -3
      tests/st/networks/models/resnet50/test_resnet50_imagenet.py

+ 12
- 10
tests/st/networks/models/resnet50/src_thor/dataset.py View File

@@ -16,18 +16,15 @@
"""create train or eval dataset."""

import os

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C

ds.config.set_seed(1)
import mindspore.dataset.transforms.c_transforms as C2


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
"""
Create a train or eval dataset.
create a train or eval dataset.

Args:
dataset_path(string): the path of dataset.
@@ -41,10 +38,14 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):

device_num = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
if do_train:
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False,
num_shards=device_num, shard_id=rank_id)

image_size = 224
@@ -54,8 +55,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Decode(),
C.Resize((256, 256)),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]


+ 4
- 3
tests/st/networks/models/resnet50/test_resnet50_imagenet.py View File

@@ -39,6 +39,7 @@ from tests.st.networks.models.resnet50.src.config import config
from tests.st.networks.models.resnet50.src.metric import DistAccuracy, ClassifyCorrectCell
from tests.st.networks.models.resnet50.src.CrossEntropySmooth import CrossEntropySmooth
from tests.st.networks.models.resnet50.src_thor.config import config as thor_config
from tests.st.networks.models.resnet50.src_thor.dataset import create_dataset as create_dataset_thor
from tests.st.networks.models.resnet50.src_thor.model_thor import Model as THOR_Model
from tests.st.networks.models.resnet50.src_thor.resnet import resnet50 as resnet50_thor

@@ -250,8 +251,8 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
num_classes=thor_config.class_num)

# train dataset
dataset = create_dataset(dataset_path=dataset_path, do_train=True,
repeat_num=1, batch_size=thor_config.batch_size)
dataset = create_dataset_thor(dataset_path=dataset_path, do_train=True,
repeat_num=1, batch_size=thor_config.batch_size)

step_size = dataset.get_dataset_size()
eval_interval = thor_config.eval_interval
@@ -367,5 +368,5 @@ def test_resnet_and_resnet_thor_imagenet_4p():
for i in range(4, device_num + 4):
os.system("rm -rf " + str(i))
print("End training...")
assert thor_acc > 0.22
assert thor_acc > 0.25
assert thor_cost < 25

Loading…
Cancel
Save