diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py index c62a5b8c21..5c2b438b5d 100755 --- a/model_zoo/official/cv/resnet/src/dataset.py +++ b/model_zoo/official/cv/resnet/src/dataset.py @@ -170,6 +170,84 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= return data_set +def create_dataset_pynative(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False, + enable_cache=False, cache_session_id=None): + """ + create a train or eval imagenet2012 dataset for resnet50 benchmark + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + target(str): the device target. Default: Ascend + distribute(bool): data for distribute or not. Default: False + enable_cache(bool): whether tensor caching service is used for eval. Default: False + cache_session_id(int): If enable_cache, cache session_id need to be provided. Default: None + + Returns: + dataset + """ + ds.config.set_numa_enable(True) + if target == "Ascend": + device_num, rank_id = _get_rank_info() + else: + if distribute: + init() + rank_id = get_rank() + device_num = get_group_size() + else: + device_num = 1 + + if device_num == 1: + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) + else: + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=2, shuffle=True, + num_shards=device_num, shard_id=rank_id) + + image_size = 224 + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + # define map operations + if do_train: + trans = [ + C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(prob=0.5), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + else: + trans = [ + C.Decode(), + C.Resize(256), + C.CenterCrop(image_size), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + + type_cast_op = C2.TypeCast(mstype.int32) + + data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=4) + # only enable cache for eval + if do_train: + enable_cache = False + if enable_cache: + if not cache_session_id: + raise ValueError("A cache session_id must be provided to use cache.") + eval_cache = ds.DatasetCache(session_id=int(cache_session_id), size=0) + data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=2, + cache=eval_cache) + else: + data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=2) + + # apply batch operations + data_set = data_set.batch(batch_size, drop_remainder=True) + + # apply dataset repeat operation + data_set = data_set.repeat(repeat_num) + + return data_set def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False, enable_cache=False, cache_session_id=None): diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index fcf67e43d1..f03a7e2a9b 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -63,6 +63,8 @@ parser.add_argument("--eval_interval", type=int, default=1, parser.add_argument('--enable_cache', type=ast.literal_eval, default=False, help='Caching the eval dataset in memory to speedup evaluation, default is False.') parser.add_argument('--cache_session_id', type=str, default="", help='The session id for cache service.') +parser.add_argument('--mode', type=str, default='GRAPH', choices=('GRAPH', 'PYNATIVE'), + help="Graph mode or PyNative mode, default is Graph mode") args_opt = parser.parse_args() set_seed(1) @@ -77,7 +79,10 @@ if args_opt.net in ("resnet18", "resnet50"): from src.dataset import create_dataset1 as create_dataset else: from src.config import config2 as config - from src.dataset import create_dataset2 as create_dataset + if args_opt.mode == "GRAPH": + from src.dataset import create_dataset2 as create_dataset + else: + from src.dataset import create_dataset_pynative as create_dataset elif args_opt.net == "resnet101": from src.resnet import resnet101 as resnet @@ -119,7 +124,10 @@ if __name__ == '__main__': ckpt_save_dir = config.save_checkpoint_path # init context - context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + if args_opt.mode == 'GRAPH': + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + else: + context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False) if args_opt.parameter_server: context.set_ps_context(enable_ps=True) if args_opt.run_distribute: