|
|
@@ -82,12 +82,6 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel:: |
|
|
} |
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
}; |
|
|
}; |
|
|
if (AnfAlgo::GetCNodeName(kernel_node) == "LayerNormBetaGammaBackprop" || |
|
|
|
|
|
AnfAlgo::GetCNodeName(kernel_node) == "LayerNormXBackprop") { |
|
|
|
|
|
if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, 0) != kernel_build_info.GetInputFormat(0)) { |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { |
|
|
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { |
|
|
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && |
|
|
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && |
|
|
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0); |
|
|
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0); |
|
|
@@ -161,7 +155,7 @@ bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return true; |
|
|
|
|
|
|
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node, |
|
|
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node, |
|
|
|