Browse Source

!5901 fix network issue

Merge pull request !5901 from panfengfeng/fix_network_issue
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
8a71db07c2
9 changed files with 144 additions and 164 deletions
  1. +0
    -83
      model_zoo/official/cv/shufflenetv2/blocks.py
  2. +4
    -4
      model_zoo/official/cv/shufflenetv2/eval.py
  3. +11
    -6
      model_zoo/official/cv/shufflenetv2/scripts/run_distribute_train_for_gpu.sh
  4. +12
    -4
      model_zoo/official/cv/shufflenetv2/scripts/run_standalone_train_for_gpu.sh
  5. +38
    -0
      model_zoo/official/cv/shufflenetv2/src/CrossEntropySmooth.py
  6. +0
    -60
      model_zoo/official/cv/shufflenetv2/src/loss.py
  7. +67
    -2
      model_zoo/official/cv/shufflenetv2/src/shufflenetv2.py
  8. +4
    -3
      model_zoo/official/cv/shufflenetv2/train.py
  9. +8
    -2
      model_zoo/official/recommend/deepfm/train.py

+ 0
- 83
model_zoo/official/cv/shufflenetv2/blocks.py View File

@@ -1,83 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
import mindspore.ops.operations as P


class ShuffleV2Block(nn.Cell):
    """Two-branch convolutional block whose branch outputs are concatenated on axis 1.

    The main branch is: 1x1 conv -> BN -> ReLU -> depthwise ``ksize`` conv
    (``group=mid_channels``) -> BN -> 1x1 conv -> BN -> ReLU, and produces
    ``oup - inp`` channels.  When ``stride == 2`` a projection branch
    (depthwise conv -> BN -> 1x1 conv -> BN -> ReLU) producing ``inp``
    channels is built as well; otherwise ``branch_proj`` is ``None``.

    Args:
        inp: Number of channels fed to each branch.
        oup: Total number of channels after concatenation.
        mid_channels: Channel width inside the main branch.
        ksize: Kernel size of the depthwise convolutions (keyword-only).
        stride: Stride of the depthwise convolutions (keyword-only).
            Expected to be 1 or 2; the check is left disabled.
    """

    def __init__(self, inp, oup, mid_channels, *, ksize, stride):
        super(ShuffleV2Block, self).__init__()
        pad = ksize // 2
        outputs = oup - inp

        self.stride = stride
        self.mid_channels = mid_channels
        self.ksize = ksize
        self.pad = pad
        self.inp = inp

        def _pointwise(cin, cout):
            # 1x1 convolution, no bias.
            return nn.Conv2d(in_channels=cin, out_channels=cout, kernel_size=1, stride=1,
                             pad_mode='pad', padding=0, has_bias=False)

        def _depthwise(channels):
            # One group per channel, padded by ksize // 2.
            return nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=ksize,
                             stride=stride, pad_mode='pad', padding=pad, group=channels,
                             has_bias=False)

        def _bn(channels):
            return nn.BatchNorm2d(num_features=channels, momentum=0.9)

        # Layer order is kept so that sub-cell indices (and hence parameter
        # names) inside the SequentialCell stay the same.
        self.branch_main = nn.SequentialCell([
            _pointwise(inp, mid_channels),
            _bn(mid_channels),
            nn.ReLU(),
            _depthwise(mid_channels),
            _bn(mid_channels),
            _pointwise(mid_channels, outputs),
            _bn(outputs),
            nn.ReLU(),
        ])

        if stride == 2:
            self.branch_proj = nn.SequentialCell([
                _depthwise(inp),
                _bn(inp),
                _pointwise(inp, inp),
                _bn(inp),
                nn.ReLU(),
            ])
        else:
            self.branch_proj = None

    def construct(self, old_x):
        """Run the block on ``old_x``.

        With ``stride == 1`` the input is split by ``channel_shuffle``: the
        first part is passed through untouched and the second goes through
        ``branch_main``.  With ``stride == 2`` both branches receive the full
        input.  Any other stride yields ``None``.
        """
        if self.stride == 1:
            shortcut, main_in = self.channel_shuffle(old_x)
            main_out = self.branch_main(main_in)
            return P.Concat(1)((shortcut, main_out))
        if self.stride == 2:
            proj_out = self.branch_proj(old_x)
            main_out = self.branch_main(old_x)
            return P.Concat(1)((proj_out, main_out))
        return None

    def channel_shuffle(self, x):
        """Rearrange a 4-D tensor and return it as two tensors.

        The tensor is reshaped to ``(N * C // 2, 2, H * W)``, its first two
        axes are swapped, and it is reshaped to ``(2, -1, C // 2, H, W)``;
        the two slices along the leading axis are returned.  The channel
        count is expected to be divisible by 4 (check left disabled).
        """
        n, c, h, w = P.Shape()(x)
        paired = P.Reshape()(x, (n * c // 2, 2, h * w,))
        swapped = P.Transpose()(paired, (1, 0, 2,))
        halves = P.Reshape()(swapped, (2, -1, c // 2, h, w,))
        return halves[0], halves[1]

+ 4
- 4
model_zoo/official/cv/shufflenetv2/eval.py View File

@@ -23,8 +23,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net


from src.config import config_gpu as cfg from src.config import config_gpu as cfg
from src.dataset import create_dataset from src.dataset import create_dataset
from network import ShuffleNetV2
from src.shufflenetv2 import ShuffleNetV2
from src.CrossEntropySmooth import CrossEntropySmooth


if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='image classification evaluation') parser = argparse.ArgumentParser(description='image classification evaluation')
@@ -43,8 +43,8 @@ if __name__ == '__main__':
load_param_into_net(net, ckpt) load_param_into_net(net, ckpt)
net.set_train(False) net.set_train(False)
dataset = create_dataset(args_opt.dataset_path, False, 0, 1) dataset = create_dataset(args_opt.dataset_path, False, 0, 1)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False,
smooth_factor=0.1, num_classes=cfg.num_classes)
loss = CrossEntropySmooth(sparse=True, reduction='mean',
smooth_factor=0.1, num_classes=cfg.num_classes)
eval_metrics = {'Loss': nn.Loss(), eval_metrics = {'Loss': nn.Loss(),
'Top1-Acc': nn.Top1CategoricalAccuracy(), 'Top1-Acc': nn.Top1CategoricalAccuracy(),
'Top5-Acc': nn.Top5CategoricalAccuracy()} 'Top5-Acc': nn.Top5CategoricalAccuracy()}


+ 11
- 6
model_zoo/official/cv/shufflenetv2/scripts/run_distribute_train_for_gpu.sh View File

@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
if [ $# -lt 3 ]
if [ $# != 3 ] && [ $# != 4 ]
then then
echo "Usage: \
sh run_distribute_train_for_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] \
echo "Usage:
sh run_distribute_train_for_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
" "
exit 1 exit 1
fi fi
@@ -48,10 +48,15 @@ cd ../train || exit


export CUDA_VISIBLE_DEVICES="$2" export CUDA_VISIBLE_DEVICES="$2"


if [ $1 -gt 1 ]
if [ $# == 3 ]
then then
mpirun -n $1 --allow-run-as-root \ mpirun -n $1 --allow-run-as-root \
python ${BASEPATH}/../train.py --platform='GPU' --is_distributed=True --dataset_path=$3 > train.log 2>&1 & python ${BASEPATH}/../train.py --platform='GPU' --is_distributed=True --dataset_path=$3 > train.log 2>&1 &
else
python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$3 > train.log 2>&1 &
fi fi

if [ $# == 4 ]
then
mpirun -n $1 --allow-run-as-root \
python ${BASEPATH}/../train.py --platform='GPU' --is_distributed=True --dataset_path=$3 --resume=$4 > train.log 2>&1 &
fi


+ 12
- 4
model_zoo/official/cv/shufflenetv2/scripts/run_standalone_train_for_gpu.sh View File

@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
if [ $# -lt 1 ]
if [ $# != 1 ] && [ $# != 2 ]
then then
echo "Usage: \
sh run_standalone_train_for_gpu.sh [DATASET_PATH] \
echo "Usage:
sh run_standalone_train_for_gpu.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
" "
exit 1 exit 1
fi fi
@@ -37,4 +37,12 @@ fi
mkdir ../train mkdir ../train
cd ../train || exit cd ../train || exit


python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 > train.log 2>&1 &
if [ $# == 1 ]
then
python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 > train.log 2>&1 &
fi

if [ $# == 2 ]
then
python ${BASEPATH}/../train.py --platform='GPU' --dataset_path=$1 --resume=$2 > train.log 2>&1 &
fi

+ 38
- 0
model_zoo/official/cv/shufflenetv2/src/CrossEntropySmooth.py View File

@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class CrossEntropySmooth(_Loss):
    """Softmax cross-entropy loss that can one-hot encode labels with smoothing.

    Args:
        sparse: When true, ``label`` is converted to a smoothed one-hot
            tensor before the loss is computed; otherwise it is used as given.
        reduction: Passed through to ``nn.SoftmaxCrossEntropyWithLogits``.
        smooth_factor: Smoothing amount; the "on" value is
            ``1 - smooth_factor`` and the "off" value is
            ``smooth_factor / (num_classes - 1)``.
        num_classes: Number of classes used to compute the "off" value.
    """

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        hit_value = 1.0 - smooth_factor
        miss_value = 1.0 * smooth_factor / (num_classes - 1)

        self.onehot = P.OneHot()
        self.sparse = sparse
        self.on_value = Tensor(hit_value, mstype.float32)
        self.off_value = Tensor(miss_value, mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        """Return the cross-entropy between ``logit`` and ``label``."""
        if self.sparse:
            # One-hot depth is taken from axis 1 of the logits.
            depth = F.shape(logit)[1]
            label = self.onehot(label, depth, self.on_value, self.off_value)
        return self.ce(logit, label)

+ 0
- 60
model_zoo/official/cv/shufflenetv2/src/loss.py View File

@@ -1,60 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
import mindspore.nn as nn


class CrossEntropy(_Loss):
    """Label-smoothed softmax cross-entropy over a (main, auxiliary) output pair.

    The result is ``mean(ce(main)) + factor * mean(ce(aux))``.

    Args:
        smooth_factor: Smoothing amount; the "on" value is
            ``1 - smooth_factor`` and the "off" value is
            ``smooth_factor / (num_classes - 1)``.
        num_classes: Number of classes used to compute the "off" value.
        factor: Weight applied to the auxiliary loss.
    """

    def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
        super(CrossEntropy, self).__init__()
        hit_value = 1.0 - smooth_factor
        miss_value = 1.0 * smooth_factor / (num_classes - 1)

        self.factor = factor
        self.onehot = P.OneHot()
        self.on_value = Tensor(hit_value, mstype.float32)
        self.off_value = Tensor(miss_value, mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        """Compute the weighted sum of the main and auxiliary losses.

        ``logits`` is unpacked into exactly two tensors: main output, then
        auxiliary output.
        """
        main_out, aux_out = logits

        # Main head: one-hot depth comes from axis 1 of its logits.
        main_target = self.onehot(label, F.shape(main_out)[1], self.on_value, self.off_value)
        main_loss = self.mean(self.ce(main_out, main_target), 0)

        # Auxiliary head gets its own one-hot target sized from its logits.
        aux_target = self.onehot(label, F.shape(aux_out)[1], self.on_value, self.off_value)
        aux_loss = self.mean(self.ce(aux_out, aux_target), 0)

        return main_loss + self.factor*aux_loss


class CrossEntropy_Val(_Loss):
    """Label-smoothed softmax cross-entropy for a single output, used at inference.

    Args:
        smooth_factor: Smoothing amount; the "on" value is
            ``1 - smooth_factor`` and the "off" value is
            ``smooth_factor / (num_classes - 1)``.
        num_classes: Number of classes used to compute the "off" value.
    """

    def __init__(self, smooth_factor=0, num_classes=1000):
        super(CrossEntropy_Val, self).__init__()
        hit_value = 1.0 - smooth_factor
        miss_value = 1.0 * smooth_factor / (num_classes - 1)

        self.onehot = P.OneHot()
        self.on_value = Tensor(hit_value, mstype.float32)
        self.off_value = Tensor(miss_value, mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        """Return the cross-entropy of ``logits`` against ``label``, mean-reduced over axis 0."""
        # One-hot depth is taken from axis 1 of the logits.
        depth = F.shape(logits)[1]
        target = self.onehot(label, depth, self.on_value, self.off_value)
        per_sample = self.ce(logits, target)
        return self.mean(per_sample, 0)

model_zoo/official/cv/shufflenetv2/network.py → model_zoo/official/cv/shufflenetv2/src/shufflenetv2.py View File

@@ -14,13 +14,78 @@
# ============================================================================ # ============================================================================
import numpy as np import numpy as np


from blocks import ShuffleV2Block

from mindspore import Tensor from mindspore import Tensor
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops.operations as P import mindspore.ops.operations as P




class ShuffleV2Block(nn.Cell):
    """Two-branch convolutional block whose branch outputs are concatenated on axis 1.

    The main branch is: 1x1 conv -> BN -> ReLU -> depthwise ``ksize`` conv
    (``group=mid_channels``) -> BN -> 1x1 conv -> BN -> ReLU, and produces
    ``oup - inp`` channels.  When ``stride == 2`` a projection branch
    (depthwise conv -> BN -> 1x1 conv -> BN -> ReLU) producing ``inp``
    channels is built as well; otherwise ``branch_proj`` is ``None``.

    Args:
        inp: Number of channels fed to each branch.
        oup: Total number of channels after concatenation.
        mid_channels: Channel width inside the main branch.
        ksize: Kernel size of the depthwise convolutions (keyword-only).
        stride: Stride of the depthwise convolutions (keyword-only).
            Expected to be 1 or 2; the check is left disabled.
    """

    def __init__(self, inp, oup, mid_channels, *, ksize, stride):
        super(ShuffleV2Block, self).__init__()
        pad = ksize // 2
        outputs = oup - inp

        self.stride = stride
        self.mid_channels = mid_channels
        self.ksize = ksize
        self.pad = pad
        self.inp = inp

        def _pointwise(cin, cout):
            # 1x1 convolution, no bias.
            return nn.Conv2d(in_channels=cin, out_channels=cout, kernel_size=1, stride=1,
                             pad_mode='pad', padding=0, has_bias=False)

        def _depthwise(channels):
            # One group per channel, padded by ksize // 2.
            return nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=ksize,
                             stride=stride, pad_mode='pad', padding=pad, group=channels,
                             has_bias=False)

        def _bn(channels):
            return nn.BatchNorm2d(num_features=channels, momentum=0.9)

        # Layer order is kept so that sub-cell indices (and hence parameter
        # names) inside the SequentialCell stay the same.
        self.branch_main = nn.SequentialCell([
            _pointwise(inp, mid_channels),
            _bn(mid_channels),
            nn.ReLU(),
            _depthwise(mid_channels),
            _bn(mid_channels),
            _pointwise(mid_channels, outputs),
            _bn(outputs),
            nn.ReLU(),
        ])

        if stride == 2:
            self.branch_proj = nn.SequentialCell([
                _depthwise(inp),
                _bn(inp),
                _pointwise(inp, inp),
                _bn(inp),
                nn.ReLU(),
            ])
        else:
            self.branch_proj = None

    def construct(self, old_x):
        """Run the block on ``old_x``.

        With ``stride == 1`` the input is split by ``channel_shuffle``: the
        first part is passed through untouched and the second goes through
        ``branch_main``.  With ``stride == 2`` both branches receive the full
        input.  Any other stride yields ``None``.
        """
        if self.stride == 1:
            shortcut, main_in = self.channel_shuffle(old_x)
            main_out = self.branch_main(main_in)
            return P.Concat(1)((shortcut, main_out))
        if self.stride == 2:
            proj_out = self.branch_proj(old_x)
            main_out = self.branch_main(old_x)
            return P.Concat(1)((proj_out, main_out))
        return None

    def channel_shuffle(self, x):
        """Rearrange a 4-D tensor and return it as two tensors.

        The tensor is reshaped to ``(N * C // 2, 2, H * W)``, its first two
        axes are swapped, and it is reshaped to ``(2, -1, C // 2, H, W)``;
        the two slices along the leading axis are returned.  The channel
        count is expected to be divisible by 4 (check left disabled).
        """
        n, c, h, w = P.Shape()(x)
        paired = P.Reshape()(x, (n * c // 2, 2, h * w,))
        swapped = P.Transpose()(paired, (1, 0, 2,))
        halves = P.Reshape()(swapped, (2, -1, c // 2, h, w,))
        return halves[0], halves[1]


class ShuffleNetV2(nn.Cell): class ShuffleNetV2(nn.Cell):
def __init__(self, input_size=224, n_class=1000, model_size='1.0x'): def __init__(self, input_size=224, n_class=1000, model_size='1.0x'):
super(ShuffleNetV2, self).__init__() super(ShuffleNetV2, self).__init__()

+ 4
- 3
model_zoo/official/cv/shufflenetv2/train.py View File

@@ -17,7 +17,6 @@ import argparse
import ast import ast
import os import os


from network import ShuffleNetV2


import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
@@ -30,9 +29,11 @@ from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed from mindspore.common import set_seed


from src.shufflenetv2 import ShuffleNetV2
from src.config import config_gpu as cfg from src.config import config_gpu as cfg
from src.dataset import create_dataset from src.dataset import create_dataset
from src.lr_generator import get_lr_basic from src.lr_generator import get_lr_basic
from src.CrossEntropySmooth import CrossEntropySmooth


set_seed(cfg.random_seed) set_seed(cfg.random_seed)


@@ -73,8 +74,8 @@ if __name__ == '__main__':
net = ShuffleNetV2(n_class=cfg.num_classes, model_size=args_opt.model_size) net = ShuffleNetV2(n_class=cfg.num_classes, model_size=args_opt.model_size)


# loss # loss
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False,
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)


# learning rate schedule # learning rate schedule
lr = get_lr_basic(lr_init=cfg.lr_init, total_epochs=cfg.epoch_size, lr = get_lr_basic(lr_init=cfg.lr_init, total_epochs=cfg.epoch_size,


+ 8
- 2
model_zoo/official/recommend/deepfm/train.py View File

@@ -71,8 +71,14 @@ if __name__ == '__main__':
print("Unsupported device_target ", args_opt.device_target) print("Unsupported device_target ", args_opt.device_target)
exit() exit()
else: else:
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
print("Unsupported device_target ", args_opt.device_target)
exit()
rank_size = None rank_size = None
rank_id = None rank_id = None




Loading…
Cancel
Save