Merge pull request !5926 from caojian05/ms_master_googlenet_support_imagenet_on_ascendtags/v1.0.0
| @@ -25,35 +25,57 @@ from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from src.config import cifar_cfg as cfg | |||||
| from src.dataset import create_dataset | |||||
| from src.config import cifar_cfg, imagenet_cfg | |||||
| from src.dataset import create_dataset_cifar10, create_dataset_imagenet | |||||
| from src.googlenet import GoogleNet | from src.googlenet import GoogleNet | ||||
| set_seed(1) | set_seed(1) | ||||
| parser = argparse.ArgumentParser(description='googlenet') | parser = argparse.ArgumentParser(description='googlenet') | ||||
| parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'], | |||||
| help='dataset name.') | |||||
| parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| if args_opt.dataset_name == 'cifar10': | |||||
| cfg = cifar_cfg | |||||
| dataset = create_dataset_cifar10(cfg.data_path, 1, False) | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||||
| net = GoogleNet(num_classes=cfg.num_classes) | |||||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, | |||||
| weight_decay=cfg.weight_decay) | |||||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||||
| elif args_opt.dataset_name == "imagenet": | |||||
| cfg = imagenet_cfg | |||||
| dataset = create_dataset_imagenet(cfg.val_data_path, 1, False) | |||||
| if not cfg.use_label_smooth: | |||||
| cfg.label_smooth_factor = 0.0 | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", | |||||
| smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) | |||||
| net = GoogleNet(num_classes=cfg.num_classes) | |||||
| model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'}) | |||||
| else: | |||||
| raise ValueError("dataset is not support.") | |||||
| device_target = cfg.device_target | device_target = cfg.device_target | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | ||||
| if device_target == "Ascend": | if device_target == "Ascend": | ||||
| context.set_context(device_id=cfg.device_id) | context.set_context(device_id=cfg.device_id) | ||||
| net = GoogleNet(num_classes=cfg.num_classes) | |||||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, | |||||
| weight_decay=cfg.weight_decay) | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||||
| if device_target == "Ascend": | |||||
| param_dict = load_checkpoint(cfg.checkpoint_path) | |||||
| else: # GPU | |||||
| if args_opt.checkpoint_path is not None: | |||||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | param_dict = load_checkpoint(args_opt.checkpoint_path) | ||||
| print("load checkpoint from [{}].".format(args_opt.checkpoint_path)) | |||||
| else: | |||||
| param_dict = load_checkpoint(cfg.checkpoint_path) | |||||
| print("load checkpoint from [{}].".format(cfg.checkpoint_path)) | |||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| net.set_train(False) | net.set_train(False) | ||||
| dataset = create_dataset(cfg.data_path, 1, False) | |||||
| acc = model.eval(dataset) | acc = model.eval(dataset) | ||||
| print("accuracy: ", acc) | print("accuracy: ", acc) | ||||
| @@ -16,18 +16,32 @@ | |||||
| ##############export checkpoint file into air and onnx models################# | ##############export checkpoint file into air and onnx models################# | ||||
| python export.py | python export.py | ||||
| """ | """ | ||||
| import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore as ms | import mindspore as ms | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | ||||
| from src.config import cifar_cfg as cfg | |||||
| from src.config import cifar_cfg, imagenet_cfg | |||||
| from src.googlenet import GoogleNet | from src.googlenet import GoogleNet | ||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| parser = argparse.ArgumentParser(description='Classification') | |||||
| parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'], | |||||
| help='dataset name.') | |||||
| args_opt = parser.parse_args() | |||||
| if args_opt.dataset_name == 'cifar10': | |||||
| cfg = cifar_cfg | |||||
| elif args_opt.dataset_name == 'imagenet': | |||||
| cfg = imagenet_cfg | |||||
| else: | |||||
| raise ValueError("dataset is not support.") | |||||
| net = GoogleNet(num_classes=cfg.num_classes) | net = GoogleNet(num_classes=cfg.num_classes) | ||||
| assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None." | |||||
| param_dict = load_checkpoint(cfg.checkpoint_path) | param_dict = load_checkpoint(cfg.checkpoint_path) | ||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| @@ -14,9 +14,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| if [ $# != 1 ] | |||||
| if [ $# != 1 ] && [ $# != 2 ] | |||||
| then | then | ||||
| echo "Usage: sh run_train.sh [RANK_TABLE_FILE]" | |||||
| echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet]" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -26,6 +26,19 @@ then | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
# Dataset to train on; an optional second command-line argument overrides it.
dataset_type='cifar10'
# Use the POSIX numeric comparison: '==' inside [ ] is a bash extension and
# fails under dash-style sh, which would silently keep the default above.
if [ $# -eq 2 ]
then
    # Quote "$2" so an empty or whitespace-containing argument is rejected
    # here instead of producing a malformed test that skips validation.
    if [ "$2" != "cifar10" ] && [ "$2" != "imagenet" ]
    then
        echo "error: the selected dataset is neither cifar10 nor imagenet"
        exit 1
    fi
    dataset_type=$2
fi
| ulimit -u unlimited | ulimit -u unlimited | ||||
| export DEVICE_NUM=8 | export DEVICE_NUM=8 | ||||
| export RANK_SIZE=8 | export RANK_SIZE=8 | ||||
| @@ -43,9 +56,9 @@ do | |||||
| mkdir ./train_parallel$i | mkdir ./train_parallel$i | ||||
| cp -r ./src ./train_parallel$i | cp -r ./src ./train_parallel$i | ||||
| cp ./train.py ./train_parallel$i | cp ./train.py ./train_parallel$i | ||||
| echo "start training for rank $RANK_ID, device $DEVICE_ID" | |||||
| echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type" | |||||
| cd ./train_parallel$i ||exit | cd ./train_parallel$i ||exit | ||||
| env > env.log | env > env.log | ||||
| python train.py --device_id=$i > log 2>&1 & | |||||
| python train.py --device_id=$i --dataset_name=$dataset_type> log 2>&1 & | |||||
| cd .. | cd .. | ||||
| done | done | ||||
| @@ -18,6 +18,7 @@ network config setting, will be used in main.py | |||||
| from easydict import EasyDict as edict | from easydict import EasyDict as edict | ||||
| cifar_cfg = edict({ | cifar_cfg = edict({ | ||||
| 'name': 'cifar10', | |||||
| 'pre_trained': False, | 'pre_trained': False, | ||||
| 'num_classes': 10, | 'num_classes': 10, | ||||
| 'lr_init': 0.1, | 'lr_init': 0.1, | ||||
| @@ -30,9 +31,45 @@ cifar_cfg = edict({ | |||||
| 'image_width': 224, | 'image_width': 224, | ||||
| 'data_path': './cifar10', | 'data_path': './cifar10', | ||||
| 'device_target': 'Ascend', | 'device_target': 'Ascend', | ||||
| 'device_id': 4, | |||||
| 'device_id': 0, | |||||
| 'keep_checkpoint_max': 10, | 'keep_checkpoint_max': 10, | ||||
| 'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt', | 'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt', | ||||
| 'onnx_filename': 'googlenet.onnx', | 'onnx_filename': 'googlenet.onnx', | ||||
| 'air_filename': 'googlenet.air' | 'air_filename': 'googlenet.air' | ||||
| }) | }) | ||||
# Settings for the ImageNet run; the train/eval/export entry scripts select
# this dict when --dataset_name is 'imagenet'.
imagenet_cfg = edict({
    'name': 'imagenet',
    'pre_trained': False,  # when true, training first loads 'checkpoint_path' into the net
    'num_classes': 1000,  # passed to GoogleNet and to the label-smoothing loss
    'lr_init': 0.1,  # base learning rate handed to the selected lr schedule
    'batch_size': 256,  # batch size applied by create_dataset_imagenet
    'epoch_size': 300,  # epochs for model.train; also max_epoch of the lr schedule
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'buffer_size': None,  # invalid parameter
    'image_height': 224,  # create_dataset_imagenet asserts height == width
    'image_width': 224,
    'data_path': './ImageNet_Original/train/',  # training images
    'val_data_path': './ImageNet_Original/val/',  # images read by the evaluation script
    'device_target': 'Ascend',
    'device_id': 0,
    'keep_checkpoint_max': 10,
    'checkpoint_path': None,  # must be set before export: the export script asserts it is not None
    'onnx_filename': 'googlenet.onnx',
    'air_filename': 'googlenet.air',

    # optimizer and lr related
    # 'exponential' selects warmup_step_lr, 'cosine_annealing' selects
    # warmup_cosine_annealing_lr; any other value raises NotImplementedError.
    'lr_scheduler': 'exponential',
    'lr_epochs': [70, 140, 210, 280],  # step schedule only: epochs at which lr is multiplied by lr_gamma
    'lr_gamma': 0.3,  # step schedule only
    'eta_min': 0.0,  # cosine schedule only
    'T_max': 150,  # cosine schedule only
    'warmup_epochs': 0,

    # loss related
    # is_dynamic_loss_scale == 1 -> DynamicLossScaleManager (and loss_scale is reset to 1);
    # otherwise FixedLossScaleManager(loss_scale) is used.
    'is_dynamic_loss_scale': 0,
    'loss_scale': 1024,
    'label_smooth_factor': 0.1,  # forced to 0.0 at run time when use_label_smooth is false
    'use_label_smooth': True,
})
| @@ -21,10 +21,10 @@ import mindspore.common.dtype as mstype | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.transforms.c_transforms as C | import mindspore.dataset.transforms.c_transforms as C | ||||
| import mindspore.dataset.vision.c_transforms as vision | import mindspore.dataset.vision.c_transforms as vision | ||||
| from src.config import cifar_cfg as cfg | |||||
| from src.config import cifar_cfg, imagenet_cfg | |||||
| def create_dataset(data_home, repeat_num=1, training=True): | |||||
| def create_dataset_cifar10(data_home, repeat_num=1, training=True): | |||||
| """Data operations.""" | """Data operations.""" | ||||
| ds.config.set_seed(1) | ds.config.set_seed(1) | ||||
| data_dir = os.path.join(data_home, "cifar-10-batches-bin") | data_dir = os.path.join(data_home, "cifar-10-batches-bin") | ||||
| @@ -37,14 +37,14 @@ def create_dataset(data_home, repeat_num=1, training=True): | |||||
| else: | else: | ||||
| data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False) | data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False) | ||||
| resize_height = cfg.image_height | |||||
| resize_width = cfg.image_width | |||||
| resize_height = cifar_cfg.image_height | |||||
| resize_width = cifar_cfg.image_width | |||||
| # define map operations | # define map operations | ||||
| random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT | random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT | ||||
| random_horizontal_op = vision.RandomHorizontalFlip() | random_horizontal_op = vision.RandomHorizontalFlip() | ||||
| resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR | resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR | ||||
| rescale_op = vision.Rescale(1.0/255.0, 0.0) | |||||
| rescale_op = vision.Rescale(1.0 / 255.0, 0.0) | |||||
| normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) | normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) | ||||
| changeswap_op = vision.HWC2CHW() | changeswap_op = vision.HWC2CHW() | ||||
| type_cast_op = C.TypeCast(mstype.int32) | type_cast_op = C.TypeCast(mstype.int32) | ||||
| @@ -59,7 +59,7 @@ def create_dataset(data_home, repeat_num=1, training=True): | |||||
| data_set = data_set.map(input_columns="image", operations=c_trans) | data_set = data_set.map(input_columns="image", operations=c_trans) | ||||
| # apply batch operations | # apply batch operations | ||||
| data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) | |||||
| data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True) | |||||
| # apply repeat operations | # apply repeat operations | ||||
| data_set = data_set.repeat(repeat_num) | data_set = data_set.repeat(repeat_num) | ||||
| @@ -67,6 +67,67 @@ def create_dataset(data_home, repeat_num=1, training=True): | |||||
| return data_set | return data_set | ||||
def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
                            num_parallel_workers=None, shuffle=None):
    """
    Build the ImageNet training or evaluation input pipeline.

    Args:
        dataset_path(string): directory handed to ImageFolderDatasetV2.
        repeat_num(int): how many times the batched dataset is repeated. Default: 1
        training(bool): True selects the augmented training transforms, False the
            decode / resize / center-crop evaluation transforms. Default: True
        num_parallel_workers(int): worker count forwarded to the dataset reader only;
            the map operations always run with 8 workers. Default: None
        shuffle(bool): shuffle flag forwarded to the dataset reader. Default: None

    Returns:
        dataset, batched with imagenet_cfg.batch_size (incomplete batches dropped)
    """
    device_num, rank_id = _get_rank_info()

    # Shard arguments are only supplied when more than one device is reported.
    reader_kwargs = {'num_parallel_workers': num_parallel_workers, 'shuffle': shuffle}
    if device_num != 1:
        reader_kwargs.update(num_shards=device_num, shard_id=rank_id)
    data_set = ds.ImageFolderDatasetV2(dataset_path, **reader_kwargs)

    assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
    image_size = imagenet_cfg.image_height
    # Per-channel statistics scaled by 255.
    mean = [channel * 255 for channel in (0.485, 0.456, 0.406)]
    std = [channel * 255 for channel in (0.229, 0.224, 0.225)]

    # define map operations: a mode-specific head followed by a shared tail
    if training:
        leading_ops = [
            vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1),
        ]
    else:
        leading_ops = [
            vision.Decode(),
            vision.Resize(256),
            vision.CenterCrop(image_size),
        ]
    transform_img = leading_ops + [vision.Normalize(mean=mean, std=std), vision.HWC2CHW()]
    transform_label = [C.TypeCast(mstype.int32)]

    data_set = data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
    data_set = data_set.map(input_columns="label", num_parallel_workers=8, operations=transform_label)

    # batch first, then repeat the batched dataset
    batched = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True)
    return batched.repeat(repeat_num)
