| @@ -60,7 +60,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||
| default: | |||
| std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | |||
| " and FORMAT_FRACTAL_NZ is not supported."; | |||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return false; | |||
| } | |||
| } | |||
| @@ -59,7 +59,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||
| default: | |||
| std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | |||
| " and FORMAT_FRACTAL_ZZ is not supported."; | |||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); | |||
| return false; | |||
| } | |||
| } | |||
| @@ -92,7 +92,8 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | |||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, | |||
| "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); | |||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||
| @@ -50,21 +50,21 @@ std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | |||
| bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | |||
| if (src_shape.empty()) { | |||
| std::string error = "Failed to transpose, empty src shape"; | |||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||
| GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape"); | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||
| GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to transpose, empty src shape"); | |||
| return false; | |||
| } | |||
| for (auto dim : src_shape) { | |||
| if (dim < 0) { | |||
| std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape)); | |||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||
| return false; | |||
| } | |||
| } | |||
| if (perm_arg.size() != src_shape.size()) { | |||
| std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) + | |||
| " and perm arg" + FmtToStr(perm_arg.size()) + " are different"; | |||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); | |||
| return false; | |||
| } | |||
| @@ -73,7 +73,7 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in | |||
| if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) { | |||
| std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) + | |||
| ", perm arg " + FmtToStr(JoinToString(perm_arg)); | |||
| GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||
| GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); | |||
| return false; | |||
| } | |||
| } | |||
| @@ -82,11 +82,11 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in | |||
| bool IsTransposeArgValid(const uint8_t *src, const std::vector<int64_t> &src_shape, DataType src_data_type, | |||
| const std::vector<int64_t> &perm_arg) { | |||
| if (src == nullptr) { | |||
| GELOGE(PARAM_INVALID, "Failed to transpose, the src is null"); | |||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to transpose, the src is null"); | |||
| return false; | |||
| } | |||
| if (GetSizeByDataType(src_data_type) < 0) { | |||
| GELOGE(UNSUPPORTED, "Failed to transpose, the data type %s is not support", | |||
| GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to transpose, the data type %s is not support", | |||
| TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||
| return false; | |||
| } | |||
| @@ -333,7 +333,7 @@ TensorValue *HybridModel::GetConstant(const NodePtr &node) const { | |||
| return nullptr; | |||
| } | |||
| auto it = constant_tensors_.find(node); | |||
| auto it = constant_tensors_.find(node->GetName()); | |||
| if (it == constant_tensors_.end()) { | |||
| GELOGD("constant not found, node name = [%s]", node->GetName().c_str()); | |||
| return nullptr; | |||
| @@ -138,7 +138,7 @@ class HybridModel { | |||
| std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148 | |||
| std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148 | |||
| std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; | |||
| std::map<NodePtr, std::unique_ptr<TensorValue>> constant_tensors_; | |||
| std::map<std::string, std::unique_ptr<TensorValue>> constant_tensors_; | |||
| std::map<NodePtr, std::vector<domi::TaskDef>> task_defs_; | |||
| std::map<NodePtr, GeModelPtr> known_shape_sub_models_; | |||
| @@ -617,9 +617,9 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { | |||
| return SUCCESS; | |||
| } | |||
| Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { | |||
| Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { | |||
| merged_graph = MakeShared<ComputeGraph>("MergedGraph"); | |||
| for (const auto &node : root_graph.GetDirectNode()) { | |||
| for (const auto &node : root_graph->GetDirectNode()) { | |||
| GE_CHECK_NOTNULL(node); | |||
| auto op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| @@ -649,7 +649,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||
| } | |||
| } | |||
| } | |||
| GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), | |||
| GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), | |||
| "[%s] Failed to merge subgraph.", | |||
| subgraph->GetName().c_str()); | |||
| } | |||
| @@ -665,7 +665,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||
| return a_level < b_level; | |||
| }); | |||
| for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { | |||
| for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) { | |||
| GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); | |||
| GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), | |||
| "Failed to add subgraph [%s]", | |||
| @@ -675,8 +675,8 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||
| return SUCCESS; | |||
| } | |||
| Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, | |||
| ComputeGraph &parent_graph, | |||
| Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, | |||
| ComputeGraphPtr &parent_graph, | |||
| ComputeGraph &sub_graph) { | |||
| auto parent_node = sub_graph.GetParentNode(); | |||
| GE_CHECK_NOTNULL(parent_node); | |||
| @@ -705,15 +705,24 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, | |||
| } | |||
| } | |||
| parent_graph.AddNode(sub_node); | |||
| if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { | |||
| for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { | |||
| auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); | |||
| GE_CHECK_NOTNULL(sub_sub_graph); | |||
| sub_sub_graph->SetParentGraph(root_graph); | |||
| } | |||
| } | |||
| parent_graph->AddNode(sub_node); | |||
| GELOGD("[%s::%s] added to parent graph: [%s].", | |||
| sub_graph.GetName().c_str(), | |||
| sub_node->GetName().c_str(), | |||
| parent_graph.GetName().c_str()); | |||
| parent_graph->GetName().c_str()); | |||
| sub_node->SetOwnerComputeGraph(root_graph); | |||
| } | |||
| GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); | |||
| root_graph.RemoveSubgraph(sub_graph.GetName()); | |||
| root_graph->RemoveSubgraph(sub_graph.GetName()); | |||
| return SUCCESS; | |||
| } | |||
| @@ -765,7 +774,7 @@ Status HybridModelBuilder::LoadGraph() { | |||
| GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | |||
| root_graph->GetDirectNodesSize(), | |||
| root_graph->GetAllNodesSize()); | |||
| GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); | |||
| GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs."); | |||
| root_graph = std::move(merged_graph); | |||
| GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | |||
| root_graph->GetDirectNodesSize(), | |||
| @@ -1035,6 +1044,14 @@ Status HybridModelBuilder::InitWeights() { | |||
| sub_weight_buffer->GetSize()); | |||
| auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); | |||
| hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); | |||
| std::map<std::string, NodePtr> name_to_node; | |||
| for (const auto &subgraph : ge_root_model_->GetRootGraph()->GetAllSubgraphs()) { | |||
| for (const auto &node : subgraph->GetAllNodes()) { | |||
| name_to_node.insert(make_pair(node->GetName(), node)); | |||
| } | |||
| } | |||
| for (auto &node : root_graph->GetDirectNode()) { | |||
| if (node->GetType() != CONSTANT) { | |||
| continue; | |||
| @@ -1065,10 +1082,25 @@ Status HybridModelBuilder::InitWeights() { | |||
| auto tensor_buffer = TensorBuffer::Create(weight_base + data_offset, tensor_size); | |||
| GE_CHECK_NOTNULL(tensor_buffer); | |||
| if (tensor_size > 0) { | |||
| auto tensor = std::shared_ptr<GeTensor>( | |||
| new (std::nothrow)GeTensor(tensor_desc, weight_buffer.GetData() + data_offset, tensor_size)); | |||
| OpDescPtr op_desc = nullptr; | |||
| auto iter = name_to_node.find(node->GetName()); | |||
| if (iter != name_to_node.end()) { | |||
| op_desc = iter->second->GetOpDesc(); | |||
| if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, std::move(tensor))) { | |||
| GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS failed."); | |||
| return FAILED; | |||
| } | |||
| } | |||
| } | |||
| std::unique_ptr<TensorValue> constant_tensor(new (std::nothrow)TensorValue(std::move(tensor_buffer))); | |||
| GE_CHECK_NOTNULL(constant_tensor); | |||
| constant_tensor->SetName("Constant_" + op_desc->GetName()); | |||
| hybrid_model_.constant_tensors_.emplace(node, std::move(constant_tensor)); | |||
| hybrid_model_.constant_tensors_.emplace(node->GetName(), std::move(constant_tensor)); | |||
| GELOGD("[%s] Constant node [%s] added, size = %ld", GetGraphName(), node->GetName().c_str(), tensor_size); | |||
| } | |||
| } | |||
| @@ -47,8 +47,8 @@ class HybridModelBuilder { | |||
| static Status HandleDtString(const GeTensor &tensor, void *var_addr); | |||
| static Status MergeInputNodes(ComputeGraph &compute_graph); | |||
| static Status MergeNetOutputNode(ComputeGraph &compute_graph); | |||
| static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); | |||
| static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph); | |||
| static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); | |||
| static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); | |||
| static Status BuildInputMapping(GraphItem &graph_item, | |||
| std::vector<NodeItem *> &data_nodes, | |||
| bool is_root_graph); | |||
| @@ -4676,5 +4676,13 @@ TEST_F(UtestFormatTranspose, invalid_dst_format) { | |||
| EXPECT_EQ(transfer.TransShape(FORMAT_NCHW, src_shape, DT_FLOAT16, FORMAT_C1HWNC0, dst_shape), | |||
| ACL_ERROR_GE_FORMAT_INVALID); | |||
| } | |||
| TEST_F(UtestFormatTranspose, invalid_src_data) { | |||
| uint16_t *data = nullptr; | |||
| TransArgs args(nullptr, FORMAT_NCHW, FORMAT_NHWC, std::vector<int64_t>{1, 3, 8, 8}, std::vector<int64_t>{1, 8, 8, 1}, DT_INT64); | |||
| FormatTransferTranspose transpose; | |||
| TransResult result; | |||
| EXPECT_EQ(transpose.TransFormat(args, result)); | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -245,7 +245,7 @@ TEST_F(UtestGeHybrid, init_weight_success) { | |||
| ASSERT_EQ(ret,PARAM_INVALID); | |||
| } | |||
| TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||
| TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||
| ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc"); | |||
| GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | |||
| HybridModel model(root_model); | |||
| @@ -256,3 +256,53 @@ TEST_F(UtestGeHybrid, init_weight_success) { | |||
| HybridModelExecutor executor(model_ptr, device_id, stream); | |||
| executor.Init(); | |||
| } | |||
| TEST_F(UtestGeHybrid, unfold_subgraphs_success) { | |||
| ComputeGraphPtr merged_graph = nullptr; | |||
| ComputeGraphPtr sub_sub_graph1 = std::make_shared<ComputeGraph>("while_cond"); | |||
| OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); | |||
| NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); | |||
| ComputeGraphPtr sub_sub_graph2 = std::make_shared<ComputeGraph>("while_body"); | |||
| /*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); | |||
| NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ | |||
| OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); | |||
| NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); | |||
| sub_sub_graph2->SetGraphUnknownFlag(true); | |||
| /*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); | |||
| NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); | |||
| sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); | |||
| sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/ | |||
| ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph"); | |||
| OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); | |||
| NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); | |||
| sub_graph->SetGraphUnknownFlag(true); | |||
| sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); | |||
| sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); | |||
| sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); | |||
| sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body"); | |||
| ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph"); | |||
| auto partitioned_call_op_desc = MakeShared<OpDesc>("partitioned_call", PARTITIONEDCALL); | |||
| auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); | |||
| partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); | |||
| partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); | |||
| root_graph->AddSubGraph(sub_sub_graph1); | |||
| root_graph->AddSubGraph(sub_sub_graph2); | |||
| sub_sub_graph1->SetParentGraph(root_graph); | |||
| sub_sub_graph2->SetParentGraph(root_graph); | |||
| sub_sub_graph1->SetParentNode(sub_graph_while_node); | |||
| sub_sub_graph2->SetParentNode(sub_graph_while_node); | |||
| root_graph->AddSubGraph(sub_graph); | |||
| sub_graph->SetParentNode(partitioned_call_node); | |||
| sub_graph->SetParentGraph(root_graph); | |||
| GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(root_graph); | |||
| HybridModel hybrid_model(root_model); | |||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||
| EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS); | |||
| } | |||