
!6191 add imagenet for alexnet

Merge pull request !6191 from wukesong/imagenet-alexent
tags/v1.0.0
mindspore-ci-bot · 5 years ago
parent commit f7c5c5265f
12 changed files with 443 additions and 52 deletions
  1. model_zoo/official/cv/alexnet/eval.py  (+38, -16)
  2. model_zoo/official/cv/alexnet/scripts/run_distribution_ascend.sh  (+53, -0)
  3. model_zoo/official/cv/alexnet/scripts/run_standalone_eval_ascend.sh  (+1, -3)
  4. model_zoo/official/cv/alexnet/scripts/run_standalone_eval_gpu.sh  (+1, -3)
  5. model_zoo/official/cv/alexnet/scripts/run_standalone_train_ascend.sh  (+4, -3)
  6. model_zoo/official/cv/alexnet/scripts/run_standalone_train_gpu.sh  (+4, -3)
  7. model_zoo/official/cv/alexnet/src/alexnet.py  (+1, -0)
  8. model_zoo/official/cv/alexnet/src/config.py  (+29, -1)
  9. model_zoo/official/cv/alexnet/src/dataset.py  (+103, -10)
  10. model_zoo/official/cv/alexnet/src/generator_lr.py  (+85, -2)
  11. model_zoo/official/cv/alexnet/src/get_param_groups.py  (+34, -0)
  12. model_zoo/official/cv/alexnet/train.py  (+90, -11)

model_zoo/official/cv/alexnet/eval.py  (+38, -16)

@@ -20,8 +20,8 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt


import ast
import argparse
from src.config import alexnet_cfg as cfg
from src.dataset import create_dataset_cifar10
from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
from src.dataset import create_dataset_cifar10, create_dataset_imagenet
from src.alexnet import AlexNet
import mindspore.nn as nn
from mindspore import context
@@ -32,28 +32,50 @@ from mindspore.nn.metrics import Accuracy


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=ast.literal_eval,
default=True, help='dataset_sink_mode is False or True')
args = parser.parse_args()


context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})

print("============== Starting Testing ==============") print("============== Starting Testing ==============")
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
ds_eval = create_dataset_cifar10(args.data_path,
cfg.batch_size,
status="test")
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== {} ==============".format(acc))

if args.dataset_name == 'cifar10':
cfg = alexnet_cifar10_cfg
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
ds_eval = create_dataset_cifar10(args.data_path, cfg.batch_size, status="test", target=args.device_target)

param_dict = load_checkpoint(args.ckpt_path)
print("load checkpoint from [{}].".format(args.ckpt_path))
load_param_into_net(network, param_dict)
network.set_train(False)

model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})

elif args.dataset_name == 'imagenet':
cfg = alexnet_imagenet_cfg
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
ds_eval = create_dataset_imagenet(args.data_path, cfg.batch_size, training=False)

param_dict = load_checkpoint(args.ckpt_path)
print("load checkpoint from [{}].".format(args.ckpt_path))
load_param_into_net(network, param_dict)
network.set_train(False)

model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

else:
raise ValueError("Unsupport dataset.")

result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("result : {}".format(result))

model_zoo/official/cv/alexnet/scripts/run_distribution_ascend.sh  (+53, -0)

@@ -0,0 +1,53 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet] [DATA_PATH]"
exit 1
fi

if [ ! -f $1 ]
then
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export DATASET_NAME=$2
export DATA_PATH=$3
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i
cp ./train.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env.log
python train.py --device_id=$i --dataset_name=$DATASET_NAME --data_path=$DATA_PATH > log 2>&1 &
cd ..
done
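Per the usage check above, the launcher is invoked as, for example, "sh run_distribution_ascend.sh ./rank_table_8pcs.json imagenet /path/to/imagenet/train" (the rank-table file name and data path are placeholders); it spawns one train.py process per device, forwarding --dataset_name and --data_path, with each copy logging to train_parallel$i/log.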

model_zoo/official/cv/alexnet/scripts/run_standalone_eval_ascend.sh  (+1, -3)

@@ -17,6 +17,4 @@
# an simple tutorial as follows, more parameters can be setting
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../eval.py --data_path=./$DATA_PATH --device_target="Ascend" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
python -s ${self_path}/../eval.py --device_target="Ascend" > log.txt 2>&1 &

model_zoo/official/cv/alexnet/scripts/run_standalone_eval_gpu.sh  (+1, -3)

@@ -17,6 +17,4 @@
# an simple tutorial as follows, more parameters can be setting
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../eval.py --data_path=./$DATA_PATH --device_target="GPU" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
python -s ${self_path}/../eval.py --device_target="GPU" > log.txt 2>&1 &

model_zoo/official/cv/alexnet/scripts/run_standalone_train_ascend.sh  (+4, -3)

@@ -14,9 +14,10 @@
# limitations under the License.
# ============================================================================


export DEVICE_NUM=1
export RANK_SIZE=1

# an simple tutorial, more
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../train.py --data_path=./$DATA_PATH --device_target="Ascend" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
python -s ${self_path}/../train.py --device_target="Ascend" > log.txt 2>&1 &

model_zoo/official/cv/alexnet/scripts/run_standalone_train_gpu.sh  (+4, -3)

@@ -14,9 +14,10 @@
# limitations under the License.
# ============================================================================


export DEVICE_NUM=1
export RANK_SIZE=1

# an simple tutorial as follows, more parameters can be setting
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../train.py --data_path=./$DATA_PATH --device_target="GPU" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
python -s ${self_path}/../train.py --device_target="GPU" > log.txt 2>&1 &

model_zoo/official/cv/alexnet/src/alexnet.py  (+1, -0)

@@ -51,6 +51,7 @@ class AlexNet(nn.Cell):
self.fc3 = fc_with_initialize(4096, num_classes)


def construct(self, x):
"""define network"""
x = self.conv1(x) x = self.conv1(x)
x = self.relu(x) x = self.relu(x)
x = self.max_pool2d(x) x = self.max_pool2d(x)


model_zoo/official/cv/alexnet/src/config.py  (+29, -1)

@@ -18,7 +18,7 @@ network config setting, will be used in train.py


from easydict import EasyDict as edict


alexnet_cfg = edict({
alexnet_cifar10_cfg = edict({
'num_classes': 10,
'learning_rate': 0.002,
'momentum': 0.9,
@@ -30,3 +30,31 @@ alexnet_cfg = edict({
'save_checkpoint_steps': 1562,
'keep_checkpoint_max': 10,
})

alexnet_imagenet_cfg = edict({
'num_classes': 1000,
'learning_rate': 0.13,
'momentum': 0.9,
'epoch_size': 150,
'batch_size': 256,
'buffer_size': None, # invalid parameter
'image_height': 227,
'image_width': 227,
'save_checkpoint_steps': 625,
'keep_checkpoint_max': 10,

# opt
'weight_decay': 0.0001,
'loss_scale': 1024,

# lr
'is_dynamic_loss_scale': 0,
'label_smooth': 1,
'label_smooth_factor': 0.1,
'lr_scheduler': 'cosine_annealing',
'warmup_epochs': 5,
'lr_epochs': [30, 60, 90, 120],
'lr_gamma': 0.1,
'T_max': 150,
'eta_min': 0.0,
})
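Both blocks are plain EasyDicts, so the training and eval scripts read fields as attributes. A minimal sketch using only values defined above:

from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg

# attribute-style access, the same pattern train.py and eval.py rely on
print(alexnet_cifar10_cfg.num_classes, alexnet_cifar10_cfg.learning_rate)    # 10 0.002
print(alexnet_imagenet_cfg.num_classes, alexnet_imagenet_cfg.lr_scheduler)   # 1000 cosine_annealing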

model_zoo/official/cv/alexnet/src/dataset.py  (+103, -10)

@@ -16,20 +16,32 @@
Produce the dataset
"""


import os

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.common import dtype as mstype
from .config import alexnet_cfg as cfg
from mindspore.communication.management import get_rank, get_group_size
from .config import alexnet_cifar10_cfg, alexnet_imagenet_cfg




def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="train"):
def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="train", target="Ascend"):
""" """
create dataset for train or test create dataset for train or test
""" """
cifar_ds = ds.Cifar10Dataset(data_path)

if target == "Ascend":
device_num, rank_id = _get_rank_info()

if target != "Ascend" or device_num == 1:
cifar_ds = ds.Cifar10Dataset(data_path)
else:
cifar_ds = ds.Cifar10Dataset(data_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
rescale = 1.0 / 255.0
shift = 0.0
cfg = alexnet_cifar10_cfg


resize_op = CV.Resize((cfg.image_height, cfg.image_width))
rescale_op = CV.Rescale(rescale, shift) rescale_op = CV.Rescale(rescale, shift)
@@ -39,16 +51,97 @@ def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="trai
random_horizontal_op = CV.RandomHorizontalFlip()
channel_swap_op = CV.HWC2CHW()
typecast_op = C.TypeCast(mstype.int32)
cifar_ds = cifar_ds.map(operations=typecast_op, input_columns="label")
cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=8)
if status == "train": if status == "train":
cifar_ds = cifar_ds.map(operations=random_crop_op, input_columns="image")
cifar_ds = cifar_ds.map(operations=random_horizontal_op, input_columns="image")
cifar_ds = cifar_ds.map(operations=resize_op, input_columns="image")
cifar_ds = cifar_ds.map(operations=rescale_op, input_columns="image")
cifar_ds = cifar_ds.map(operations=normalize_op, input_columns="image")
cifar_ds = cifar_ds.map(operations=channel_swap_op, input_columns="image")
cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op, num_parallel_workers=8)
cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op, num_parallel_workers=8)
cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=8)
cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=8)
cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op, num_parallel_workers=8)
cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op, num_parallel_workers=8)


cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size)
cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
cifar_ds = cifar_ds.repeat(repeat_size)
return cifar_ds


def create_dataset_imagenet(dataset_path, batch_size=32, repeat_num=1, training=True,
num_parallel_workers=None, shuffle=None, sampler=None, class_indexing=None):
"""
create a train or eval imagenet2012 dataset for AlexNet

Args:
dataset_path(string): the path of dataset.
batch_size(int): the batch size of dataset. Default: 32
repeat_num(int): the repeat times of dataset. Default: 1
training(bool): whether dataset is used for train or eval. Default: True
num_parallel_workers(int): number of workers to read the data. Default: None
shuffle(bool): whether to shuffle the dataset. Default: None
sampler: sampler object used to choose samples from the dataset. Default: None
class_indexing(dict): a str-to-int mapping from folder name to class index. Default: None

Returns:
dataset
"""

device_num, rank_id = _get_rank_info()
cfg = alexnet_imagenet_cfg

if num_parallel_workers is None:
num_parallel_workers = int(64 / device_num)
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers,
shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
num_shards=device_num, shard_id=rank_id)

assert cfg.image_height == cfg.image_width, "imagenet_cfg.image_height not equal imagenet_cfg.image_width"
image_size = cfg.image_height
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

# define map operations
if training:
transform_img = [
CV.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
CV.RandomHorizontalFlip(prob=0.5),
CV.Normalize(mean=mean, std=std),
CV.HWC2CHW()
]
else:
transform_img = [
CV.Decode(),
CV.Resize((256, 256)),
CV.CenterCrop(image_size),
CV.Normalize(mean=mean, std=std),
CV.HWC2CHW()
]

transform_label = [C.TypeCast(mstype.int32)]

data_set = data_set.map(input_columns="image", num_parallel_workers=num_parallel_workers,
operations=transform_img)
data_set = data_set.map(input_columns="label", num_parallel_workers=num_parallel_workers,
operations=transform_label)

num_parallel_workers2 = int(16 / device_num)
data_set = data_set.batch(batch_size, num_parallel_workers=num_parallel_workers2, drop_remainder=True)

# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)

return data_set


def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))

if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
# rank_size = rank_id = None

rank_size = 1
rank_id = 0

return rank_size, rank_id
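A minimal standalone sketch of the new ImageNet pipeline (the directory is a placeholder; with RANK_SIZE unset, _get_rank_info() falls back to a single shard):

from src.dataset import create_dataset_imagenet

# placeholder path to an ImageFolder-style ImageNet train split
ds_train = create_dataset_imagenet("/path/to/imagenet/train", batch_size=256, training=True)
print("batches per epoch:", ds_train.get_dataset_size())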

model_zoo/official/cv/alexnet/src/generator_lr.py  (+85, -2)

@@ -13,10 +13,11 @@
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
from collections import Counter
import numpy as np



def get_lr(current_step, lr_max, total_epochs, steps_per_epoch):
def get_lr_cifar10(current_step, lr_max, total_epochs, steps_per_epoch):
""" """
generate learning rate array generate learning rate array


@@ -42,3 +43,85 @@ def get_lr(current_step, lr_max, total_epochs, steps_per_epoch):
learning_rate = lr_each_step[current_step:]


return learning_rate

def get_lr_imagenet(cfg, steps_per_epoch):
"""generate learning rate array"""
if cfg.lr_scheduler == 'exponential':
lr = warmup_step_lr(cfg.learning_rate,
cfg.lr_epochs,
steps_per_epoch,
cfg.warmup_epochs,
cfg.epoch_size,
gamma=cfg.lr_gamma,
)
elif cfg.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(cfg.learning_rate,
steps_per_epoch,
cfg.warmup_epochs,
cfg.epoch_size,
cfg.T_max,
cfg.eta_min)
else:
raise NotImplementedError(cfg.lr_scheduler)

return lr



def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
"""Linear learning rate"""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr

def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""Linear warm up learning rate"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)

lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma**milestones_steps_counter[i]
lr_each_step.append(lr)

return np.array(lr_each_step).astype(np.float32)

def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)

def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)

def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
""" Cosine annealing learning rate"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)

lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
lr_each_step.append(lr)

return np.array(lr_each_step).astype(np.float32)
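For intuition, a small sketch of how the cosine schedule behaves with the imagenet defaults (learning_rate=0.13, warmup_epochs=5, T_max=150, eta_min=0); steps_per_epoch=625 is an assumption for illustration only, in practice it comes from ds_train.get_dataset_size():

import math

base_lr, warmup_epochs, max_epoch, T_max, eta_min = 0.13, 5, 150, 150, 0.0
steps_per_epoch = 625  # assumed value, for illustration only

def lr_at(step):
    """Learning rate at a global step, mirroring warmup_cosine_annealing_lr above."""
    warmup_steps = warmup_epochs * steps_per_epoch
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps  # linear warmup from 0
    epoch = step // steps_per_epoch
    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / T_max)) / 2

# ramps to base_lr over the first 5 epochs, then decays towards eta_min by epoch T_max
print(lr_at(0), lr_at(5 * steps_per_epoch - 1), lr_at(149 * steps_per_epoch))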

model_zoo/official/cv/alexnet/src/get_param_groups.py  (+34, -0)

@@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get parameters for Momentum optimizer"""
def get_param_groups(network):
"""get parameters"""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
else:
decay_params.append(x)

return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
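A hedged sketch of how these groups are consumed; it mirrors the imagenet branch of train.py below, except that a constant learning_rate is passed instead of the full schedule Tensor, purely to keep the example short:

import mindspore.nn as nn
from src.alexnet import AlexNet
from src.config import alexnet_imagenet_cfg as cfg
from src.get_param_groups import get_param_groups

network = AlexNet(cfg.num_classes)
# [{'params': biases/gamma/beta, 'weight_decay': 0.0}, {'params': remaining weights}]
groups = get_param_groups(network)
opt = nn.Momentum(params=groups, learning_rate=cfg.learning_rate,
                  momentum=cfg.momentum, weight_decay=cfg.weight_decay,
                  loss_scale=cfg.loss_scale)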

model_zoo/official/cv/alexnet/train.py  (+90, -11)

@@ -20,14 +20,18 @@ python train.py --data_path /YourDataPath


import ast
import argparse
from src.config import alexnet_cfg as cfg
from src.dataset import create_dataset_cifar10
from src.generator_lr import get_lr
import os
from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
from src.dataset import create_dataset_cifar10, create_dataset_imagenet
from src.generator_lr import get_lr_cifar10, get_lr_imagenet
from src.alexnet import AlexNet
from src.get_param_groups import get_param_groups
import mindspore.nn as nn
from mindspore.communication.management import init, get_rank
from mindspore import context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.common import set_seed
@@ -36,27 +40,102 @@ set_seed(1)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=ast.literal_eval,
default=True, help='dataset_sink_mode is False or True')

parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
args = parser.parse_args()


if args.dataset_name == "cifar10":
cfg = alexnet_cifar10_cfg
elif args.dataset_name == "imagenet":
cfg = alexnet_imagenet_cfg
else:
raise ValueError("Unsupport dataset.")

device_target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(save_graphs=False)
device_num = int(os.environ.get("DEVICE_NUM", 1))

if device_target == "Ascend":
context.set_context(device_id=args.device_id)

if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
elif device_target == "GPU":
init()
if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
raise ValueError("Unsupported platform.")

if args.dataset_name == "cifar10":
ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, target=args.device_target)
elif args.dataset_name == "imagenet":
ds_train = create_dataset_imagenet(args.data_path, cfg.batch_size)
else:
raise ValueError("Unsupport dataset.")


ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1)
network = AlexNet(cfg.num_classes) network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})

loss_scale_manager = None
metrics = None
if args.dataset_name == 'cifar10':
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
lr = Tensor(get_lr_cifar10(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
metrics = {"Accuracy": Accuracy()}

elif args.dataset_name == 'imagenet':
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

lr = Tensor(get_lr_imagenet(cfg, ds_train.get_dataset_size()))

opt = nn.Momentum(params=get_param_groups(network),
learning_rate=lr,
momentum=cfg.momentum,
weight_decay=cfg.weight_decay,
loss_scale=cfg.loss_scale)

from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
if cfg.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)

else:
raise ValueError("Unsupport dataset.")

if device_target == "Ascend":
model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, amp_level="O2", keep_batchnorm_fp32=False,
loss_scale_manager=loss_scale_manager)
elif device_target == "GPU":
model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, loss_scale_manager=loss_scale_manager)
else:
raise ValueError("Unsupported platform.")

if device_num > 1:
ckpt_save_dir = os.path.join(args.ckpt_path + "_" + str(get_rank()))
else:
ckpt_save_dir = args.ckpt_path

time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
config_ck = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=ckpt_save_dir, config=config_ck)


print("============== Starting Training ==============") print("============== Starting Training ==============")
model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],

