|
|
|
@@ -162,8 +162,18 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co |
|
|
|
} |
|
|
|
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder = |
|
|
|
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); |
|
|
|
bool is_ref = false; |
|
|
|
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); |
|
|
|
if (op_info != nullptr) { |
|
|
|
is_ref = op_info->is_ref(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); |
|
|
|
if (MsContext::GetInstance()->execution_mode() == kPynativeMode && |
|
|
|
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// we set special device info of a input tensor. |
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) { |
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { |
|
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; |
|
|
|
builder->SetOutputsFormat(output_format); |
|
|
|
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; |
|
|
|
|