Merge pull request !376 from wenchunjiang/getnext_pass
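This change moves the CreateMemcpyAsyncOp helper out of the insert_memcpy_async pass and into ascend_helper so it can be shared, and adds a new enhancer pass, InsertMemcpyAsyncForGetNext. The pass matches a GetNext node, creates a tuple_getitem for each of its outputs, wraps every output in a memcpy_async op labeled with kAttrLabelForInsertStreamActive, and groups the results back into a make_tuple. A backend unit test and its Python graph fixture cover the multi-output case.

The excerpt does not show where the pass is registered in the Ascend backend pipeline; presumably it is appended to a backend PassManager in the same way the unit test below drives it. A minimal sketch, mirroring the test code and assuming an existing kernel_graph:

    // Sketch only, not part of this diff: run the new pass over a kernel graph.
    auto optimizer = std::make_shared<opt::GraphOptimizer>();
    auto pm = std::make_shared<opt::PassManager>();
    pm->AddPass(std::make_shared<opt::InsertMemcpyAsyncForGetNext>());
    optimizer->AddPassManager(pm);
    auto new_graph = optimizer->Optimize(kernel_graph);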
@@ -360,5 +360,17 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt
  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
  return tuple_getitem;
}

AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto prim = std::make_shared<Primitive>(kMemCpyAsyncOpName);
  std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim), node};
  auto new_node = graph->NewCNode(new_node_inputs);
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_abstract(node->abstract());
  new_node->set_scope(node->scope());
  return new_node;
}
} // namespace opt
} // namespace mindspore
@@ -65,6 +65,8 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx);
AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_
@@ -0,0 +1,74 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"
#include <vector>
#include <memory>
#include "pre_activate/ascend/ascend_helper.h"
#include "session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
AnfNodePtr InsertMemcpyAsyncForGetNextOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  if (func_graph == nullptr || node == nullptr) {
    return nullptr;
  }

  size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  if (output_num == 0) {
    MS_LOG(DEBUG) << "Output number is zero, no need to insert memcpy_async!";
    return node;
  }

  // The output of GetNext is a tuple and its memory is dynamic, so insert a memcpy_async for each output element.
  std::vector<AnfNodePtr> make_tuple_inputs;
  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t output_index = 0; output_index < output_num; ++output_index) {
    auto tuple_get_item = CreatTupleGetItemNode(func_graph, node, output_index);
    auto new_node = CreateMemcpyAsyncOp(func_graph, tuple_get_item);
    if (new_node == nullptr) {
      MS_LOG(EXCEPTION) << "Create memcpy_async op failed!";
    }
    AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), new_node);
    make_tuple_inputs.push_back(new_node);
  }
  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
const BaseRef InsertMemcpyAsyncForGetNext::DefinePattern() const {
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  auto prim = std::make_shared<Primitive>(kGetNextOpName);
  return VectorRef({prim, Xs});
}

const AnfNodePtr InsertMemcpyAsyncForGetNext::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
  if (func_graph == nullptr || node == nullptr || !AnfAlgo::IsRealKernel(node)) {
    return nullptr;
  }
  if (AnfAlgo::HasNodeAttr(kAttrVisited, node)) {
    MS_LOG(DEBUG) << "Node op_name[" << kGetNextOpName << "] has been visited.";
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  return InsertMemcpyAsyncForGetNextOutputs(func_graph, node);
}
} // namespace opt
} // namespace mindspore
@@ -0,0 +1,35 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_

#include "pre_activate/common/optimizer.h"

namespace mindspore {
namespace opt {
class InsertMemcpyAsyncForGetNext : public PatternProcessPass {
 public:
  explicit InsertMemcpyAsyncForGetNext(bool multigraph = true)
      : PatternProcessPass("insert_memcpy_async_for_getnext", multigraph) {}
  ~InsertMemcpyAsyncForGetNext() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_
@@ -18,22 +18,11 @@
#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "optimizer/opt.h"
#include "pre_activate/ascend/ascend_helper.h"

namespace mindspore {
namespace opt {
namespace {
AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto prim = std::make_shared<Primitive>(kMemCpyAsyncOpName);
  std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim), node};
  auto new_node = graph->NewCNode(new_node_inputs);
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_abstract(node->abstract());
  new_node->set_scope(node->scope());
  return new_node;
}
const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
@@ -0,0 +1,67 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "session/ascend_session.h"
#include "pipeline/resource.h"
#include "operator/ops.h"
#include "ir/manager.h"
#include "debug/anf_ir_dump.h"
#include "utils/utils.h"
#include "kernel/kernel_build_info.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"

namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;

class TestHWInsertMemcpyAsyncForGetNext : public BackendCommon {
 public:
  TestHWInsertMemcpyAsyncForGetNext() : get_py_fun_("gtest_input.pre_activate.insert_memcpy_async_for_getnext", true) {}
  ~TestHWInsertMemcpyAsyncForGetNext() override = default;

 public:
  UT::PyFuncGraphFetcher get_py_fun_;
};

TEST_F(TestHWInsertMemcpyAsyncForGetNext, test_insert_memcpy_async_for_getnext_multi_output) {
  FuncGraphPtr g_before =
    get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_getnext", "getnext_multi_output_before");
  AbstractBasePtrList args_spec_list{};
  auto kernel_graph = GetKernelGraph(g_before, args_spec_list);

  KernelBuildInfoBuilder builder;
  builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
  builder.SetOutputsDeviceType({kFloat32->type_id(), kInt32->type_id()});
  auto ret = kernel_graph->get_return();
  EXPECT_NE(ret->input(1), nullptr);
  EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1), nullptr);
  auto get_next = ret->input(1)->cast<CNodePtr>()->input(1);
  get_next->set_kernel_info(std::make_shared<device::KernelInfo>());
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), get_next.get());

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::InsertMemcpyAsyncForGetNext>());
  optimizer->AddPassManager(pm);
  auto new_graph = optimizer->Optimize(kernel_graph);

  FuncGraphPtr g_after =
    get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_getnext", "getnext_multi_output_after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore
@@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
from mindspore.ops import Primitive
import mindspore as ms

get_next = P.GetNext([ms.float32, ms.int32], [[32, 64], [32]], 2, "")
memcpy_async = Primitive('memcpy_async')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
        return self.fnDict[name]


def test_insert_memcpy_async_for_getnext(tag):
    fns = FnDict()

    @fns
    def getnext_multi_output_before():
        res = get_next()
        return res

    @fns
    def getnext_multi_output_after():
        res = get_next()
        data = tuple_getitem(res, 0)
        label = tuple_getitem(res, 1)
        memcpy_async_data = memcpy_async(data)
        memcpy_async_label = memcpy_async(label)
        tuple = make_tuple(make_tuple(memcpy_async_data, memcpy_async_label))
        return tuple

    return fns[tag]