Browse Source

!5926 googlenet support imagenet dataset on Ascend

Merge pull request !5926 from caojian05/ms_master_googlenet_support_imagenet_on_ascend
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f5e6fa95f3
11 changed files with 398 additions and 38 deletions
  1. +34
    -12
      model_zoo/official/cv/googlenet/eval.py
  2. +16
    -2
      model_zoo/official/cv/googlenet/export.py
  3. +17
    -4
      model_zoo/official/cv/googlenet/scripts/run_train.sh
  4. +38
    -1
      model_zoo/official/cv/googlenet/src/config.py
  5. +67
    -6
      model_zoo/official/cv/googlenet/src/dataset.py
  6. +1
    -0
      model_zoo/official/cv/googlenet/src/googlenet.py
  7. +0
    -0
      model_zoo/official/cv/googlenet/src/lr_scheduler/__init__.py
  8. +20
    -0
      model_zoo/official/cv/googlenet/src/lr_scheduler/linear_warmup.py
  9. +39
    -0
      model_zoo/official/cv/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py
  10. +59
    -0
      model_zoo/official/cv/googlenet/src/lr_scheduler/warmup_step_lr.py
  11. +107
    -13
      model_zoo/official/cv/googlenet/train.py

+ 34
- 12
model_zoo/official/cv/googlenet/eval.py View File

@@ -25,35 +25,57 @@ from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed from mindspore.common import set_seed


from src.config import cifar_cfg as cfg
from src.dataset import create_dataset
from src.config import cifar_cfg, imagenet_cfg
from src.dataset import create_dataset_cifar10, create_dataset_imagenet

from src.googlenet import GoogleNet from src.googlenet import GoogleNet


set_seed(1) set_seed(1)


parser = argparse.ArgumentParser(description='googlenet') parser = argparse.ArgumentParser(description='googlenet')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
args_opt = parser.parse_args() args_opt = parser.parse_args()


if __name__ == '__main__': if __name__ == '__main__':

if args_opt.dataset_name == 'cifar10':
cfg = cifar_cfg
dataset = create_dataset_cifar10(cfg.data_path, 1, False)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
weight_decay=cfg.weight_decay)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

elif args_opt.dataset_name == "imagenet":
cfg = imagenet_cfg
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
net = GoogleNet(num_classes=cfg.num_classes)
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

else:
raise ValueError("dataset is not support.")

device_target = cfg.device_target device_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
if device_target == "Ascend": if device_target == "Ascend":
context.set_context(device_id=cfg.device_id) context.set_context(device_id=cfg.device_id)


net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

if device_target == "Ascend":
param_dict = load_checkpoint(cfg.checkpoint_path)
else: # GPU
if args_opt.checkpoint_path is not None:
param_dict = load_checkpoint(args_opt.checkpoint_path) param_dict = load_checkpoint(args_opt.checkpoint_path)
print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
else:
param_dict = load_checkpoint(cfg.checkpoint_path)
print("load checkpoint from [{}].".format(cfg.checkpoint_path))


load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
net.set_train(False) net.set_train(False)
dataset = create_dataset(cfg.data_path, 1, False)
acc = model.eval(dataset) acc = model.eval(dataset)
print("accuracy: ", acc) print("accuracy: ", acc)

+ 16
- 2
model_zoo/official/cv/googlenet/export.py View File

@@ -16,18 +16,32 @@
##############export checkpoint file into air and onnx models################# ##############export checkpoint file into air and onnx models#################
python export.py python export.py
""" """
import argparse
import numpy as np import numpy as np


import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train.serialization import load_checkpoint, load_param_into_net, export


from src.config import cifar_cfg as cfg
from src.config import cifar_cfg, imagenet_cfg
from src.googlenet import GoogleNet from src.googlenet import GoogleNet



if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Classification')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
help='dataset name.')
args_opt = parser.parse_args()

if args_opt.dataset_name == 'cifar10':
cfg = cifar_cfg
elif args_opt.dataset_name == 'imagenet':
cfg = imagenet_cfg
else:
raise ValueError("dataset is not support.")

net = GoogleNet(num_classes=cfg.num_classes) net = GoogleNet(num_classes=cfg.num_classes)

assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None."
param_dict = load_checkpoint(cfg.checkpoint_path) param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)




+ 17
- 4
model_zoo/official/cv/googlenet/scripts/run_train.sh View File

@@ -14,9 +14,9 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================


