Merge pull request !4787 from wenchunjiang/fix_code_checktags/v0.7.0-beta
| @@ -261,6 +261,15 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph, | |||||
| } | } | ||||
| } | } | ||||
| EraseAssign(all_nodes, para_to_written_node, root_graph); | |||||
| root_graph->set_execution_order(exec_order); | |||||
| } | |||||
| void AscendControlParser::EraseAssign(const std::set<CNodePtr> &all_nodes, | |||||
| const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node, | |||||
| NotNull<KernelGraphPtr> root_graph) { | |||||
| std::vector<CNodePtr> exec_order = root_graph->execution_order(); | |||||
| ReferenceCounter parameter_count([](int32_t read, int32_t write) -> bool { return write == 1; }); | |||||
| while (parameter_count.HasValidElem()) { | while (parameter_count.HasValidElem()) { | ||||
| auto [para, read, written] = parameter_count.GetOneValidElem(); | auto [para, read, written] = parameter_count.GetOneValidElem(); | ||||
| MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written " << written << " times."; | MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written " << written << " times."; | ||||
| @@ -293,7 +302,6 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| root_graph->set_execution_order(exec_order); | |||||
| } | } | ||||
| void AscendControlParser::EraseLabel(NotNull<KernelGraphPtr> root_graph) { | void AscendControlParser::EraseLabel(NotNull<KernelGraphPtr> root_graph) { | ||||
| @@ -740,6 +748,18 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||||
| std::vector<CNodePtr> execution_order; | std::vector<CNodePtr> execution_order; | ||||
| uint32_t child_order_index = 0; | uint32_t child_order_index = 0; | ||||
| auto recurse_child_graph = [&](uint32_t index, uint32_t label_index, const CNodePtr &node) { | |||||
| if (!CheckLabelIndex(index, label_index, node)) { | |||||
| MS_LOG(EXCEPTION) << "Check label index fail"; | |||||
| } | |||||
| if (child_order_index >= graph->child_graph_order().size()) { | |||||
| MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); | |||||
| } | |||||
| auto child_graph = graph->child_graph_order()[child_order_index++]; | |||||
| auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); | |||||
| execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); | |||||
| }; | |||||
| for (auto &node : cnodes) { | for (auto &node : cnodes) { | ||||
| uint32_t child_graph_index = 0; | uint32_t child_graph_index = 0; | ||||
| execution_order.push_back(node); | execution_order.push_back(node); | ||||
| @@ -749,27 +769,11 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { | if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { | ||||
| std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList); | std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList); | ||||
| for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { | for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { | ||||
| if (!CheckLabelIndex(child_graph_index++, *iter, node)) { | |||||
| MS_LOG(EXCEPTION) << "Check label index fail"; | |||||
| } | |||||
| if (child_order_index >= graph->child_graph_order().size()) { | |||||
| MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); | |||||
| } | |||||
| auto child_graph = graph->child_graph_order()[child_order_index++]; | |||||
| auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); | |||||
| execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); | |||||
| recurse_child_graph(child_graph_index++, *iter, node); | |||||
| } | } | ||||
| } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { | } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { | ||||
| uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); | uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); | ||||
| if (!CheckLabelIndex(child_graph_index, label_index, node)) { | |||||
| MS_LOG(EXCEPTION) << "Check label index fail"; | |||||
| } | |||||
| if (child_order_index >= graph->child_graph_order().size()) { | |||||
| MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); | |||||
| } | |||||
| auto child_graph = graph->child_graph_order()[child_order_index++]; | |||||
| auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); | |||||
| execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); | |||||
| recurse_child_graph(child_graph_index, label_index, node); | |||||
| } | } | ||||
| } | } | ||||
| graph->set_execution_order(execution_order); | graph->set_execution_order(execution_order); | ||||
| @@ -44,6 +44,9 @@ class AscendControlParser { | |||||
| class ReferenceCounter; | class ReferenceCounter; | ||||
| static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list); | static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list); | ||||
| static void EraseAssign(const std::set<CNodePtr> &all_nodes, | |||||
| const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node, | |||||
| NotNull<KernelGraphPtr> root_graph); | |||||
| static void EraseLabel(NotNull<KernelGraphPtr> root_graph); | static void EraseLabel(NotNull<KernelGraphPtr> root_graph); | ||||
| static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg, | static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg, | ||||
| const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list, | const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list, | ||||
| @@ -77,6 +80,7 @@ class AscendControlParser { | |||||
| class AscendControlParser::ReferenceCounter { | class AscendControlParser::ReferenceCounter { | ||||
| public: | public: | ||||
| explicit ReferenceCounter(std::function<bool(int32_t, int32_t)> func) : predicate_(func), count_() {} | explicit ReferenceCounter(std::function<bool(int32_t, int32_t)> func) : predicate_(func), count_() {} | ||||
| ~ReferenceCounter() = default; | |||||
| void AddReadCount(const AnfNodePtr &key, int32_t num); | void AddReadCount(const AnfNodePtr &key, int32_t num); | ||||
| void AddWriteCount(const AnfNodePtr &key, int32_t num); | void AddWriteCount(const AnfNodePtr &key, int32_t num); | ||||
| void EraseElem(const AnfNodePtr &key); | void EraseElem(const AnfNodePtr &key); | ||||
| @@ -236,6 +236,5 @@ void AscendInferenceSession::GetModelInputsInfo(uint32_t graph_id, std::vector<t | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -222,12 +222,10 @@ void AscendSession::BuildGraph(GraphId graph_id) { | |||||
| auto graph_order = GetGraphOrder(final_graph_id_); | auto graph_order = GetGraphOrder(final_graph_id_); | ||||
| auto &graph_type = GetGraphOrderType(final_graph_id_); | auto &graph_type = GetGraphOrderType(final_graph_id_); | ||||
| for (size_t i = 0; i < graph_order.size(); i++) { | for (size_t i = 0; i < graph_order.size(); i++) { | ||||
| if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) { | |||||
| continue; | |||||
| if (!(graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START)) { | |||||
| auto child_graph = GetGraph(graph_order[i]); | |||||
| CompileChildGraph(child_graph); | |||||
| } | } | ||||
| MS_LOG(INFO) << "Start build child graph " << graph_order[i]; | |||||
| auto child_graph = GetGraph(graph_order[i]); | |||||
| CompileChildGraph(child_graph); | |||||
| } | } | ||||
| SetSummaryNodes(graph.get()); | SetSummaryNodes(graph.get()); | ||||
| // merge child graph | // merge child graph | ||||
| @@ -379,7 +379,6 @@ void GPUSession::PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph | |||||
| tensor_loader->EmptyPrevTensor(); | tensor_loader->EmptyPrevTensor(); | ||||
| } | } | ||||
| #endif | #endif | ||||
| } // namespace gpu | } // namespace gpu | ||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,7 +33,8 @@ using std::string; | |||||
| using std::vector; | using std::vector; | ||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| namespace mindspore::inference { | |||||
| namespace mindspore { | |||||
| namespace inference { | |||||
| std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) { | std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) { | ||||
| try { | try { | ||||
| @@ -153,7 +154,10 @@ Status ServingTensor2MSTensor(size_t index, const InferTensorBase &out_tensor, t | |||||
| MSI_LOG_ERROR << "invalid data buffer"; | MSI_LOG_ERROR << "invalid data buffer"; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| memcpy_s(ms_tensor->data_c(), ms_tensor->Size(), out_tensor.data(), out_tensor.data_size()); | |||||
| auto ret_code = memcpy_s(ms_tensor->data_c(), ms_tensor->Size(), out_tensor.data(), out_tensor.data_size()); | |||||
| if (ret_code != 0) { | |||||
| MS_LOG(ERROR) << "Failed to copy data from ms_tensor to out_tensor."; | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -272,6 +276,10 @@ void MSInferSession::RegAllOp() { | |||||
| return; | return; | ||||
| } | } | ||||
| PyObject *c_expression_dict = PyModule_GetDict(c_expression); | PyObject *c_expression_dict = PyModule_GetDict(c_expression); | ||||
| if (c_expression_dict == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Failed to get dict from mindspore._c_expression module."; | |||||
| return; | |||||
| } | |||||
| PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); | PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); | ||||
| if (op_info_loader_class == nullptr) { | if (op_info_loader_class == nullptr) { | ||||
| @@ -392,4 +400,5 @@ Status MSInferSession::GetModelInputsInfo(uint32_t model_id, std::vector<inferen | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace mindspore::inference | |||||
| } // namespace inference | |||||
| } // namespace mindspore | |||||
| @@ -49,7 +49,7 @@ class MSInferSession : public InferSession { | |||||
| std::shared_ptr<session::SessionBasic> session_impl_ = nullptr; | std::shared_ptr<session::SessionBasic> session_impl_ = nullptr; | ||||
| std::vector<uint32_t> graph_id_; | std::vector<uint32_t> graph_id_; | ||||
| std::string device_type_; | std::string device_type_; | ||||
| int32_t device_id_; | |||||
| int32_t device_id_ = 0; | |||||
| #ifdef ENABLE_D | #ifdef ENABLE_D | ||||
| rtContext_t context_ = nullptr; | rtContext_t context_ = nullptr; | ||||
| #endif | #endif | ||||
| @@ -246,6 +246,10 @@ void KernelGraph::SetExecOrderByDefault() { | |||||
| } | } | ||||
| CheckLoop(); | CheckLoop(); | ||||
| // resort start label / end goto | // resort start label / end goto | ||||
| execution_order_ = SortStartLabelAndEndGoto(); | |||||
| } | |||||
| std::vector<CNodePtr> KernelGraph::SortStartLabelAndEndGoto() { | |||||
| std::vector<CNodePtr> re_order; | std::vector<CNodePtr> re_order; | ||||
| if (start_label_ != nullptr) { | if (start_label_ != nullptr) { | ||||
| re_order.push_back(start_label_); | re_order.push_back(start_label_); | ||||
| @@ -272,7 +276,7 @@ void KernelGraph::SetExecOrderByDefault() { | |||||
| if (end_goto_ != nullptr) { | if (end_goto_ != nullptr) { | ||||
| re_order.push_back(end_goto_); | re_order.push_back(end_goto_); | ||||
| } | } | ||||
| execution_order_ = re_order; | |||||
| return re_order; | |||||
| } | } | ||||
| void KernelGraph::CheckLoop() { | void KernelGraph::CheckLoop() { | ||||
| @@ -736,27 +740,29 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de | |||||
| for (const auto &tmp : prior_nodes) { | for (const auto &tmp : prior_nodes) { | ||||
| GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); | GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); | ||||
| } | } | ||||
| std::vector<AnfNodePtr> real_depend_nodes; | std::vector<AnfNodePtr> real_depend_nodes; | ||||
| std::set<AnfNodePtr> depend_visited; | std::set<AnfNodePtr> depend_visited; | ||||
| for (const auto &tmp : depend_nodes) { | for (const auto &tmp : depend_nodes) { | ||||
| GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); | GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); | ||||
| } | } | ||||
| UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes); | |||||
| } | |||||
| } | |||||
| for (auto &first_node : real_prior_nodes) { | |||||
| if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { | |||||
| void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes, | |||||
| const std::vector<AnfNodePtr> &real_depend_nodes) { | |||||
| for (auto &first_node : real_prior_nodes) { | |||||
| if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| for (auto &second_node : real_depend_nodes) { | |||||
| if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| for (auto &second_node : real_depend_nodes) { | |||||
| if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { | |||||
| continue; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(first_node); | |||||
| MS_EXCEPTION_IF_NULL(second_node); | |||||
| MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() | |||||
| << ",second node:" << second_node->DebugString(); | |||||
| AddDependEdge(second_node, first_node, 1); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(first_node); | |||||
| MS_EXCEPTION_IF_NULL(second_node); | |||||
| MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); | |||||
| AddDependEdge(second_node, first_node, 1); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -174,6 +174,8 @@ class KernelGraph : public FuncGraph { | |||||
| void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes); | void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes); | ||||
| // add node depend edge by data edge or control depend | // add node depend edge by data edge or control depend | ||||
| void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); | void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); | ||||
| void UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes, | |||||
| const std::vector<AnfNodePtr> &real_depend_nodes); | |||||
| // handle control depend | // handle control depend | ||||
| std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node); | std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node); | ||||
| bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, | ||||
| @@ -183,6 +185,7 @@ class KernelGraph : public FuncGraph { | |||||
| AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); | AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); | ||||
| AnfNodePtr TransCNodeTuple(const CNodePtr &node); | AnfNodePtr TransCNodeTuple(const CNodePtr &node); | ||||
| AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx); | AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx); | ||||
| std::vector<CNodePtr> SortStartLabelAndEndGoto(); | |||||
| std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | ||||
| std::vector<AnfNodePtr> child_graph_result_; | std::vector<AnfNodePtr> child_graph_result_; | ||||