Merge pull request !5901 from panfengfeng/fix_network_issue
@@ -1,83 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-import mindspore.nn as nn
-import mindspore.ops.operations as P
-
-
-class ShuffleV2Block(nn.Cell):
-    def __init__(self, inp, oup, mid_channels, *, ksize, stride):
-        super(ShuffleV2Block, self).__init__()
-        self.stride = stride
-        ##assert stride in [1, 2]
-        self.mid_channels = mid_channels
-        self.ksize = ksize
-        pad = ksize // 2
-        self.pad = pad
-        self.inp = inp
-        outputs = oup - inp
-        branch_main = [
-            # pw
-            nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1,
-                      pad_mode='pad', padding=0, has_bias=False),
-            nn.BatchNorm2d(num_features=mid_channels, momentum=0.9),
-            nn.ReLU(),
-            # dw
-            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride,
-                      pad_mode='pad', padding=pad, group=mid_channels, has_bias=False),
-            nn.BatchNorm2d(num_features=mid_channels, momentum=0.9),
-            # pw-linear
-            nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1,
-                      pad_mode='pad', padding=0, has_bias=False),
-            nn.BatchNorm2d(num_features=outputs, momentum=0.9),
-            nn.ReLU(),
-        ]
-        self.branch_main = nn.SequentialCell(branch_main)
-        if stride == 2:
-            branch_proj = [
-                # dw
-                nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride,
-                          pad_mode='pad', padding=pad, group=inp, has_bias=False),
-                nn.BatchNorm2d(num_features=inp, momentum=0.9),
-                # pw-linear
-                nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1,
-                          pad_mode='pad', padding=0, has_bias=False),
-                nn.BatchNorm2d(num_features=inp, momentum=0.9),
-                nn.ReLU(),
-            ]
-            self.branch_proj = nn.SequentialCell(branch_proj)
-        else:
-            self.branch_proj = None
-
-    def construct(self, old_x):
-        if self.stride == 1:
-            x_proj, x = self.channel_shuffle(old_x)
-            return P.Concat(1)((x_proj, self.branch_main(x)))
-        if self.stride == 2:
-            x_proj = old_x
-            x = old_x
-            return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x)))
-        return None
-
-    def channel_shuffle(self, x):
-        batchsize, num_channels, height, width = P.Shape()(x)
-        ##assert (num_channels % 4 == 0)
-        x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,))
-        x = P.Transpose()(x, (1, 0, 2,))
-        x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,))
-        return x[0], x[1]
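
Note: channel_shuffle implements ShuffleNetV2's channel split-and-shuffle with a reshape/transpose pair. A minimal NumPy sketch of the same trick, with toy shapes (not part of this PR):

import numpy as np

x = np.arange(16).reshape(1, 4, 2, 2)   # toy input: batch=1, 4 channels, 2x2 spatial
b, c, h, w = x.shape
y = x.reshape(b * c // 2, 2, h * w)     # fold adjacent channel pairs together
y = y.transpose(1, 0, 2)                # move the within-pair index to the front
y = y.reshape(2, -1, c // 2, h, w)      # split into two half-channel groups
x_proj, x_main = y[0], y[1]             # each is (batch, c // 2, h, w)

Here x_proj ends up with channels 0 and 2 and x_main with channels 1 and 3, so the two halves are interleaved across the original channel order rather than simply cut in two.
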
| @@ -23,8 +23,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.config import config_gpu as cfg | from src.config import config_gpu as cfg | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from network import ShuffleNetV2 | |||||
| from src.shufflenetv2 import ShuffleNetV2 | |||||
| from src.CrossEntropySmooth import CrossEntropySmooth | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| parser = argparse.ArgumentParser(description='image classification evaluation') | parser = argparse.ArgumentParser(description='image classification evaluation') | ||||
| @@ -43,8 +43,8 @@ if __name__ == '__main__': | |||||
| load_param_into_net(net, ckpt) | load_param_into_net(net, ckpt) | ||||
| net.set_train(False) | net.set_train(False) | ||||
| dataset = create_dataset(args_opt.dataset_path, False, 0, 1) | dataset = create_dataset(args_opt.dataset_path, False, 0, 1) | ||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False, | |||||
| smooth_factor=0.1, num_classes=cfg.num_classes) | |||||
| loss = CrossEntropySmooth(sparse=True, reduction='mean', | |||||
| smooth_factor=0.1, num_classes=cfg.num_classes) | |||||
| eval_metrics = {'Loss': nn.Loss(), | eval_metrics = {'Loss': nn.Loss(), | ||||
| 'Top1-Acc': nn.Top1CategoricalAccuracy(), | 'Top1-Acc': nn.Top1CategoricalAccuracy(), | ||||
| 'Top5-Acc': nn.Top5CategoricalAccuracy()} | 'Top5-Acc': nn.Top5CategoricalAccuracy()} | ||||
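
Note: this loss swap appears to track an upstream API change; the is_grad and smooth_factor keyword arguments passed to nn.SoftmaxCrossEntropyWithLogits here are not accepted by newer MindSpore releases, so label smoothing moves into the custom CrossEntropySmooth cell introduced later in this PR.
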
| @@ -13,10 +13,10 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| if [ $# -lt 3 ] | |||||
| if [ $# != 3 ] && [ $# != 4 ] | |||||
| then | then | ||||
| echo "Usage: \ | |||||
| sh run_distribute_train_for_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] \ | |||||
| echo "Usage: | |||||
| sh run_distribute_train_for_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) | |||||
| " | " | ||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -48,10 +48,15 @@ cd ../train || exit | |||||
| export CUDA_VISIBLE_DEVICES="$2" | export CUDA_VISIBLE_DEVICES="$2" | ||||
| if [ $1 -gt 1 ] | |||||
| if [ $# == 3 ] | |||||
| then | then | ||||
| mpirun -n $1 --allow-run-as-root \ | mpirun -n $1 --allow-run-as-root \ | ||||
| python ${BASEPATH}/../train.py --platform='GPU' --is_distributed=True --dataset_path=$3 > train.log 2>&1 & | python ${BASEPATH}/../train.py --platform='GPU' --is_distributed=True --dataset_path=$3 > train.log 2>&1 & | ||||
| else | |||||
| python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$3 > train.log 2>&1 & | |||||
| fi | fi | ||||
| if [ $# == 4 ] | |||||
| then | |||||
| mpirun -n $1 --allow-run-as-root \ | |||||
| python ${BASEPATH}/../train.py --platform='GPU' --is_distributed=True --dataset_path=$3 --resume=$4 > train.log 2>&1 & | |||||
| fi | |||||
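
Example invocation with hypothetical paths: sh run_distribute_train_for_gpu.sh 8 0,1,2,3,4,5,6,7 /path/to/imagenet /path/to/pretrained.ckpt — the optional fourth argument resumes from a checkpoint; omit it to train from scratch.
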
| @@ -13,10 +13,10 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| if [ $# -lt 1 ] | |||||
| if [ $# != 1 ] && [ $# != 2 ] | |||||
| then | then | ||||
| echo "Usage: \ | |||||
| sh run_standalone_train_for_gpu.sh [DATASET_PATH] \ | |||||
| echo "Usage: | |||||
| sh run_standalone_train_for_gpu.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) | |||||
| " | " | ||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -37,4 +37,12 @@ fi | |||||
| mkdir ../train | mkdir ../train | ||||
| cd ../train || exit | cd ../train || exit | ||||
| python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 > train.log 2>&1 & | |||||
| if [ $# == 1 ] | |||||
| then | |||||
| python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 > train.log 2>&1 & | |||||
| fi | |||||
| if [ $# == 2 ] | |||||
| then | |||||
| python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 --resume=$2 > train.log 2>&1 & | |||||
| fi | |||||
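
Example invocation with hypothetical paths: sh run_standalone_train_for_gpu.sh /path/to/imagenet /path/to/pretrained.ckpt resumes from the given checkpoint; passing only the dataset path trains from scratch.
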
| @@ -0,0 +1,38 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """define loss function for network""" | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.nn.loss.loss import _Loss | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import operations as P | |||||
| class CrossEntropySmooth(_Loss): | |||||
| """CrossEntropy""" | |||||
| def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000): | |||||
| super(CrossEntropySmooth, self).__init__() | |||||
| self.onehot = P.OneHot() | |||||
| self.sparse = sparse | |||||
| self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) | |||||
| self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) | |||||
| self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction) | |||||
| def construct(self, logit, label): | |||||
| if self.sparse: | |||||
| label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) | |||||
| loss = self.ce(logit, label) | |||||
| return loss | |||||
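
Note: the on/off values implement standard label smoothing. A quick check of the arithmetic, as a NumPy sketch with toy sizes (not part of this PR):

import numpy as np

smooth_factor, num_classes, label = 0.1, 5, 2
on_value = 1.0 - smooth_factor                  # 0.9 for the true class
off_value = smooth_factor / (num_classes - 1)   # 0.025 for each wrong class
target = np.full(num_classes, off_value)        # smoothed one-hot target
target[label] = on_value
print(target)                                   # [0.025 0.025 0.9 0.025 0.025], sums to 1.0
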
| @@ -1,60 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """define loss function for network.""" | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.nn.loss.loss import _Loss | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore import Tensor | |||||
| import mindspore.nn as nn | |||||
| class CrossEntropy(_Loss): | |||||
| """the redefined loss function with SoftmaxCrossEntropyWithLogits""" | |||||
| def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4): | |||||
| super(CrossEntropy, self).__init__() | |||||
| self.factor = factor | |||||
| self.onehot = P.OneHot() | |||||
| self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) | |||||
| self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) | |||||
| self.ce = nn.SoftmaxCrossEntropyWithLogits() | |||||
| self.mean = P.ReduceMean(False) | |||||
| def construct(self, logits, label): | |||||
| logit, aux = logits | |||||
| one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) | |||||
| loss_logit = self.ce(logit, one_hot_label) | |||||
| loss_logit = self.mean(loss_logit, 0) | |||||
| one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value) | |||||
| loss_aux = self.ce(aux, one_hot_label_aux) | |||||
| loss_aux = self.mean(loss_aux, 0) | |||||
| return loss_logit + self.factor*loss_aux | |||||
| class CrossEntropy_Val(_Loss): | |||||
| """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process""" | |||||
| def __init__(self, smooth_factor=0, num_classes=1000): | |||||
| super(CrossEntropy_Val, self).__init__() | |||||
| self.onehot = P.OneHot() | |||||
| self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) | |||||
| self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) | |||||
| self.ce = nn.SoftmaxCrossEntropyWithLogits() | |||||
| self.mean = P.ReduceMean(False) | |||||
| def construct(self, logits, label): | |||||
| one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value) | |||||
| loss_logit = self.ce(logits, one_hot_label) | |||||
| loss_logit = self.mean(loss_logit, 0) | |||||
| return loss_logit | |||||
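
Note: the removed CrossEntropy expected the network to return a (logit, aux) tuple from an auxiliary classifier head, with CrossEntropy_Val handling the single-logit inference case. ShuffleNetV2 produces no auxiliary output, which presumably motivated replacing both with the single CrossEntropySmooth cell above.
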
| @@ -14,13 +14,78 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| import numpy as np | import numpy as np | ||||
| from blocks import ShuffleV2Block | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.ops.operations as P | import mindspore.ops.operations as P | ||||
| class ShuffleV2Block(nn.Cell): | |||||
| def __init__(self, inp, oup, mid_channels, *, ksize, stride): | |||||
| super(ShuffleV2Block, self).__init__() | |||||
| self.stride = stride | |||||
| ##assert stride in [1, 2] | |||||
| self.mid_channels = mid_channels | |||||
| self.ksize = ksize | |||||
| pad = ksize // 2 | |||||
| self.pad = pad | |||||
| self.inp = inp | |||||
| outputs = oup - inp | |||||
| branch_main = [ | |||||
| # pw | |||||
| nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1, | |||||
| pad_mode='pad', padding=0, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), | |||||
| nn.ReLU(), | |||||
| # dw | |||||
| nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride, | |||||
| pad_mode='pad', padding=pad, group=mid_channels, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), | |||||
| # pw-linear | |||||
| nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1, | |||||
| pad_mode='pad', padding=0, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=outputs, momentum=0.9), | |||||
| nn.ReLU(), | |||||
| ] | |||||
| self.branch_main = nn.SequentialCell(branch_main) | |||||
| if stride == 2: | |||||
| branch_proj = [ | |||||
| # dw | |||||
| nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride, | |||||
| pad_mode='pad', padding=pad, group=inp, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=inp, momentum=0.9), | |||||
| # pw-linear | |||||
| nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1, | |||||
| pad_mode='pad', padding=0, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=inp, momentum=0.9), | |||||
| nn.ReLU(), | |||||
| ] | |||||
| self.branch_proj = nn.SequentialCell(branch_proj) | |||||
| else: | |||||
| self.branch_proj = None | |||||
| def construct(self, old_x): | |||||
| if self.stride == 1: | |||||
| x_proj, x = self.channel_shuffle(old_x) | |||||
| return P.Concat(1)((x_proj, self.branch_main(x))) | |||||
| if self.stride == 2: | |||||
| x_proj = old_x | |||||
| x = old_x | |||||
| return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x))) | |||||
| return None | |||||
| def channel_shuffle(self, x): | |||||
| batchsize, num_channels, height, width = P.Shape()(x) | |||||
| ##assert (num_channels % 4 == 0) | |||||
| x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,)) | |||||
| x = P.Transpose()(x, (1, 0, 2,)) | |||||
| x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,)) | |||||
| return x[0], x[1] | |||||
| class ShuffleNetV2(nn.Cell): | class ShuffleNetV2(nn.Cell): | ||||
| def __init__(self, input_size=224, n_class=1000, model_size='1.0x'): | def __init__(self, input_size=224, n_class=1000, model_size='1.0x'): | ||||
| super(ShuffleNetV2, self).__init__() | super(ShuffleNetV2, self).__init__() | ||||
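
A minimal shape check for the relocated block — a sketch assuming a working MindSpore install, run from the model's root directory; the 24-in/116-out channel sizes mirror the first stride-2 stage of the 1.0x model:

import numpy as np
from mindspore import Tensor, context
from src.shufflenetv2 import ShuffleV2Block

context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

# stride=2 halves the spatial size and concatenates both branches
block = ShuffleV2Block(24, 116, mid_channels=58, ksize=3, stride=2)
x = Tensor(np.random.randn(1, 24, 56, 56).astype(np.float32))
y = block(x)
print(y.shape)   # expected (1, 116, 28, 28): 24 proj channels + 92 main channels
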
| @@ -17,7 +17,6 @@ import argparse | |||||
| import ast | import ast | ||||
| import os | import os | ||||
| from network import ShuffleNetV2 | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context | from mindspore import context | ||||
| @@ -30,9 +29,11 @@ from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from src.shufflenetv2 import ShuffleNetV2 | |||||
| from src.config import config_gpu as cfg | from src.config import config_gpu as cfg | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.lr_generator import get_lr_basic | from src.lr_generator import get_lr_basic | ||||
| from src.CrossEntropySmooth import CrossEntropySmooth | |||||
| set_seed(cfg.random_seed) | set_seed(cfg.random_seed) | ||||
| @@ -73,8 +74,8 @@ if __name__ == '__main__': | |||||
| net = ShuffleNetV2(n_class=cfg.num_classes, model_size=args_opt.model_size) | net = ShuffleNetV2(n_class=cfg.num_classes, model_size=args_opt.model_size) | ||||
| # loss | # loss | ||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False, | |||||
| smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) | |||||
| loss = CrossEntropySmooth(sparse=True, reduction="mean", | |||||
| smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) | |||||
| # learning rate schedule | # learning rate schedule | ||||
| lr = get_lr_basic(lr_init=cfg.lr_init, total_epochs=cfg.epoch_size, | lr = get_lr_basic(lr_init=cfg.lr_init, total_epochs=cfg.epoch_size, | ||||
| @@ -71,8 +71,14 @@ if __name__ == '__main__': | |||||
| print("Unsupported device_target ", args_opt.device_target) | print("Unsupported device_target ", args_opt.device_target) | ||||
| exit() | exit() | ||||
| else: | else: | ||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) | |||||
| if args_opt.device_target == "Ascend": | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) | |||||
| elif args_opt.device_target == "GPU": | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) | |||||
| else: | |||||
| print("Unsupported device_target ", args_opt.device_target) | |||||
| exit() | |||||
| rank_size = None | rank_size = None | ||||
| rank_id = None | rank_id = None | ||||
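
Note: the reworked branch reads DEVICE_ID only on Ascend; on GPU, device binding comes from the CUDA_VISIBLE_DEVICES variable exported by the launch scripts above, so the GPU path no longer depends on DEVICE_ID being set.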