|
|
@@ -176,7 +176,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co |
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { |
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { |
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; |
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; |
|
|
builder->SetOutputsFormat(output_format); |
|
|
builder->SetOutputsFormat(output_format); |
|
|
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; |
|
|
|
|
|
|
|
|
std::vector<TypeId> output_type = {AnfAlgo::GetInputDeviceDataType(kernel_node, input_index)}; |
|
|
builder->SetOutputsDeviceType(output_type); |
|
|
builder->SetOutputsDeviceType(output_type); |
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); |
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); |
|
|
} |
|
|
} |
|
|
|