if [ $# != 1 ]
if [ $# != 1 ] && [ $# != 2 ]
then then
echo "Usage: sh run_train.sh [RANK_TABLE_FILE]"
echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet]"
exit 1 exit 1
fi fi


@@ -26,6 +26,19 @@ then
exit 1 exit 1
fi fi



# Dataset to train on; stays at cifar10 when the optional second argument is omitted.
dataset_type='cifar10'
# Compare the argument count numerically with -eq: the usage text runs this
# script through `sh`, and `==` inside `[ ]` is a bash extension that strictly
# POSIX shells reject.
if [ $# -eq 2 ]
then
    # Quote "$2" so an empty or whitespace-containing argument is reported as
    # an invalid dataset instead of breaking the `[` expression itself.
    if [ "$2" != "cifar10" ] && [ "$2" != "imagenet" ]
    then
        echo "error: the selected dataset is neither cifar10 nor imagenet"
        exit 1
    fi
    dataset_type=$2
fi


ulimit -u unlimited ulimit -u unlimited
export DEVICE_NUM=8 export DEVICE_NUM=8
export RANK_SIZE=8 export RANK_SIZE=8
@@ -43,9 +56,9 @@ do
mkdir ./train_parallel$i mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i cp -r ./src ./train_parallel$i
cp ./train.py ./train_parallel$i cp ./train.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
cd ./train_parallel$i ||exit cd ./train_parallel$i ||exit
env > env.log env > env.log
python train.py --device_id=$i > log 2>&1 &
python train.py --device_id=$i --dataset_name=$dataset_type> log 2>&1 &
cd .. cd ..
done done

+ 38
- 1
model_zoo/official/cv/googlenet/src/config.py View File

@@ -18,6 +18,7 @@ network config setting, will be used in main.py
from easydict import EasyDict as edict from easydict import EasyDict as edict


cifar_cfg = edict({ cifar_cfg = edict({
'name': 'cifar10',
'pre_trained': False, 'pre_trained': False,
'num_classes': 10, 'num_classes': 10,
'lr_init': 0.1, 'lr_init': 0.1,
@@ -30,9 +31,45 @@ cifar_cfg = edict({
'image_width': 224, 'image_width': 224,
'data_path': './cifar10', 'data_path': './cifar10',
'device_target': 'Ascend', 'device_target': 'Ascend',
'device_id': 4,
'device_id': 0,
'keep_checkpoint_max': 10, 'keep_checkpoint_max': 10,
'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt', 'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt',
'onnx_filename': 'googlenet.onnx', 'onnx_filename': 'googlenet.onnx',
'air_filename': 'googlenet.air' 'air_filename': 'googlenet.air'
}) })

# Settings used when --dataset_name=imagenet is selected by train.py, eval.py
# and export.py.
imagenet_cfg = edict({
    'name': 'imagenet',
    # When truthy, train.py loads 'checkpoint_path' into the network before training.
    'pre_trained': False,
    'num_classes': 1000,
    # Peak learning rate handed to the selected lr scheduler.
    'lr_init': 0.1,
    # Batch size applied inside create_dataset_imagenet (incomplete batches are dropped).
    'batch_size': 256,
    # Number of training epochs; also the schedule length (max_epoch) of the lr schedulers.
    'epoch_size': 300,
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'buffer_size': None,  # marked invalid by the author; not read by the ImageNet code shown here
    # create_dataset_imagenet asserts that height and width are equal.
    'image_height': 224,
    'image_width': 224,
    # Training images (read by train.py).
    'data_path': './ImageNet_Original/train/',
    # Validation images (read by eval.py).
    'val_data_path': './ImageNet_Original/val/',
    'device_target': 'Ascend',
    'device_id': 0,
    'keep_checkpoint_max': 10,
    # Must be set before running export.py (it asserts non-None) or before
    # running eval.py without --checkpoint_path.
    'checkpoint_path': None,
    'onnx_filename': 'googlenet.onnx',
    'air_filename': 'googlenet.air',

    # optimizer and lr related
    # 'exponential' selects warmup_step_lr, 'cosine_annealing' selects
    # warmup_cosine_annealing_lr; any other value raises NotImplementedError.
    'lr_scheduler': 'exponential',
    # Epoch milestones at which the step schedule multiplies the rate by 'lr_gamma'.
    'lr_epochs': [70, 140, 210, 280],
    'lr_gamma': 0.3,
    # 'eta_min' and 'T_max' are only read by the cosine-annealing schedule.
    'eta_min': 0.0,
    'T_max': 150,
    # Number of leading epochs using linear warmup; 0 disables warmup.
    'warmup_epochs': 0,

    # loss related
    # Truthy: train.py resets 'loss_scale' to 1; a value of exactly 1 selects
    # DynamicLossScaleManager. Otherwise FixedLossScaleManager('loss_scale') is used.
    'is_dynamic_loss_scale': 0,
    'loss_scale': 1024,
    # Smoothing factor for the loss; forced to 0.0 when 'use_label_smooth' is False.
    'label_smooth_factor': 0.1,
    'use_label_smooth': True,
})

+ 67
- 6
model_zoo/official/cv/googlenet/src/dataset.py View File

@@ -21,10 +21,10 @@ import mindspore.common.dtype as mstype
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision import mindspore.dataset.vision.c_transforms as vision
from src.config import cifar_cfg as cfg
from src.config import cifar_cfg, imagenet_cfg




def create_dataset(data_home, repeat_num=1, training=True):
def create_dataset_cifar10(data_home, repeat_num=1, training=True):
"""Data operations.""" """Data operations."""
ds.config.set_seed(1) ds.config.set_seed(1)
data_dir = os.path.join(data_home, "cifar-10-batches-bin") data_dir = os.path.join(data_home, "cifar-10-batches-bin")
@@ -37,14 +37,14 @@ def create_dataset(data_home, repeat_num=1, training=True):
else: else:
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False) data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False)


resize_height = cfg.image_height
resize_width = cfg.image_width
resize_height = cifar_cfg.image_height
resize_width = cifar_cfg.image_width


# define map operations # define map operations
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = vision.RandomHorizontalFlip() random_horizontal_op = vision.RandomHorizontalFlip()
resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = vision.Rescale(1.0/255.0, 0.0)
rescale_op = vision.Rescale(1.0 / 255.0, 0.0)
normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
changeswap_op = vision.HWC2CHW() changeswap_op = vision.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32) type_cast_op = C.TypeCast(mstype.int32)
@@ -59,7 +59,7 @@ def create_dataset(data_home, repeat_num=1, training=True):
data_set = data_set.map(input_columns="image", operations=c_trans) data_set = data_set.map(input_columns="image", operations=c_trans)


# apply batch operations # apply batch operations
data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True)


