Browse Source

!15920 pynative benchmark

From: @jojobugfree
Reviewed-by: @jjfeing,@zhoufeng54
Signed-off-by: @zhoufeng54
pull/15920/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
eb487e8deb
2 changed files with 88 additions and 2 deletions
  1. +78
    -0
      model_zoo/official/cv/resnet/src/dataset.py
  2. +10
    -2
      model_zoo/official/cv/resnet/train.py

+ 78
- 0
model_zoo/official/cv/resnet/src/dataset.py View File

@@ -170,6 +170,84 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=

return data_set

def create_dataset_pynative(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False,
                            enable_cache=False, cache_session_id=None):
    """
    Build the ImageFolder train or eval input pipeline used by the resnet50 PyNative benchmark.

    Args:
        dataset_path(string): directory handed to ImageFolderDataset.
        do_train(bool): True selects the random-crop/flip training transforms,
            False the decode/resize/center-crop evaluation transforms.
        repeat_num(int): value passed to the final repeat operation. Default: 1
        batch_size(int): samples per batch; an incomplete last batch is dropped. Default: 32
        target(str): the device target. Default: Ascend
        distribute(bool): for non-Ascend targets, whether to call init() and shard by rank. Default: False
        enable_cache(bool): attach a dataset cache to the label map; ignored when do_train is True. Default: False
        cache_session_id(int): session id of the cache, required whenever the cache is used. Default: None

    Returns:
        dataset

    Raises:
        ValueError: the cache is requested for eval but cache_session_id is empty.
    """
    ds.config.set_numa_enable(True)

    # work out how many shards exist and which one this process reads
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    elif distribute:
        init()
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        device_num = 1

    # a single device reads everything with 8 workers; sharded readers use 2 workers each
    if device_num == 1:
        reader_args = {"num_parallel_workers": 8, "shuffle": True}
    else:
        reader_args = {"num_parallel_workers": 2, "shuffle": True, "num_shards": device_num, "shard_id": rank_id}
    data_set = ds.ImageFolderDataset(dataset_path, **reader_args)

    image_size = 224
    # per-channel statistics scaled by 255
    mean = [channel * 255 for channel in (0.485, 0.456, 0.406)]
    std = [channel * 255 for channel in (0.229, 0.224, 0.225)]

    # define map operations: the leading ops differ between train and eval, the trailing two are shared
    if do_train:
        leading_ops = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
        ]
    else:
        leading_ops = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
        ]
    trans = leading_ops + [C.Normalize(mean=mean, std=std), C.HWC2CHW()]

    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=4)

    label_map_args = {"operations": type_cast_op, "input_columns": "label", "num_parallel_workers": 2}
    # only enable cache for eval
    if enable_cache and not do_train:
        if not cache_session_id:
            raise ValueError("A cache session_id must be provided to use cache.")
        label_map_args["cache"] = ds.DatasetCache(session_id=int(cache_session_id), size=0)
    data_set = data_set.map(**label_map_args)

    # apply batch operations, then the dataset repeat operation
    return data_set.batch(batch_size, drop_remainder=True).repeat(repeat_num)

def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False,
enable_cache=False, cache_session_id=None):


+ 10
- 2
model_zoo/official/cv/resnet/train.py View File

@@ -63,6 +63,8 @@ parser.add_argument("--eval_interval", type=int, default=1,
# --enable_cache is converted with ast.literal_eval, so it is written on the command line as a
# Python literal (e.g. True or False) rather than used as a bare flag.
parser.add_argument('--enable_cache', type=ast.literal_eval, default=False,
help='Caching the eval dataset in memory to speedup evaluation, default is False.')
# The cache session id is kept as a string; it defaults to the empty string.
parser.add_argument('--cache_session_id', type=str, default="", help='The session id for cache service.')
# Execution mode switch; argparse restricts it to 'GRAPH' (the default) or 'PYNATIVE'.
parser.add_argument('--mode', type=str, default='GRAPH', choices=('GRAPH', 'PYNATIVE'),
help="Graph mode or PyNative mode, default is Graph mode")
# Parse the command line into args_opt, which the rest of the script reads.
args_opt = parser.parse_args()

set_seed(1)
@@ -77,7 +79,10 @@ if args_opt.net in ("resnet18", "resnet50"):
from src.dataset import create_dataset1 as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset2 as create_dataset
if args_opt.mode == "GRAPH":
from src.dataset import create_dataset2 as create_dataset
else:
from src.dataset import create_dataset_pynative as create_dataset

elif args_opt.net == "resnet101":
from src.resnet import resnet101 as resnet
@@ -119,7 +124,10 @@ if __name__ == '__main__':
ckpt_save_dir = config.save_checkpoint_path

# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
if args_opt.mode == 'GRAPH':
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False)
if args_opt.parameter_server:
context.set_ps_context(enable_ps=True)
if args_opt.run_distribute:


Loading…
Cancel
Save