From e0e055a0b86a082f1179af672cc74a19ef8f6914 Mon Sep 17 00:00:00 2001
From: lichenever
Date: Thu, 11 Jun 2020 20:21:28 +0800
Subject: [PATCH] add sparse gatherv2

---
 mindspore/ccsrc/parallel/dynamic_creator.h    |   1 +
 .../parallel/ops_info/gather_v2_p_info.cc     |   2 +-
 .../parallel/ops_info/gather_v2_p_info.h      |  12 +
 mindspore/ccsrc/parallel/ops_info/ops_utils.h |   1 +
 .../ccsrc/parallel/step_auto_parallel.cc      |   2 +-
 mindspore/ccsrc/parallel/step_parallel.cc     |   2 +-
 .../python/parallel/test_sparse_gather_v2.py  | 220 ++++++++++++++++++
 7 files changed, 237 insertions(+), 3 deletions(-)
 create mode 100644 tests/ut/python/parallel/test_sparse_gather_v2.py

diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h
index 4fd5f34cf2..f8e1d62d0a 100644
--- a/mindspore/ccsrc/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/parallel/dynamic_creator.h
@@ -121,6 +121,7 @@ REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo);
 REGISTER(AssignSubInfo);
 REGISTER(ReLUInfo);
 REGISTER(GatherV2Info);
+REGISTER(SparseGatherV2Info);
 REGISTER(SqrtInfo);
 REGISTER(SigmoidInfo);
 REGISTER(GetNextInfo);
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
index 3d9470e7d8..1c40350e6a 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
@@ -399,7 +399,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
   auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
   auto gather_v2 =
-    gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
+    gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
   auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
   auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
   auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)});
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
index 22aff16b49..b139ee215c 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
@@ -63,6 +63,7 @@
 
   int32_t axis_;
   std::string target_;
+  std::string replace_op_name_ = GATHERV2;
   int32_t bias_;
   int32_t slice_size_;
   Shape out_dev_matrix_shape_;
@@ -70,6 +71,17 @@
   bool reduce_scatter_flag_ = false;
   int32_t split_num_ = 1;
 };
+
+class SparseGatherV2Info : public GatherV2PInfo {
+ public:
+  SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                     const PrimitiveAttrs &attrs)
+      : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {
+    // Assign the inherited member: a shadowing member here would never be read by ComputeReplaceGraph.
+    replace_op_name_ = SPARSE_GATHERV2;
+  }
+  ~SparseGatherV2Info() override = default;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
index 1110bedc3f..bc0d669baa 100644
--- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
@@ -205,6 +205,7 @@ constexpr char EQUAL[] = "Equal";
"NotEqual"; constexpr char LOGICALNOT[] = "LogicalNot"; constexpr char GATHERV2[] = "GatherV2"; +constexpr char SPARSE_GATHERV2[] = "SparseGatherV2"; constexpr char STRIDEDSLICE[] = "StridedSlice"; constexpr char BROADCAST[] = "Broadcast"; constexpr char SQRT[] = "Sqrt"; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 7d1ff623b9..429241c8b7 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -261,7 +261,7 @@ bool IsSplittableOperator(const std::string &op_name) { REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, - STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, + STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS}; // clang-format on diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 166ce6b038..4528ff8639 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -535,7 +535,7 @@ std::vector ReplaceOpInput(const Operator &replace_op, const std::st } std::vector replace_input = {NewValueNode(pyop_instance), node->input(1)}; auto prim = GetValueNode(node->input(0)); - if (prim->name() == GATHERV2) { + if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) { replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)}; } if (!params.empty()) { diff --git a/tests/ut/python/parallel/test_sparse_gather_v2.py b/tests/ut/python/parallel/test_sparse_gather_v2.py new file mode 100644 index 0000000000..dd0517a08e --- /dev/null +++ b/tests/ut/python/parallel/test_sparse_gather_v2.py @@ -0,0 +1,220 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.api import _executor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x, y):
+        predict = self.network(x, y)
+        return self.loss(predict)
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x, y):
+        return C.grad_all(self.network)(x, y)
+
+
+class Net(nn.Cell):
+    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
+        super().__init__()
+        if shape is None:
+            shape = [64, 64]
+        self.gatherv2 = P.SparseGatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
+        self.mul = P.Mul().set_strategy(strategy2)
+        self.index = Tensor(np.ones(shape), dtype=ms.int32)
+        self.axis = axis
+
+    def construct(self, x, y):
+        out = self.gatherv2(x, self.index, self.axis)
+        out = self.mul(out, y)
+        return out
+
+
+def test_gatherv2_semi_auto0():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto1():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto2():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((2, 4), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto3():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto4():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto5():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((2, 4), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto6():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(0, None, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto7():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = GradWrap(NetWithLoss(Net(1, None, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_semi_auto8():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8,), (1, 1))
+    strategy2 = ((4, 2), (4, 2))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_auto0():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
+    net = GradWrap(NetWithLoss(Net(0)))
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_auto1():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
+    net = GradWrap(NetWithLoss(Net(1)))
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu0():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu1():
+    context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((16, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu2():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
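
The replace graph touched in gather_v2_p_info.cc above is easiest to check numerically: GatherV2PInfo::ComputeReplaceGraph wires a Sub -> ReLU -> Minimum -> Equal -> GatherV2/SparseGatherV2 -> Cast -> ExpandDims -> Mul chain so that each shard of a row-split table gathers only the rows it owns and zeroes the rest. The NumPy sketch below is not part of the patch; sharded_gather is a hypothetical helper that imitates one shard's subgraph for a table split along axis 0, and the plain sum stands in for the collective op that combines the shard outputs.

import numpy as np

def sharded_gather(param, indices, rank, num_shards):
    """Emulate one shard of the replace graph built in ComputeReplaceGraph."""
    slice_size = param.shape[0] // num_shards              # rows owned by this shard
    local = param[rank * slice_size:(rank + 1) * slice_size]
    sub = indices - rank * slice_size                      # Sub: shift into local row numbers
    relu = np.maximum(sub, 0)                              # ReLU: clamp indices below the slice
    minimum = np.minimum(relu, slice_size - 1)             # Minimum: clamp indices above the slice
    equal = sub == minimum                                 # Equal: true only for rows this shard owns
    gathered = local[minimum]                              # GatherV2/SparseGatherV2 on clipped indices
    mask = np.expand_dims(equal, -1).astype(param.dtype)   # Cast + ExpandDims
    return gathered * mask                                 # Mul: zero out rows owned by other shards

# Summing the per-shard outputs (the role of the collective op) matches a plain gather.
param = np.arange(32, dtype=np.float32).reshape(8, 4)
indices = np.array([0, 3, 5, 7])
out = sum(sharded_gather(param, indices, r, 2) for r in range(2))
assert np.array_equal(out, param[indices])

The clip-then-mask trick is what lets a single NewOpInst(replace_op_name_) site serve both GatherV2Info and the new SparseGatherV2Info: only the primitive name changes, the surrounding mask arithmetic stays identical.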