| @@ -82,12 +82,6 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel:: | |||||
| } | } | ||||
| return true; | return true; | ||||
| }; | }; | ||||
| if (AnfAlgo::GetCNodeName(kernel_node) == "LayerNormBetaGammaBackprop" || | |||||
| AnfAlgo::GetCNodeName(kernel_node) == "LayerNormXBackprop") { | |||||
| if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, 0) != kernel_build_info.GetInputFormat(0)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { | if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { | ||||
| return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && | return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && | ||||
| AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0); | AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0); | ||||
| @@ -161,7 +155,7 @@ bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| return true; | |||||
| return false; | |||||
| } | } | ||||
| void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node, | void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node, | ||||