| @@ -23,6 +23,7 @@ | |||
| #include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h" | |||
| #include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h" | |||
| #include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" | |||
| #include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h" | |||
| #include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h" | |||
| #include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h" | |||
| #include "pre_activate/pass/communication_op_fusion.h" | |||
| @@ -149,6 +150,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { | |||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<SplitFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| } | |||
| } // namespace | |||
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "pre_activate/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(tensor_scatter_update); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kTensorMoveOpName)), | |||
| tensor_scatter_update->input(1)}; | |||
| auto tensor_move = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(tensor_move); | |||
| tensor_move->set_scope(tensor_scatter_update->scope()); | |||
| tensor_move->set_abstract(tensor_scatter_update->abstract()); | |||
| AnfAlgo::SetNodeAttr(kAttrUseLocking, MakeValue(false), tensor_move); | |||
| return tensor_move; | |||
| } | |||
| CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update, | |||
| const CNodePtr &tensor_move) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(tensor_scatter_update); | |||
| MS_EXCEPTION_IF_NULL(tensor_move); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kScatterNdUpdateOpName)), tensor_move, | |||
| tensor_scatter_update->input(2), tensor_scatter_update->input(3)}; | |||
| auto scatter_nd_update = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(scatter_nd_update); | |||
| scatter_nd_update->set_scope(tensor_scatter_update->scope()); | |||
| scatter_nd_update->set_abstract(tensor_scatter_update->abstract()); | |||
| return scatter_nd_update; | |||
| } | |||
| } // namespace | |||
| const BaseRef TensorScatterUpdateFission::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| auto prim = std::make_shared<Primitive>(kTensorScatterUpdateOpName); | |||
| return VectorRef({prim, Xs}); | |||
| } | |||
| const AnfNodePtr TensorScatterUpdateFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto tensor_scatter_update = node->cast<CNodePtr>(); | |||
| if (tensor_scatter_update == nullptr || tensor_scatter_update->size() != 4) { | |||
| return nullptr; | |||
| } | |||
| auto tensor_move = CreateTensorMove(func_graph, tensor_scatter_update); | |||
| return CreateScatterNdUpdate(func_graph, tensor_scatter_update, tensor_move); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ | |||
| #include "pre_activate/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TensorScatterUpdateFission : public PatternProcessPass { | |||
| public: | |||
| explicit TensorScatterUpdateFission(bool multigraph = true) | |||
| : PatternProcessPass("tensor_scatter_update_fission", multigraph) {} | |||
| ~TensorScatterUpdateFission() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ | |||
| @@ -164,6 +164,18 @@ constexpr auto kStridedReadOpName = "StridedRead"; | |||
constexpr auto kStridedWriteOpName = "StridedWrite";

// Fused Adam operators.
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";

// Apply* / SparseApply* operators.
constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2";
constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2";
constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl";
constexpr auto kSparseApplyFtrlV2OpName = "SparseApplyFtrlV2";
constexpr auto kApplyKerasMomentumOpName = "ApplyKerasMomentum";
constexpr auto kSparseApplyProximalAdagradOpName = "SparseApplyProximalAdagrad";
constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp";
constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta";
constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad";

// Operators involved in the TensorScatterUpdate fission.
constexpr auto kTensorMoveOpName = "TensorMove";
constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate";
constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate";

// attr key name
constexpr auto kAttrInputNames = "input_names";
| @@ -224,6 +236,9 @@ constexpr auto kAttrOutputNum = "output_num"; | |||
constexpr auto kAttrSizeSplits = "size_splits";
constexpr auto kAttrOutputDefault = "output_default";
constexpr auto kAttrPrimitiveTarget = "primitive_target";
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr auto kAttrOffset = "offset";
// Attribute key set by the TensorScatterUpdate fission pass.
constexpr auto kAttrUseLocking = "use_locking";

// attr value
constexpr auto kValueTargetSwitch = "target_switch";
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TestHWOptTensorScatterUpdateFission : public BackendCommon { | |||
| public: | |||
| TestHWOptTensorScatterUpdateFission() | |||
| : get_py_fun_("gtest_input.pre_activate.tensor_scatter_update_fission_test", true) {} | |||
| ~TestHWOptTensorScatterUpdateFission() override = default; | |||
| UT::PyFuncGraphFetcher get_py_fun_; | |||
| }; | |||
| TEST_F(TestHWOptTensorScatterUpdateFission, test_fission) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tensor_scatter_update_fission", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp1{2, 3}; | |||
| std::vector<int> shp2{2, 2}; | |||
| std::vector<int> shp3{2}; | |||
| auto inputx = std::make_shared<abstract::AbstractTensor>(kFloat32, shp1); | |||
| auto indices = std::make_shared<abstract::AbstractTensor>(kInt32, shp2); | |||
| auto update = std::make_shared<abstract::AbstractTensor>(kFloat32, shp3); | |||
| AbstractBasePtrList args_spec_list{inputx, indices, update}; | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::TensorScatterUpdateFission>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(fg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tensor_scatter_update_fission", "after"); | |||
| EXPECT_NE(g_after, nullptr); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,50 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import operations as P | |||
# Operator that the "before" graph is built from.
tensor_scatter_update = P.TensorScatterUpdate()

# Primitives used to spell out the "after" graph.
tensor_move = Primitive('TensorMove')
scatter_nd_update = Primitive('ScatterNdUpdate')

# Tuple primitives.
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
class FnDict:
    """Registry of functions keyed by their ``__name__``; instances are usable as decorators."""

    def __init__(self):
        # Maps function name -> function object.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it.

        Returning ``fn`` keeps the decorated name bound to the function instead of ``None``
        when the instance is applied with ``@`` syntax.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered under ``name``; raises ``KeyError`` if there is none."""
        return self.fnDict[name]
def test_tensor_scatter_update_fission(tag):
    """Return the graph function registered under ``tag`` ('before' or 'after').

    'before' applies TensorScatterUpdate directly; 'after' applies TensorMove to ``x`` and feeds
    the result into ScatterNdUpdate, wrapping the output in make_tuple.
    """
    graphs = FnDict()

    @graphs
    def before(x, indices, updates):
        return tensor_scatter_update(x, indices, updates)

    @graphs
    def after(x, indices, updates):
        moved = tensor_move(x)
        scattered = scatter_nd_update(moved, indices, updates)
        return make_tuple(scattered)

    return graphs[tag]