# apply repeat operations # apply repeat operations
data_set = data_set.repeat(repeat_num) data_set = data_set.repeat(repeat_num)
@@ -67,6 +67,67 @@ def create_dataset(data_home, repeat_num=1, training=True):
return data_set return data_set




def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
                            num_parallel_workers=None, shuffle=None):
    """
    Build the ImageNet train or eval data pipeline.

    Args:
        dataset_path(string): folder handed to the image-folder dataset reader.
        repeat_num(int): the repeat times of dataset. Default: 1
        training(bool): True selects the random-augmentation transforms,
            False selects decode / resize / center-crop. Default: True
        num_parallel_workers(int): forwarded to the dataset reader only; the
            map operations always use 8 workers. Default: None
        shuffle(bool): forwarded to the dataset reader. Default: None

    Returns:
        dataset, batched with imagenet_cfg.batch_size (incomplete batches dropped)
    """
    device_num, rank_id = _get_rank_info()

    # Shard the reader only when more than one device takes part.
    reader_options = {"num_parallel_workers": num_parallel_workers, "shuffle": shuffle}
    if device_num != 1:
        reader_options["num_shards"] = device_num
        reader_options["shard_id"] = rank_id
    data_set = ds.ImageFolderDatasetV2(dataset_path, **reader_options)

    assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
    image_size = imagenet_cfg.image_height
    # Per-channel statistics scaled by 255 (pixels are not rescaled before Normalize).
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # The two modes differ only in how the image is decoded and cropped.
    if training:
        crop_ops = [
            vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1),
        ]
    else:
        crop_ops = [
            vision.Decode(),
            vision.Resize(256),
            vision.CenterCrop(image_size),
        ]
    transform_img = crop_ops + [
        vision.Normalize(mean=mean, std=std),
        vision.HWC2CHW(),
    ]
    transform_label = [C.TypeCast(mstype.int32)]

    data_set = data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
    data_set = data_set.map(input_columns="label", num_parallel_workers=8, operations=transform_label)

    # Batch first, then repeat.
    data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True)
    return data_set.repeat(repeat_num)


def _get_rank_info(): def _get_rank_info():
""" """
get rank size and rank id get rank size and rank id


+ 1
- 0
model_zoo/official/cv/googlenet/src/googlenet.py View File

@@ -112,6 +112,7 @@ class GoogleNet(nn.Cell):




