diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 51a2d3b64f..d901e82178 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -160,7 +160,7 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const std::vector &input_tensors, std::map *tensor_to_node) { auto &node = node_output_pair.first; - auto &output_index = node_output_pair.second; + int output_index = SizeToInt(node_output_pair.second); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; @@ -172,25 +172,24 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, if (type_id == kTypeUnknown) { type_id = AnfAlgo::GetOutputInferDataType(node, output_index); } - tensor::TensorPtr tensor = nullptr; std::vector temp_shape; - if (graph->IsUniqueTargetInternalOutput(node, output_index)) { - temp_shape.emplace_back(1); - tensor = std::make_shared(type_id, temp_shape); - tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); - tensor->set_sync_status(kNoNeedSync); - } else { + auto shape = AnfAlgo::GetOutputInferShape(node, output_index); + (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); + tensor::TensorPtr tensor; + bool is_internal_output = graph->IsInternalOutput(node, output_index); + if (is_internal_output) { tensor = graph->GetInternalOutputTensor(node, output_index); if (tensor == nullptr) { - auto shape = AnfAlgo::GetOutputInferShape(node, output_index); - (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); tensor = std::make_shared(type_id, temp_shape); - bool is_internal_output = graph->IsInternalOutput(node, output_index); - if (is_internal_output) { - graph->AddInternalOutputTensor(node, output_index, tensor); - } + graph->AddInternalOutputTensor(node, output_index, tensor); } - tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); + } else { + tensor = std::make_shared(type_id, temp_shape); + } + tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); + if (is_internal_output) { + tensor->set_sync_status(kNoNeedSync); + } else { // if in pynative mode,data only copied to host when user want to print data auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); @@ -682,16 +681,20 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter if (param_value != nullptr) { new_parameter = param_value->parameter(); - if (new_parameter == nullptr) { - TraceGuard trace_guard(std::make_shared(anf->debug_info())); - new_parameter = graph->NewParameter(anf->cast()); - param_value->set_parameter(new_parameter); - } - } else { + } + if (new_parameter == nullptr) { TraceGuard trace_guard(std::make_shared(anf->debug_info())); new_parameter = graph->NewParameter(anf->cast()); - } + auto input_node_iter = partial_parameters_map_.find(anf); + if (input_node_iter != partial_parameters_map_.end()) { + InitInternalOutputParameter(input_node_iter->second, new_parameter); + } + + if (param_value != nullptr) { + param_value->set_parameter(new_parameter); + } + } new_parameter->IncreaseUsedGraphCount(); graph_inputs->push_back(new_parameter); valid_inputs->push_back(true); @@ -1772,10 +1775,11 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) { std::vector ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager, const AnfNodePtr &front_node) { + MS_EXCEPTION_IF_NULL(front_func_graph_manager); auto &users = front_func_graph_manager->node_users()[front_node]; std::vector result; for (auto &user : users) { - if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) { + if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend)) { auto depend_cnode = user.first->cast(); if (depend_cnode == nullptr) { continue; @@ -1785,9 +1789,12 @@ std::vector ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr } auto res = ExtendNodeUsers(front_func_graph_manager, user.first); result.insert(result.end(), res.begin(), res.end()); - continue; + } else if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) { + auto res = ExtendNodeUsers(front_func_graph_manager, user.first); + (void)result.insert(result.end(), res.begin(), res.end()); + } else { + (void)result.emplace_back(user.first); } - (void)result.emplace_back(user.first); } return result; } @@ -1813,10 +1820,54 @@ AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) { } return nullptr; } +} // namespace -void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node, - const FuncGraphManagerPtr &front_func_graph_manager, - const std::shared_ptr &backend_graph) { +constexpr auto kMixTarget = "MixTarget"; +constexpr auto kNoTarget = "NoTarget"; +std::string SessionBasic::AddPartialParametersMap(const FuncGraphManagerPtr &front_func_graph_manager, + const AnfNodePtr &partial_node) { + MS_EXCEPTION_IF_NULL(partial_node); + auto iter = partial_target_map_.find(partial_node); + if (iter != partial_target_map_.end()) { + return iter->second; + } + auto partial_cnode = partial_node->cast(); + MS_EXCEPTION_IF_NULL(partial_cnode); + auto partial_graph = GetValueNode(partial_cnode->input(kFirstDataInputIndex)); + MS_EXCEPTION_IF_NULL(partial_graph); + auto parameters = partial_graph->parameters(); + auto partial_inputs = partial_cnode->inputs(); + if (parameters.size() + 2 != partial_inputs.size()) { + return kMixTarget; + } + for (size_t i = 0; i < parameters.size(); ++i) { + partial_parameters_map_[parameters[i]] = partial_inputs[2 + i]; + } + auto graph_nodes = TopoSort(partial_graph->get_return()); + std::string graph_target = kNoTarget; + for (auto &node : graph_nodes) { + if (!node->isa()) { + continue; + } + if (!AnfAlgo::IsRealKernel(node)) { + continue; + } + std::string cur_target = GetCNodeTarget(node); + if (graph_target == kNoTarget) { + graph_target = cur_target; + } + if (graph_target != cur_target) { + graph_target = kMixTarget; + break; + } + } + (void)partial_target_map_.insert({partial_node, graph_target}); + return graph_target; +} + +void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node, + const FuncGraphManagerPtr &front_func_graph_manager, + const std::shared_ptr &backend_graph) { auto front_node = GetSupportedInternalNode(input_front_node); if (front_node == nullptr) { return; @@ -1840,7 +1891,14 @@ void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr & } if (internal_output) { auto users = ExtendNodeUsers(front_func_graph_manager, front_node); - for (auto user : users) { + for (auto &user : users) { + if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial)) { + auto partial_target = AddPartialParametersMap(front_func_graph_manager, user); + if (partial_target != kNoTarget && partial_target != kernel_target) { + unique_target = false; + } + continue; + } if (!CNodeFirstInputIsPrimitive(user)) { internal_output = false; break; @@ -1860,7 +1918,6 @@ void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr & backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target); } } -} // namespace CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr &graph) { MS_EXCEPTION_IF_NULL(graph); @@ -1869,7 +1926,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: MS_EXCEPTION_IF_NULL(output); MS_LOG(INFO) << "Output:" << output->DebugString(); } - auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { + auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr { auto backend_anf = graph->GetBackendAnfByFrontAnf(out); if (backend_anf != nullptr) { auto context_ptr = MsContext::GetInstance(); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 4ba99bc1b5..f425de317e 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -156,6 +156,11 @@ class SessionBasic : public std::enable_shared_from_this { std::unordered_map *other_graph_cnode); std::vector CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph); void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector &real_inputs); + void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node, + const FuncGraphManagerPtr &front_func_graph_manager, + const std::shared_ptr &backend_graph); + std::string AddPartialParametersMap(const FuncGraphManagerPtr &front_func_graph_manager, + const AnfNodePtr &partial_node); protected: friend class Executor; @@ -255,6 +260,8 @@ class SessionBasic : public std::enable_shared_from_this { std::unordered_map> graphs_; std::unordered_map> run_op_graphs_; std::unordered_map front_backend_graph_map_; + std::unordered_map partial_parameters_map_; + std::unordered_map partial_target_map_; std::shared_ptr context_; CallBackFunc summary_callback_; static GraphId graph_sum_; diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index b5d8d470dc..a98bdaef2f 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -189,16 +189,19 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput( MS_EXCEPTION_IF_NULL(address); TypeId infer_type_id = AnfAlgo::GetOutputInferDataType(node, index); TypeId device_type_id = AnfAlgo::GetOutputDeviceDataType(node, index); - tensor::TensorPtr tensor = kernel_graph->GetInternalOutputTensor(node, index); - if (tensor == nullptr) { - auto shape = AnfAlgo::GetOutputInferShape(node, index); - ShapeVector temp_shape; - (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); - tensor = std::make_shared(infer_type_id, temp_shape); - bool is_internal_output = kernel_graph->IsInternalOutput(node, index); - if (is_internal_output) { - kernel_graph->AddInternalOutputTensor(node, index, tensor); + auto shape = AnfAlgo::GetOutputInferShape(node, index); + ShapeVector temp_shape; + tensor::TensorPtr tensor; + bool is_internal_output = kernel_graph->IsInternalOutput(node, index); + (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); + if (is_internal_output) { + tensor = kernel_graph->GetInternalOutputTensor(node, index); + if (tensor == nullptr) { + tensor = std::make_shared(infer_type_id, temp_shape); } + kernel_graph->AddInternalOutputTensor(node, index, tensor); + } else { + tensor = std::make_shared(infer_type_id, temp_shape); } tensor->set_device_address(address); if (bound_addresses_.find(address) == bound_addresses_.end()) {