diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 74acf11a6a..f0b6ba4567 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -178,7 +178,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; builder->SetOutputsFormat(output_format); - std::vector output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; + std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; builder->SetOutputsDeviceType(output_type); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); } diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc index bd3ed8e30e..626eef570b 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc @@ -582,6 +582,7 @@ bool IsShapeMatchFormat(const std::vector &shape, const std::string &for bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { MS_EXCEPTION_IF_NULL(kernel_node); + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); const size_t kCAxis = 1; for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); @@ -594,6 +595,12 @@ bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel:: if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) { return false; } + if (kernel_name == "ReduceMean") { + auto keep_dims = AnfAlgo::GetNodeAttr(kernel_node, kAttrKeepDims); + if (keep_dims == false && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) { + return false; + } + } } for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); @@ -606,6 +613,12 @@ bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel:: } return false; } + if (kernel_name == "ReduceMean") { + auto keep_dims = AnfAlgo::GetNodeAttr(kernel_node, kAttrKeepDims); + if (keep_dims == false && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) { + return false; + } + } } if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index 0422456971..fddd132d06 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -315,7 +315,10 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { - if (node->isa() || node->isa()) { + if (node->isa()) { + return true; + } + if (node->isa() && AnfAlgo::IsParameterWeight(node->cast())) { return true; } return false;