| @@ -339,7 +339,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| } | |||
| } | |||
| void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) { | |||
| void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); | |||
| std::vector<session::KernelWithIndex> non_communication_op; | |||
| @@ -350,6 +350,7 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) | |||
| if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) { | |||
| continue; | |||
| } | |||
| graph->AddFinalOutputKernel(item_with_index.first); | |||
| if (AnfAlgo::IsCommunicationOp(item_with_index.first)) { | |||
| AssignCommunicationNodeMem(kStaticMem, item_with_index.first); | |||
| } else { | |||
| @@ -95,7 +95,7 @@ class KernelRuntime { | |||
| #endif | |||
| private: | |||
| void AssignStaticMemoryOutput(const session::KernelGraph *graph); | |||
| void AssignStaticMemoryOutput(session::KernelGraph *graph); | |||
| void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs, | |||
| AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); | |||
| bool LaunchKernelMod(const session::KernelGraph &graph); | |||
| @@ -25,7 +25,7 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive_base.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "operator/ops.h" | |||
| namespace mindspore { | |||
| @@ -179,4 +179,43 @@ std::string get_id(const AnfNodePtr &node) { | |||
| void reset_id() { node_ids.clear(); } | |||
| } // namespace id_generator | |||
| std::string GetCNodeTarget(const AnfNodePtr &node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->device_target(); | |||
| if (!node->isa<CNode>()) { | |||
| return default_target; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto attr_input = cnode->input(0); | |||
| if (attr_input == nullptr) { | |||
| return default_target; | |||
| } | |||
| auto value_node = attr_input->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| return default_target; | |||
| } | |||
| auto value = value_node->value(); | |||
| if (value == nullptr) { | |||
| return default_target; | |||
| } | |||
| if (!value->isa<Primitive>()) { | |||
| return default_target; | |||
| } | |||
| auto primitive = value->cast<PrimitivePtr>(); | |||
| auto att_target = primitive->GetAttr("primitive_target"); | |||
| if (att_target != nullptr) { | |||
| if (!att_target->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; | |||
| } | |||
| auto target = GetValue<std::string>(att_target); | |||
| if (kTargetSet.find(target) == kTargetSet.end()) { | |||
| MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; | |||
| } | |||
| return target; | |||
| } | |||
| return default_target; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -448,7 +448,7 @@ void reset_id(); | |||
| } // namespace id_generator | |||
| using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>; | |||
| using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>; | |||
| std::string GetCNodeTarget(const AnfNodePtr &node); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_ANF_H_ | |||
| @@ -46,6 +46,11 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An | |||
| if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr front_node; | |||
| auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); | |||
| if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { | |||
| front_node = kernel_graph->GetFrontNodeByInternalOutput(node); | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| MS_LOG(DEBUG) << "====process op: " << node->DebugString(); | |||
| AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); | |||
| @@ -56,7 +61,12 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An | |||
| return new_node; | |||
| } | |||
| } | |||
| return InsertTransOpForOutput(func_graph, new_node, kernel_select_); | |||
| auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_); | |||
| if (kernel_graph != nullptr && front_node != nullptr) { | |||
| auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node); | |||
| kernel_graph->ReplaceInternalOutput(old_node, final_node); | |||
| } | |||
| return final_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -976,15 +976,6 @@ void AscendSession::SetFinalGraphOutput(const BaseRef &output) { | |||
| } | |||
| } | |||
| KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) { | |||
| auto it = graphs_.find(graph_id); | |||
| if (it == graphs_.end()) { | |||
| MS_LOG(WARNING) << "Can't find graph " << graph_id; | |||
| return nullptr; | |||
| } | |||
| return it->second; | |||
| } | |||
| void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) { | |||
| MS_LOG(INFO) << "Start!"; | |||
| MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]"; | |||
| @@ -128,8 +128,6 @@ class AscendSession : public SessionBasic { | |||
| void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node); | |||
| // insert depend to graph, used to attch control nodes to graph | |||
| void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node); | |||
| // Get graph by graph id ,if not exist return null ptr | |||
| KernelGraphPtr GetGraph(GraphId graph_id); | |||
| // set child graph parameter if front arg is a anf | |||
| void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx); | |||
| // set child graph parameter if front arg is a tensor | |||
| @@ -329,6 +329,9 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { | |||
| FrontBackendlMapUpdate(cnode, new_cnode); | |||
| } | |||
| AnfAlgo::SetGraphId(graph_id_, cnode.get()); | |||
| if (IsInternalOutput(cnode)) { | |||
| ReplaceInternalOutput(cnode, new_cnode); | |||
| } | |||
| return new_cnode; | |||
| } | |||
| @@ -872,6 +875,76 @@ void KernelGraph::PrintGraphExecuteOrder() const { | |||
| } | |||
| } | |||
| void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) { | |||
| if (front_node == nullptr || node == nullptr) { | |||
| MS_LOG(INFO) << "Front node or node is nullptr"; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString(); | |||
| front_to_internal_outputs_map_[front_node] = node; | |||
| internal_outputs_to_front_map_[node] = front_node; | |||
| } | |||
| void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) { | |||
| if (new_node == nullptr || node == nullptr) { | |||
| MS_LOG(INFO) << "New node or node is nullptr"; | |||
| return; | |||
| } | |||
| if (node == new_node) { | |||
| MS_LOG(INFO) << "New node and node is the same"; | |||
| return; | |||
| } | |||
| auto iter = internal_outputs_to_front_map_.find(node); | |||
| if (iter == internal_outputs_to_front_map_.end()) { | |||
| MS_LOG(INFO) << "Node is not internal output"; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString(); | |||
| internal_outputs_to_front_map_[new_node] = iter->second; | |||
| front_to_internal_outputs_map_[iter->second] = new_node; | |||
| internal_outputs_to_front_map_.erase(iter); | |||
| } | |||
| AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const { | |||
| auto iter = front_to_internal_outputs_map_.find(front_node); | |||
| if (iter != front_to_internal_outputs_map_.end()) { | |||
| return iter->second; | |||
| } | |||
| return nullptr; | |||
| } | |||
| bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const { | |||
| if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const { | |||
| auto iter = internal_outputs_to_front_map_.find(node); | |||
| if (iter != internal_outputs_to_front_map_.end()) { | |||
| return iter->second; | |||
| } | |||
| return nullptr; | |||
| } | |||
| void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return; | |||
| } | |||
| (void)final_output_kernels_.insert(node); | |||
| } | |||
| bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const { | |||
| if (node == nullptr) { | |||
| return false; | |||
| } | |||
| if (final_output_kernels_.find(node) != final_output_kernels_.end()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } | |||
| KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); } | |||
| @@ -144,6 +144,13 @@ class KernelGraph : public FuncGraph { | |||
| void PrintGraphExecuteOrder() const; | |||
| const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; } | |||
| void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; } | |||
| void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node); | |||
| void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node); | |||
| AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const; | |||
| bool IsInternalOutput(const AnfNodePtr &node) const; | |||
| AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const; | |||
| void AddFinalOutputKernel(const AnfNodePtr &node); | |||
| bool IsFinalOutputKernel(const AnfNodePtr &node) const; | |||
| private: | |||
| // remove value node form graph | |||
| @@ -202,6 +209,9 @@ class KernelGraph : public FuncGraph { | |||
| CNodePtr start_label_; | |||
| CNodePtr end_goto_; | |||
| bool null_output_; | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_; | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> internal_outputs_to_front_map_; | |||
| std::set<AnfNodePtr> final_output_kernels_; | |||
| }; | |||
| } // namespace session | |||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | |||
| @@ -95,6 +95,13 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne | |||
| TypeId type_id = kNumberTypeFloat32; | |||
| type_id = AnfAlgo::GetOutputInferDataType(node, output_index); | |||
| std::vector<int> temp_shape; | |||
| if (graph.IsInternalOutput(node)) { | |||
| temp_shape.emplace_back(1); | |||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); | |||
| tensor->set_device_address(address); | |||
| tensor->set_dirty(false); | |||
| return tensor; | |||
| } | |||
| (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); | |||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); | |||
| // if in paynative mode,data only copyed to host when user want to print data | |||
| @@ -172,48 +179,6 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { | |||
| return new_value_node; | |||
| } | |||
| std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<AnfNodePtr> parameters; | |||
| std::vector<AnfNodePtr> pre_graph_out = {node}; | |||
| // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive | |||
| if (!AnfAlgo::IsRealKernel(node)) { | |||
| pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); | |||
| } | |||
| auto valid_inputs = graph->MutableValidInputs(); | |||
| MS_EXCEPTION_IF_NULL(valid_inputs); | |||
| auto graph_inputs = graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { | |||
| auto parameter = graph->NewParameter(); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| parameter->set_abstract(abstract); | |||
| auto new_parameter = graph->NewParameter(parameter); | |||
| parameters.push_back(new_parameter); | |||
| valid_inputs->push_back(valid_input); | |||
| graph_inputs->push_back(new_parameter); | |||
| }; | |||
| for (const auto &out_node : pre_graph_out) { | |||
| MS_EXCEPTION_IF_NULL(out_node); | |||
| auto abstract = out_node->abstract(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| // create multiple parameters if is a tuple output real kernel | |||
| if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) { | |||
| auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_abstract); | |||
| MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]"; | |||
| for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) { | |||
| create_parameter((*tuple_abstract)[output_idx]); | |||
| } | |||
| continue; | |||
| } | |||
| // create single parameter if is a abstract real kernel | |||
| create_parameter(out_node->abstract()); | |||
| } | |||
| return parameters; | |||
| } | |||
| size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "Load kInputCtrlTensors"; | |||
| @@ -323,6 +288,103 @@ bool ExistSummaryNode(const KernelGraph *graph) { | |||
| } // namespace | |||
| GraphId SessionBasic::graph_sum_ = 0; | |||
| KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) { | |||
| auto it = graphs_.find(graph_id); | |||
| if (it == graphs_.end()) { | |||
| MS_LOG(WARNING) << "Can't find graph " << graph_id; | |||
| return nullptr; | |||
| } | |||
| return it->second; | |||
| } | |||
| void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter) { | |||
| auto graph_id = GetGraphIdByNode(out_node); | |||
| if (graph_id == kInvalidGraphId) { | |||
| return; | |||
| } | |||
| auto node_graph = GetGraph(graph_id); | |||
| if (node_graph == nullptr) { | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString(); | |||
| auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node); | |||
| if (ref_node == nullptr) { | |||
| MS_LOG(INFO) << "No corresponding internal output for output node"; | |||
| return; | |||
| } | |||
| auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0); | |||
| auto ref_real_node = real_kernel.first; | |||
| auto ref_real_node_index = real_kernel.second; | |||
| if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) && | |||
| node_graph->IsFinalOutputKernel(ref_real_node)) { | |||
| auto kernel_info = ref_real_node->kernel_info(); | |||
| if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) { | |||
| MS_LOG(INFO) << "No kernel info"; | |||
| return; | |||
| } | |||
| auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index); | |||
| if (address == nullptr) { | |||
| MS_LOG(INFO) << "No kernel address"; | |||
| return; | |||
| } | |||
| auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index); | |||
| auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index); | |||
| parameter->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| auto d_kernel_info = parameter->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| builder.SetOutputsDeviceType({type}); | |||
| builder.SetOutputsFormat({format}); | |||
| d_kernel_info->set_select_kernel_build_info(builder.Build()); | |||
| AnfAlgo::SetOutputAddr(address, 0, parameter.get()); | |||
| } | |||
| } | |||
| std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, | |||
| KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<AnfNodePtr> parameters; | |||
| std::vector<AnfNodePtr> pre_graph_out = {node}; | |||
| // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive | |||
| if (!AnfAlgo::IsRealKernel(node)) { | |||
| pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); | |||
| } | |||
| auto valid_inputs = graph->MutableValidInputs(); | |||
| MS_EXCEPTION_IF_NULL(valid_inputs); | |||
| auto graph_inputs = graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { | |||
| auto parameter = graph->NewParameter(); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| parameter->set_abstract(abstract); | |||
| auto new_parameter = graph->NewParameter(parameter); | |||
| parameters.push_back(new_parameter); | |||
| valid_inputs->push_back(valid_input); | |||
| graph_inputs->push_back(new_parameter); | |||
| }; | |||
| for (const auto &out_node : pre_graph_out) { | |||
| MS_EXCEPTION_IF_NULL(out_node); | |||
| auto abstract = out_node->abstract(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| // create multiple parameters if is a tuple output real kernel | |||
| if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) { | |||
| auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_abstract); | |||
| MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]"; | |||
| for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) { | |||
| create_parameter((*tuple_abstract)[output_idx]); | |||
| } | |||
| continue; | |||
| } | |||
| // create single parameter if is a abstract real kernel | |||
| create_parameter(out_node->abstract()); | |||
| InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]); | |||
| } | |||
| return parameters; | |||
| } | |||
| ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, | |||
| KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| @@ -857,6 +919,29 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: | |||
| auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { | |||
| auto backend_anf = graph->GetBackendAnfByFrontAnf(out); | |||
| if (backend_anf != nullptr) { | |||
| auto front_real_kernel = AnfAlgo::VisitKernel(out, 0); | |||
| auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0); | |||
| MS_EXCEPTION_IF_NULL(out); | |||
| auto out_func_graph = out->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(out_func_graph); | |||
| auto out_func_graph_manager = out_func_graph->manager(); | |||
| if (out_func_graph_manager == nullptr) { | |||
| return backend_anf; | |||
| } | |||
| auto node_users = out_func_graph_manager->node_users(); | |||
| auto users = node_users[out]; | |||
| bool internal_output = true; | |||
| std::string kernel_target = GetCNodeTarget(front_real_kernel.first); | |||
| for (auto user : users) { | |||
| if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { | |||
| internal_output = false; | |||
| break; | |||
| } | |||
| } | |||
| if (internal_output) { | |||
| MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString(); | |||
| graph->AddInternalOutput(out, backend_real_kernel.first); | |||
| } | |||
| return backend_anf; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!"; | |||
| @@ -110,6 +110,8 @@ class SessionBasic { | |||
| #endif | |||
| protected: | |||
| // Get graph by graph id ,if not exist return null ptr | |||
| KernelGraphPtr GetGraph(GraphId graph_id); | |||
| virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| const std::vector<tensor::TensorPtr> &inputs_const) const; | |||
| void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs, | |||
| @@ -127,11 +129,13 @@ class SessionBasic { | |||
| BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); | |||
| // create a new kernel graph and update the graph sum | |||
| KernelGraphPtr NewKernelGraph(); | |||
| std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph); | |||
| virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | |||
| ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); | |||
| ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); | |||
| AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | |||
| void AddParameterToGraphInputs(const std::vector<AnfNodePtr> ¶meters, KernelGraph *graph); | |||
| void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter); | |||
| std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; | |||
| std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; | |||
| @@ -52,45 +52,6 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() { | |||
| } | |||
| namespace { | |||
| std::string GetCNodeTarget(const AnfNodePtr &node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->device_target(); | |||
| if (!node->isa<CNode>()) { | |||
| return default_target; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto attr_input = cnode->input(kAnfPrimitiveIndex); | |||
| if (attr_input == nullptr) { | |||
| return default_target; | |||
| } | |||
| auto value_node = attr_input->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| return default_target; | |||
| } | |||
| auto value = value_node->value(); | |||
| if (value == nullptr) { | |||
| return default_target; | |||
| } | |||
| if (!value->isa<Primitive>()) { | |||
| return default_target; | |||
| } | |||
| auto primitive = value->cast<PrimitivePtr>(); | |||
| auto att_target = primitive->GetAttr("primitive_target"); | |||
| if (att_target != nullptr) { | |||
| if (!att_target->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; | |||
| } | |||
| auto target = GetValue<std::string>(att_target); | |||
| if (kTargetSet.find(target) == kTargetSet.end()) { | |||
| MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; | |||
| } | |||
| return target; | |||
| } | |||
| return default_target; | |||
| } | |||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||