| @@ -362,9 +362,12 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra | |||
| MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | |||
| } | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode || | |||
| AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) { | |||
| AnfAlgo::IsParameterWeight(input_param)) { | |||
| tensor->set_device_address(device_address); | |||
| } | |||
| if (kernel_graph->IsUpdatedParameter(input_param)) { | |||
| tensor->SetIsUpdateByDevice(); | |||
| } | |||
| } | |||
| tensor->set_sync_status(kNoNeedSync); | |||
| } | |||
| @@ -164,20 +164,26 @@ void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; | |||
| } | |||
| for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) { | |||
| auto &item = input_nodes[input_idx]; | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| if (item->isa<Parameter>() && !HasAbstractMonad(item)) { | |||
| auto address = AnfAlgo::GetMutableOutputAddr(item, 0); | |||
| auto tensor = inputs_const[input_idx]; | |||
| auto tensor_address = tensor->device_address(); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| if (tensor_address != nullptr && tensor_address != address && | |||
| (std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != | |||
| device::DeviceAddressType::kCPU || | |||
| AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) { | |||
| tensor->data_sync(false); | |||
| } | |||
| auto &input_node = input_nodes[input_idx]; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (!input_node->isa<Parameter>() || HasAbstractMonad(input_node)) { | |||
| continue; | |||
| } | |||
| auto address = AnfAlgo::GetMutableOutputAddr(input_node, 0); | |||
| auto tensor = inputs_const[input_idx]; | |||
| auto tensor_address = tensor->device_address(); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| if (tensor_address == nullptr || tensor_address == address) { | |||
| continue; | |||
| } | |||
| auto input_param = input_node->cast<ParameterPtr>(); | |||
| if (AnfAlgo::IsParameterWeight(input_param) && !tensor->IsUpdatedByDevice()) { | |||
| continue; | |||
| } | |||
| if (std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != | |||
| device::DeviceAddressType::kCPU) { | |||
| tensor->data_sync(false); | |||
| } | |||
| } | |||
| } | |||
| @@ -1323,15 +1323,19 @@ void KernelGraph::SetOptimizerFlag() { | |||
| auto node_name = AnfAlgo::GetCNodeName(cnode); | |||
| if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) { | |||
| has_optimizer_ = true; | |||
| return; | |||
| } else if (node_name.find("Assign") == string::npos) { | |||
| continue; | |||
| } | |||
| if (node_name.find("Assign") != string::npos) { | |||
| for (auto &input : cnode->inputs()) { | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if (input->isa<Parameter>() && AnfAlgo::IsParameterWeight(input->cast<ParameterPtr>())) { | |||
| has_optimizer_ = true; | |||
| return; | |||
| } | |||
| for (auto &input : cnode->inputs()) { | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| auto real_node = AnfAlgo::VisitKernel(input, 0).first; | |||
| if (!real_node->isa<Parameter>()) { | |||
| continue; | |||
| } | |||
| auto param = real_node->cast<ParameterPtr>(); | |||
| if (AnfAlgo::IsParameterWeight(param)) { | |||
| has_optimizer_ = true; | |||
| (void)updated_parameters_.insert(param); | |||
| } | |||
| } | |||
| } | |||
| @@ -63,6 +63,7 @@ class KernelGraph : public FuncGraph { | |||
| ref_out_in_map_ = graph.ref_out_in_map_; | |||
| node_output_edges_ = graph.node_output_edges_; | |||
| summary_nodes_ = graph.summary_nodes_; | |||
| updated_parameters_ = graph.updated_parameters_; | |||
| executable_ = graph.executable_; | |||
| summary_node_exist_ = graph.summary_node_exist_; | |||
| valid_inputs_ = graph.valid_inputs_; | |||
| @@ -259,6 +260,12 @@ class KernelGraph : public FuncGraph { | |||
| void SetInputNodes(); | |||
| const std::vector<AnfNodePtr> &input_nodes() const { return input_nodes_; } | |||
| bool has_optimizer() const { return has_optimizer_; } | |||
| bool IsUpdatedParameter(const ParameterPtr &param) {  // true iff executing this graph writes the parameter back (collected from Assign-like nodes in SetOptimizerFlag) | |||
| if (updated_parameters_.find(param) != updated_parameters_.end()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| // handle graph dependency | |||
| void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) { | |||
| if (graph != nullptr) { | |||
| @@ -373,6 +380,8 @@ class KernelGraph : public FuncGraph { | |||
| std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_; | |||
| std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_; | |||
| std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_; | |||
| // parameters that will be updated when graph is executed | |||
| std::unordered_set<ParameterPtr> updated_parameters_; | |||
| // graph needn't execute | |||
| bool executable_{false}; | |||
| // exist summary node in graph | |||
| @@ -74,7 +74,7 @@ bool RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodeP | |||
| (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) { | |||
| return false; | |||
| } | |||
| if (AnfAlgo::IsRealKernel(node)) { | |||
| if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { | |||
| return true; | |||
| } | |||
| (*idx) += 1; | |||
| @@ -40,6 +40,28 @@ using mindspore::kernel::AddressPtr; | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace { | |||
| std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph *graph) {  // graph->inputs() plus every Parameter actually consumed by an executed kernel | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto graph_inputs = graph->inputs(); | |||
| std::vector<AnfNodePtr> result(graph_inputs.begin(), graph_inputs.end()); | |||
| std::set<AnfNodePtr> inputs_set(graph_inputs.begin(), graph_inputs.end());  // dedup set: each parameter is appended at most once | |||
| auto kernels = graph->execution_order(); | |||
| for (auto &kernel : kernels) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| auto input_num = AnfAlgo::GetInputTensorNum(kernel); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto input_node = kernel->input(i + 1);  // NOTE(review): assumes input(0) is the primitive, so data inputs start at 1 — confirm against CNode convention | |||
| auto input_real_node = AnfAlgo::VisitKernelWithReturnType(input_node, 0).first;  // presumably resolves through wrapper nodes to the real producer — verify | |||
| if (input_real_node->isa<Parameter>() && inputs_set.find(input_real_node) == inputs_set.end()) {  // a Parameter used by a kernel but missing from graph->inputs() | |||
| (void)inputs_set.insert(input_real_node); | |||
| (void)result.emplace_back(input_real_node); | |||
| } | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| } // namespace | |||
| constexpr size_t kMinInputSize = 2; | |||
| KernelRuntime::~KernelRuntime() {} | |||
| @@ -277,17 +299,21 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(mem_manager_); | |||
| MS_LOG(INFO) << "AssignStaticMemoryInput start for graph " << graph->graph_id(); | |||
| auto graph_inputs = graph->inputs(); | |||
| auto graph_inputs = GetGraphInputs(graph); | |||
| auto graph_valid_input = graph->valid_inputs(); | |||
| graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end()); | |||
| std::vector<AnfNodePtr> need_alloc_nodes; | |||
| auto add_need_alloc_nodes = [&need_alloc_nodes, this](const AnfNodePtr &node) { | |||
| auto add_need_alloc_nodes = [&need_alloc_nodes, graph, this](const AnfNodePtr &node) { | |||
| if (!node->isa<Parameter>()) { | |||
| return; | |||
| } | |||
| if (NodeOutputDeviceAddressExist(node, 0)) { | |||
| return; | |||
| } | |||
| auto input_param = node->cast<ParameterPtr>(); | |||
| if (!input_param->IsUsedByRealKernelInGraph(graph->graph_id())) { | |||
| return; | |||
| } | |||
| need_alloc_nodes.push_back(node); | |||
| }; | |||
| @@ -356,6 +356,8 @@ class Tensor : public MetaTensor { | |||
| bool IsGraphOutput() { return graph_output_; } | |||
| void SetIsGraphOutput() { graph_output_ = true; } | |||
| bool IsUpdatedByDevice() { return updated_by_device_; }  // true when the device may rewrite this tensor in place, so host data must be re-synced | |||
| void SetIsUpdateByDevice() { updated_by_device_ = true; }  // sticky flag; set by session when the tensor backs an updated graph parameter | |||
| private: | |||
| bool init_flag_{false}; | |||
| @@ -364,6 +366,7 @@ class Tensor : public MetaTensor { | |||
| mutable std::shared_ptr<WaitEvent> event_{nullptr}; | |||
| mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; | |||
| bool graph_output_{false}; | |||
| bool updated_by_device_{false}; | |||
| DeviceSyncPtr device_sync_{nullptr}; | |||
| bool cache_enable_{false}; | |||
| std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr}; | |||