Browse Source

fix switch input type

tags/v0.6.0-beta
hexia 5 years ago
parent
commit
7b6e7bd62f
2 changed files with 13 additions and 6 deletions
  1. +11
    -4
      mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
  2. +2
    -2
      mindspore/ccsrc/kernel/rts/label_switch.cc

+ 11
- 4
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc View File

@@ -566,10 +566,17 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern
MS_LOG(WARNING) << "Kernel [" << (kernel_info_list.size() + index) MS_LOG(WARNING) << "Kernel [" << (kernel_info_list.size() + index)
<< "] :" << aicpu_kernel_info_list[index]->ToString(); << "] :" << aicpu_kernel_info_list[index]->ToString();
} }
MS_LOG(WARNING) << " <<<";
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid kernel info, not supported the type:" << buffer.str()
<< ", please refer to the supported dtypes in candidates kernel info list";
if (IsPrimitiveCNode(kernel_node, prim::kPrimLabelSwitch)) {
auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
// Set format and data type for input tensor.
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
} else {
MS_LOG(WARNING) << " <<<";
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid kernel info, not supported the type:" << buffer.str()
<< ", please refer to the supported dtypes in candidates kernel info list";
}
} }
return select_status; return select_status;
} }


+ 2
- 2
mindspore/ccsrc/kernel/rts/label_switch.cc View File

@@ -75,8 +75,8 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr


std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernelInfo() { std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernelInfo() {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> label_switch_build_info{}; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> label_switch_build_info{};
vector<string> input_format{kOpFormat_DEFAULT, kOpFormat_DEFAULT};
vector<TypeId> input_type{kNumberTypeUInt32, kNumberTypeBool};
vector<string> input_format{kOpFormat_DEFAULT};
vector<TypeId> input_type{kNumberTypeInt32};
if (input_format.size() != input_type.size()) { if (input_format.size() != input_type.size()) {
MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size " MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size "
<< input_type.size(); << input_type.size();


Loading…
Cancel
Save