diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc
index 46ee6bd6..5f688d7c 100755
--- a/ge/hybrid/executor/worker/shape_inference_engine.cc
+++ b/ge/hybrid/executor/worker/shape_inference_engine.cc
@@ -69,8 +69,20 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
   GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str());
   {
     RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start");
+    vector<DataType> temp_dtype;
+    OpDescPtr op_desc_before = node_item.node->GetOpDesc();
+    for (auto &tensor_desc : op_desc_before->GetAllOutputsDescPtr()) {
+      temp_dtype.emplace_back(tensor_desc->GetDataType());
+    }
     GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), "Invoke InferShapeAndType failed.");
+    OpDescPtr op_desc_after = node_item.node->GetOpDesc();
+    auto all_output_tensor = op_desc_after->GetAllOutputsDescPtr();
+    for (size_t i = 0; i < all_output_tensor.size(); ++i) {
+      if (all_output_tensor.at(i)->GetDataType() != temp_dtype[i]) {
+        all_output_tensor.at(i)->SetDataType(temp_dtype[i]);
+      }
+    }
     RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End");
   }
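
Below is a minimal, self-contained sketch of the snapshot-and-restore pattern the patch applies: output data types are recorded before shape inference and written back afterwards if inference changed them, so only the inferred shapes are kept. The types Node, TensorDesc, and the function InferShapeAndTypeForRunning here are simplified stand-ins (assumptions), not GE's real OpDescPtr / GeTensorDescPtr / ShapeRefiner API; only the control flow mirrors the diff.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Simplified stand-ins for GE's tensor descriptor and node (assumptions).
    enum class DataType : std::uint32_t { DT_FLOAT, DT_FLOAT16, DT_INT32 };

    struct TensorDesc {
      DataType dtype = DataType::DT_FLOAT;
    };

    struct Node {
      std::vector<TensorDesc> outputs;
    };

    // Hypothetical stand-in for the inference call: besides shapes, it may
    // rewrite output data types as a side effect.
    void InferShapeAndTypeForRunning(Node &node) {
      for (auto &out : node.outputs) {
        out.dtype = DataType::DT_FLOAT16;  // simulated unwanted dtype change
      }
    }

    // Snapshot output dtypes, run inference, then restore any dtype that
    // inference changed, matching the control flow of the patch above.
    void InferShapePreservingDtypes(Node &node) {
      std::vector<DataType> saved;
      saved.reserve(node.outputs.size());
      for (const auto &out : node.outputs) {
        saved.push_back(out.dtype);
      }

      InferShapeAndTypeForRunning(node);

      for (std::size_t i = 0; i < node.outputs.size(); ++i) {
        if (node.outputs[i].dtype != saved[i]) {
          node.outputs[i].dtype = saved[i];
        }
      }
    }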