From 3348e5a7c2a24dd331fe68f8d7487274915b275c Mon Sep 17 00:00:00 2001 From: lianliguang Date: Sat, 11 Apr 2020 15:55:50 +0800 Subject: [PATCH] deal something special of adam's kernel select --- mindspore/ccsrc/device/ascend/kernel_select_ascend.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index d05b9fafa1..0a23e2da7b 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -82,6 +82,13 @@ bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel:: } return true; }; + if (AnfAlgo::GetCNodeName(kernel_node) == "Adam") { + auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_num - 1) != + kernel_build_info.GetInputFormat(input_num - 1)) { + return false; + } + } if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);