diff --git a/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py b/tests/st/auto_parallel/resnet50_expand_loss.py
similarity index 87%
rename from tests/st/auto_parallel/test_resnet50_expand_loss_2p.py
rename to tests/st/auto_parallel/resnet50_expand_loss.py
index fa46a87bc9..0a74752b2d 100644
--- a/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py
+++ b/tests/st/auto_parallel/resnet50_expand_loss.py
@@ -22,7 +22,7 @@ import mindspore.context as context
 import mindspore.nn as nn
 import mindspore.ops.functional as F
 from mindspore import Tensor
-from mindspore.common.initializer import One
+from mindspore.common.initializer import TruncatedNormal
 from mindspore.communication.management import init
 from mindspore.nn.loss.loss import _Loss
 from mindspore.nn.optim.momentum import Momentum
@@ -35,10 +35,11 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=int(os.getenv('DEVICE_ID')))
 init()
 context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
+np.random.seed(10)
 
 
 def weight_variable():
-    return One()
+    return TruncatedNormal(0.01)
 
 
 def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
@@ -93,11 +94,9 @@ class BasicBlock(nn.Cell):
         identity = x
 
         x = self.conv1(x)
-        x = self.bn1(x)
         x = self.relu(x)
 
         x = self.conv2(x)
-        x = self.bn2(x)
 
         if self.downsample:
             identity = self.down_sample_layer(identity)
@@ -120,13 +119,10 @@ class ResidualBlock(nn.Cell):
         out_chls = out_channels // self.expansion
 
         self.conv1 = _conv1x1(in_channels, out_chls, stride=1)
-        self.bn1 = _fused_bn(out_chls, momentum=momentum)
 
         self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)
-        self.bn2 = _fused_bn(out_chls, momentum=momentum)
 
         self.conv3 = _conv1x1(out_chls, out_channels, stride=1)
-        self.bn3 = _fused_bn(out_channels, momentum=momentum)
 
         self.relu = P.ReLU()
         self.downsample = (in_channels != out_channels)
@@ -134,7 +130,6 @@ class ResidualBlock(nn.Cell):
         if self.downsample:
             self.conv_down_sample = _conv1x1(in_channels, out_channels, stride=stride)
-            self.bn_down_sample = _fused_bn(out_channels, momentum=momentum)
         elif self.stride != 1:
             self.maxpool_down = nn.MaxPool2d(kernel_size=1,
                                              stride=2,
                                              pad_mode='same')
@@ -144,19 +139,15 @@ class ResidualBlock(nn.Cell):
         identity = x
 
         out = self.conv1(x)
-        out = self.bn1(out)
         out = self.relu(out)
 
         out = self.conv2(out)
-        out = self.bn2(out)
         out = self.relu(out)
 
         out = self.conv3(out)
-        out = self.bn3(out)
 
         if self.downsample:
             identity = self.conv_down_sample(identity)
-            identity = self.bn_down_sample(identity)
         elif self.stride != 1:
             identity = self.maxpool_down(identity)
 
@@ -211,7 +202,7 @@ class ResNet(nn.Cell):
         self.mean = P.ReduceMean(keep_dims=True)
         self.end_point = nn.Dense(2048, num_classes, has_bias=True,
                                   weight_init=weight_variable(),
-                                  bias_init=weight_variable())
+                                  bias_init=weight_variable()).add_flags_recursive(fp16=True)
         self.squeeze = P.Squeeze()
         self.cast = P.Cast()
 
@@ -231,7 +222,6 @@ class ResNet(nn.Cell):
 
     def construct(self, x):
         x = self.conv1(x)
-        x = self.bn1(x)
         x = self.relu(x)
 
         c1 = self.maxpool(x)
@@ -277,6 +267,7 @@ class SoftmaxCrossEntropyExpand(_Loss):
         self.eps = Tensor(1e-24, mstype.float32)
 
     def construct(self, logit, label):
+        logit = self.cast(logit, mstype.float32)
         logit_max = self.max(logit, -1)
         exp = self.exp(self.sub(logit, logit_max))
         exp_sum = self.sum(exp, -1)
@@ -369,41 +360,19 @@ class ModelCallback(Callback):
         self.loss_list.append(result.asnumpy().mean())
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.env_single
-def test_train_feed(num_classes=8192):
+def test_train_feed(num_classes=65536):
     set_algo_parameters(elementwise_op_strategy_follow=True)
     parallel_callback = ModelCallback()
     data_gen = DataGenerator()
-    _, input_part = data_gen.input_data((32 * 2, 3, 224, 224))
-    _, label_part = data_gen.label_data((32 * 2,))
+    _, input_part = data_gen.input_data((32 * 8, 3, 224, 224))
+    _, label_part = data_gen.label_data((32 * 8,))
     dataset = Dataset(input_part, label_part)
     net = resnet50(num_classes)
     loss = SoftmaxCrossEntropyExpand(sparse=True)
-    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 10.0, 0.9)
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
     model = Model(net, loss_fn=loss, optimizer=opt)
     model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
     loss_value = np.array(parallel_callback.loss_list)
-    expect_out = [9.010913, 8.855984, 8.56246, 8.146317, 7.624489]
-    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.env_single
-def test_train_feed2(num_classes=1001):
-    set_algo_parameters(elementwise_op_strategy_follow=True)
-    parallel_callback = ModelCallback()
-    data_gen = DataGenerator()
-    _, input_part = data_gen.input_data((32 * 2, 3, 224, 224))
-    _, label_part = data_gen.label_data((32 * 2,))
-    dataset = Dataset(input_part, label_part)
-    net = resnet50(num_classes)
-    loss = SoftmaxCrossEntropyExpand(sparse=True)
-    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 10.0, 0.9)
-    model = Model(net, loss_fn=loss, optimizer=opt)
-    model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
-    loss_value = np.array(parallel_callback.loss_list)
-    expect_out = [6.908755, 6.8358116, 6.6986914, 6.506859, 6.2708097]
+    expect_out = [11.11153, 11.090023, 11.050361, 10.994822, 10.924148]
+    print(loss_value)
     assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
diff --git a/tests/st/auto_parallel/run_auto_parallel_resnet50_expand_loss.sh b/tests/st/auto_parallel/run_auto_parallel_resnet50_expand_loss.sh
index 094668ba5c..00f3d1af33 100644
--- a/tests/st/auto_parallel/run_auto_parallel_resnet50_expand_loss.sh
+++ b/tests/st/auto_parallel/run_auto_parallel_resnet50_expand_loss.sh
@@ -18,7 +18,6 @@ BASE_PATH=$(cd "$(dirname $0)"; pwd)
 CONFIG_PATH=/home/workspace/mindspore_config
 export DEVICE_NUM=8
 export RANK_SIZE=$DEVICE_NUM
-ulimit -n 65535
 source ${BASE_PATH}/env.sh
 unset SLOG_PRINT_TO_STDOUT
 export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json
diff --git a/tests/st/auto_parallel/test_resnet50_expand_loss.py b/tests/st/auto_parallel/test_resnet50_expand_loss.py
new file mode 100644
index 0000000000..ddcddd73c2
--- /dev/null
+++ b/tests/st/auto_parallel/test_resnet50_expand_loss.py
@@ -0,0 +1,26 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import pytest
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_single
+def test_expand_loss():
+    sh_path = os.path.split(os.path.realpath(__file__))[0]
+    ret = os.system(f"sh {sh_path}/run_auto_parallel_resnet50_expand_loss.sh")
+    assert ret == 0