| def _get_rank_info(): | def _get_rank_info(): | ||||
| """ | """ | ||||
| get rank size and rank id | get rank size and rank id | ||||
| @@ -112,6 +112,7 @@ class GoogleNet(nn.Cell): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| """construct""" | |||||
| x = self.conv1(x) | x = self.conv1(x) | ||||
| x = self.maxpool1(x) | x = self.maxpool1(x) | ||||
| @@ -0,0 +1,20 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """lr""" | |||||
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    """Return the learning rate at current_step on the straight line from init_lr to base_lr.

    The line reaches base_lr when current_step equals warmup_steps.
    """
    start = float(init_lr)
    slope = (float(base_lr) - start) / float(warmup_steps)
    return start + slope * current_step
| @@ -0,0 +1,39 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """lr""" | |||||
| import math | |||||
| import numpy as np | |||||
| from .linear_warmup import linear_warmup_lr | |||||
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """Build a per-step schedule: linear warmup from 0, then cosine annealing.

    During the first warmup_epochs * steps_per_epoch steps the rate ramps
    linearly up to lr. Afterwards it follows a cosine curve between lr and
    eta_min, evaluated once per epoch (every step of an epoch shares a value).

    Returns:
        float32 numpy array with one entry per training step.
    """
    peak_lr = lr
    n_total = int(max_epoch * steps_per_epoch)
    n_warmup = int(warmup_epochs * steps_per_epoch)

    schedule = []
    for step in range(n_total):
        if step < n_warmup:
            value = linear_warmup_lr(step + 1, n_warmup, peak_lr, 0)
        else:
            # The cosine phase advances per epoch, not per step.
            epoch = step // steps_per_epoch
            value = eta_min + (peak_lr - eta_min) * (1. + math.cos(math.pi * epoch / T_max)) / 2
        schedule.append(value)

    return np.array(schedule).astype(np.float32)
