|
|
|
@@ -69,8 +69,20 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { |
|
|
|
GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); |
|
|
|
{ |
|
|
|
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); |
|
|
|
vector<ge::DataType> temp_dtype; |
|
|
|
OpDescPtr op_desc_before = node_item.node->GetOpDesc(); |
|
|
|
for (auto &tensor_desc: op_desc_before->GetAllOutputsDescPtr()) { |
|
|
|
temp_dtype.emplace_back(tensor_desc->GetDataType()); |
|
|
|
} |
|
|
|
GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), |
|
|
|
"Invoke InferShapeAndType failed."); |
|
|
|
OpDescPtr op_desc_after = node_item.node->GetOpDesc(); |
|
|
|
auto all_output_tensor = op_desc_after->GetAllOutputsDescPtr(); |
|
|
|
for (size_t i = 0; i < all_output_tensor.size(); ++i) { |
|
|
|
if (all_output_tensor.at(i)->GetDataType() != temp_dtype[i]) { |
|
|
|
all_output_tensor.at(i)->SetDataType(temp_dtype[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); |
|
|
|
} |
|
|
|
|
|
|
|
|