| @@ -94,6 +94,7 @@ | |||
| #include "pre_activate/ascend/ir_fission/split_fission.h" | |||
| #include "pre_activate/ascend/format_type/modify_ops_attrs.h" | |||
| #include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h" | |||
| #include "pre_activate/ascend/ir_fusion/add_input_to_output.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "utils/config_manager.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| @@ -259,6 +260,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| } | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AddInputToOutput>()); | |||
| optimizer->AddPassManager(ir_fusion_pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| @@ -70,6 +70,21 @@ class KernelQuery { | |||
| } | |||
| }; | |||
| using KernelQueryPtr = std::shared_ptr<KernelQuery>; | |||
| class OpFinder { | |||
| public: | |||
| OpFinder() = default; | |||
| virtual ~OpFinder() = default; | |||
| virtual int GetOpRegisteredOutputNum(const std::string &op_name) { | |||
| auto op_info = kernel::OpLib::FindOp(op_name, kernel::kTBE); | |||
| if (op_info == nullptr) { | |||
| return -1; | |||
| } | |||
| return op_info->outputs_ptr().size(); | |||
| } | |||
| }; | |||
| using OpFinderPtr = std::shared_ptr<OpFinder>; | |||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, | |||
| const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {}); | |||
| @@ -0,0 +1,115 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/ir_fusion/add_input_to_output.h" | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "pre_activate/ascend/ir_fusion/input_to_output_registry.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| void GetInputOrOutputNames(const CNodePtr &cnode, const std::string &attr_name, std::vector<std::string> *names_vec) { | |||
| MS_EXCEPTION_IF_NULL(names_vec); | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| ValuePtr names_value = primitive->GetAttr(attr_name); | |||
| if (names_value == nullptr) { | |||
| return; | |||
| } | |||
| *names_vec = GetValue<std::vector<std::string>>(names_value); | |||
| } | |||
| void AddOutputs(const CNodePtr &cnode, const std::vector<size_t> &input_indices) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<std::string> input_names_vec; | |||
| GetInputOrOutputNames(cnode, kAttrInputNames, &input_names_vec); | |||
| std::vector<std::string> output_names_vec; | |||
| GetInputOrOutputNames(cnode, kAttrOutputNames, &output_names_vec); | |||
| AbstractBasePtrList abstract_list; | |||
| auto origin_abstract = cnode->abstract(); | |||
| MS_EXCEPTION_IF_NULL(origin_abstract); | |||
| if (origin_abstract->isa<abstract::AbstractTuple>()) { | |||
| auto origin_abstract_tuple = dyn_cast<abstract::AbstractTuple>(origin_abstract); | |||
| MS_EXCEPTION_IF_NULL(origin_abstract_tuple); | |||
| AbstractBasePtrList origin_abstract_list = origin_abstract_tuple->elements(); | |||
| (void)std::copy(origin_abstract_list.begin(), origin_abstract_list.end(), std::back_inserter(abstract_list)); | |||
| } else { | |||
| abstract_list.emplace_back(origin_abstract); | |||
| } | |||
| for (size_t i = 0; i < input_indices.size(); ++i) { | |||
| size_t index = input_indices[i]; | |||
| if (index + 1 >= cnode->inputs().size()) { | |||
| MS_LOG(INFO) << "The input index " << index << " for converting to output is out of range, " | |||
| << "node: " << cnode->DebugString(); | |||
| continue; | |||
| } | |||
| auto node_to_output = cnode->input(index + 1); | |||
| MS_EXCEPTION_IF_NULL(node_to_output); | |||
| abstract_list.emplace_back(node_to_output->abstract()); | |||
| if (!input_names_vec.empty() && !output_names_vec.empty() && index < input_names_vec.size()) { | |||
| output_names_vec.emplace_back(input_names_vec[index]); | |||
| } | |||
| } | |||
| if (!output_names_vec.empty()) { | |||
| AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names_vec), cnode); | |||
| } | |||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||
| cnode->set_abstract(abstract_tuple); | |||
| } | |||
| } // namespace | |||
| const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::string op_name = AnfAlgo::GetCNodeName(cnode); | |||
| InputToOutputRegister reg; | |||
| if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, ®)) { | |||
| return nullptr; | |||
| } | |||
| int output_num = op_finder_->GetOpRegisteredOutputNum(op_name); | |||
| // No need add output when it is not a tbe op. | |||
| if (output_num == -1) { | |||
| return nullptr; | |||
| } | |||
| // No need add output if the output num matches the registered output num for tbe. | |||
| if (AnfAlgo::GetOutputTensorNum(cnode) >= IntToSize(output_num)) { | |||
| return nullptr; | |||
| } | |||
| bool is_origin_tuple_output = AnfAlgo::IsTupleOutput(cnode); | |||
| AddOutputs(cnode, reg.input_indices()); | |||
| // No need to create tuple_getitem if the origin output is a tuple because there has already been some tuple_getitems | |||
| // pointed to the outputs. | |||
| if (is_origin_tuple_output) { | |||
| return nullptr; | |||
| } | |||
| std::vector<AnfNodePtr> new_outputs; | |||
| auto new_abstract_tuple = dyn_cast<abstract::AbstractTuple>(cnode->abstract()); | |||
| MS_EXCEPTION_IF_NULL(new_abstract_tuple); | |||
| CreateMultipleOutputsOfAnfNode(func_graph, cnode, new_abstract_tuple->size(), &new_outputs); | |||
| if (new_outputs.size() != new_abstract_tuple->size()) { | |||
| MS_LOG(EXCEPTION) << "Failed to create outputs of " << cnode->DebugString(); | |||
| } | |||
| return new_outputs[0]; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "pre_activate/ascend/ascend_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class AddInputToOutput : public PatternProcessPass { | |||
| public: | |||
| explicit AddInputToOutput(bool multigraph = true) | |||
| : PatternProcessPass("add_input_to_output", multigraph), op_finder_(std::make_shared<OpFinder>()) {} | |||
| ~AddInputToOutput() override = default; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| OpFinderPtr op_finder_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ | |||
| @@ -0,0 +1,122 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/ir_fusion/input_to_output_registry.h" | |||
| #include <utility> | |||
| #include "utils/utils.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool ApplyRMSPropPreCheck(const CNodePtr &node) { | |||
| return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); | |||
| } | |||
| bool FusedMulApplyMomentumPreCheck(const CNodePtr &node) { | |||
| TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); | |||
| return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16); | |||
| } | |||
| bool SparseApplyRMSPropPreCheck(const CNodePtr &node) { | |||
| return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); | |||
| } | |||
| bool ApplyAdagradV2PreCheck(const CNodePtr &node) { | |||
| TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); | |||
| return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16); | |||
| } | |||
| bool ApplyKerasMomentumPreCheck(const CNodePtr &node) { | |||
| TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); | |||
| return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16); | |||
| } | |||
| bool SparseApplyFtrlPreCheck(const CNodePtr &node) { | |||
| return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); | |||
| } | |||
| bool SparseApplyFtrlV2PreCheck(const CNodePtr &node) { | |||
| return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); | |||
| } | |||
| bool SparseApplyAdagradV2PreCheck(const CNodePtr &node) { | |||
| return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); | |||
| } | |||
| bool SparseApplyAdadeltaPreCheck(const CNodePtr &node) { | |||
| return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); | |||
| } | |||
| } // namespace | |||
| InputToOutputRegistry::InputToOutputRegistry() { | |||
| Register(kApplyRMSPropOpName, {1, 2}, ApplyRMSPropPreCheck); | |||
| Register(kFusedMulApplyMomentumOpName, {1}, FusedMulApplyMomentumPreCheck); | |||
| Register(kApplyAdagradOpName, {1}); | |||
| Register(kApplyAdagradDAName, {1, 2}); | |||
| Register(kApplyAdadeltaOpName, {1, 2}); | |||
| Register(kApplyPowerSignOpName, {1}); | |||
| Register(kApplyProximalAdagradOpName, {1}); | |||
| Register(kApplyAdaMaxOpName, {1, 2}); | |||
| Register(kApplyAdagradV2OpName, {1}, ApplyAdagradV2PreCheck); | |||
| Register(kApplyKerasMomentumOpName, {1}, ApplyKerasMomentumPreCheck); | |||
| Register(kSparseApplyFtrlOpName, {1, 2}, SparseApplyFtrlPreCheck); | |||
| Register(kSparseApplyFtrlV2OpName, {1, 2}, SparseApplyFtrlV2PreCheck); | |||
| Register(kSparseApplyAdagradV2OpName, {1}, SparseApplyAdagradV2PreCheck); | |||
| Register(kSparseApplyProximalAdagradOpName, {1}); | |||
| Register(kSparseApplyAdagradOpName, {1}); | |||
| Register(kApplyFtrlV2OpName, {1, 2}); | |||
| Register(kApplyMomentumOpName, {1}); | |||
| Register(kApplyFtrlOpName, {1, 2}); | |||
| Register(kApplyAdamOpName, {1, 2}); | |||
| Register(kApplyCenteredRMSPropOpName, {1, 2, 3}); | |||
| Register(kApplyAddSignOpName, {1}); | |||
| Register(kSparseApplyRMSPropOpName, {1, 2}, SparseApplyRMSPropPreCheck); | |||
| Register(kSparseApplyAdadeltaOpName, {1, 2}, SparseApplyAdadeltaPreCheck); | |||
| Register(kApplyAdamWithAmsgradOpName, {1, 2}); | |||
| } | |||
| InputToOutputRegistry &InputToOutputRegistry::Instance() { | |||
| static InputToOutputRegistry instance; | |||
| return instance; | |||
| } | |||
| void InputToOutputRegistry::Register(const InputToOutputRegister ®) { | |||
| auto op_name = reg.op_name(); | |||
| if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) { | |||
| (void)op_input_to_output_map_.insert(make_pair(op_name, reg)); | |||
| MS_LOG(DEBUG) << op_name << " input2output register successfully!"; | |||
| } | |||
| } | |||
| void InputToOutputRegistry::Register(const std::string &op_name, const std::vector<size_t> &input_indices, | |||
| const PreCheckFunc &pre_check_func) { | |||
| if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) { | |||
| InputToOutputRegister reg(op_name, pre_check_func); | |||
| reg.set_input_indices(input_indices); | |||
| (void)op_input_to_output_map_.insert(make_pair(op_name, reg)); | |||
| MS_LOG(DEBUG) << op_name << " input2output register successfully!"; | |||
| } | |||
| } | |||
| bool InputToOutputRegistry::GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const { | |||
| if (op_input_to_output_map_.find(op_name) != op_input_to_output_map_.end()) { | |||
| *reg = op_input_to_output_map_.at(op_name); | |||
| MS_LOG(DEBUG) << op_name << " input2output find in registry."; | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,64 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "ir/anf.h" | |||
| #include "common/utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| using PreCheckFunc = std::function<bool(const CNodePtr &node)>; | |||
| class InputToOutputRegister { | |||
| public: | |||
| explicit InputToOutputRegister( | |||
| const std::string &op_name = "", const PreCheckFunc &pre_check_func = [](const CNodePtr &node) { return true; }) | |||
| : op_name_(op_name), pre_check_func_(pre_check_func) {} | |||
| virtual ~InputToOutputRegister() = default; | |||
| void set_input_indices(const std::vector<size_t> &input_indices) { input_indices_ = input_indices; } | |||
| const std::vector<size_t> &input_indices() const { return input_indices_; } | |||
| const std::string &op_name() const { return op_name_; } | |||
| private: | |||
| std::string op_name_; | |||
| std::vector<size_t> input_indices_; | |||
| PreCheckFunc pre_check_func_; | |||
| }; | |||
| class InputToOutputRegistry { | |||
| public: | |||
| static InputToOutputRegistry &Instance(); | |||
| void Register(const InputToOutputRegister ®); | |||
| void Register( | |||
| const std::string &op_name, const std::vector<size_t> &input_indices, | |||
| const PreCheckFunc &pre_check_func = [](const CNodePtr &node) { return true; }); | |||
| bool GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const; | |||
| private: | |||
| InputToOutputRegistry(); | |||
| ~InputToOutputRegistry() = default; | |||
| DISABLE_COPY_AND_ASSIGN(InputToOutputRegistry) | |||
| std::unordered_map<std::string, InputToOutputRegister> op_input_to_output_map_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ | |||
| @@ -164,6 +164,15 @@ constexpr auto kStridedReadOpName = "StridedRead"; | |||
| constexpr auto kStridedWriteOpName = "StridedWrite"; | |||
| constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay"; | |||
| constexpr auto kFusedAdamName = "FusedAdam"; | |||
| constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2"; | |||
| constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2"; | |||
| constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl"; | |||
| constexpr auto kSparseApplyFtrlV2OpName = "SparseApplyFtrlV2"; | |||
| constexpr auto kApplyKerasMomentumOpName = "ApplyKerasMomentum"; | |||
| constexpr auto kSparseApplyProximalAdagradOpName = "SparseApplyProximalAdagrad"; | |||
| constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp"; | |||
| constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta"; | |||
| constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad"; | |||
| // attr key name | |||
| constexpr auto kAttrInputNames = "input_names"; | |||
| @@ -0,0 +1,74 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #define private public | |||
| #define protected public | |||
| #include "pre_activate/ascend/ir_fusion/add_input_to_output.h" | |||
| #undef private | |||
| #undef protected | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TestHWAddInputToOutput : public BackendCommon { | |||
| public: | |||
| TestHWAddInputToOutput() : getPyFun_("gtest_input.pre_activate.add_input_to_output_test", true) {} | |||
| ~TestHWAddInputToOutput() override = default; | |||
| public: | |||
| UT::PyFuncGraphFetcher getPyFun_; | |||
| }; | |||
| class MockOpFinder : public OpFinder { | |||
| public: | |||
| MockOpFinder() = default; | |||
| ~MockOpFinder() override = default; | |||
| int GetOpRegisteredOutputNum(const std::string &op_name) override { return 2; } | |||
| }; | |||
| TEST_F(TestHWAddInputToOutput, test_add_input_to_output) { | |||
| FuncGraphPtr g = getPyFun_.CallAndParseRet("test_add_input_to_output", "before"); | |||
| EXPECT_NE(g, nullptr); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 5; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kg, nullptr); | |||
| auto ret = kg->get_return(); | |||
| EXPECT_NE(ret, nullptr); | |||
| auto make_tuple = ret->input(1); | |||
| EXPECT_NE(make_tuple, nullptr); | |||
| auto momentum = make_tuple->cast<CNodePtr>()->input(1); | |||
| EXPECT_NE(momentum, nullptr); | |||
| EXPECT_NE(momentum->abstract(), nullptr); | |||
| EXPECT_FALSE(momentum->abstract()->isa<abstract::AbstractTuple>()); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::AddInputToOutput>(); | |||
| pass->op_finder_ = std::make_shared<MockOpFinder>(); | |||
| pm->AddPass(pass); | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(kg); | |||
| EXPECT_TRUE(momentum->abstract()->isa<abstract::AbstractTuple>()); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore.ops import operations as P | |||
| ApplyMomentum = P.ApplyMomentum() | |||
| class FnDict: | |||
| def __init__(self): | |||
| self.fnDict = {} | |||
| def __call__(self, fn): | |||
| self.fnDict[fn.__name__] = fn | |||
| def __getitem__(self, name): | |||
| return self.fnDict[name] | |||
| def test_add_input_to_output(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4): | |||
| return ApplyMomentum(input0, input1, input2, input3, input4) | |||
| return fns[tag] | |||