Merge pull request !6191 from wukesong/imagenet-alexent (tags/v1.0.0)
@@ -20,8 +20,8 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
 import ast
 import argparse
-from src.config import alexnet_cfg as cfg
-from src.dataset import create_dataset_cifar10
+from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
+from src.dataset import create_dataset_cifar10, create_dataset_imagenet
 from src.alexnet import AlexNet
 import mindspore.nn as nn
 from mindspore import context
@@ -32,28 +32,50 @@ from mindspore.nn.metrics import Accuracy
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
+    parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
+                        help='dataset name.')
     parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                         help='device where the code will be implemented (default: Ascend)')
     parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
     parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
-                        path where the trained ckpt file')
+                        path where the trained ckpt file is saved')
     parser.add_argument('--dataset_sink_mode', type=ast.literal_eval,
                         default=True, help='dataset_sink_mode is False or True')
     args = parser.parse_args()

     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

-    network = AlexNet(cfg.num_classes)
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
-    repeat_size = cfg.epoch_size
-    opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
-    model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
     print("============== Starting Testing ==============")
-    param_dict = load_checkpoint(args.ckpt_path)
-    load_param_into_net(network, param_dict)
-    ds_eval = create_dataset_cifar10(args.data_path,
-                                     cfg.batch_size,
-                                     status="test")
-    acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
-    print("============== {} ==============".format(acc))
+    if args.dataset_name == 'cifar10':
+        cfg = alexnet_cifar10_cfg
+        network = AlexNet(cfg.num_classes)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
+        ds_eval = create_dataset_cifar10(args.data_path, cfg.batch_size, status="test", target=args.device_target)
+        param_dict = load_checkpoint(args.ckpt_path)
+        print("load checkpoint from [{}].".format(args.ckpt_path))
+        load_param_into_net(network, param_dict)
+        network.set_train(False)
+        model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
+    elif args.dataset_name == 'imagenet':
+        cfg = alexnet_imagenet_cfg
+        network = AlexNet(cfg.num_classes)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        ds_eval = create_dataset_imagenet(args.data_path, cfg.batch_size, training=False)
+        param_dict = load_checkpoint(args.ckpt_path)
+        print("load checkpoint from [{}].".format(args.ckpt_path))
+        load_param_into_net(network, param_dict)
+        network.set_train(False)
+        model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+    else:
+        raise ValueError("Unsupported dataset.")
+
+    result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
+    print("result : {}".format(result))
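For reference, a hedged sketch of how the rewritten entry point is driven; the paths, checkpoint name, and accuracy numbers below are illustrative placeholders, not results from this PR:

```python
# Hypothetical invocation of the updated eval.py; all paths are placeholders.
#
#   python eval.py --dataset_name imagenet --data_path /path/to/imagenet/val \
#                  --ckpt_path ./ckpt/checkpoint_alexnet-150_625.ckpt
#
# The ImageNet branch registers {'top_1_accuracy', 'top_5_accuracy'}, so
# model.eval() returns a dict keyed by those metric names, e.g.:
#   result : {'top_1_accuracy': 0.5..., 'top_5_accuracy': 0.7...}
```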
@@ -0,0 +1,53 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 3 ]
+then
+    echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet] [DATA_PATH]"
+    exit 1
+fi
+
+if [ ! -f $1 ]
+then
+    echo "error: RANK_TABLE_FILE=$1 is not a file"
+    exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=8
+export RANK_SIZE=8
+RANK_TABLE_FILE=$(realpath $1)
+export RANK_TABLE_FILE
+export DATASET_NAME=$2
+export DATA_PATH=$3
+echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
+
+export SERVER_ID=0
+rank_start=$((DEVICE_NUM * SERVER_ID))
+for((i=0; i<${DEVICE_NUM}; i++))
+do
+    export DEVICE_ID=$i
+    export RANK_ID=$((rank_start + i))
+    rm -rf ./train_parallel$i
+    mkdir ./train_parallel$i
+    cp -r ./src ./train_parallel$i
+    cp ./train.py ./train_parallel$i
+    echo "start training for rank $RANK_ID, device $DEVICE_ID"
+    cd ./train_parallel$i || exit
+    env > env.log
+    python train.py --device_id=$i --dataset_name=$DATASET_NAME --data_path=$DATA_PATH > log 2>&1 &
+    cd ..
+done
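A small sketch of the rank bookkeeping the loop performs, mirroring the shell arithmetic with the script's own exports:

```python
# RANK_ID = DEVICE_NUM * SERVER_ID + i for each local device i on this server.
DEVICE_NUM, SERVER_ID = 8, 0
for i in range(DEVICE_NUM):
    print(f"device {i} -> RANK_ID {DEVICE_NUM * SERVER_ID + i}")   # ranks 0..7
```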
@@ -17,6 +17,4 @@
 # a simple tutorial follows; more parameters can be set
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
-DATA_PATH=$1
-CKPT_PATH=$2
-python -s ${self_path}/../eval.py --data_path=./$DATA_PATH --device_target="Ascend" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
+python -s ${self_path}/../eval.py --device_target="Ascend" > log.txt 2>&1 &
@@ -17,6 +17,4 @@
 # a simple tutorial follows; more parameters can be set
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
-DATA_PATH=$1
-CKPT_PATH=$2
-python -s ${self_path}/../eval.py --data_path=./$DATA_PATH --device_target="GPU" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
+python -s ${self_path}/../eval.py --device_target="GPU" > log.txt 2>&1 &
@@ -14,9 +14,10 @@
 # limitations under the License.
 # ============================================================================
+export DEVICE_NUM=1
+export RANK_SIZE=1
 # a simple tutorial follows; more parameters can be set
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
-DATA_PATH=$1
-CKPT_PATH=$2
-python -s ${self_path}/../train.py --data_path=./$DATA_PATH --device_target="Ascend" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
+python -s ${self_path}/../train.py --device_target="Ascend" > log.txt 2>&1 &
@@ -14,9 +14,10 @@
 # limitations under the License.
 # ============================================================================
+export DEVICE_NUM=1
+export RANK_SIZE=1
 # a simple tutorial follows; more parameters can be set
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
-DATA_PATH=$1
-CKPT_PATH=$2
-python -s ${self_path}/../train.py --data_path=./$DATA_PATH --device_target="GPU" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
+python -s ${self_path}/../train.py --device_target="GPU" > log.txt 2>&1 &
@@ -51,6 +51,7 @@ class AlexNet(nn.Cell):
         self.fc3 = fc_with_initialize(4096, num_classes)

     def construct(self, x):
+        """define network"""
         x = self.conv1(x)
         x = self.relu(x)
         x = self.max_pool2d(x)
@@ -18,7 +18,7 @@ network config setting, will be used in train.py
 from easydict import EasyDict as edict

-alexnet_cfg = edict({
+alexnet_cifar10_cfg = edict({
     'num_classes': 10,
     'learning_rate': 0.002,
     'momentum': 0.9,
@@ -30,3 +30,31 @@ alexnet_cfg = edict({
     'save_checkpoint_steps': 1562,
     'keep_checkpoint_max': 10,
 })
+
+alexnet_imagenet_cfg = edict({
+    'num_classes': 1000,
+    'learning_rate': 0.13,
+    'momentum': 0.9,
+    'epoch_size': 150,
+    'batch_size': 256,
+    'buffer_size': None,  # invalid parameter
+    'image_height': 227,
+    'image_width': 227,
+    'save_checkpoint_steps': 625,
+    'keep_checkpoint_max': 10,
+
+    # opt
+    'weight_decay': 0.0001,
+    'loss_scale': 1024,
+
+    # lr
+    'is_dynamic_loss_scale': 0,
+    'label_smooth': 1,
+    'label_smooth_factor': 0.1,
+    'lr_scheduler': 'cosine_annealing',
+    'warmup_epochs': 5,
+    'lr_epochs': [30, 60, 90, 120],
+    'lr_gamma': 0.1,
+    'T_max': 150,
+    'eta_min': 0.0,
+})
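A hedged sanity check on `'save_checkpoint_steps': 625`; the ImageNet-1k train-split size of 1,281,167 images is an assumption, not stated in this PR:

```python
# Back-of-envelope check, assuming the standard ImageNet-1k train split and
# the 8-device launch from run_train.sh above.
images = 1281167       # assumed ImageNet-1k train images
batch_size = 256       # alexnet_imagenet_cfg['batch_size']
device_num = 8         # run_train.sh exports DEVICE_NUM=8
print(images // (batch_size * device_num))   # 625 -> one checkpoint per epoch
```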
@@ -16,20 +16,32 @@
 Produce the dataset
 """
+import os
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as C
 import mindspore.dataset.vision.c_transforms as CV
 from mindspore.common import dtype as mstype
-from .config import alexnet_cfg as cfg
+from mindspore.communication.management import get_rank, get_group_size
+from .config import alexnet_cifar10_cfg, alexnet_imagenet_cfg


-def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="train"):
+def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="train", target="Ascend"):
     """
     create dataset for train or test
     """
-    cifar_ds = ds.Cifar10Dataset(data_path)
+    if target == "Ascend":
+        device_num, rank_id = _get_rank_info()
+
+    if target != "Ascend" or device_num == 1:
+        cifar_ds = ds.Cifar10Dataset(data_path)
+    else:
+        cifar_ds = ds.Cifar10Dataset(data_path, num_parallel_workers=8, shuffle=True,
+                                     num_shards=device_num, shard_id=rank_id)
+
     rescale = 1.0 / 255.0
     shift = 0.0
+    cfg = alexnet_cifar10_cfg

     resize_op = CV.Resize((cfg.image_height, cfg.image_width))
     rescale_op = CV.Rescale(rescale, shift)
@@ -39,16 +51,97 @@ def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="trai
     random_horizontal_op = CV.RandomHorizontalFlip()
     channel_swap_op = CV.HWC2CHW()
     typecast_op = C.TypeCast(mstype.int32)
-    cifar_ds = cifar_ds.map(operations=typecast_op, input_columns="label")
+    cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=8)
     if status == "train":
-        cifar_ds = cifar_ds.map(operations=random_crop_op, input_columns="image")
-        cifar_ds = cifar_ds.map(operations=random_horizontal_op, input_columns="image")
-    cifar_ds = cifar_ds.map(operations=resize_op, input_columns="image")
-    cifar_ds = cifar_ds.map(operations=rescale_op, input_columns="image")
-    cifar_ds = cifar_ds.map(operations=normalize_op, input_columns="image")
-    cifar_ds = cifar_ds.map(operations=channel_swap_op, input_columns="image")
+        cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op, num_parallel_workers=8)
+        cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op, num_parallel_workers=8)
+    cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=8)
+    cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=8)
+    cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op, num_parallel_workers=8)
+    cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op, num_parallel_workers=8)

     cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size)
     cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
     cifar_ds = cifar_ds.repeat(repeat_size)
     return cifar_ds
+
+
+def create_dataset_imagenet(dataset_path, batch_size=32, repeat_num=1, training=True,
+                            num_parallel_workers=None, shuffle=None, sampler=None, class_indexing=None):
+    """
+    create a train or eval imagenet2012 dataset for AlexNet
+
+    Args:
+        dataset_path(string): the path of dataset.
+        batch_size(int): the batch size of dataset. Default: 32
+        repeat_num(int): the repeat times of dataset. Default: 1
+        training(bool): whether dataset is used for train or eval. Default: True
+        num_parallel_workers(int): number of workers to read the data. Default: None
+        shuffle(bool): whether to shuffle the dataset. Default: None
+        sampler(Sampler): sampler used to pick samples. Default: None
+        class_indexing(dict): mapping from folder name to class index. Default: None
+
+    Returns:
+        dataset
+    """
+    device_num, rank_id = _get_rank_info()
+    cfg = alexnet_imagenet_cfg
+
+    if num_parallel_workers is None:
+        num_parallel_workers = int(64 / device_num)
+
+    data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers,
+                                     shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
+                                     num_shards=device_num, shard_id=rank_id)
+
+    assert cfg.image_height == cfg.image_width, "imagenet_cfg.image_height not equal imagenet_cfg.image_width"
+    image_size = cfg.image_height
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # define map operations
+    if training:
+        transform_img = [
+            CV.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            CV.RandomHorizontalFlip(prob=0.5),
+            CV.Normalize(mean=mean, std=std),
+            CV.HWC2CHW()
+        ]
+    else:
+        transform_img = [
+            CV.Decode(),
+            CV.Resize((256, 256)),
+            CV.CenterCrop(image_size),
+            CV.Normalize(mean=mean, std=std),
+            CV.HWC2CHW()
+        ]
+
+    transform_label = [C.TypeCast(mstype.int32)]
+
+    data_set = data_set.map(input_columns="image", num_parallel_workers=num_parallel_workers,
+                            operations=transform_img)
+    data_set = data_set.map(input_columns="label", num_parallel_workers=num_parallel_workers,
+                            operations=transform_label)
+
+    num_parallel_workers2 = int(16 / device_num)
+    data_set = data_set.batch(batch_size, num_parallel_workers=num_parallel_workers2, drop_remainder=True)
+
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
+
+
+def _get_rank_info():
+    """
+    get rank size and rank id
+    """
+    rank_size = int(os.environ.get("RANK_SIZE", 1))
+
+    if rank_size > 1:
+        rank_size = get_group_size()
+        rank_id = get_rank()
+    else:
+        # rank_size = rank_id = None
+        rank_size = 1
+        rank_id = 0
+
+    return rank_size, rank_id
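A minimal single-device usage sketch: with `RANK_SIZE` unset, `_get_rank_info()` returns `(1, 0)` and the dataset is not sharded; the path is a placeholder:

```python
from src.dataset import create_dataset_imagenet

# Eval pipeline: Decode -> Resize(256) -> CenterCrop(227) -> Normalize -> HWC2CHW
ds_eval = create_dataset_imagenet("/path/to/imagenet/val", batch_size=256,
                                  training=False)
print(ds_eval.get_dataset_size())   # batches per epoch; drop_remainder=True
```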
@@ -13,10 +13,11 @@
 # limitations under the License.
 # ============================================================================
 """learning rate generator"""
+import math
+from collections import Counter
 import numpy as np


-def get_lr(current_step, lr_max, total_epochs, steps_per_epoch):
+def get_lr_cifar10(current_step, lr_max, total_epochs, steps_per_epoch):
     """
     generate learning rate array
@@ -42,3 +43,85 @@ def get_lr(current_step, lr_max, total_epochs, steps_per_epoch):
     learning_rate = lr_each_step[current_step:]
     return learning_rate
+
+
+def get_lr_imagenet(cfg, steps_per_epoch):
+    """generate learning rate array"""
+    if cfg.lr_scheduler == 'exponential':
+        lr = warmup_step_lr(cfg.learning_rate,
+                            cfg.lr_epochs,
+                            steps_per_epoch,
+                            cfg.warmup_epochs,
+                            cfg.epoch_size,
+                            gamma=cfg.lr_gamma,
+                            )
+    elif cfg.lr_scheduler == 'cosine_annealing':
+        lr = warmup_cosine_annealing_lr(cfg.learning_rate,
+                                        steps_per_epoch,
+                                        cfg.warmup_epochs,
+                                        cfg.epoch_size,
+                                        cfg.T_max,
+                                        cfg.eta_min)
+    else:
+        raise NotImplementedError(cfg.lr_scheduler)
+
+    return lr
+
+
+def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
+    """linear warmup learning rate"""
+    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
+    lr = float(init_lr) + lr_inc * current_step
+    return lr
+
+
+def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
+    """step-decay learning rate with linear warmup"""
+    base_lr = lr
+    warmup_init_lr = 0
+    total_steps = int(max_epoch * steps_per_epoch)
+    warmup_steps = int(warmup_epochs * steps_per_epoch)
+    milestones = lr_epochs
+    milestones_steps = []
+    for milestone in milestones:
+        milestones_step = milestone * steps_per_epoch
+        milestones_steps.append(milestones_step)
+
+    lr_each_step = []
+    lr = base_lr
+    milestones_steps_counter = Counter(milestones_steps)
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
+        else:
+            lr = lr * gamma**milestones_steps_counter[i]
+        lr_each_step.append(lr)
+
+    return np.array(lr_each_step).astype(np.float32)
+
+
+def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
+    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
+
+
+def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
+    lr_epochs = []
+    for i in range(1, max_epoch):
+        if i % epoch_size == 0:
+            lr_epochs.append(i)
+    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
+
+
+def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
+    """cosine annealing learning rate with linear warmup"""
+    base_lr = lr
+    warmup_init_lr = 0
+    total_steps = int(max_epoch * steps_per_epoch)
+    warmup_steps = int(warmup_epochs * steps_per_epoch)
+
+    lr_each_step = []
+    for i in range(total_steps):
+        last_epoch = i // steps_per_epoch
+        if i < warmup_steps:
+            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
+        else:
+            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
+        lr_each_step.append(lr)
+
+    return np.array(lr_each_step).astype(np.float32)
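A minimal sketch of the cosine schedule above, using the `alexnet_imagenet_cfg` defaults (lr=0.13, warmup_epochs=5, epoch_size=150, T_max=150, eta_min=0.0) and assuming 625 steps per epoch and execution from the model's root directory:

```python
from src.generator_lr import warmup_cosine_annealing_lr

lr = warmup_cosine_annealing_lr(0.13, steps_per_epoch=625, warmup_epochs=5,
                                max_epoch=150, T_max=150, eta_min=0.0)
print(lr.shape)          # (93750,): 150 epochs x 625 steps
print(lr[0], lr[3124])   # ~4.2e-05 -> 0.13: linear warmup over the first 5 epochs
print(lr[-1])            # ~1.4e-05: cosine decay toward eta_min by epoch 150
```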
@@ -0,0 +1,34 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""get parameters for Momentum optimizer"""
+
+
+def get_param_groups(network):
+    """split trainable parameters into decay and no-decay groups"""
+    decay_params = []
+    no_decay_params = []
+    for x in network.trainable_params():
+        parameter_name = x.name
+        if parameter_name.endswith('.bias'):
+            # biases do not use weight decay
+            no_decay_params.append(x)
+        elif parameter_name.endswith('.gamma'):
+            # BN gamma does not use weight decay (note: this AlexNet has no BN layers)
+            no_decay_params.append(x)
+        elif parameter_name.endswith('.beta'):
+            # BN beta does not use weight decay (note: this AlexNet has no BN layers)
+            no_decay_params.append(x)
+        else:
+            decay_params.append(x)
+
+    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
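A usage sketch, mirroring the ImageNet branch of train.py below; the class count and hyperparameter values are illustrative:

```python
import mindspore.nn as nn
from src.alexnet import AlexNet
from src.get_param_groups import get_param_groups

network = AlexNet(1000)
groups = get_param_groups(network)   # [{biases: no decay}, {weights: decay}]
# Group-wise weight_decay overrides the global value, so only the second
# group is regularized.
opt = nn.Momentum(params=groups, learning_rate=0.13, momentum=0.9,
                  weight_decay=0.0001, loss_scale=1024)
```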
@@ -20,14 +20,18 @@ python train.py --data_path /YourDataPath
 import ast
 import argparse
-from src.config import alexnet_cfg as cfg
-from src.dataset import create_dataset_cifar10
-from src.generator_lr import get_lr
+import os
+from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
+from src.dataset import create_dataset_cifar10, create_dataset_imagenet
+from src.generator_lr import get_lr_cifar10, get_lr_imagenet
 from src.alexnet import AlexNet
+from src.get_param_groups import get_param_groups
 import mindspore.nn as nn
+from mindspore.communication.management import init, get_rank
 from mindspore import context
 from mindspore import Tensor
 from mindspore.train import Model
+from mindspore.context import ParallelMode
 from mindspore.nn.metrics import Accuracy
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.common import set_seed
@@ -36,27 +40,102 @@ set_seed(1)
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
+    parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
+                        help='dataset name.')
     parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                         help='device where the code will be implemented (default: Ascend)')
     parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
     parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
-                        path where the trained ckpt file')
+                        path where the trained ckpt file is saved')
     parser.add_argument('--dataset_sink_mode', type=ast.literal_eval,
                         default=True, help='dataset_sink_mode is False or True')
+    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')
     args = parser.parse_args()

+    if args.dataset_name == "cifar10":
+        cfg = alexnet_cifar10_cfg
+    elif args.dataset_name == "imagenet":
+        cfg = alexnet_imagenet_cfg
+    else:
+        raise ValueError("Unsupported dataset.")
+
+    device_target = args.device_target
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    context.set_context(save_graphs=False)
+
+    device_num = int(os.environ.get("DEVICE_NUM", 1))
+    if device_target == "Ascend":
+        context.set_context(device_id=args.device_id)
+        if device_num > 1:
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+            init()
+    elif device_target == "GPU":
+        init()
+        if device_num > 1:
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+    else:
+        raise ValueError("Unsupported platform.")
+
+    if args.dataset_name == "cifar10":
+        ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, target=args.device_target)
+    elif args.dataset_name == "imagenet":
+        ds_train = create_dataset_imagenet(args.data_path, cfg.batch_size)
+    else:
+        raise ValueError("Unsupported dataset.")

-    ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1)
     network = AlexNet(cfg.num_classes)
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
-    lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
-    opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
-    model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
+
+    loss_scale_manager = None
+    metrics = None
+    if args.dataset_name == 'cifar10':
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        lr = Tensor(get_lr_cifar10(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
+        opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
+        metrics = {"Accuracy": Accuracy()}
+    elif args.dataset_name == 'imagenet':
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        lr = Tensor(get_lr_imagenet(cfg, ds_train.get_dataset_size()))
+        opt = nn.Momentum(params=get_param_groups(network),
+                          learning_rate=lr,
+                          momentum=cfg.momentum,
+                          weight_decay=cfg.weight_decay,
+                          loss_scale=cfg.loss_scale)
+
+        from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
+        if cfg.is_dynamic_loss_scale == 1:
+            loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
+        else:
+            loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
+    else:
+        raise ValueError("Unsupported dataset.")
+
+    if device_target == "Ascend":
+        model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, amp_level="O2", keep_batchnorm_fp32=False,
+                      loss_scale_manager=loss_scale_manager)
+    elif device_target == "GPU":
+        model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, loss_scale_manager=loss_scale_manager)
+    else:
+        raise ValueError("Unsupported platform.")
+
+    if device_num > 1:
+        ckpt_save_dir = os.path.join(args.ckpt_path + "_" + str(get_rank()))
+    else:
+        ckpt_save_dir = args.ckpt_path

     time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
+    config_ck = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck)
+    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=ckpt_save_dir, config=config_ck)
     print("============== Starting Training ==============")
     model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
                 dataset_sink_mode=args.dataset_sink_mode)
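Taken together, a hedged end-to-end launch sketch; the dataset paths and the rank-table filename are placeholders, and the commands mirror the argparse defaults and the usage string of run_train.sh above:

```python
# Standalone CIFAR-10 run, then a distributed ImageNet run:
#
#   python train.py --dataset_name cifar10 --data_path ./cifar-10-batches-bin --ckpt_path ./ckpt
#   sh run_train.sh ./rank_table_8pcs.json imagenet /path/to/imagenet/train
#
# With DEVICE_NUM > 1, each rank saves its checkpoints under "<ckpt_path>_<rank>".
```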