Merge pull request !2557 from huanghui/unsorted-segment-sum-fission-passtags/v0.7.0-beta
| @@ -26,6 +26,7 @@ | |||
| #include "backend/optimizer/ascend/ir_fission/reduce_min_fission.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" | |||
| #include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h" | |||
| #include "backend/optimizer/pass/communication_op_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" | |||
| @@ -172,6 +173,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) { | |||
| ir_fusion_pm->AddPass(std::make_shared<PackFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<ConcatFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<UnsortSegmentSumFission>()); | |||
| } | |||
| } // namespace | |||
| @@ -0,0 +1,118 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(origin_node); | |||
| std::vector<AnfNodePtr> padding_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)), | |||
| origin_node->input(1)}; | |||
| auto padding = graph->NewCNode(padding_inputs); | |||
| MS_EXCEPTION_IF_NULL(padding); | |||
| padding->set_scope(origin_node->scope()); | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0); | |||
| shape[shape.size() - 1] = pad_dim_size; | |||
| AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)}, {shape}, | |||
| padding.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrPadDimSize, MakeValue(SizeToInt(pad_dim_size)), padding); | |||
| return padding; | |||
| } | |||
| CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &padding, | |||
| const size_t &pad_dim_size) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(origin_node); | |||
| MS_EXCEPTION_IF_NULL(padding); | |||
| std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding, origin_node->input(2)}; | |||
| auto unsorted_segment_sum = graph->NewCNode(unsorted_segment_sum8_inputs); | |||
| MS_EXCEPTION_IF_NULL(unsorted_segment_sum); | |||
| unsorted_segment_sum->set_scope(origin_node->scope()); | |||
| auto shape = AnfAlgo::GetOutputInferShape(origin_node, 0); | |||
| shape[shape.size() - 1] = pad_dim_size; | |||
| AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape}, | |||
| unsorted_segment_sum.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrNumSegments, MakeValue(SizeToInt(shape[0])), unsorted_segment_sum); | |||
| return unsorted_segment_sum; | |||
| } | |||
| CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum, | |||
| const CNodePtr &unsorted_segment_sum8) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(unsort_segment_sum); | |||
| MS_EXCEPTION_IF_NULL(unsorted_segment_sum8); | |||
| std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), | |||
| unsorted_segment_sum8}; | |||
| auto slice = graph->NewCNode(slice_inputs); | |||
| MS_EXCEPTION_IF_NULL(slice); | |||
| slice->set_scope(unsort_segment_sum->scope()); | |||
| slice->set_abstract(unsort_segment_sum->abstract()); | |||
| auto unsort_segment_sum_shape = AnfAlgo::GetOutputInferShape(unsort_segment_sum, 0); | |||
| std::vector<size_t> offsets(unsort_segment_sum_shape.size(), 0); | |||
| AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Int(offsets)), slice); | |||
| AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Int(unsort_segment_sum_shape)), slice); | |||
| return slice; | |||
| } | |||
| } // namespace | |||
| const BaseRef UnsortSegmentSumFission::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| VectorRef pattern({prim::kPrimUnsortedSegmentSum, Xs}); | |||
| return pattern; | |||
| } | |||
| const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto origin_node = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(origin_node); | |||
| if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) { | |||
| MS_LOG(INFO) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum | |||
| << ". CNode= " << origin_node->DebugString(); | |||
| return nullptr; | |||
| } | |||
| auto input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0); | |||
| if (input0_shape[input0_shape.size() - 1] != 1) { | |||
| MS_LOG(INFO) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is " | |||
| << input0_shape[input0_shape.size() - 1]; | |||
| return nullptr; | |||
| } | |||
| size_t pad_dim_size; | |||
| auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0); | |||
| if (input_dtype == kNumberTypeFloat32) { | |||
| pad_dim_size = 8; | |||
| } else if (input_dtype == kNumberTypeFloat16) { | |||
| pad_dim_size = 16; | |||
| } else { | |||
| MS_LOG(INFO) << "UnsortedSegmentSum data type not in (float21, float16), no need change"; | |||
| return nullptr; | |||
| } | |||
| auto padding = CreatePadding(graph, origin_node, pad_dim_size); | |||
| auto unsorted_segment_sum8 = CreateUnsortedSegmentSum(graph, origin_node, padding, pad_dim_size); | |||
| return CreateSlice(graph, origin_node, unsorted_segment_sum8); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_FISSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_FISSION_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/optimizer/ascend/ascend_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class UnsortSegmentSumFission : public PatternProcessPass { | |||
| public: | |||
| explicit UnsortSegmentSumFission(bool multigraph = true) | |||
| : PatternProcessPass("unsorted_segment_sum_fission", multigraph) {} | |||
| ~UnsortSegmentSumFission() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_FISSION_H_ | |||
| @@ -98,6 +98,7 @@ constexpr size_t kTopkInputNum = 3; | |||
| constexpr size_t kLarsV2InputNum = 5; | |||
| constexpr size_t kFusedMulApplyMomentumOutputNum = 2; | |||
| constexpr size_t kSplitInputNum = 2; | |||
| constexpr size_t kUnsortedSegmentSumInputNum = 2; | |||
| enum FusedBatchNormInput { | |||
| kX = 1, | |||
| @@ -182,6 +182,7 @@ constexpr auto kPushOpName = "Push"; | |||
| constexpr auto kPullOpName = "Pull"; | |||
| constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup"; | |||
| constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy"; | |||
| constexpr auto kPaddingOpName = "Padding"; | |||
| // attr key name | |||
| constexpr auto kAttrInputNames = "input_names"; | |||
| @@ -253,6 +254,10 @@ constexpr auto kAttrInputNums = "inputNums"; | |||
| constexpr auto kAttrT = "T"; | |||
| constexpr auto kAttrNum = "num"; | |||
| constexpr auto kAttrRankSize = "rank_size"; | |||
| constexpr auto kAttrPadDimSize = "pad_dim_size"; | |||
| constexpr auto kAttrNumSegments = "num_segments"; | |||
| constexpr auto kAttrBegin = "begin"; | |||
| constexpr auto kAttrSize = "size"; | |||
| // attr value | |||
| constexpr auto kValueTargetSwitch = "target_switch"; | |||
| @@ -0,0 +1,47 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| context.set_context(save_graphs=True) | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.unsorted_segment_sum = P.UnsortedSegmentSum() | |||
| self.num_segments = 3 | |||
| def construct(self, x, segment_ids): | |||
| x = self.unsorted_segment_sum(x, segment_ids, self.num_segments) | |||
| return x | |||
| def test_net(): | |||
| input_x = np.random.randn(3, 39, 1).astype(np.float32) | |||
| segment_ids = Tensor([0, 1, 2], mindspore.int32) | |||
| net = Net() | |||
| output = net(Tensor(input_x), segment_ids) | |||
| print("result", output.asnumpy()) | |||
| if __name__ == "__main__": | |||
| test_net() | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h" | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TestHWUnsortedSegmentSumFission : public BackendCommon { | |||
| public: | |||
| TestHWUnsortedSegmentSumFission() : get_py_fun_("gtest_input.pre_activate.unsorted_segment_sum_fission", true) {} | |||
| ~TestHWUnsortedSegmentSumFission() override = default; | |||
| UT::PyFuncGraphFetcher get_py_fun_; | |||
| }; | |||
| TEST_F(TestHWUnsortedSegmentSumFission, test_fission) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "before1"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp_x{16, 1}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::UnsortSegmentSumFission>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "after1"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWUnsortedSegmentSumFission, test_no_fission) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "before2"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp_x{16, 2}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::UnsortSegmentSumFission>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "after2"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,63 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import operations as P | |||
| make_tuple = Primitive('make_tuple') | |||
| tuple_getitem = Primitive('tuple_getitem') | |||
| unsorted_segment_sum = P.UnsortedSegmentSum() | |||
| num_segments = 4 | |||
| padding = Primitive('Padding') | |||
| op_slice = Primitive('Slice') | |||
| op_unsorted_segment_sum = Primitive('UnsortedSegmentSum') | |||
| class FnDict: | |||
| def __init__(self): | |||
| self.fnDict = {} | |||
| def __call__(self, fn): | |||
| self.fnDict[fn.__name__] = fn | |||
| def __getitem__(self, name): | |||
| return self.fnDict[name] | |||
| def test_unsorted_segment_sum_fission(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before1(input0, input1): | |||
| x = unsorted_segment_sum(input0, input1, num_segments) | |||
| return x | |||
| @fns | |||
| def after1(input0, input1): | |||
| x = padding(input0) | |||
| x = op_unsorted_segment_sum(x, input1) | |||
| x = op_slice(x) | |||
| return make_tuple(x) | |||
| @fns | |||
| def before2(input0, input1): | |||
| x = unsorted_segment_sum(input0, input1, num_segments) | |||
| return x | |||
| @fns | |||
| def after2(input0, input1): | |||
| x = op_unsorted_segment_sum(input0, input1) | |||
| return make_tuple(x) | |||
| return fns[tag] | |||