Browse Source

Fix bug in GPU refresh of parameter and value node format info during kernel selection; do not refresh ZN_LSTM format for value nodes

tags/v1.1.0
Lianliguang 5 years ago
parent
commit
61f3c134c0
2 changed files with 19 additions and 16 deletions
  1. +10
    -10
      mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
  2. +9
    -6
      mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc

+ 10
- 10
mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc View File

@@ -478,16 +478,6 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue;
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
continue;
}
if (selected_kernel_info.GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
continue;
}
@@ -500,6 +490,16 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
if (AnfAlgo::OutputAddrExist(real_input_node, 0)) {
continue;
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);


+ 9
- 6
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc View File

@@ -133,29 +133,32 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = kernel_node->input(input_index + 1);
MS_EXCEPTION_IF_NULL(input_kernel_node);
if (!input_kernel_node->isa<Parameter>()) {
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
MS_EXCEPTION_IF_NULL(input_with_index.first);
auto real_input_node = input_with_index.first;
if (!real_input_node->isa<Parameter>()) {
continue;
}
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();

auto param = input_kernel_node->cast<ParameterPtr>();
auto param = real_input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
if (!AnfAlgo::IsParameterWeight(param)) {
std::vector<std::string> output_format = {kOpFormat_DEFAULT};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)};
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
continue;
}
if ((AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) ||
if ((AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) ||
(AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
}
}


Loading…
Cancel
Save