diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index d05b9fafa1..a7c8d131fb 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -82,6 +82,12 @@ bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel:: } return true; }; + if (AnfAlgo::GetCNodeName(kernel_node) == "LayerNormBetaGammaBackprop" || + AnfAlgo::GetCNodeName(kernel_node) == "LayerNormXBackprop") { + if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, 0) != kernel_build_info.GetInputFormat(0)) { + return true; + } + } if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0); @@ -155,7 +161,7 @@ bool PriorityChooseItem(const std::vector &cur_item, std::vector *best return false; } } - return false; + return true; } void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr &kernel_node,