Merge pull request !206 from chenfei/mastertags/v0.2.0-alpha
| @@ -154,6 +154,9 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | |||||
| const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | ||||
| const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | ||||
| const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); | const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); | ||||
| const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); | |||||
| const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); | |||||
| const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||||
| // NN | // NN | ||||
| const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | ||||
| @@ -160,6 +160,9 @@ extern const PrimitivePtr kPrimMul; | |||||
| extern const PrimitivePtr kPrimMinimum; | extern const PrimitivePtr kPrimMinimum; | ||||
| extern const PrimitivePtr kPrimMaximum; | extern const PrimitivePtr kPrimMaximum; | ||||
| extern const PrimitivePtr kPrimSquare; | extern const PrimitivePtr kPrimSquare; | ||||
| extern const PrimitivePtr kPrimEqual; | |||||
| extern const PrimitivePtr kPrimLess; | |||||
| extern const PrimitivePtr kPrimLessEqual; | |||||
| // NN | // NN | ||||
| extern const PrimitivePtr kPrimFlatten; | extern const PrimitivePtr kPrimFlatten; | ||||
| @@ -506,11 +506,13 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true | |||||
| kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE); | kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE); | ||||
| kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE); | kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE); | ||||
| kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL); | kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL); | ||||
| // condition graph's output must be single output | |||||
| if (condition_graph->outputs().size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "Condition_graph output num " << condition_graph_id << " should be 1"; | |||||
| auto cond_output_it = condition_output_.find(condition_graph_id); | |||||
| if (cond_output_it == condition_output_.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id; | |||||
| } | } | ||||
| AnfNodePtr cond_output_kernel = condition_graph->outputs()[0]; | |||||
| auto cond_output_kernel = | |||||
| AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first; | |||||
| MS_EXCEPTION_IF_NULL(cond_output_kernel); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; | std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; | ||||
| CNodePtr switch_node = condition_graph->NewCNode(inputs); | CNodePtr switch_node = condition_graph->NewCNode(inputs); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get()); | AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get()); | ||||
| @@ -569,12 +571,14 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { | |||||
| } | } | ||||
| } | } | ||||
| void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id) { | |||||
| void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id, | |||||
| const AnfNodePtr &output) { | |||||
| if (switches_.find(cond_graph_id) != switches_.end()) { | if (switches_.find(cond_graph_id) != switches_.end()) { | ||||
| MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before "; | MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before "; | ||||
| return; | return; | ||||
| } | } | ||||
| switches_[cond_graph_id] = std::pair<GraphId, GraphId>(true_graph_id, false_graph_id); | switches_[cond_graph_id] = std::pair<GraphId, GraphId>(true_graph_id, false_graph_id); | ||||
| condition_output_[cond_graph_id] = output; | |||||
| MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; | MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; | ||||
| // set the type of condition graph | // set the type of condition graph | ||||
| auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); | auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); | ||||
| @@ -682,12 +686,14 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An | |||||
| auto from_graph_id = GetGraphIdByNode(front_anf); | auto from_graph_id = GetGraphIdByNode(front_anf); | ||||
| auto from_graph = GetGraph(from_graph_id); | auto from_graph = GetGraph(from_graph_id); | ||||
| MS_EXCEPTION_IF_NULL(from_graph); | MS_EXCEPTION_IF_NULL(from_graph); | ||||
| auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get()); | |||||
| auto to_graph = GetGraph(to_graph_id); | |||||
| auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); | |||||
| MS_EXCEPTION_IF_NULL(to_graph); | |||||
| MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node[" | MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node[" | ||||
| << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get()) | << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get()) | ||||
| << "]"; | << "]"; | ||||
| // a node should not assign to itself | // a node should not assign to itself | ||||
| auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); | |||||
| if (backend_arg.get() == backend_parameter.get()) { | if (backend_arg.get() == backend_parameter.get()) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -703,15 +709,16 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An | |||||
| return; | return; | ||||
| } | } | ||||
| } | } | ||||
| InsertMultipleAssignToGraph(from_graph_id, backend_arg, backend_parameter); | |||||
| // if front anf is a parameter, we can assign the value back, because backend_parameter | |||||
| // won't be changed in it's graph unless it's a weight. If backend_parameter is a weight, | |||||
| // we do should assign the value back. | |||||
| auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get()); | |||||
| auto to_graph = GetGraph(to_graph_id); | |||||
| MS_EXCEPTION_IF_NULL(to_graph); | |||||
| // if a parameter is a weight and not linked to any executable node,device type will be kTypeUnknown,set it's device | |||||
| // type same to arg | |||||
| if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) { | |||||
| AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get()); | |||||
| } | |||||
| InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter); | |||||
| // if front anf is a parameter,we can assign the value back,because backend_parameter won't be change in it's graph | |||||
| // unless it's a weigth.If backend_parameter is a weight,we do should assign the value back | |||||
| if (backend_arg->isa<Parameter>() && !to_graph->execution_order().empty()) { | if (backend_arg->isa<Parameter>() && !to_graph->execution_order().empty()) { | ||||
| InsertMultipleAssignToGraph(to_graph_id, backend_parameter, backend_arg); | |||||
| InsertAssignToGraph(to_graph_id, backend_parameter, backend_arg); | |||||
| } | } | ||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| } | } | ||||
| @@ -755,7 +762,25 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { | |||||
| DumpGraphInputArgs(args); | DumpGraphInputArgs(args); | ||||
| UpdateGraphOrder(g); | UpdateGraphOrder(g); | ||||
| std::vector<AnfNodePtr> graph_inputs = to_graph->inputs(); | std::vector<AnfNodePtr> graph_inputs = to_graph->inputs(); | ||||
| auto valid_inputs = to_graph->ValidInputs(); | |||||
| size_t real_args_size = 0; | |||||
| for (size_t i = 0; i < args.size(); i++) { | |||||
| real_args_size += AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem}).size(); | |||||
| } | |||||
| if (real_args_size != graph_inputs.size()) { | |||||
| for (size_t j = 0; j < valid_inputs.size(); j++) { | |||||
| if (valid_inputs[j]) { | |||||
| MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); | |||||
| } | |||||
| } | |||||
| MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() | |||||
| << " not equal"; | |||||
| } | |||||
| size_t input_index = 0; | size_t input_index = 0; | ||||
| if (graph_inputs.size() != valid_inputs.size()) { | |||||
| MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size() | |||||
| << ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; | |||||
| } | |||||
| for (size_t i = 0; i < args.size(); i++) { | for (size_t i = 0; i < args.size(); i++) { | ||||
| if (input_index >= graph_inputs.size()) { | if (input_index >= graph_inputs.size()) { | ||||
| MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); | MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); | ||||
| @@ -763,6 +788,10 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { | |||||
| if (utils::isa<AnfNodePtr>(args[i])) { | if (utils::isa<AnfNodePtr>(args[i])) { | ||||
| // arg is a anf node | // arg is a anf node | ||||
| for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem})) { | for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem})) { | ||||
| if (!valid_inputs[input_index]) { | |||||
| MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString(); | |||||
| continue; | |||||
| } | |||||
| SetChildGraphParameter(real_arg, graph_inputs[input_index]); | SetChildGraphParameter(real_arg, graph_inputs[input_index]); | ||||
| input_index++; | input_index++; | ||||
| } | } | ||||
| @@ -49,9 +49,8 @@ class AscendSession : public SessionBasic { | |||||
| // set output of final graph | // set output of final graph | ||||
| void SetFinalGraphOutput(const BaseRef &output) override; | void SetFinalGraphOutput(const BaseRef &output) override; | ||||
| // insert switch and set the relative active ops | // insert switch and set the relative active ops | ||||
| void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g) override; | |||||
| // set args of child graph. the arg maybe come from a output of other child graphs, | |||||
| // or from final graph's parameter | |||||
| void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override; | |||||
| // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter | |||||
| void SetChildGraphInput(GraphId g, const VectorRef &args) override; | void SetChildGraphInput(GraphId g, const VectorRef &args) override; | ||||
| // get graph id in child graphs by ME front anf node pointer | // get graph id in child graphs by ME front anf node pointer | ||||
| GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; | GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; | ||||
| @@ -116,6 +115,7 @@ class AscendSession : public SessionBasic { | |||||
| std::unordered_map<GraphId, GraphId> while_condition_graphs_; | std::unordered_map<GraphId, GraphId> while_condition_graphs_; | ||||
| // record all conditions | // record all conditions | ||||
| std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_; | std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_; | ||||
| std::unordered_map<GraphId, AnfNodePtr> condition_output_; | |||||
| // final_graph_id is used in every root graph has it's own session situation | // final_graph_id is used in every root graph has it's own session situation | ||||
| GraphId final_graph_id_; | GraphId final_graph_id_; | ||||
| }; | }; | ||||
| @@ -372,8 +372,7 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de | |||||
| MS_EXCEPTION_IF_NULL(depend_node); | MS_EXCEPTION_IF_NULL(depend_node); | ||||
| std::vector<AnfNodePtr> prior_nodes = {prior_node}; | std::vector<AnfNodePtr> prior_nodes = {prior_node}; | ||||
| std::vector<AnfNodePtr> depend_nodes = {depend_node}; | std::vector<AnfNodePtr> depend_nodes = {depend_node}; | ||||
| MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "],depend node[" << depend_node->DebugString() | |||||
| << "],depend_mode=[" << AnfAlgo::GetNodeAttr<int>(cnode, "depend_mode") << "]"; | |||||
| MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString(); | |||||
| if (prior_node->isa<Parameter>()) { | if (prior_node->isa<Parameter>()) { | ||||
| prior_nodes = GetOutputNodes(prior_node); | prior_nodes = GetOutputNodes(prior_node); | ||||
| } | } | ||||
| @@ -86,6 +86,9 @@ class KernelGraph : public FuncGraph { | |||||
| bool executable() const { return executable_; } | bool executable() const { return executable_; } | ||||
| // set executable of graph | // set executable of graph | ||||
| void set_executable(bool executable) { executable_ = executable; } | void set_executable(bool executable) { executable_ = executable; } | ||||
| // set invalid inputs for control sink | |||||
| std::vector<bool> *MutableValidInputs() { return &valid_inputs_; } | |||||
| std::vector<bool> ValidInputs() { return valid_inputs_; } | |||||
| private: | private: | ||||
| // remove value node form graph | // remove value node form graph | ||||
| @@ -118,6 +121,8 @@ class KernelGraph : public FuncGraph { | |||||
| std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_; | std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_; | ||||
| // graph needn't execute | // graph needn't execute | ||||
| bool executable_; | bool executable_; | ||||
| // valid inputs | |||||
| std::vector<bool> valid_inputs_; | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | ||||
| @@ -243,29 +243,38 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { | |||||
| return new_value_node; | return new_value_node; | ||||
| } | } | ||||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) { | |||||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(anf); | MS_EXCEPTION_IF_NULL(anf); | ||||
| if (!anf->isa<Parameter>()) { | if (!anf->isa<Parameter>()) { | ||||
| MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; | MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; | ||||
| } | } | ||||
| auto graph_inputs = graph->MutableInputs(); | auto graph_inputs = graph->MutableInputs(); | ||||
| MS_EXCEPTION_IF_NULL(graph_inputs); | MS_EXCEPTION_IF_NULL(graph_inputs); | ||||
| auto valid_inputs = graph->MutableValidInputs(); | |||||
| MS_EXCEPTION_IF_NULL(valid_inputs); | |||||
| ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); | ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); | ||||
| graph->FrontBackendlMapAdd(anf, new_parameter); | |||||
| graph_inputs->push_back(new_parameter); | graph_inputs->push_back(new_parameter); | ||||
| valid_inputs->push_back(valid_input); | |||||
| return new_parameter; | return new_parameter; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { | |||||
| std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| std::vector<AnfNodePtr> parameters; | std::vector<AnfNodePtr> parameters; | ||||
| std::vector<AnfNodePtr> pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); | std::vector<AnfNodePtr> pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); | ||||
| auto valid_inputs = graph->MutableValidInputs(); | |||||
| MS_EXCEPTION_IF_NULL(valid_inputs); | |||||
| auto graph_inputs = graph->MutableInputs(); | |||||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||||
| auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { | auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { | ||||
| auto parameter = graph->NewParameter(); | auto parameter = graph->NewParameter(); | ||||
| MS_EXCEPTION_IF_NULL(parameter); | MS_EXCEPTION_IF_NULL(parameter); | ||||
| parameter->set_abstract(abstract); | parameter->set_abstract(abstract); | ||||
| parameters.push_back(graph->NewParameter(parameter)); | |||||
| auto new_parameter = graph->NewParameter(parameter); | |||||
| parameters.push_back(new_parameter); | |||||
| valid_inputs->push_back(valid_input); | |||||
| graph_inputs->push_back(new_parameter); | |||||
| }; | }; | ||||
| for (const auto &out_node : pre_graph_out) { | for (const auto &out_node : pre_graph_out) { | ||||
| MS_EXCEPTION_IF_NULL(out_node); | MS_EXCEPTION_IF_NULL(out_node); | ||||
| @@ -287,18 +296,15 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelG | |||||
| return parameters; | return parameters; | ||||
| } | } | ||||
| AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) { | |||||
| AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(anf); | MS_EXCEPTION_IF_NULL(anf); | ||||
| if (!anf->isa<CNode>()) { | if (!anf->isa<CNode>()) { | ||||
| MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a cnode"; | |||||
| MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode"; | |||||
| } | } | ||||
| MS_LOG(INFO) << "create a new parameter from cnode[" << anf->DebugString() << "]"; | |||||
| auto parameters = CreateParameterFromTuple(anf, graph); | |||||
| auto graph_inputs = graph->MutableInputs(); | |||||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||||
| (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(*graph_inputs)); | |||||
| MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; | |||||
| auto parameters = CreateParameterFromTuple(anf, valid_input, graph); | |||||
| if (parameters.empty()) { | if (parameters.empty()) { | ||||
| MS_LOG(EXCEPTION) << "no parameter exist!!"; | |||||
| MS_LOG(EXCEPTION) << "No parameter exist!!"; | |||||
| } | } | ||||
| if (parameters.size() == 1) { | if (parameters.size() == 1) { | ||||
| return parameters[0]; | return parameters[0]; | ||||
| @@ -307,7 +313,7 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph | |||||
| (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); | (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); | ||||
| auto make_tuple = graph->NewCNode(make_tuple_input); | auto make_tuple = graph->NewCNode(make_tuple_input); | ||||
| MS_EXCEPTION_IF_NULL(make_tuple); | MS_EXCEPTION_IF_NULL(make_tuple); | ||||
| MS_LOG(INFO) << "new make tuple [" << make_tuple->DebugString() << "] of parameters"; | |||||
| MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters"; | |||||
| return make_tuple; | return make_tuple; | ||||
| } | } | ||||
| @@ -397,14 +403,20 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { | |||||
| GraphId SessionBasic::graph_sum_ = 0; | GraphId SessionBasic::graph_sum_ = 0; | ||||
| CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { | |||||
| CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, | |||||
| bool *from_other_graph, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(from_other_graph); | |||||
| MS_EXCEPTION_IF_NULL(other_graph_cnode); | |||||
| *from_other_graph = false; | |||||
| // get primitive of old node | // get primitive of old node | ||||
| auto prim = AnfAlgo::GetCNodePrimitive(cnode); | auto prim = AnfAlgo::GetCNodePrimitive(cnode); | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| // push attr to inputs[0] of new cnode | // push attr to inputs[0] of new cnode | ||||
| std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))}; | std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))}; | ||||
| // if has multiple depends,only select first depend as parameter | |||||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | ||||
| auto anf = cnode->inputs()[input_idx]; | auto anf = cnode->inputs()[input_idx]; | ||||
| MS_EXCEPTION_IF_NULL(anf); | MS_EXCEPTION_IF_NULL(anf); | ||||
| @@ -412,6 +424,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | ||||
| cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | ||||
| continue; | continue; | ||||
| } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { | |||||
| cnode_inputs.push_back((*other_graph_cnode)[anf]); | |||||
| continue; | |||||
| } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) { | } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) { | ||||
| // if input is a value node, | // if input is a value node, | ||||
| auto new_value_node = CreateNewValueNode(anf, graph); | auto new_value_node = CreateNewValueNode(anf, graph); | ||||
| @@ -421,38 +436,60 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||||
| continue; | continue; | ||||
| } else if (anf->isa<Parameter>()) { | } else if (anf->isa<Parameter>()) { | ||||
| // if anf is a parameter | // if anf is a parameter | ||||
| cnode_inputs.emplace_back(CreateNewParameterFromParameter(anf, graph)); | |||||
| auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); | |||||
| cnode_inputs.push_back(new_parameter); | |||||
| if (GetGraphIdByNode(anf) == kInvalidGraphId) { | |||||
| graph->FrontBackendlMapAdd(anf, new_parameter); | |||||
| } else { | |||||
| (*other_graph_cnode)[anf] = new_parameter; | |||||
| } | |||||
| continue; | continue; | ||||
| } else if (anf->isa<CNode>()) { | } else if (anf->isa<CNode>()) { | ||||
| *from_other_graph = true; | |||||
| // the input node is a cnode from other graph | // the input node is a cnode from other graph | ||||
| cnode_inputs.emplace_back(CreateNewParameterFromCNode(anf, graph)); | |||||
| auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); | |||||
| cnode_inputs.push_back(parameter_from_cnode); | |||||
| (*other_graph_cnode)[anf] = parameter_from_cnode; | |||||
| continue; | continue; | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "unexpected input[" << anf->DebugString() << "]"; | |||||
| MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; | |||||
| } | } | ||||
| return graph->NewCNode(cnode_inputs); | |||||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); | |||||
| auto new_cnode = graph->NewCNode(cnode_inputs); | |||||
| TraceManager::EndTrace(); | |||||
| return new_cnode; | |||||
| } | } | ||||
| KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode; | |||||
| auto graph = std::make_shared<KernelGraph>(); | auto graph = std::make_shared<KernelGraph>(); | ||||
| graph->set_graph_id(graph_sum_); | graph->set_graph_id(graph_sum_); | ||||
| MS_LOG(INFO) << "Create graph: " << graph_sum_; | |||||
| size_t from_other_graph_depend_num = 0; | |||||
| for (const auto &node : lst) { | for (const auto &node : lst) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_LOG(DEBUG) << "start create new cnode,node = " << node->DebugString(); | |||||
| MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); | |||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| MS_LOG(EXCEPTION) << "Inst node " << node->DebugString() << " is not CNode"; | |||||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode"; | |||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); | |||||
| // create a new cnode object | // create a new cnode object | ||||
| auto new_cnode = CreateNewCNode(cnode, graph.get()); | |||||
| bool from_other_graph = false; | |||||
| // only first depend from other graph can create | |||||
| bool valid_input = true; | |||||
| if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { | |||||
| valid_input = false; | |||||
| } | |||||
| auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode); | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) { | |||||
| from_other_graph_depend_num++; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(new_cnode); | MS_EXCEPTION_IF_NULL(new_cnode); | ||||
| new_cnode->set_abstract(cnode->abstract()); | new_cnode->set_abstract(cnode->abstract()); | ||||
| new_cnode->set_scope(cnode->scope()); | new_cnode->set_scope(cnode->scope()); | ||||
| // record map relations between anf from ME and new anf node used in backend | // record map relations between anf from ME and new anf node used in backend | ||||
| graph->FrontBackendlMapAdd(node, new_cnode); | graph->FrontBackendlMapAdd(node, new_cnode); | ||||
| TraceManager::EndTrace(); | |||||
| } | } | ||||
| // add a make_tuple at the end of graph as output | // add a make_tuple at the end of graph as output | ||||
| graph->set_output(ConstructOutput(outputs, graph)); | graph->set_output(ConstructOutput(outputs, graph)); | ||||
| @@ -631,12 +668,15 @@ void SessionBasic::ToTensorPtr(const OpRunInfo &op_run_info, std::vector<tensor: | |||||
| CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) { | CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| std::vector<AnfNodePtr> output_args; | std::vector<AnfNodePtr> output_args; | ||||
| auto FindEqu = [graph](const AnfNodePtr &out) -> AnfNodePtr { | |||||
| auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { | |||||
| auto backend_anf = graph->GetBackendAnfByFrontAnf(out); | auto backend_anf = graph->GetBackendAnfByFrontAnf(out); | ||||
| if (backend_anf != nullptr) { | if (backend_anf != nullptr) { | ||||
| return backend_anf; | return backend_anf; | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "Can not find the node in the equiv map!"; | |||||
| for (const auto &output : outputs) { | |||||
| MS_LOG(INFO) << "output:" << output->DebugString(); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!"; | |||||
| }; | }; | ||||
| output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); | output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); | ||||
| (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args), | (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args), | ||||
| @@ -69,14 +69,15 @@ class SessionBasic { | |||||
| std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | ||||
| CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); | |||||
| CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); | |||||
| // set parameters of final graph | // set parameters of final graph | ||||
| virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; } | virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; } | ||||
| // set output of final graph | // set output of final graph | ||||
| virtual void SetFinalGraphOutput(const BaseRef &) {} | virtual void SetFinalGraphOutput(const BaseRef &) {} | ||||
| // insert switch and set the relative active ops | // insert switch and set the relative active ops | ||||
| virtual void SwitchCompile(GraphId, GraphId, GraphId) {} | |||||
| virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {} | |||||
| // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter | // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter | ||||
| virtual void SetChildGraphInput(GraphId, const VectorRef &) {} | virtual void SetChildGraphInput(GraphId, const VectorRef &) {} | ||||
| // get graph id in child graphs by ME front anf node pointer | // get graph id in child graphs by ME front anf node pointer | ||||
| @@ -136,7 +136,7 @@ void MsBackend::SetSwitchGraph() { | |||||
| MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString(); | MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString(); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g; | MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g; | ||||
| sess_->SwitchCompile(cond_g, true_g, false_g); | |||||
| sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_)); | |||||
| } | } | ||||
| is_switch_call_ = false; | is_switch_call_ = false; | ||||
| MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_; | MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_; | ||||