Merge pull request !5901 from panfengfeng/fix_network_issue
@@ -1,83 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-import mindspore.nn as nn
-import mindspore.ops.operations as P
-
-
-class ShuffleV2Block(nn.Cell):
-    def __init__(self, inp, oup, mid_channels, *, ksize, stride):
-        super(ShuffleV2Block, self).__init__()
-        self.stride = stride
-        ##assert stride in [1, 2]
-        self.mid_channels = mid_channels
-        self.ksize = ksize
-        pad = ksize // 2
-        self.pad = pad
-        self.inp = inp
-        outputs = oup - inp
-
-        branch_main = [
-            # pw
-            nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1,
-                      pad_mode='pad', padding=0, has_bias=False),
-            nn.BatchNorm2d(num_features=mid_channels, momentum=0.9),
-            nn.ReLU(),
-            # dw
-            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride,
-                      pad_mode='pad', padding=pad, group=mid_channels, has_bias=False),
-            nn.BatchNorm2d(num_features=mid_channels, momentum=0.9),
-            # pw-linear
-            nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1,
-                      pad_mode='pad', padding=0, has_bias=False),
-            nn.BatchNorm2d(num_features=outputs, momentum=0.9),
-            nn.ReLU(),
-        ]
-        self.branch_main = nn.SequentialCell(branch_main)
-
-        if stride == 2:
-            branch_proj = [
-                # dw
-                nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride,
-                          pad_mode='pad', padding=pad, group=inp, has_bias=False),
-                nn.BatchNorm2d(num_features=inp, momentum=0.9),
-                # pw-linear
-                nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1,
-                          pad_mode='pad', padding=0, has_bias=False),
-                nn.BatchNorm2d(num_features=inp, momentum=0.9),
-                nn.ReLU(),
-            ]
-            self.branch_proj = nn.SequentialCell(branch_proj)
-        else:
-            self.branch_proj = None
-
-    def construct(self, old_x):
-        if self.stride == 1:
-            x_proj, x = self.channel_shuffle(old_x)
-            return P.Concat(1)((x_proj, self.branch_main(x)))
-        if self.stride == 2:
-            x_proj = old_x
-            x = old_x
-            return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x)))
-        return None
-
-    def channel_shuffle(self, x):
-        batchsize, num_channels, height, width = P.Shape()(x)
-        ##assert (num_channels % 4 == 0)
-        x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,))
-        x = P.Transpose()(x, (1, 0, 2,))
-        x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,))
-        return x[0], x[1]
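Note: this file is only deleted here because the same ShuffleV2Block is moved into src/shufflenetv2.py later in this PR. For reference, the Reshape/Transpose/Reshape dance in channel_shuffle is the standard ShuffleNetV2 channel split-and-shuffle; a minimal NumPy sketch of the same index math (illustration only, not part of the PR):

    import numpy as np

    def channel_shuffle(x):
        # x: (N, C, H, W) -> two tensors of shape (N, C//2, H, W);
        # even channels go to the first output, odd channels to the second
        n, c, h, w = x.shape
        x = x.reshape(n * c // 2, 2, h * w)
        x = x.transpose(1, 0, 2)
        x = x.reshape(2, -1, c // 2, h, w)
        return x[0], x[1]

    x = np.arange(2 * 4 * 3 * 3, dtype=np.float32).reshape(2, 4, 3, 3)
    x_proj, x_main = channel_shuffle(x)  # each has shape (2, 2, 3, 3)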
@@ -23,8 +23,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.config import config_gpu as cfg
 from src.dataset import create_dataset
-from network import ShuffleNetV2
+from src.shufflenetv2 import ShuffleNetV2
+from src.CrossEntropySmooth import CrossEntropySmooth

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='image classification evaluation')
@@ -43,8 +43,8 @@ if __name__ == '__main__':
     load_param_into_net(net, ckpt)
     net.set_train(False)
     dataset = create_dataset(args_opt.dataset_path, False, 0, 1)
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False,
-                                            smooth_factor=0.1, num_classes=cfg.num_classes)
+    loss = CrossEntropySmooth(sparse=True, reduction='mean',
+                              smooth_factor=0.1, num_classes=cfg.num_classes)
     eval_metrics = {'Loss': nn.Loss(),
                     'Top1-Acc': nn.Top1CategoricalAccuracy(),
                     'Top5-Acc': nn.Top5CategoricalAccuracy()}
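Context for this change: newer MindSpore releases dropped the is_grad/smooth_factor/num_classes arguments from nn.SoftmaxCrossEntropyWithLogits, so label smoothing moves into the custom CrossEntropySmooth cell added later in this PR. A rough sketch of how eval.py wires the pieces together (sketch only; net, dataset, cfg and eval_metrics come from the surrounding code shown in this hunk):

    from mindspore.train.model import Model

    loss = CrossEntropySmooth(sparse=True, reduction='mean',
                              smooth_factor=0.1, num_classes=cfg.num_classes)
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    metrics = model.eval(dataset)  # {'Loss': ..., 'Top1-Acc': ..., 'Top5-Acc': ...}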
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-if [ $# -lt 3 ]
+if [ $# != 3 ] && [ $# != 4 ]
 then
-    echo "Usage: \
-         sh run_distribute_train_for_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] \
+    echo "Usage:
+         sh run_distribute_train_for_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
          "
     exit 1
 fi
@@ -48,10 +48,15 @@ cd ../train || exit
 export CUDA_VISIBLE_DEVICES="$2"

-if [ $1 -gt 1 ]
+if [ $# == 3 ]
 then
     mpirun -n $1 --allow-run-as-root \
     python ${BASEPATH}/../train.py --platform='GPU' --is_distributed=True --dataset_path=$3 > train.log 2>&1 &
-else
-    python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$3 > train.log 2>&1 &
 fi
+
+if [ $# == 4 ]
+then
+    mpirun -n $1 --allow-run-as-root \
+    python ${BASEPATH}/../train.py --platform='GPU' --is_distributed=True --dataset_path=$3 --resume=$4 > train.log 2>&1 &
+fi
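With this change the distributed script always launches through mpirun (single-device runs belong to the standalone script below), and the branch is now keyed on the argument count rather than DEVICE_NUM. Typical invocations, with placeholder paths: sh run_distribute_train_for_gpu.sh 8 0,1,2,3,4,5,6,7 /path/to/imagenet for a fresh run, or append /path/to/pretrained.ckpt as a fourth argument to resume from a checkpoint.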
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-if [ $# -lt 1 ]
+if [ $# != 1 ] && [ $# != 2 ]
 then
-    echo "Usage: \
-         sh run_standalone_train_for_gpu.sh [DATASET_PATH] \
+    echo "Usage:
+         sh run_standalone_train_for_gpu.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
          "
     exit 1
 fi
@@ -37,4 +37,12 @@ fi
 mkdir ../train
 cd ../train || exit
-python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 > train.log 2>&1 &
+if [ $# == 1 ]
+then
+    python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 > train.log 2>&1 &
+fi
+
+if [ $# == 2 ]
+then
+    python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 --resume=$2 > train.log 2>&1 &
+fi
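Same pattern as the distributed script: sh run_standalone_train_for_gpu.sh /path/to/imagenet trains from scratch, and an optional second argument such as /path/to/pretrained.ckpt is forwarded to train.py as --resume (paths here are placeholders).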
@@ -0,0 +1,38 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""define loss function for network"""
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.nn.loss.loss import _Loss
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
+
+class CrossEntropySmooth(_Loss):
+    """CrossEntropy"""
+    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
+        super(CrossEntropySmooth, self).__init__()
+        self.onehot = P.OneHot()
+        self.sparse = sparse
+        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
+        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
+
+    def construct(self, logit, label):
+        if self.sparse:
+            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
+        loss = self.ce(logit, label)
+        return loss
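A quick smoke test for the new cell (sketch only, not part of the PR; assumes the module above lands at src/CrossEntropySmooth.py as the imports elsewhere in this PR suggest):

    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype
    from src.CrossEntropySmooth import CrossEntropySmooth

    loss_fn = CrossEntropySmooth(sparse=True, reduction='mean',
                                 smooth_factor=0.1, num_classes=10)
    logits = Tensor(np.random.randn(4, 10).astype(np.float32))
    labels = Tensor(np.array([1, 0, 3, 9]), mstype.int32)
    print(loss_fn(logits, labels))  # scalar: label-smoothed cross-entropy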
@@ -1,60 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""define loss function for network."""
-from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import _Loss
-from mindspore.ops import operations as P
-from mindspore.ops import functional as F
-from mindspore import Tensor
-import mindspore.nn as nn
-
-
-class CrossEntropy(_Loss):
-    """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
-    def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
-        super(CrossEntropy, self).__init__()
-        self.factor = factor
-        self.onehot = P.OneHot()
-        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
-        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
-        self.ce = nn.SoftmaxCrossEntropyWithLogits()
-        self.mean = P.ReduceMean(False)
-
-    def construct(self, logits, label):
-        logit, aux = logits
-        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
-        loss_logit = self.ce(logit, one_hot_label)
-        loss_logit = self.mean(loss_logit, 0)
-        one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value)
-        loss_aux = self.ce(aux, one_hot_label_aux)
-        loss_aux = self.mean(loss_aux, 0)
-        return loss_logit + self.factor*loss_aux
-
-
-class CrossEntropy_Val(_Loss):
-    """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process"""
-    def __init__(self, smooth_factor=0, num_classes=1000):
-        super(CrossEntropy_Val, self).__init__()
-        self.onehot = P.OneHot()
-        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
-        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
-        self.ce = nn.SoftmaxCrossEntropyWithLogits()
-        self.mean = P.ReduceMean(False)
-
-    def construct(self, logits, label):
-        one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
-        loss_logit = self.ce(logits, one_hot_label)
-        loss_logit = self.mean(loss_logit, 0)
-        return loss_logit
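Worth noting why this file goes away entirely: the deleted CrossEntropy expects logits to be a (logit, aux) pair and mixes in an auxiliary-head loss weighted by factor=0.4, which fits a network with an auxiliary classifier rather than ShuffleNetV2's single-output construct. The simpler CrossEntropySmooth added above is the appropriate replacement for both training and evaluation.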
@@ -14,13 +14,78 @@
 # ============================================================================
 import numpy as np
-from blocks import ShuffleV2Block
 from mindspore import Tensor
 import mindspore.nn as nn
 import mindspore.ops.operations as P
+
+
+class ShuffleV2Block(nn.Cell):
+    def __init__(self, inp, oup, mid_channels, *, ksize, stride):
+        super(ShuffleV2Block, self).__init__()
+        self.stride = stride
+        ##assert stride in [1, 2]
+        self.mid_channels = mid_channels
+        self.ksize = ksize
+        pad = ksize // 2
+        self.pad = pad
+        self.inp = inp
+        outputs = oup - inp
+
+        branch_main = [
+            # pw
+            nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1,
+                      pad_mode='pad', padding=0, has_bias=False),
+            nn.BatchNorm2d(num_features=mid_channels, momentum=0.9),
+            nn.ReLU(),
+            # dw
+            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride,
+                      pad_mode='pad', padding=pad, group=mid_channels, has_bias=False),
+            nn.BatchNorm2d(num_features=mid_channels, momentum=0.9),
+            # pw-linear
+            nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1,
+                      pad_mode='pad', padding=0, has_bias=False),
+            nn.BatchNorm2d(num_features=outputs, momentum=0.9),
+            nn.ReLU(),
+        ]
+        self.branch_main = nn.SequentialCell(branch_main)
+
+        if stride == 2:
+            branch_proj = [
+                # dw
+                nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride,
+                          pad_mode='pad', padding=pad, group=inp, has_bias=False),
+                nn.BatchNorm2d(num_features=inp, momentum=0.9),
+                # pw-linear
+                nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1,
+                          pad_mode='pad', padding=0, has_bias=False),
+                nn.BatchNorm2d(num_features=inp, momentum=0.9),
+                nn.ReLU(),
+            ]
+            self.branch_proj = nn.SequentialCell(branch_proj)
+        else:
+            self.branch_proj = None
+
+    def construct(self, old_x):
+        if self.stride == 1:
+            x_proj, x = self.channel_shuffle(old_x)
+            return P.Concat(1)((x_proj, self.branch_main(x)))
+        if self.stride == 2:
+            x_proj = old_x
+            x = old_x
+            return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x)))
+        return None
+
+    def channel_shuffle(self, x):
+        batchsize, num_channels, height, width = P.Shape()(x)
+        ##assert (num_channels % 4 == 0)
+        x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,))
+        x = P.Transpose()(x, (1, 0, 2,))
+        x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,))
+        return x[0], x[1]
+
+
 class ShuffleNetV2(nn.Cell):
     def __init__(self, input_size=224, n_class=1000, model_size='1.0x'):
         super(ShuffleNetV2, self).__init__()
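With the block inlined here, building the network needs only the src package (sketch only, assuming the constructor signature shown above; '1.0x' is one of the ShuffleNetV2 width multipliers):

    import numpy as np
    from mindspore import Tensor
    from src.shufflenetv2 import ShuffleNetV2

    net = ShuffleNetV2(n_class=1000, model_size='1.0x')
    x = Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32))
    logits = net(x)  # expected shape: (1, 1000)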
@@ -17,7 +17,6 @@ import argparse
 import ast
 import os
-from network import ShuffleNetV2

 import mindspore.nn as nn
 from mindspore import context
@@ -30,9 +29,11 @@ from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.common import set_seed

+from src.shufflenetv2 import ShuffleNetV2
 from src.config import config_gpu as cfg
 from src.dataset import create_dataset
 from src.lr_generator import get_lr_basic
+from src.CrossEntropySmooth import CrossEntropySmooth

 set_seed(cfg.random_seed)
@@ -73,8 +74,8 @@ if __name__ == '__main__':
     net = ShuffleNetV2(n_class=cfg.num_classes, model_size=args_opt.model_size)

     # loss
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False,
-                                            smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
+    loss = CrossEntropySmooth(sparse=True, reduction="mean",
+                              smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)

     # learning rate schedule
     lr = get_lr_basic(lr_init=cfg.lr_init, total_epochs=cfg.epoch_size,
@@ -71,8 +71,14 @@ if __name__ == '__main__':
         print("Unsupported device_target ", args_opt.device_target)
         exit()
     else:
-        device_id = int(os.getenv('DEVICE_ID'))
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
+        if args_opt.device_target == "Ascend":
+            device_id = int(os.getenv('DEVICE_ID'))
+            context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
+        elif args_opt.device_target == "GPU":
+            context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
+        else:
+            print("Unsupported device_target ", args_opt.device_target)
+            exit()
     rank_size = None
     rank_id = None
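The likely motivation for this branch, consistent with the GPU scripts above exporting CUDA_VISIBLE_DEVICES: DEVICE_ID is an Ascend convention and is typically unset on GPU machines, so the old unconditional int(os.getenv('DEVICE_ID')) would raise TypeError before the context was configured. The rewritten block reads DEVICE_ID only for Ascend and lets the GPU backend select devices via CUDA_VISIBLE_DEVICES instead of a fixed device_id.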