| @@ -0,0 +1,59 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """lr""" | |||||
| from collections import Counter | |||||
| import numpy as np | |||||
| from .linear_warmup import linear_warmup_lr | |||||
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """Build a per-step schedule: linear warmup from 0, then step decay.

    After the warmup steps, the rate is multiplied by gamma at the first step
    of every epoch listed in lr_epochs. An epoch listed k times applies
    gamma k times at that step.

    Returns:
        float32 numpy array with one entry per training step.
    """
    peak_lr = lr
    n_total = int(max_epoch * steps_per_epoch)
    n_warmup = int(warmup_epochs * steps_per_epoch)

    # Number of decay applications scheduled at each global step.
    decay_hits = Counter(epoch * steps_per_epoch for epoch in lr_epochs)

    current = peak_lr
    schedule = []
    for step in range(n_total):
        if step < n_warmup:
            current = linear_warmup_lr(step + 1, n_warmup, peak_lr, 0)
        else:
            # Missing steps count as 0 hits, so the rate is multiplied by gamma ** 0.
            current = current * gamma ** decay_hits[step]
        schedule.append(current)

    return np.array(schedule).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    """Step-decay schedule without warmup: warmup_step_lr with zero warmup epochs."""
    return warmup_step_lr(
        lr=lr,
        lr_epochs=milestones,
        steps_per_epoch=steps_per_epoch,
        warmup_epochs=0,
        max_epoch=max_epoch,
        gamma=gamma,
    )
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    """Step-decay schedule that multiplies the rate by gamma every epoch_size epochs.

    The decay epochs are the multiples of epoch_size in [1, max_epoch).
    """
    decay_epochs = [epoch for epoch in range(1, max_epoch) if epoch % epoch_size == 0]
    return multi_step_lr(lr, decay_epochs, steps_per_epoch, max_epoch, gamma=gamma)
