From 61f3c134c005ce1b74d5b7fa03dc943116e5eca4 Mon Sep 17 00:00:00 2001 From: Lianliguang Date: Tue, 20 Oct 2020 21:57:22 +0800 Subject: [PATCH] fix bug of gpu refresh parameter & valuenode's format info when kernel selecting && do not refresh ZN_LSTM format for valuenode --- .../device/ascend/kernel_select_ascend.cc | 20 +++++++++---------- .../runtime/device/gpu/kernel_info_setter.cc | 15 ++++++++------ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index 5a423a62ae..86424fa83a 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -478,16 +478,6 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co if (real_input_node->isa() && !AnfAlgo::IsParameterWeight(real_input_node->cast())) { continue; } - auto builder = std::make_shared(); - if (IsValueNode(input_kernel_node) && - AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { - std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; - builder->SetOutputsFormat(output_format); - std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; - builder->SetOutputsDeviceType(output_type); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); - continue; - } if (selected_kernel_info.GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) { continue; } @@ -500,6 +490,16 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co if (AnfAlgo::OutputAddrExist(real_input_node, 0)) { continue; } + auto builder = std::make_shared(); + if (IsValueNode(input_kernel_node) && + AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { + std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; + builder->SetOutputsFormat(output_format); + std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; + builder->SetOutputsDeviceType(output_type); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); + continue; + } if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; builder->SetOutputsFormat(output_format); diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc index 8d73a33329..b4eeb40247 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc @@ -133,29 +133,32 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { auto input_kernel_node = kernel_node->input(input_index + 1); MS_EXCEPTION_IF_NULL(input_kernel_node); - if (!input_kernel_node->isa()) { + auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); + MS_EXCEPTION_IF_NULL(input_with_index.first); + auto real_input_node = input_with_index.first; + if (!real_input_node->isa()) { continue; } std::shared_ptr builder = std::make_shared(); - auto param = input_kernel_node->cast(); + auto param = real_input_node->cast(); MS_EXCEPTION_IF_NULL(param); if (!AnfAlgo::IsParameterWeight(param)) { std::vector output_format = {kOpFormat_DEFAULT}; builder->SetOutputsFormat(output_format); - std::vector output_type = {AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)}; + std::vector output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; builder->SetOutputsDeviceType(output_type); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); continue; } - if ((AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) || + if ((AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) || (AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) { std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; builder->SetOutputsFormat(output_format); std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; builder->SetOutputsDeviceType(output_type); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); } } }