def construct(self, x): def construct(self, x):
"""construct"""
x = self.conv1(x) x = self.conv1(x)
x = self.maxpool1(x) x = self.maxpool1(x)




+ 0
- 0
model_zoo/official/cv/googlenet/src/lr_scheduler/__init__.py View File


+ 20
- 0
model_zoo/official/cv/googlenet/src/lr_scheduler/linear_warmup.py View File

@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    """Return the linearly warmed-up learning rate at ``current_step``.

    The rate starts at ``init_lr`` for step 0 and grows by a constant
    increment per step, reaching ``base_lr`` at ``current_step == warmup_steps``.
    """
    start = float(init_lr)
    per_step = (float(base_lr) - start) / float(warmup_steps)
    return start + per_step * current_step

+ 39
- 0
model_zoo/official/cv/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py View File

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
import math
import numpy as np
from .linear_warmup import linear_warmup_lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """Build a per-step schedule: linear warmup followed by cosine annealing.

    Warmup ramps from 0 towards ``lr`` over ``warmup_epochs`` epochs. After
    that the rate follows a cosine between ``lr`` and ``eta_min``, evaluated
    once per epoch (every step of an epoch shares the same value).

    Returns:
        numpy float32 array with ``int(max_epoch * steps_per_epoch)`` entries.
    """
    peak_lr = lr
    warmup_start_lr = 0
    n_total = int(max_epoch * steps_per_epoch)
    n_warmup = int(warmup_epochs * steps_per_epoch)

    def _rate_at(step):
        # Warmup uses step + 1, so its last step lands on the peak rate.
        if step < n_warmup:
            return linear_warmup_lr(step + 1, n_warmup, peak_lr, warmup_start_lr)
        epoch = step // steps_per_epoch
        return eta_min + (peak_lr - eta_min) * (1. + math.cos(math.pi * epoch / T_max)) / 2

    schedule = [_rate_at(step) for step in range(n_total)]
    return np.array(schedule).astype(np.float32)

+ 59
- 0
model_zoo/official/cv/googlenet/src/lr_scheduler/warmup_step_lr.py View File

@@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
from collections import Counter
import numpy as np
from .linear_warmup import linear_warmup_lr
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """Build a per-step schedule: linear warmup followed by step decay.

    Warmup ramps from 0 towards ``lr`` over ``warmup_epochs`` epochs. After
    that the rate is multiplied by ``gamma`` at the first step of every epoch
    listed in ``lr_epochs``; an epoch listed k times decays by ``gamma ** k``.

    Returns:
        numpy float32 array with ``int(max_epoch * steps_per_epoch)`` entries.
    """
    peak_lr = lr
    warmup_start_lr = 0
    n_total = int(max_epoch * steps_per_epoch)
    n_warmup = int(warmup_epochs * steps_per_epoch)
    # Maps a global step to how many milestones fall on it (0 when absent).
    decay_hits = Counter(epoch * steps_per_epoch for epoch in lr_epochs)

    schedule = []
    current = peak_lr
    for step in range(n_total):
        if step >= n_warmup:
            current = current * gamma ** decay_hits[step]
        else:
            current = linear_warmup_lr(step + 1, n_warmup, peak_lr, warmup_start_lr)
        schedule.append(current)
    return np.array(schedule).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    """Step-decay schedule without warmup: decay by ``gamma`` at each epoch in ``milestones``."""
    no_warmup_epochs = 0
    return warmup_step_lr(
        lr,
        milestones,
        steps_per_epoch,
        no_warmup_epochs,
        max_epoch,
        gamma=gamma,
    )
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    """Step-decay schedule that decays by ``gamma`` every ``epoch_size`` epochs.

    Milestones are the multiples of ``epoch_size`` in ``[1, max_epoch)``.
    """
    decay_epochs = [epoch for epoch in range(1, max_epoch) if epoch % epoch_size == 0]
    return multi_step_lr(lr, decay_epochs, steps_per_epoch, max_epoch, gamma=gamma)

+ 107
- 13
model_zoo/official/cv/googlenet/train.py View File

@@ -27,18 +27,19 @@ from mindspore import context
from mindspore.communication.management import init, get_rank from mindspore.communication.management import init, get_rank
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed from mindspore.common import set_seed


from src.config import cifar_cfg as cfg
from src.dataset import create_dataset
from src.config import cifar_cfg, imagenet_cfg
from src.dataset import create_dataset_cifar10, create_dataset_imagenet
from src.googlenet import GoogleNet from src.googlenet import GoogleNet


