diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h index 953380fb32..723c018d7f 100644 --- a/mindspore/ccsrc/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/parallel/dynamic_creator.h @@ -127,6 +127,7 @@ REGISTER(NegInfo); REGISTER(BatchMatMulInfo); REGISTER(ExpandDimsInfo); REGISTER(SqueezeInfo); +REGISTER(SigmoidCrossEntropyWithLogitsInfo); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h index 376a1fb4cf..78dfc23803 100644 --- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h +++ b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h @@ -120,6 +120,15 @@ class AssignSubInfo : public ArithmeticBase { : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~AssignSubInfo() override = default; }; + +// All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label. +class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { + public: + SigmoidCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, + const PrimitiveAttrs& attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~SigmoidCrossEntropyWithLogitsInfo() override = default; +}; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h index 50920e5954..0b52d8e83c 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h @@ -138,6 +138,7 @@ constexpr char ALL_GATHER[] = "AllGather"; constexpr char REDUCE_SCATTER[] = "ReduceScatter"; constexpr char CONCAT[] = "Concat"; constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits"; +constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits"; constexpr char MATMUL[] = "MatMul"; constexpr char GELU[] = "Gelu"; constexpr char TANH[] = "Tanh"; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 5caf6573f2..1d52eac82d 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -78,6 +78,7 @@ std::vector splittable_op_ = {MATMUL, FUSE_BATCH_NORM, POOLING, SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, + SIGMOID_CROSS_ENTROPY_WITH_LOGITS, MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, diff --git a/tests/ut/python/parallel/test_sigmoid_cross_entropy_with_logits.py b/tests/ut/python/parallel/test_sigmoid_cross_entropy_with_logits.py new file mode 100644 index 0000000000..d59d053b07 --- /dev/null +++ b/tests/ut/python/parallel/test_sigmoid_cross_entropy_with_logits.py @@ -0,0 +1,83 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.nn import Cell, TrainOneStepCell, Momentum +from mindspore.ops import operations as P +from mindspore.common.api import _executor + + +class Net(Cell): + def __init__(self, mul_weight, strategy1=None, strategy2=None): + super().__init__() + self.mul = P.Mul().set_strategy(strategy1) + self.loss = P.SigmoidCrossEntropyWithLogits().set_strategy(strategy2) + self.mul_weight = Parameter(mul_weight, "w1") + + def construct(self, x, b): + out = self.mul(x, self.mul_weight) + out = self.loss(out, b) + return out + + +_x = Tensor(np.ones([128, 64]), dtype=ms.float32) +_w1 = Tensor(np.ones([128, 64]), dtype=ms.float32) +_b = Tensor(np.ones([128, 64]), dtype=ms.float32) + + +def compile(net): + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + _executor.compile(train_net, _x, _b) + context.reset_auto_parallel_context() + + +def test_sigmoid_cross_entropy_with_logits_data_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((16, 1), (16, 1)) + strategy2 = ((16, 1), (16, 1)) + net = Net(_w1, strategy1, strategy2) + compile(net) + + +def test_sigmoid_cross_entropy_with_logits_model_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((1, 16), (1, 16)) + strategy2 = ((1, 16), (1, 16)) + net = Net(_w1, strategy1, strategy2) + compile(net) + + +def test_sigmoid_cross_entropy_with_logits_hybrid_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((2, 8), (2, 8)) + strategy2 = ((2, 8), (2, 8)) + net = Net(_w1, strategy1, strategy2) + compile(net) + + +def test_sigmoid_cross_entropy_with_logits_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0) + net = Net(_w1) + compile(net) + + +def test_sigmoid_cross_entropy_with_logits_repeat_calc(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((2, 8), (2, 8)) + strategy2 = ((2, 2), (2, 2)) + net = Net(_w1, strategy1, strategy2) + compile(net)