| @@ -27,18 +27,19 @@ from mindspore import context | |||||
| from mindspore.communication.management import init, get_rank | from mindspore.communication.management import init, get_rank | ||||
| from mindspore.nn.optim.momentum import Momentum | from mindspore.nn.optim.momentum import Momentum | ||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | ||||
| from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager | |||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from src.config import cifar_cfg as cfg | |||||
| from src.dataset import create_dataset | |||||
| from src.config import cifar_cfg, imagenet_cfg | |||||
| from src.dataset import create_dataset_cifar10, create_dataset_imagenet | |||||
| from src.googlenet import GoogleNet | from src.googlenet import GoogleNet | ||||
| set_seed(1) | set_seed(1) | ||||
| def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): | |||||
| def lr_steps_cifar10(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): | |||||
| """Set learning rate.""" | """Set learning rate.""" | ||||
| lr_each_step = [] | lr_each_step = [] | ||||
| total_steps = steps_per_epoch * total_epochs | total_steps = steps_per_epoch * total_epochs | ||||
| @@ -59,11 +60,46 @@ def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): | |||||
| return learning_rate | return learning_rate | ||||
def lr_steps_imagenet(_cfg, steps_per_epoch):
    """Return the per-step learning-rate schedule chosen by _cfg.lr_scheduler.

    'exponential' builds a warmup + step-decay schedule (warmup_step_lr) and
    'cosine_annealing' builds a warmup + cosine schedule
    (warmup_cosine_annealing_lr).

    Raises:
        NotImplementedError: for any other scheduler name.
    """
    from src.lr_scheduler.warmup_step_lr import warmup_step_lr
    from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr

    scheduler = _cfg.lr_scheduler

    if scheduler == 'exponential':
        return warmup_step_lr(
            _cfg.lr_init,
            _cfg.lr_epochs,
            steps_per_epoch,
            _cfg.warmup_epochs,
            _cfg.epoch_size,
            gamma=_cfg.lr_gamma,
        )

    if scheduler == 'cosine_annealing':
        return warmup_cosine_annealing_lr(
            _cfg.lr_init,
            steps_per_epoch,
            _cfg.warmup_epochs,
            _cfg.epoch_size,
            _cfg.T_max,
            _cfg.eta_min,
        )

    raise NotImplementedError(scheduler)
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| parser = argparse.ArgumentParser(description='Cifar10 classification') | |||||
| parser = argparse.ArgumentParser(description='Classification') | |||||
| parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'], | |||||
| help='dataset name.') | |||||
| parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') | parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| if args_opt.dataset_name == "cifar10": | |||||
| cfg = cifar_cfg | |||||
| elif args_opt.dataset_name == "imagenet": | |||||
| cfg = imagenet_cfg | |||||
| else: | |||||
| raise ValueError("Unsupport dataset.") | |||||
| # set context | |||||
| device_target = cfg.device_target | device_target = cfg.device_target | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | ||||
| @@ -90,7 +126,13 @@ if __name__ == '__main__': | |||||
| else: | else: | ||||
| raise ValueError("Unsupported platform.") | raise ValueError("Unsupported platform.") | ||||
| dataset = create_dataset(cfg.data_path, 1) | |||||
| if args_opt.dataset_name == "cifar10": | |||||
| dataset = create_dataset_cifar10(cfg.data_path, 1) | |||||
| elif args_opt.dataset_name == "imagenet": | |||||
| dataset = create_dataset_imagenet(cfg.data_path, 1) | |||||
| else: | |||||
| raise ValueError("Unsupport dataset.") | |||||
| batch_num = dataset.get_dataset_size() | batch_num = dataset.get_dataset_size() | ||||
| net = GoogleNet(num_classes=cfg.num_classes) | net = GoogleNet(num_classes=cfg.num_classes) | ||||
| @@ -98,23 +140,75 @@ if __name__ == '__main__': | |||||
| if cfg.pre_trained: | if cfg.pre_trained: | ||||
| param_dict = load_checkpoint(cfg.checkpoint_path) | param_dict = load_checkpoint(cfg.checkpoint_path) | ||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) | |||||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, | |||||
| weight_decay=cfg.weight_decay) | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||||
| loss_scale_manager = None | |||||
| if args_opt.dataset_name == 'cifar10': | |||||
| lr = lr_steps_cifar10(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) | |||||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), | |||||
| learning_rate=Tensor(lr), | |||||
| momentum=cfg.momentum, | |||||
| weight_decay=cfg.weight_decay) | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||||
| elif args_opt.dataset_name == 'imagenet': | |||||
| lr = lr_steps_imagenet(cfg, batch_num) | |||||
def get_param_groups(network):
    """Split the trainable parameters into a no-weight-decay group and a default group.

    Parameters whose name ends with '.bias', '.gamma' or '.beta' go into the
    first group, which sets weight_decay to 0.0. All others go into the second
    group, which carries no weight_decay key.
    """
    exempt_suffixes = ('.bias', '.gamma', '.beta')
    exempt, decayed = [], []
    for param in network.trainable_params():
        bucket = exempt if param.name.endswith(exempt_suffixes) else decayed
        bucket.append(param)
    return [{'params': exempt, 'weight_decay': 0.0}, {'params': decayed}]
