Browse Source

!7533 fix bug of gpu's kernel setter

Merge pull request !7533 from lianliguang/master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
08dad79529
2 changed files with 19 additions and 16 deletions
  1. +10
    -10
      mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
  2. +9
    -6
      mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc

+ 10
- 10
mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc View File

@@ -316,16 +316,6 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) { if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue; continue;
} }
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
continue;
}
if (selected_kernel_info.GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) { if (selected_kernel_info.GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
continue; continue;
} }
@@ -338,6 +328,16 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
if (AnfAlgo::OutputAddrExist(real_input_node, 0)) { if (AnfAlgo::OutputAddrExist(real_input_node, 0)) {
continue; continue;
} }
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format); builder->SetOutputsFormat(output_format);


+ 9
- 6
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc View File

@@ -133,29 +133,32 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = kernel_node->input(input_index + 1); auto input_kernel_node = kernel_node->input(input_index + 1);
MS_EXCEPTION_IF_NULL(input_kernel_node); MS_EXCEPTION_IF_NULL(input_kernel_node);
if (!input_kernel_node->isa<Parameter>()) {
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
MS_EXCEPTION_IF_NULL(input_with_index.first);
auto real_input_node = input_with_index.first;
if (!real_input_node->isa<Parameter>()) {
continue; continue;
} }
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder = std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();


auto param = input_kernel_node->cast<ParameterPtr>();
auto param = real_input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param); MS_EXCEPTION_IF_NULL(param);
if (!AnfAlgo::IsParameterWeight(param)) { if (!AnfAlgo::IsParameterWeight(param)) {
std::vector<std::string> output_format = {kOpFormat_DEFAULT}; std::vector<std::string> output_format = {kOpFormat_DEFAULT};
builder->SetOutputsFormat(output_format); builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)};
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
builder->SetOutputsDeviceType(output_type); builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
continue; continue;
} }
if ((AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) ||
if ((AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) ||
(AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) { (AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format); builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type); builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
} }
} }
} }


Loading…
Cancel
Save