Merge pull request !1911 from lianliguang/add-a-function-to-charge-the-node-input-or-output-if-is-a-scalartags/v0.5.0-beta
| @@ -37,11 +37,11 @@ class SupportedChecker { | |||||
| public: | public: | ||||
| SupportedChecker() = default; | SupportedChecker() = default; | ||||
| virtual ~SupportedChecker() = default; | virtual ~SupportedChecker() = default; | ||||
| virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node, | |||||
| virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node, | |||||
| const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | ||||
| return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); | return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); | ||||
| } | } | ||||
| virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node, | |||||
| virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node, | |||||
| const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | ||||
| return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); | return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); | ||||
| } | } | ||||
| @@ -38,9 +38,9 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); | auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); | ||||
| if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) { | |||||
| return node; | |||||
| } else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) { | |||||
| if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) { | |||||
| return nullptr; | |||||
| } else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) { | |||||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info); | auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info); | ||||
| builder->SetKernelType(AICPU_KERNEL); | builder->SetKernelType(AICPU_KERNEL); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | ||||
| @@ -49,7 +49,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph | |||||
| MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" | MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" | ||||
| << node->DebugString() << "]"; | << node->DebugString() << "]"; | ||||
| } | } | ||||
| return node; | |||||
| return nullptr; | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -148,7 +148,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod | |||||
| auto indices_const = CreateValueNode(new_cnode); | auto indices_const = CreateValueNode(new_cnode); | ||||
| new_cnode->add_input(indices_const); | new_cnode->add_input(indices_const); | ||||
| MS_EXCEPTION_IF_NULL(supported_checker_); | MS_EXCEPTION_IF_NULL(supported_checker_); | ||||
| if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { | |||||
| if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { | |||||
| MS_LOG(INFO) << "split topk failed, check to aicpu."; | MS_LOG(INFO) << "split topk failed, check to aicpu."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap | |||||
| new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); | new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); | ||||
| auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName); | auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName); | ||||
| if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) { | |||||
| if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) { | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata), | std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata), | ||||
| utils::cast<AnfNodePtr>((*equiv)[input_varptr_])}; | utils::cast<AnfNodePtr>((*equiv)[input_varptr_])}; | ||||
| auto new_node = func_graph->NewCNode(inputs); | auto new_node = func_graph->NewCNode(inputs); | ||||
| @@ -976,5 +976,21 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { | |||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); | MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); | ||||
| } | } | ||||
| bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) { | |||||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); | |||||
| if (shape.empty()) { | |||||
| return true; | |||||
| } | |||||
| return shape.size() == kShape1dDims && shape[0] == 1; | |||||
| } | |||||
| bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) { | |||||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); | |||||
| if (shape.empty()) { | |||||
| return true; | |||||
| } | |||||
| return shape.size() == kShape1dDims && shape[0] == 1; | |||||
| } | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -185,6 +185,8 @@ class AnfRuntimeAlgorithm { | |||||
| static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); | static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); | ||||
| static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node); | static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node); | ||||
| static bool IsSwitchCall(const CNodePtr &call_node); | static bool IsSwitchCall(const CNodePtr &call_node); | ||||
| static bool IsScalarInput(const CNodePtr &cnode, size_t index); | |||||
| static bool IsScalarOutput(const CNodePtr &cnode, size_t index); | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | using AnfAlgo = session::AnfRuntimeAlgorithm; | ||||
| @@ -207,7 +207,9 @@ constexpr auto kValueTargetOther = "target_other"; | |||||
| // some size | // some size | ||||
| const size_t kShape4dDims = 4; | const size_t kShape4dDims = 4; | ||||
| const size_t kShape2dDims = 2; | |||||
| const size_t kShape5dDims = 5; | const size_t kShape5dDims = 5; | ||||
| const size_t kShape1dDims = 1; | |||||
| const size_t kCubeSize = 16; | const size_t kCubeSize = 16; | ||||
| const size_t kMemAlignSize = 512; | const size_t kMemAlignSize = 512; | ||||
| const int kParameterDataTensorMask = 0; | const int kParameterDataTensorMask = 0; | ||||
| @@ -55,8 +55,7 @@ class MockSupportedChecker : public SupportedChecker { | |||||
| public: | public: | ||||
| MockSupportedChecker() = default; | MockSupportedChecker() = default; | ||||
| ~MockSupportedChecker() override = default; | ~MockSupportedChecker() override = default; | ||||
| bool CheckAiCoreSupported(const AnfNodePtr &anf_node, | |||||
| const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||||
| bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||||
| return true; | return true; | ||||
| } | } | ||||
| }; // namespace opt | }; // namespace opt | ||||
| @@ -42,7 +42,7 @@ class MockSupportedChecker : public SupportedChecker { | |||||
| public: | public: | ||||
| MockSupportedChecker() = default; | MockSupportedChecker() = default; | ||||
| ~MockSupportedChecker() override = default; | ~MockSupportedChecker() override = default; | ||||
| bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||||
| bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||||
| return true; | return true; | ||||
| } | } | ||||
| }; | }; | ||||