diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index bec2d0c67a..d77d366c09 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -146,8 +146,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull func_graph) { std::vector all_graphs; auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); // Update Graph Dynamic Shape Attr - UpdateGraphDynamicShapeAttr(NOT_NULL(root_graph)); - root_graph->UpdateGraphDynamicAttr(); + UpdateAllGraphDynamicShapeAttr(all_graphs); BackendOptimization(all_graphs); // empty graph dont entry to backend if (root_graph->execution_order().empty()) { diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 2e9ca1918d..58ab3be525 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1527,6 +1527,17 @@ bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr) { return false; } +void SessionBasic::UpdateAllGraphDynamicShapeAttr(const std::vector &all_graphs) { + bool is_dynamic = false; + for (const auto &graph : all_graphs) { + UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); + is_dynamic = graph->is_dynamic_shape() || is_dynamic; + } + if (is_dynamic && all_graphs.size() > 1) { + MS_LOG(EXCEPTION) << "Dynamic shape is not supported with control flow."; + } +} + void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull &root_graph) { for (const auto &cnode : root_graph->execution_order()) { auto output_dynamic = IsNodeOutputDynamicShape(NOT_NULL(cnode)); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index d526cd781d..b304eb1053 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -180,6 +180,7 @@ class SessionBasic : public std::enable_shared_from_this { void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter); AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector &node_list); void UpdateGraphDynamicShapeAttr(const NotNull &root_graph); + void UpdateAllGraphDynamicShapeAttr(const std::vector &all_graphs); std::unordered_map> graphs_; std::unordered_map> run_op_graphs_;