| if cfg.is_dynamic_loss_scale: | |||||
| cfg.loss_scale = 1 | |||||
| opt = Momentum(params=get_param_groups(net), | |||||
| learning_rate=Tensor(lr), | |||||
| momentum=cfg.momentum, | |||||
| weight_decay=cfg.weight_decay, | |||||
| loss_scale=cfg.loss_scale) | |||||
| if not cfg.use_label_smooth: | |||||
| cfg.label_smooth_factor = 0.0 | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", | |||||
| smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) | |||||
| if cfg.is_dynamic_loss_scale == 1: | |||||
| loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000) | |||||
| else: | |||||
| loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False) | |||||
| if device_target == "Ascend": | if device_target == "Ascend": | ||||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | ||||
| amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) | |||||
| amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager) | |||||
| ckpt_save_dir = "./" | ckpt_save_dir = "./" | ||||
| else: # GPU | |||||
| else: # GPU | |||||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | ||||
| amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None) | |||||
| amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager) | |||||
| ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/" | ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/" | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) | config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) | ||||
| time_cb = TimeMonitor(data_size=batch_num) | time_cb = TimeMonitor(data_size=batch_num) | ||||
| ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory=ckpt_save_dir, config=config_ck) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir, | |||||
| config=config_ck) | |||||
| loss_cb = LossMonitor() | loss_cb = LossMonitor() | ||||
| model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) | model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) | ||||
| print("train success") | print("train success") | ||||