Browse Source

Pre Merge pull request !986 from 周超/dev2

pull/986/MERGE
周超 Gitee 5 years ago
parent
commit
e947fa3f08
1 changed files with 12 additions and 0 deletions
  1. +12
    -0
      ge/hybrid/executor/worker/shape_inference_engine.cc

+ 12
- 0
ge/hybrid/executor/worker/shape_inference_engine.cc View File

@@ -69,8 +69,20 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str());
{
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start");
vector<ge::DataType> temp_dtype;
OpDescPtr op_desc_before = node_item.node->GetOpDesc();
for (auto &tensor_desc: op_desc_before->GetAllOutputsDescPtr()) {
temp_dtype.emplace_back(tensor_desc->GetDataType());
}
GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true),
"Invoke InferShapeAndType failed.");
OpDescPtr op_desc_after = node_item.node->GetOpDesc();
auto all_output_tensor = op_desc_after->GetAllOutputsDescPtr();
for (size_t i = 0; i < all_output_tensor.size(); ++i) {
if (all_output_tensor.at(i)->GetDataType() != temp_dtype[i]) {
all_output_tensor.at(i)->SetDataType(temp_dtype[i]);
}
}
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End");
}



Loading…
Cancel
Save