| @@ -1637,6 +1637,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem | |||||
| auto temp_graph = MakeShared<ComputeGraph>("temp"); | auto temp_graph = MakeShared<ComputeGraph>("temp"); | ||||
| GE_CHECK_NOTNULL(temp_graph); | GE_CHECK_NOTNULL(temp_graph); | ||||
| auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); | auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); | ||||
| wrapper_op_desc->SetId(parent_node_item->node_id); | |||||
| GeModelPtr ge_model = subgraph_models_[subgraph_name]; | GeModelPtr ge_model = subgraph_models_[subgraph_name]; | ||||
| GE_CHECK_NOTNULL(ge_model); | GE_CHECK_NOTNULL(ge_model); | ||||
| hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); | hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); | ||||
| @@ -1916,7 +1917,6 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||||
| NodeItem *node_item = nullptr; | NodeItem *node_item = nullptr; | ||||
| GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); | GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); | ||||
| GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | ||||
| GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(node_item)); | |||||
| GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | ||||
| node_item->input_start = input_start; | node_item->input_start = input_start; | ||||
| @@ -2069,22 +2069,17 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { | |||||
| } | } | ||||
| Status HybridModelBuilder::ParseDependentByParallelGroup() { | Status HybridModelBuilder::ParseDependentByParallelGroup() { | ||||
| for (auto &it : hybrid_model_.node_items_) { | |||||
| GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(it.second.get())); | |||||
| } | |||||
| for (const auto &it : node_to_parallel_groups_) { | for (const auto &it : node_to_parallel_groups_) { | ||||
| auto node_item = it.first; | auto node_item = it.first; | ||||
| auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); | |||||
| auto dst_executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); | |||||
| for (const auto ¶llel_group : it.second) { | for (const auto ¶llel_group : it.second) { | ||||
| auto &dependent_nodes = parallel_group_to_nodes_[parallel_group]; | auto &dependent_nodes = parallel_group_to_nodes_[parallel_group]; | ||||
| NodeItem *nearest_dep_node = nullptr; | NodeItem *nearest_dep_node = nullptr; | ||||
| int max_id = -1; | int max_id = -1; | ||||
| for (auto &dep_node : dependent_nodes) { | for (auto &dep_node : dependent_nodes) { | ||||
| if (node_item == dep_node) { | |||||
| continue; | |||||
| } | |||||
| auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node); | |||||
| if (src_engine_type == dst_engine_type) { | |||||
| continue; | |||||
| } | |||||
| if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) { | if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) { | ||||
| nearest_dep_node = dep_node; | nearest_dep_node = dep_node; | ||||
| max_id = dep_node->node_id; | max_id = dep_node->node_id; | ||||
| @@ -2092,10 +2087,12 @@ Status HybridModelBuilder::ParseDependentByParallelGroup() { | |||||
| } | } | ||||
| if (nearest_dep_node != nullptr) { | if (nearest_dep_node != nullptr) { | ||||
| GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]", | |||||
| parallel_group.c_str(), | |||||
| nearest_dep_node->NodeName().c_str(), | |||||
| node_item->NodeName().c_str()); | |||||
| GELOGD("[%s] Nearest node = [%s]", node_item->NodeName().c_str(), nearest_dep_node->NodeName().c_str()); | |||||
| auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*nearest_dep_node->node); | |||||
| if (src_engine_type == dst_executor_type) { | |||||
| GELOGD("No need to add dependency for nodes with same executor type"); | |||||
| continue; | |||||
| } | |||||
| auto &deps = node_item->dependents_for_execution; | auto &deps = node_item->dependents_for_execution; | ||||
| if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) { | if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) { | ||||
| GELOGD("%s->%s Already has dependency, skip it", | GELOGD("%s->%s Already has dependency, skip it", | ||||
| @@ -2105,6 +2102,10 @@ Status HybridModelBuilder::ParseDependentByParallelGroup() { | |||||
| } | } | ||||
| nearest_dep_node->has_observer = true; | nearest_dep_node->has_observer = true; | ||||
| deps.emplace_back(nearest_dep_node->node); | deps.emplace_back(nearest_dep_node->node); | ||||
| GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]", | |||||
| parallel_group.c_str(), | |||||
| nearest_dep_node->NodeName().c_str(), | |||||
| node_item->NodeName().c_str()); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||