Merge pull request !206 from chenfei/mastertags/v0.2.0-alpha
| @@ -154,6 +154,9 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | |||
| const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | |||
| const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | |||
| const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); | |||
| const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); | |||
| const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); | |||
| const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||
| // NN | |||
| const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| @@ -160,6 +160,9 @@ extern const PrimitivePtr kPrimMul; | |||
| extern const PrimitivePtr kPrimMinimum; | |||
| extern const PrimitivePtr kPrimMaximum; | |||
| extern const PrimitivePtr kPrimSquare; | |||
| extern const PrimitivePtr kPrimEqual; | |||
| extern const PrimitivePtr kPrimLess; | |||
| extern const PrimitivePtr kPrimLessEqual; | |||
| // NN | |||
| extern const PrimitivePtr kPrimFlatten; | |||
| @@ -506,11 +506,13 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true | |||
| kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE); | |||
| kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE); | |||
| kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL); | |||
| // condition graph's output must be single output | |||
| if (condition_graph->outputs().size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Condition_graph output num " << condition_graph_id << " should be 1"; | |||
| auto cond_output_it = condition_output_.find(condition_graph_id); | |||
| if (cond_output_it == condition_output_.end()) { | |||
| MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id; | |||
| } | |||
| AnfNodePtr cond_output_kernel = condition_graph->outputs()[0]; | |||
| auto cond_output_kernel = | |||
| AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first; | |||
| MS_EXCEPTION_IF_NULL(cond_output_kernel); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; | |||
| CNodePtr switch_node = condition_graph->NewCNode(inputs); | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get()); | |||
| @@ -569,12 +571,14 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { | |||
| } | |||
| } | |||
| void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id) { | |||
| void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id, | |||
| const AnfNodePtr &output) { | |||
| if (switches_.find(cond_graph_id) != switches_.end()) { | |||
| MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before "; | |||
| return; | |||
| } | |||
| switches_[cond_graph_id] = std::pair<GraphId, GraphId>(true_graph_id, false_graph_id); | |||
| condition_output_[cond_graph_id] = output; | |||
| MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; | |||
| // set the type of condition graph | |||
| auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); | |||
| @@ -682,12 +686,14 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An | |||
| auto from_graph_id = GetGraphIdByNode(front_anf); | |||
| auto from_graph = GetGraph(from_graph_id); | |||
| MS_EXCEPTION_IF_NULL(from_graph); | |||
| auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get()); | |||
| auto to_graph = GetGraph(to_graph_id); | |||
| auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); | |||
| MS_EXCEPTION_IF_NULL(to_graph); | |||
| MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node[" | |||
| << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get()) | |||
| << "]"; | |||
| // a node should not assign to itself | |||
| auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); | |||
| if (backend_arg.get() == backend_parameter.get()) { | |||
| return; | |||
| } | |||
| @@ -703,15 +709,16 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An | |||
| return; | |||
| } | |||
| } | |||
| InsertMultipleAssignToGraph(from_graph_id, backend_arg, backend_parameter); | |||
| // if front anf is a parameter, we can assign the value back, because backend_parameter | |||
| // won't be changed in it's graph unless it's a weight. If backend_parameter is a weight, | |||
| // we do should assign the value back. | |||
| auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get()); | |||
| auto to_graph = GetGraph(to_graph_id); | |||
| MS_EXCEPTION_IF_NULL(to_graph); | |||
| // if a parameter is a weight and not linked to any executable node,device type will be kTypeUnknown,set it's device | |||
| // type same to arg | |||
| if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) { | |||
| AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get()); | |||
| } | |||
| InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter); | |||
| // if front anf is a parameter,we can assign the value back,because backend_parameter won't be change in it's graph | |||
| // unless it's a weigth.If backend_parameter is a weight,we do should assign the value back | |||
| if (backend_arg->isa<Parameter>() && !to_graph->execution_order().empty()) { | |||
| InsertMultipleAssignToGraph(to_graph_id, backend_parameter, backend_arg); | |||
| InsertAssignToGraph(to_graph_id, backend_parameter, backend_arg); | |||
| } | |||
| MS_LOG(INFO) << "Finish!"; | |||
| } | |||
| @@ -755,7 +762,25 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { | |||
| DumpGraphInputArgs(args); | |||
| UpdateGraphOrder(g); | |||
| std::vector<AnfNodePtr> graph_inputs = to_graph->inputs(); | |||
| auto valid_inputs = to_graph->ValidInputs(); | |||
| size_t real_args_size = 0; | |||
| for (size_t i = 0; i < args.size(); i++) { | |||
| real_args_size += AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem}).size(); | |||
| } | |||
| if (real_args_size != graph_inputs.size()) { | |||
| for (size_t j = 0; j < valid_inputs.size(); j++) { | |||
| if (valid_inputs[j]) { | |||
| MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); | |||
| } | |||
| } | |||
| MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() | |||
| << " not equal"; | |||
| } | |||
| size_t input_index = 0; | |||
| if (graph_inputs.size() != valid_inputs.size()) { | |||
| MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size() | |||
| << ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; | |||
| } | |||
| for (size_t i = 0; i < args.size(); i++) { | |||
| if (input_index >= graph_inputs.size()) { | |||
| MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); | |||
| @@ -763,6 +788,10 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { | |||
| if (utils::isa<AnfNodePtr>(args[i])) { | |||
| // arg is a anf node | |||
| for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem})) { | |||
| if (!valid_inputs[input_index]) { | |||
| MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString(); | |||
| continue; | |||
| } | |||
| SetChildGraphParameter(real_arg, graph_inputs[input_index]); | |||
| input_index++; | |||
| } | |||
| @@ -49,9 +49,8 @@ class AscendSession : public SessionBasic { | |||
| // set output of final graph | |||
| void SetFinalGraphOutput(const BaseRef &output) override; | |||
| // insert switch and set the relative active ops | |||
| void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g) override; | |||
| // set args of child graph. the arg maybe come from a output of other child graphs, | |||
| // or from final graph's parameter | |||
| void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override; | |||
| // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter | |||
| void SetChildGraphInput(GraphId g, const VectorRef &args) override; | |||
| // get graph id in child graphs by ME front anf node pointer | |||
| GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; | |||
| @@ -116,6 +115,7 @@ class AscendSession : public SessionBasic { | |||
| std::unordered_map<GraphId, GraphId> while_condition_graphs_; | |||
| // record all conditions | |||
| std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_; | |||
| std::unordered_map<GraphId, AnfNodePtr> condition_output_; | |||
| // final_graph_id is used in every root graph has it's own session situation | |||
| GraphId final_graph_id_; | |||
| }; | |||
| @@ -372,8 +372,7 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| std::vector<AnfNodePtr> prior_nodes = {prior_node}; | |||
| std::vector<AnfNodePtr> depend_nodes = {depend_node}; | |||
| MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "],depend node[" << depend_node->DebugString() | |||
| << "],depend_mode=[" << AnfAlgo::GetNodeAttr<int>(cnode, "depend_mode") << "]"; | |||
| MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString(); | |||
| if (prior_node->isa<Parameter>()) { | |||
| prior_nodes = GetOutputNodes(prior_node); | |||
| } | |||
| @@ -86,6 +86,9 @@ class KernelGraph : public FuncGraph { | |||
| bool executable() const { return executable_; } | |||
| // set executable of graph | |||
| void set_executable(bool executable) { executable_ = executable; } | |||
| // set invalid inputs for control sink | |||
| std::vector<bool> *MutableValidInputs() { return &valid_inputs_; } | |||
| std::vector<bool> ValidInputs() { return valid_inputs_; } | |||
| private: | |||
| // remove value node form graph | |||
| @@ -118,6 +121,8 @@ class KernelGraph : public FuncGraph { | |||
| std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_; | |||
| // graph needn't execute | |||
| bool executable_; | |||
| // valid inputs | |||
| std::vector<bool> valid_inputs_; | |||
| }; | |||
| } // namespace session | |||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | |||
| @@ -243,29 +243,38 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { | |||
| return new_value_node; | |||
| } | |||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) { | |||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| if (!anf->isa<Parameter>()) { | |||
| MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; | |||
| } | |||
| auto graph_inputs = graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| auto valid_inputs = graph->MutableValidInputs(); | |||
| MS_EXCEPTION_IF_NULL(valid_inputs); | |||
| ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); | |||
| graph->FrontBackendlMapAdd(anf, new_parameter); | |||
| graph_inputs->push_back(new_parameter); | |||
| valid_inputs->push_back(valid_input); | |||
| return new_parameter; | |||
| } | |||
| std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { | |||
| std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<AnfNodePtr> parameters; | |||
| std::vector<AnfNodePtr> pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); | |||
| auto valid_inputs = graph->MutableValidInputs(); | |||
| MS_EXCEPTION_IF_NULL(valid_inputs); | |||
| auto graph_inputs = graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { | |||
| auto parameter = graph->NewParameter(); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| parameter->set_abstract(abstract); | |||
| parameters.push_back(graph->NewParameter(parameter)); | |||
| auto new_parameter = graph->NewParameter(parameter); | |||
| parameters.push_back(new_parameter); | |||
| valid_inputs->push_back(valid_input); | |||
| graph_inputs->push_back(new_parameter); | |||
| }; | |||
| for (const auto &out_node : pre_graph_out) { | |||
| MS_EXCEPTION_IF_NULL(out_node); | |||
| @@ -287,18 +296,15 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelG | |||
| return parameters; | |||
| } | |||
| AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) { | |||
| AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| if (!anf->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a cnode"; | |||
| MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode"; | |||
| } | |||
| MS_LOG(INFO) << "create a new parameter from cnode[" << anf->DebugString() << "]"; | |||
| auto parameters = CreateParameterFromTuple(anf, graph); | |||
| auto graph_inputs = graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(*graph_inputs)); | |||
| MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; | |||
| auto parameters = CreateParameterFromTuple(anf, valid_input, graph); | |||
| if (parameters.empty()) { | |||
| MS_LOG(EXCEPTION) << "no parameter exist!!"; | |||
| MS_LOG(EXCEPTION) << "No parameter exist!!"; | |||
| } | |||
| if (parameters.size() == 1) { | |||
| return parameters[0]; | |||
| @@ -307,7 +313,7 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph | |||
| (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); | |||
| auto make_tuple = graph->NewCNode(make_tuple_input); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| MS_LOG(INFO) << "new make tuple [" << make_tuple->DebugString() << "] of parameters"; | |||
| MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters"; | |||
| return make_tuple; | |||
| } | |||
| @@ -397,14 +403,20 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { | |||
| GraphId SessionBasic::graph_sum_ = 0; | |||
| CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { | |||
| CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, | |||
| bool *from_other_graph, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(from_other_graph); | |||
| MS_EXCEPTION_IF_NULL(other_graph_cnode); | |||
| *from_other_graph = false; | |||
| // get primitive of old node | |||
| auto prim = AnfAlgo::GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| // push attr to inputs[0] of new cnode | |||
| std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))}; | |||
| // if has multiple depends,only select first depend as parameter | |||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||
| auto anf = cnode->inputs()[input_idx]; | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| @@ -412,6 +424,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | |||
| cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | |||
| continue; | |||
| } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { | |||
| cnode_inputs.push_back((*other_graph_cnode)[anf]); | |||
| continue; | |||
| } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) { | |||
| // if input is a value node, | |||
| auto new_value_node = CreateNewValueNode(anf, graph); | |||
| @@ -421,38 +436,60 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||
| continue; | |||
| } else if (anf->isa<Parameter>()) { | |||
| // if anf is a parameter | |||
| cnode_inputs.emplace_back(CreateNewParameterFromParameter(anf, graph)); | |||
| auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); | |||
| cnode_inputs.push_back(new_parameter); | |||
| if (GetGraphIdByNode(anf) == kInvalidGraphId) { | |||
| graph->FrontBackendlMapAdd(anf, new_parameter); | |||
| } else { | |||
| (*other_graph_cnode)[anf] = new_parameter; | |||
| } | |||
| continue; | |||
| } else if (anf->isa<CNode>()) { | |||
| *from_other_graph = true; | |||
| // the input node is a cnode from other graph | |||
| cnode_inputs.emplace_back(CreateNewParameterFromCNode(anf, graph)); | |||
| auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); | |||
| cnode_inputs.push_back(parameter_from_cnode); | |||
| (*other_graph_cnode)[anf] = parameter_from_cnode; | |||
| continue; | |||
| } | |||
| MS_LOG(EXCEPTION) << "unexpected input[" << anf->DebugString() << "]"; | |||
| MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; | |||
| } | |||
| return graph->NewCNode(cnode_inputs); | |||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); | |||
| auto new_cnode = graph->NewCNode(cnode_inputs); | |||
| TraceManager::EndTrace(); | |||
| return new_cnode; | |||
| } | |||
| KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode; | |||
| auto graph = std::make_shared<KernelGraph>(); | |||
| graph->set_graph_id(graph_sum_); | |||
| MS_LOG(INFO) << "Create graph: " << graph_sum_; | |||
| size_t from_other_graph_depend_num = 0; | |||
| for (const auto &node : lst) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_LOG(DEBUG) << "start create new cnode,node = " << node->DebugString(); | |||
| MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); | |||
| if (!node->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Inst node " << node->DebugString() << " is not CNode"; | |||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode"; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); | |||
| // create a new cnode object | |||
| auto new_cnode = CreateNewCNode(cnode, graph.get()); | |||
| bool from_other_graph = false; | |||
| // only first depend from other graph can create | |||
| bool valid_input = true; | |||
| if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { | |||
| valid_input = false; | |||
| } | |||
| auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode); | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) { | |||
| from_other_graph_depend_num++; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| new_cnode->set_abstract(cnode->abstract()); | |||
| new_cnode->set_scope(cnode->scope()); | |||
| // record map relations between anf from ME and new anf node used in backend | |||
| graph->FrontBackendlMapAdd(node, new_cnode); | |||
| TraceManager::EndTrace(); | |||
| } | |||
| // add a make_tuple at the end of graph as output | |||
| graph->set_output(ConstructOutput(outputs, graph)); | |||
| @@ -631,12 +668,15 @@ void SessionBasic::ToTensorPtr(const OpRunInfo &op_run_info, std::vector<tensor: | |||
| CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<AnfNodePtr> output_args; | |||
| auto FindEqu = [graph](const AnfNodePtr &out) -> AnfNodePtr { | |||
| auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { | |||
| auto backend_anf = graph->GetBackendAnfByFrontAnf(out); | |||
| if (backend_anf != nullptr) { | |||
| return backend_anf; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Can not find the node in the equiv map!"; | |||
| for (const auto &output : outputs) { | |||
| MS_LOG(INFO) << "output:" << output->DebugString(); | |||
| } | |||
| MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!"; | |||
| }; | |||
| output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args), | |||
| @@ -69,14 +69,15 @@ class SessionBasic { | |||
| std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | |||
| CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); | |||
| CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); | |||
| // set parameters of final graph | |||
| virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; } | |||
| // set output of final graph | |||
| virtual void SetFinalGraphOutput(const BaseRef &) {} | |||
| // insert switch and set the relative active ops | |||
| virtual void SwitchCompile(GraphId, GraphId, GraphId) {} | |||
| virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {} | |||
| // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter | |||
| virtual void SetChildGraphInput(GraphId, const VectorRef &) {} | |||
| // get graph id in child graphs by ME front anf node pointer | |||
| @@ -136,7 +136,7 @@ void MsBackend::SetSwitchGraph() { | |||
| MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString(); | |||
| } | |||
| MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g; | |||
| sess_->SwitchCompile(cond_g, true_g, false_g); | |||
| sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_)); | |||
| } | |||
| is_switch_call_ = false; | |||
| MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_; | |||