set_seed(1) set_seed(1)


def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
def lr_steps_cifar10(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
"""Set learning rate.""" """Set learning rate."""
lr_each_step = [] lr_each_step = []
total_steps = steps_per_epoch * total_epochs total_steps = steps_per_epoch * total_epochs
@@ -59,11 +60,46 @@ def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
return learning_rate return learning_rate




def lr_steps_imagenet(_cfg, steps_per_epoch):
    """Build the per-step learning-rate array for ImageNet training.

    ``_cfg.lr_scheduler`` picks the schedule: 'exponential' uses the warmup
    step schedule, 'cosine_annealing' uses the warmup cosine schedule.

    Raises:
        NotImplementedError: for any other scheduler name.
    """
    from src.lr_scheduler.warmup_step_lr import warmup_step_lr
    from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr

    scheduler_name = _cfg.lr_scheduler

    if scheduler_name == 'exponential':
        return warmup_step_lr(_cfg.lr_init, _cfg.lr_epochs, steps_per_epoch,
                              _cfg.warmup_epochs, _cfg.epoch_size, gamma=_cfg.lr_gamma)

    if scheduler_name == 'cosine_annealing':
        return warmup_cosine_annealing_lr(_cfg.lr_init, steps_per_epoch, _cfg.warmup_epochs,
                                          _cfg.epoch_size, _cfg.T_max, _cfg.eta_min)

    raise NotImplementedError(scheduler_name)


if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Cifar10 classification')
parser = argparse.ArgumentParser(description='Classification')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
args_opt = parser.parse_args() args_opt = parser.parse_args()


if args_opt.dataset_name == "cifar10":
cfg = cifar_cfg
elif args_opt.dataset_name == "imagenet":
cfg = imagenet_cfg
else:
raise ValueError("Unsupport dataset.")

# set context
device_target = cfg.device_target device_target = cfg.device_target


context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
@@ -90,7 +126,13 @@ if __name__ == '__main__':
else: else:
raise ValueError("Unsupported platform.") raise ValueError("Unsupported platform.")


dataset = create_dataset(cfg.data_path, 1)
if args_opt.dataset_name == "cifar10":
dataset = create_dataset_cifar10(cfg.data_path, 1)
elif args_opt.dataset_name == "imagenet":
dataset = create_dataset_imagenet(cfg.data_path, 1)
else:
raise ValueError("Unsupport dataset.")

batch_num = dataset.get_dataset_size() batch_num = dataset.get_dataset_size()


net = GoogleNet(num_classes=cfg.num_classes) net = GoogleNet(num_classes=cfg.num_classes)
@@ -98,23 +140,75 @@ if __name__ == '__main__':
if cfg.pre_trained: if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path) param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

loss_scale_manager = None
if args_opt.dataset_name == 'cifar10':
lr = lr_steps_cifar10(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
learning_rate=Tensor(lr),
momentum=cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)

elif args_opt.dataset_name == 'imagenet':
lr = lr_steps_imagenet(cfg, batch_num)


def get_param_groups(network):
    """Split the trainable parameters into a no-decay group and a decay group.

    Parameters whose name ends in '.bias', '.gamma' or '.beta' are placed in
    the first group, which pins ``weight_decay`` to 0.0. All other parameters
    go in the second group, which carries no ``weight_decay`` key of its own.
    Order within each group follows ``network.trainable_params()``.
    """
    # Biases plus the gamma/beta suffixes the original author attributed to batch norm.
    exempt_suffixes = ('.bias', '.gamma', '.beta')
    exempt, regular = [], []
    for param in network.trainable_params():
        target = exempt if param.name.endswith(exempt_suffixes) else regular
        target.append(param)

    return [{'params': exempt, 'weight_decay': 0.0}, {'params': regular}]


if cfg.is_dynamic_loss_scale:
cfg.loss_scale = 1

opt = Momentum(params=get_param_groups(net),
learning_rate=Tensor(lr),
momentum=cfg.momentum,
weight_decay=cfg.weight_decay,
loss_scale=cfg.loss_scale)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)

if cfg.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)


if device_target == "Ascend": if device_target == "Ascend":
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
ckpt_save_dir = "./" ckpt_save_dir = "./"
else: # GPU
else: # GPU
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None)
amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager)
ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/" ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/"


config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num) time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory=ckpt_save_dir, config=config_ck)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir,
config=config_ck)
loss_cb = LossMonitor() loss_cb = LossMonitor()
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success") print("train success")

Loading…
Cancel
Save