@@ -339,7 +339,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
   }
 }
 
-void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
+void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
   std::vector<session::KernelWithIndex> non_communication_op;
@@ -350,6 +350,7 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph)
     if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
       continue;
     }
+    graph->AddFinalOutputKernel(item_with_index.first);
    if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
       AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
     } else {
@@ -95,7 +95,7 @@ class KernelRuntime {
 #endif
 
  private:
-  void AssignStaticMemoryOutput(const session::KernelGraph *graph);
+  void AssignStaticMemoryOutput(session::KernelGraph *graph);
   void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
                      AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
   bool LaunchKernelMod(const session::KernelGraph &graph);
@@ -25,7 +25,7 @@
 #include "ir/func_graph.h"
 #include "ir/primitive_base.h"
+#include "utils/context/ms_context.h"
 #include "operator/ops.h"
 
 namespace mindspore {
@@ -179,4 +179,43 @@ std::string get_id(const AnfNodePtr &node) {
 void reset_id() { node_ids.clear(); }
 }  // namespace id_generator
+
+std::string GetCNodeTarget(const AnfNodePtr &node) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string default_target = context_ptr->device_target();
+  if (!node->isa<CNode>()) {
+    return default_target;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto attr_input = cnode->input(0);
+  if (attr_input == nullptr) {
+    return default_target;
+  }
+  auto value_node = attr_input->cast<ValueNodePtr>();
+  if (value_node == nullptr) {
+    return default_target;
+  }
+  auto value = value_node->value();
+  if (value == nullptr) {
+    return default_target;
+  }
+  if (!value->isa<Primitive>()) {
+    return default_target;
+  }
+  auto primitive = value->cast<PrimitivePtr>();
+  auto att_target = primitive->GetAttr("primitive_target");
+  if (att_target != nullptr) {
+    if (!att_target->isa<StringImm>()) {
+      MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
+    }
+    auto target = GetValue<std::string>(att_target);
+    if (kTargetSet.find(target) == kTargetSet.end()) {
+      MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
+    }
+    return target;
+  }
+  return default_target;
+}
 }  // namespace mindspore
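GetCNodeTarget, added above, resolves which device target a CNode runs on: an explicit "primitive_target" attribute (restricted to CPU, GPU, or Ascend) wins, and anything else falls back to the device target from MsContext. A minimal standalone sketch of that decision, using a hypothetical ResolveTarget helper instead of the real Primitive/MsContext API, for illustration only:

    // Simplified, standalone illustration of the GetCNodeTarget fallback rule.
    // ResolveTarget() is a hypothetical helper, not part of MindSpore.
    #include <optional>
    #include <set>
    #include <stdexcept>
    #include <string>

    const std::set<std::string> kKnownTargets = {"CPU", "GPU", "Ascend"};

    std::string ResolveTarget(const std::optional<std::string> &primitive_target,
                              const std::string &default_target) {
      if (!primitive_target.has_value()) {
        return default_target;  // no attribute: fall back to the context device target
      }
      if (kKnownTargets.find(*primitive_target) == kKnownTargets.end()) {
        throw std::invalid_argument("Only support string CPU|GPU|Ascend for primitive_target");
      }
      return *primitive_target;  // the attribute overrides the default
    }

For example, ResolveTarget(std::nullopt, "Ascend") yields "Ascend", while ResolveTarget("CPU", "Ascend") yields "CPU".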
@@ -448,7 +448,7 @@ void reset_id();
 }  // namespace id_generator
 using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>;
 using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
+std::string GetCNodeTarget(const AnfNodePtr &node);
 }  // namespace mindspore
 
 #endif  // MINDSPORE_CCSRC_IR_ANF_H_
@@ -46,6 +46,11 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
   if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
     return nullptr;
   }
+  AnfNodePtr front_node;
+  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+  if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
+    front_node = kernel_graph->GetFrontNodeByInternalOutput(node);
+  }
   AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
   MS_LOG(DEBUG) << "====process op: " << node->DebugString();
   AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
@@ -56,7 +61,12 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
       return new_node;
     }
   }
-  return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
+  auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_);
+  if (kernel_graph != nullptr && front_node != nullptr) {
+    auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node);
+    kernel_graph->ReplaceInternalOutput(old_node, final_node);
+  }
+  return final_node;
 }
 }  // namespace opt
 }  // namespace mindspore
@@ -976,15 +976,6 @@ void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
   }
 }
 
-KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
-  auto it = graphs_.find(graph_id);
-  if (it == graphs_.end()) {
-    MS_LOG(WARNING) << "Can't find graph " << graph_id;
-    return nullptr;
-  }
-  return it->second;
-}
-
 void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) {
   MS_LOG(INFO) << "Start!";
   MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]";
@@ -128,8 +128,6 @@ class AscendSession : public SessionBasic {
   void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
   // insert depend to graph, used to attch control nodes to graph
   void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
-  // Get graph by graph id ,if not exist return null ptr
-  KernelGraphPtr GetGraph(GraphId graph_id);
   // set child graph parameter if front arg is a anf
   void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
   // set child graph parameter if front arg is a tensor
@@ -329,6 +329,9 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
     FrontBackendlMapUpdate(cnode, new_cnode);
   }
   AnfAlgo::SetGraphId(graph_id_, cnode.get());
+  if (IsInternalOutput(cnode)) {
+    ReplaceInternalOutput(cnode, new_cnode);
+  }
   return new_cnode;
 }
 
@@ -872,6 +875,76 @@ void KernelGraph::PrintGraphExecuteOrder() const {
   }
 }
 
+void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) {
+  if (front_node == nullptr || node == nullptr) {
+    MS_LOG(INFO) << "Front node or node is nullptr";
+    return;
+  }
+  MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
+  front_to_internal_outputs_map_[front_node] = node;
+  internal_outputs_to_front_map_[node] = front_node;
+}
+
+void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
+  if (new_node == nullptr || node == nullptr) {
+    MS_LOG(INFO) << "New node or node is nullptr";
+    return;
+  }
+  if (node == new_node) {
+    MS_LOG(INFO) << "New node and node is the same";
+    return;
+  }
+  auto iter = internal_outputs_to_front_map_.find(node);
+  if (iter == internal_outputs_to_front_map_.end()) {
+    MS_LOG(INFO) << "Node is not internal output";
+    return;
+  }
+  MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
+  internal_outputs_to_front_map_[new_node] = iter->second;
+  front_to_internal_outputs_map_[iter->second] = new_node;
+  internal_outputs_to_front_map_.erase(iter);
+}
+
+AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
+  auto iter = front_to_internal_outputs_map_.find(front_node);
+  if (iter != front_to_internal_outputs_map_.end()) {
+    return iter->second;
+  }
+  return nullptr;
+}
+
+bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
+  if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) {
+    return true;
+  }
+  return false;
+}
+
+AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const {
+  auto iter = internal_outputs_to_front_map_.find(node);
+  if (iter != internal_outputs_to_front_map_.end()) {
+    return iter->second;
+  }
+  return nullptr;
+}
+
+void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) {
+  if (node == nullptr) {
+    return;
+  }
+  (void)final_output_kernels_.insert(node);
+}
+
+bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
+  if (node == nullptr) {
+    return false;
+  }
+  if (final_output_kernels_.find(node) != final_output_kernels_.end()) {
+    return true;
+  }
+  return false;
+}
+
 std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
 
 KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }
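AddInternalOutput and ReplaceInternalOutput above maintain a pair of maps that act as a bidirectional index between a front-end output node and its current backend kernel; when a pass such as InsertTransOp rewrites the backend node, ReplaceInternalOutput moves the entry so lookups by the front node keep resolving to the live node. A minimal standalone sketch of that bookkeeping, using plain strings and a hypothetical InternalOutputIndex type instead of AnfNodePtr and KernelGraph:

    // Standalone sketch of the front <-> internal-output bookkeeping; not MindSpore code.
    #include <string>
    #include <unordered_map>

    struct InternalOutputIndex {
      std::unordered_map<std::string, std::string> front_to_internal;
      std::unordered_map<std::string, std::string> internal_to_front;

      void Add(const std::string &front, const std::string &internal_node) {
        front_to_internal[front] = internal_node;
        internal_to_front[internal_node] = front;
      }

      // When a pass replaces old_node with new_node, both directions are updated
      // so later lookups by the front node still reach the current backend node.
      void Replace(const std::string &old_node, const std::string &new_node) {
        auto it = internal_to_front.find(old_node);
        if (it == internal_to_front.end() || old_node == new_node) {
          return;  // not an internal output, or nothing to do
        }
        internal_to_front[new_node] = it->second;
        front_to_internal[it->second] = new_node;
        internal_to_front.erase(it);
      }
    };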
@@ -144,6 +144,13 @@ class KernelGraph : public FuncGraph {
   void PrintGraphExecuteOrder() const;
   const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
   void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
+  void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
+  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
+  AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
+  bool IsInternalOutput(const AnfNodePtr &node) const;
+  AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const;
+  void AddFinalOutputKernel(const AnfNodePtr &node);
+  bool IsFinalOutputKernel(const AnfNodePtr &node) const;
 
  private:
   // remove value node form graph
@@ -202,6 +209,9 @@ class KernelGraph : public FuncGraph {
   CNodePtr start_label_;
   CNodePtr end_goto_;
   bool null_output_;
+  std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
+  std::unordered_map<AnfNodePtr, AnfNodePtr> internal_outputs_to_front_map_;
+  std::set<AnfNodePtr> final_output_kernels_;
 };
 }  // namespace session
 using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
@@ -95,6 +95,13 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
   TypeId type_id = kNumberTypeFloat32;
   type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
   std::vector<int> temp_shape;
+  if (graph.IsInternalOutput(node)) {
+    temp_shape.emplace_back(1);
+    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
+    tensor->set_device_address(address);
+    tensor->set_dirty(false);
+    return tensor;
+  }
   (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
   // if in paynative mode,data only copyed to host when user want to print data
@@ -172,48 +179,6 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
   return new_value_node;
 }
 
-std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(graph);
-  std::vector<AnfNodePtr> parameters;
-  std::vector<AnfNodePtr> pre_graph_out = {node};
-  // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
-  if (!AnfAlgo::IsRealKernel(node)) {
-    pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
-  }
-  auto valid_inputs = graph->MutableValidInputs();
-  MS_EXCEPTION_IF_NULL(valid_inputs);
-  auto graph_inputs = graph->MutableInputs();
-  MS_EXCEPTION_IF_NULL(graph_inputs);
-  auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
-    auto parameter = graph->NewParameter();
-    MS_EXCEPTION_IF_NULL(parameter);
-    parameter->set_abstract(abstract);
-    auto new_parameter = graph->NewParameter(parameter);
-    parameters.push_back(new_parameter);
-    valid_inputs->push_back(valid_input);
-    graph_inputs->push_back(new_parameter);
-  };
-  for (const auto &out_node : pre_graph_out) {
-    MS_EXCEPTION_IF_NULL(out_node);
-    auto abstract = out_node->abstract();
-    MS_EXCEPTION_IF_NULL(abstract);
-    // create multiple parameters if is a tuple output real kernel
-    if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
-      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
-      MS_EXCEPTION_IF_NULL(tuple_abstract);
-      MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
-      for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
-        create_parameter((*tuple_abstract)[output_idx]);
-      }
-      continue;
-    }
-    // create single parameter if is a abstract real kernel
-    create_parameter(out_node->abstract());
-  }
-  return parameters;
-}
-
 size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Load kInputCtrlTensors";
@@ -323,6 +288,103 @@ bool ExistSummaryNode(const KernelGraph *graph) {
 }  // namespace
 
 GraphId SessionBasic::graph_sum_ = 0;
+
+KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) {
+  auto it = graphs_.find(graph_id);
+  if (it == graphs_.end()) {
+    MS_LOG(WARNING) << "Can't find graph " << graph_id;
+    return nullptr;
+  }
+  return it->second;
+}
+
+void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
+  auto graph_id = GetGraphIdByNode(out_node);
+  if (graph_id == kInvalidGraphId) {
+    return;
+  }
+  auto node_graph = GetGraph(graph_id);
+  if (node_graph == nullptr) {
+    return;
+  }
+  MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
+  auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
+  if (ref_node == nullptr) {
+    MS_LOG(INFO) << "No corresponding internal output for output node";
+    return;
+  }
+  auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0);
+  auto ref_real_node = real_kernel.first;
+  auto ref_real_node_index = real_kernel.second;
+  if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
+      node_graph->IsFinalOutputKernel(ref_real_node)) {
+    auto kernel_info = ref_real_node->kernel_info();
+    if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) {
+      MS_LOG(INFO) << "No kernel info";
+      return;
+    }
+    auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
+    if (address == nullptr) {
+      MS_LOG(INFO) << "No kernel address";
+      return;
+    }
+    auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
+    auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
+    parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
+    auto d_kernel_info = parameter->kernel_info();
+    MS_EXCEPTION_IF_NULL(d_kernel_info);
+    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+    builder.SetOutputsDeviceType({type});
+    builder.SetOutputsFormat({format});
+    d_kernel_info->set_select_kernel_build_info(builder.Build());
+    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
+  }
+}
+
+std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input,
+                                                               KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> parameters;
+  std::vector<AnfNodePtr> pre_graph_out = {node};
+  // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
+  if (!AnfAlgo::IsRealKernel(node)) {
+    pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
+  }
+  auto valid_inputs = graph->MutableValidInputs();
+  MS_EXCEPTION_IF_NULL(valid_inputs);
+  auto graph_inputs = graph->MutableInputs();
+  MS_EXCEPTION_IF_NULL(graph_inputs);
+  auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
+    auto parameter = graph->NewParameter();
+    MS_EXCEPTION_IF_NULL(parameter);
+    parameter->set_abstract(abstract);
+    auto new_parameter = graph->NewParameter(parameter);
+    parameters.push_back(new_parameter);
+    valid_inputs->push_back(valid_input);
+    graph_inputs->push_back(new_parameter);
+  };
+  for (const auto &out_node : pre_graph_out) {
+    MS_EXCEPTION_IF_NULL(out_node);
+    auto abstract = out_node->abstract();
+    MS_EXCEPTION_IF_NULL(abstract);
+    // create multiple parameters if is a tuple output real kernel
+    if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
+      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
+      MS_EXCEPTION_IF_NULL(tuple_abstract);
+      MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
+      for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
+        create_parameter((*tuple_abstract)[output_idx]);
+      }
+      continue;
+    }
+    // create single parameter if is a abstract real kernel
+    create_parameter(out_node->abstract());
+    InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]);
+  }
+  return parameters;
+}
+
 ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
                                                            KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(anf);
@@ -857,6 +919,29 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
   auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
     auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
     if (backend_anf != nullptr) {
+      auto front_real_kernel = AnfAlgo::VisitKernel(out, 0);
+      auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0);
+      MS_EXCEPTION_IF_NULL(out);
+      auto out_func_graph = out->func_graph();
+      MS_EXCEPTION_IF_NULL(out_func_graph);
+      auto out_func_graph_manager = out_func_graph->manager();
+      if (out_func_graph_manager == nullptr) {
+        return backend_anf;
+      }
+      auto node_users = out_func_graph_manager->node_users();
+      auto users = node_users[out];
+      bool internal_output = true;
+      std::string kernel_target = GetCNodeTarget(front_real_kernel.first);
+      for (auto user : users) {
+        if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
+          internal_output = false;
+          break;
+        }
+      }
+      if (internal_output) {
+        MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString();
+        graph->AddInternalOutput(out, backend_real_kernel.first);
+      }
       return backend_anf;
     }
     MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
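The FindEqu change above only registers a graph output as an internal output when every user of the front node is a real kernel on the same device target as the output's own real kernel (per GetCNodeTarget). A simplified standalone sketch of that test, with hypothetical User and IsInternalOutputCandidate names rather than the AnfNode/manager API:

    // Standalone sketch of the "internal output" test used in ConstructOutput:
    // an output qualifies only when every user is a real kernel on the same target.
    #include <string>
    #include <vector>

    struct User {
      bool is_real_kernel;
      std::string target;
    };

    bool IsInternalOutputCandidate(const std::string &kernel_target, const std::vector<User> &users) {
      for (const auto &user : users) {
        if (!user.is_real_kernel || user.target != kernel_target) {
          return false;  // a non-kernel user or a cross-target user disqualifies it
        }
      }
      return true;
    }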
@@ -110,6 +110,8 @@
 #endif
 
  protected:
+  // Get graph by graph id ,if not exist return null ptr
+  KernelGraphPtr GetGraph(GraphId graph_id);
   virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                              const std::vector<tensor::TensorPtr> &inputs_const) const;
   void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
@@ -127,11 +129,13 @@
   BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
   // create a new kernel graph and update the graph sum
   KernelGraphPtr NewKernelGraph();
+  std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph);
   virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
   ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
   ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
   AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
   void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
+  void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
 
   std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
   std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
@@ -52,45 +52,6 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
 }
 
 namespace {
-std::string GetCNodeTarget(const AnfNodePtr &node) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  std::string default_target = context_ptr->device_target();
-  if (!node->isa<CNode>()) {
-    return default_target;
-  }
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto attr_input = cnode->input(kAnfPrimitiveIndex);
-  if (attr_input == nullptr) {
-    return default_target;
-  }
-  auto value_node = attr_input->cast<ValueNodePtr>();
-  if (value_node == nullptr) {
-    return default_target;
-  }
-  auto value = value_node->value();
-  if (value == nullptr) {
-    return default_target;
-  }
-  if (!value->isa<Primitive>()) {
-    return default_target;
-  }
-  auto primitive = value->cast<PrimitivePtr>();
-  auto att_target = primitive->GetAttr("primitive_target");
-  if (att_target != nullptr) {
-    if (!att_target->isa<StringImm>()) {
-      MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
-    }
-    auto target = GetValue<std::string>(att_target);
-    if (kTargetSet.find(target) == kTargetSet.end()) {
-      MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
-    }
-    return target;
-  }
-  return default_target;
-}
-
 bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);