|
|
|
@@ -345,7 +345,6 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { |
|
|
|
SetGraphKernelInfo(kernel_node, func_graph); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::string> inputs_format; |
|
|
|
std::vector<TypeId> inputs_type; |
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { |
|
|
|
@@ -368,12 +367,10 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { |
|
|
|
builder->SetInputsDeviceType(inputs_type); |
|
|
|
builder->SetOutputsFormat(outputs_format); |
|
|
|
builder->SetOutputsDeviceType(outputs_type); |
|
|
|
|
|
|
|
bool result = false; |
|
|
|
if (kernel_type == UNKNOWN_KERNEL_TYPE) { |
|
|
|
result = |
|
|
|
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build()); |
|
|
|
|
|
|
|
if (!result) { |
|
|
|
result = SelectAkgKernel(kernel_node, builder->Build()); |
|
|
|
kernel_type = AKG_KERNEL; |
|
|
|
@@ -381,7 +378,6 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { |
|
|
|
} else if (kernel_type == AKG_KERNEL) { |
|
|
|
result = SelectAkgKernel(kernel_node, builder->Build()); |
|
|
|
} |
|
|
|
|
|
|
|
if (!result) { |
|
|
|
PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type); |
|
|
|
return; |
|
|
|
|