| @@ -37,14 +37,12 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| std::unordered_set<AnfNodePtr> record{cnode}; | |||
| auto write_input = cnode->input(1); | |||
| if (CheckEltWiseNode(manager.get(), write_input)) { | |||
| (void)record.insert(write_input); | |||
| auto input_cnode = write_input->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||
| write_input = input_cnode->input(1); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(write_input); | |||
| if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) || | |||
| fusion_id_allocator->HasFusionIdAttr(write_input)) { | |||
| @@ -63,7 +61,6 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con | |||
| fusion_id_allocator->HasFusionIdAttr(conv_input)) { | |||
| return; | |||
| } | |||
| if (AnfAlgo::GetCNodeName(conv_input) == kStridedReadOpName) { | |||
| (void)record.insert(conv_input); | |||
| candidate_fusion->push_back(record); | |||
| @@ -44,18 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() == kPynativeMode) { | |||
| if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { | |||
| return nullptr; | |||
| } | |||
| auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0); | |||
| if (do_mask_input_format != kOpFormat_DEFAULT) { | |||
| auto builder = | |||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node)); | |||
| builder->SetInputFormat(kOpFormat_DEFAULT, 0); | |||
| builder->SetOutputFormat(kOpFormat_DEFAULT, 0); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | |||
| } | |||
| return nullptr; | |||
| return RectifyKernelInfoInPynativeProcess(node); | |||
| } | |||
| if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { | |||
| return nullptr; | |||
| @@ -139,6 +128,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string | |||
| } | |||
| return convert_format; | |||
| } | |||
| void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, | |||
| const std::string &format) const { | |||
| for (const auto &do_mask : do_mask_node_list) { | |||
| @@ -150,5 +140,24 @@ void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<C | |||
| } | |||
| } | |||
| AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { | |||
| return nullptr; | |||
| } | |||
| auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0); | |||
| if (do_mask_input_format != kOpFormat_DEFAULT) { | |||
| auto builder = | |||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node)); | |||
| builder->SetInputFormat(kOpFormat_DEFAULT, 0); | |||
| builder->SetOutputFormat(kOpFormat_DEFAULT, 0); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -33,6 +33,7 @@ class RectifyDoMaskKernelInfo : public PatternProcessPass { | |||
| private: | |||
| void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const; | |||
| AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const; | |||
| std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const; | |||
| void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format) const; | |||
| }; | |||
| @@ -112,32 +112,13 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con | |||
| } | |||
| auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode); | |||
| while (index < input_num) { | |||
| auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); | |||
| ++index; | |||
| MS_EXCEPTION_IF_NULL(replacing_node); | |||
| if (!replacing_node->isa<CNode>()) { | |||
| new_depend_inputs.push_back(replacing_node); | |||
| continue; | |||
| } | |||
| auto replacing_cnode = replacing_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(replacing_cnode); | |||
| // Deal with the make_tuple with TransData or Cast inputs. | |||
| auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode); | |||
| if (make_tuple_replace_node != nullptr) { | |||
| new_depend_inputs.push_back(make_tuple_replace_node); | |||
| continue; | |||
| } | |||
| AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); | |||
| if (replace_node == nullptr) { | |||
| new_depend_inputs.push_back(replacing_node); | |||
| MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " | |||
| << node->DebugString(); | |||
| continue; | |||
| } | |||
| auto replace_node = GetConvertNode(func_graph, node, index); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| new_depend_inputs.push_back(replace_node); | |||
| ++index; | |||
| } | |||
| auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); | |||
| CNodePtr new_depend; | |||
| CNodePtr new_depend = nullptr; | |||
| if (kernel_graph == nullptr) { | |||
| new_depend = func_graph->NewCNode(new_depend_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_depend); | |||
| @@ -150,5 +131,31 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con | |||
| } | |||
| return new_depend; | |||
| } | |||
| const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const size_t index) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto depend_cnode = node->cast<CNodePtr>(); | |||
| auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); | |||
| MS_EXCEPTION_IF_NULL(replacing_node); | |||
| if (!replacing_node->isa<CNode>()) { | |||
| return replacing_node; | |||
| } | |||
| auto replacing_cnode = replacing_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(replacing_cnode); | |||
| // Deal with the make_tuple with TransData or Cast inputs. | |||
| auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode); | |||
| if (make_tuple_replace_node != nullptr) { | |||
| return make_tuple_replace_node; | |||
| } | |||
| AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); | |||
| if (replace_node == nullptr) { | |||
| MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); | |||
| return replacing_node; | |||
| } | |||
| return replace_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -27,6 +27,7 @@ class OptimizeDependence : public PatternProcessPass { | |||
| ~OptimizeDependence() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| const AnfNodePtr GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t index) const; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||