| @@ -122,7 +122,7 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| GE_CHECK_NOTNULL(input_desc); | GE_CHECK_NOTNULL(input_desc); | ||||
| int64_t tensor_size = -1; | int64_t tensor_size = -1; | ||||
| (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | ||||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s]", | |||||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu", | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| idx, | idx, | ||||
| src_tensor_desc->GetShape().ToString().c_str(), | src_tensor_desc->GetShape().ToString().c_str(), | ||||
| @@ -71,7 +71,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | ||||
| GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | ||||
| "Invoke InferShapeAndType failed."); | |||||
| "Invoke InferShapeAndType failed."); | |||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); | ||||
| } | } | ||||
| @@ -229,66 +229,87 @@ Status ShapeInferenceEngine::UpdatePeerNodeShape(const Node &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ShapeInferenceEngine::CanonicalizeShape(GeTensorDesc &tensor_desc, | |||||
| std::vector<int64_t> &shape, | |||||
| bool fallback_with_range) { | |||||
| const auto &tensor_shape = tensor_desc.MutableShape(); | |||||
| if (tensor_shape.IsUnknownShape()) { | |||||
| if (!fallback_with_range) { | |||||
| GELOGE(INTERNAL_ERROR, "Output shape is still unknown after shape inference. shape = [%s]", | |||||
| tensor_shape.ToString().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("Calc output size by range"); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| GE_CHK_GRAPH_STATUS_RET(tensor_desc.GetShapeRange(shape_range), "Failed to get shape range"); | |||||
| if (shape_range.size() != shape.size()) { | |||||
| GELOGE(INTERNAL_ERROR, "Number of shape ranges (%zu) mismatches that of dims (%zu)", | |||||
| shape_range.size(), | |||||
| shape.size()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| for (size_t dim_index = 0; dim_index < shape.size(); ++dim_index) { | |||||
| if (shape[dim_index] == ge::UNKNOWN_DIM) { | |||||
| shape[dim_index] = shape_range[dim_index].second; | |||||
| } | |||||
| } | |||||
| GELOGD("After canonicalization, shape = [%s], before = [%s]", | |||||
| GeShape(shape).ToString().c_str(), | |||||
| tensor_shape.ToString().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ShapeInferenceEngine::CalcTensorSize(DataType data_type, | |||||
| const std::vector<int64_t> &shape, | |||||
| int64_t &tensor_size) { | |||||
| GELOGD("To calc tensor size by shape = [%s]", GeShape(shape).ToString().c_str()); | |||||
| uint32_t type_size; | |||||
| if (!TypeUtils::GetDataTypeLength(data_type, type_size)) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to get data type size"); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| tensor_size = type_size; | |||||
| for (const auto &dim : shape) { | |||||
| GE_CHECK_GE(dim, 0); | |||||
| GE_CHK_STATUS_RET(Int64MulCheckOverflow(tensor_size, dim), | |||||
| "Shape size overflow, shape = [%s]", | |||||
| GeShape(shape).ToString().c_str()); | |||||
| tensor_size *= dim; | |||||
| } | |||||
| GE_CHK_STATUS_RET(CheckInt64AddOverflow(tensor_size, kAlignment - 1), | |||||
| "Tensor size is too large: %ld, shape = [%s]", | |||||
| tensor_size, | |||||
| GeShape(shape).ToString().c_str()); | |||||
| tensor_size = (tensor_size + kAlignment - 1) / kAlignment * kAlignment; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ShapeInferenceEngine::CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range) { | Status ShapeInferenceEngine::CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range) { | ||||
| auto op_desc = node_item.GetOpDesc(); | auto op_desc = node_item.GetOpDesc(); | ||||
| for (size_t output_index = 0; output_index < op_desc->GetOutputsSize(); ++output_index) { | for (size_t output_index = 0; output_index < op_desc->GetOutputsSize(); ++output_index) { | ||||
| auto tensor_desc = op_desc->MutableOutputDesc(output_index); | auto tensor_desc = op_desc->MutableOutputDesc(output_index); | ||||
| GE_CHECK_NOTNULL(tensor_desc); | GE_CHECK_NOTNULL(tensor_desc); | ||||
| const auto &shape = tensor_desc->MutableShape(); | const auto &shape = tensor_desc->MutableShape(); | ||||
| // modify on copy | |||||
| auto dims = shape.GetDims(); | auto dims = shape.GetDims(); | ||||
| auto dim_num = dims.size(); | |||||
| if (shape.IsUnknownShape()) { | |||||
| if (!fallback_with_range) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] Shape of output[%zu] is still unknown after shape inference. shape = [%s]", | |||||
| node_item.NodeName().c_str(), | |||||
| output_index, | |||||
| shape.ToString().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("[%s] Calc output[%zu] size by range", node_item.NodeName().c_str(), output_index); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| GE_CHK_GRAPH_STATUS_RET(tensor_desc->GetShapeRange(shape_range), | |||||
| "[$s] Failed to get shape range for output: %zu", | |||||
| node_item.NodeName().c_str(), | |||||
| output_index); | |||||
| if (shape_range.size() != dim_num) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] Number of shape ranges (%zu) mismatches that of dims (%zu), index = %zu", | |||||
| node_item.NodeName().c_str(), | |||||
| shape_range.size(), | |||||
| dim_num, | |||||
| output_index); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) { | |||||
| if (dims[dim_index] == ge::UNKNOWN_DIM) { | |||||
| dims[dim_index] = shape_range[dim_index].second; | |||||
| } | |||||
| } | |||||
| } | |||||
| uint32_t type_size = 0; | |||||
| if (!TypeUtils::GetDataTypeLength(tensor_desc->GetDataType(), type_size)) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to get data type size"); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| int64_t tensor_size = type_size; | |||||
| for (const auto &dim : dims) { | |||||
| GE_CHECK_GE(dim, 0); | |||||
| GE_CHK_STATUS_RET(Int64MulCheckOverflow(tensor_size, dim), | |||||
| "[%s] Shape size overflow, shape = [%s]", | |||||
| node_item.NodeName().c_str(), | |||||
| shape.ToString().c_str()); | |||||
| tensor_size *= dim; | |||||
| } | |||||
| GE_CHK_STATUS_RET(CanonicalizeShape(*tensor_desc, dims, fallback_with_range), | |||||
| "[%s] Failed to canonicalize shape for output %zu", | |||||
| node_item.NodeName().c_str(), | |||||
| output_index); | |||||
| GE_CHK_STATUS_RET(CheckInt64AddOverflow(tensor_size, kAlignment - 1), | |||||
| "[%s] Output[%zu] Tensor size too large, shape = [%s]", | |||||
| int64_t tensor_size; | |||||
| GE_CHK_STATUS_RET(CalcTensorSize(tensor_desc->GetDataType(), dims, tensor_size), | |||||
| "[%s] Failed to calc tensor size for output %zu", | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| output_index, | |||||
| shape.ToString().c_str()); | |||||
| tensor_size = (tensor_size + kAlignment - 1) / kAlignment * kAlignment; | |||||
| output_index); | |||||
| GELOGD("[%s] Tensor size of output %zu = %ld", node_item.NodeName().c_str(), output_index, tensor_size); | |||||
| (void) TensorUtils::SetSize(*tensor_desc, tensor_size); | (void) TensorUtils::SetSize(*tensor_desc, tensor_size); | ||||
| } | } | ||||
| @@ -37,6 +37,8 @@ class ShapeInferenceEngine { | |||||
| static Status CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range = false); | static Status CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range = false); | ||||
| private: | private: | ||||
| static Status CanonicalizeShape(GeTensorDesc &tensor_desc, std::vector<int64_t> &shape, bool fallback_with_range); | |||||
| static Status CalcTensorSize(DataType data_type, const std::vector<int64_t> &shape, int64_t &tensor_size); | |||||
| static Status UpdatePeerNodeShape(const Node &node); | static Status UpdatePeerNodeShape(const Node &node); | ||||
| Status AwaitDependentNodes(NodeState &node_state); | Status AwaitDependentNodes(NodeState &node_state); | ||||
| @@ -127,12 +127,7 @@ Status NodeItem::Create(const NodePtr &node, std::unique_ptr<NodeItem> &node_ite | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status NodeItem::Init() { | |||||
| GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX); | |||||
| GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX); | |||||
| num_inputs = static_cast<int>(op_desc->GetInputsSize()); | |||||
| num_outputs = static_cast<int>(op_desc->GetOutputsSize()); | |||||
| void NodeItem::ResolveOptionalInputs() { | |||||
| if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) { | if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) { | ||||
| has_optional_inputs = true; | has_optional_inputs = true; | ||||
| for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { | for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { | ||||
| @@ -144,7 +139,18 @@ Status NodeItem::Init() { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| Status NodeItem::InitInputsAndOutputs() { | |||||
| GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX); | |||||
| GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX); | |||||
| num_inputs = static_cast<int>(op_desc->GetInputsSize()); | |||||
| num_outputs = static_cast<int>(op_desc->GetOutputsSize()); | |||||
| ResolveOptionalInputs(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status NodeItem::ResolveDynamicState() { | |||||
| (void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic); | (void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic); | ||||
| GELOGD("node name = %s, is_dynamic = %d.", this->node_name.c_str(), is_dynamic); | GELOGD("node name = %s, is_dynamic = %d.", this->node_name.c_str(), is_dynamic); | ||||
| if (!is_dynamic) { | if (!is_dynamic) { | ||||
| @@ -152,42 +158,54 @@ Status NodeItem::Init() { | |||||
| "[%s] Failed to get shape status.", | "[%s] Failed to get shape status.", | ||||
| node->GetName().c_str()); | node->GetName().c_str()); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| } | |||||
| if (is_dynamic) { | |||||
| for (int i = 0; i < num_inputs; ++i) { | |||||
| const auto &input_desc = MutableInputDesc(i); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| if (input_desc->MutableShape().IsUnknownShape()) { | |||||
| is_input_shape_static_.push_back(false); | |||||
| } else { | |||||
| num_static_input_shapes++; | |||||
| is_input_shape_static_.push_back(true); | |||||
| GELOGD("[%s] The shape of input[%d] is static. shape = [%s]", | |||||
| NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str()); | |||||
| } | |||||
| Status NodeItem::ResolveStaticInputsAndOutputs() { | |||||
| for (int i = 0; i < num_inputs; ++i) { | |||||
| const auto &input_desc = MutableInputDesc(i); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| if (input_desc->MutableShape().IsUnknownShape()) { | |||||
| is_input_shape_static_.push_back(false); | |||||
| } else { | |||||
| num_static_input_shapes++; | |||||
| is_input_shape_static_.push_back(true); | |||||
| GELOGD("[%s] The shape of input[%d] is static. shape = [%s]", | |||||
| NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str()); | |||||
| } | } | ||||
| } | |||||
| for (int i = 0; i < num_outputs; ++i) { | |||||
| const auto &output_desc = op_desc->MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| if (output_desc->MutableShape().IsUnknownShape()) { | |||||
| is_output_shape_static = false; | |||||
| break; | |||||
| } | |||||
| for (int i = 0; i < num_outputs; ++i) { | |||||
| const auto &output_desc = op_desc->MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| if (output_desc->MutableShape().IsUnknownShape()) { | |||||
| is_output_shape_static = false; | |||||
| break; | |||||
| } | } | ||||
| } | |||||
| if (is_output_shape_static) { | |||||
| GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(*this)); | |||||
| } | |||||
| if (is_output_shape_static) { | |||||
| GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(*this)); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| if (IsControlOp() || node_type == PARTITIONEDCALL) { | |||||
| shape_inference_type = DEPEND_COMPUTE; | |||||
| } else { | |||||
| int32_t unknown_shape_type_val = 0; | |||||
| (void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val); | |||||
| shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val); | |||||
| } | |||||
| void NodeItem::ResolveUnknownShapeType() { | |||||
| if (IsControlOp() || node_type == PARTITIONEDCALL) { | |||||
| shape_inference_type = DEPEND_COMPUTE; | |||||
| } else { | |||||
| int32_t unknown_shape_type_val = 0; | |||||
| (void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val); | |||||
| shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val); | |||||
| } | |||||
| } | |||||
| Status NodeItem::Init() { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs()); | |||||
| GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState()); | |||||
| if (is_dynamic) { | |||||
| ResolveUnknownShapeType(); | |||||
| GE_CHK_STATUS_RET_NOLOG(ResolveStaticInputsAndOutputs()); | |||||
| GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str()); | GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str()); | ||||
| } | } | ||||
| @@ -103,6 +103,11 @@ struct NodeItem { | |||||
| private: | private: | ||||
| explicit NodeItem(NodePtr node); | explicit NodeItem(NodePtr node); | ||||
| Status Init(); | Status Init(); | ||||
| Status InitInputsAndOutputs(); | |||||
| void ResolveOptionalInputs(); | |||||
| Status ResolveDynamicState(); | |||||
| Status ResolveStaticInputsAndOutputs(); | |||||
| void ResolveUnknownShapeType(); | |||||
| std::vector<bool> is_input_shape_static_; | std::vector<bool> is_input_shape_static_; | ||||
| std::vector<uint32_t> input_desc_indices_; | std::vector<uint32_t> input_desc_indices_; | ||||