| @@ -571,8 +571,6 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tens | |||||
| std::set<KernelGraphPtr> memo; | std::set<KernelGraphPtr> memo; | ||||
| SyncDataToExtraParams(NOT_NULL(kernel_graph), NOT_NULL(&memo)); | SyncDataToExtraParams(NOT_NULL(kernel_graph), NOT_NULL(&memo)); | ||||
| memo.clear(); | memo.clear(); | ||||
| // load input data from user input | |||||
| LoadInputData(kernel_graph, inputs); | |||||
| if (debugger_) { | if (debugger_) { | ||||
| debugger_->PreExecute(kernel_graph, graph_sum_); | debugger_->PreExecute(kernel_graph, graph_sum_); | ||||
| } | } | ||||
| @@ -130,6 +130,32 @@ void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &ker | |||||
| runtime_.SyncValueNodeDeviceAddr(kernel_graph.get()); | runtime_.SyncValueNodeDeviceAddr(kernel_graph.get()); | ||||
| } | } | ||||
| void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||||
| const std::vector<tensor::TensorPtr> &inputs_const) const { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto &input_nodes = kernel_graph->inputs(); | |||||
| if (input_nodes.size() != inputs_const.size()) { | |||||
| MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; | |||||
| } | |||||
| for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) { | |||||
| auto &item = input_nodes[input_idx]; | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| if (item->isa<Parameter>() && !HasAbstractMonad(item)) { | |||||
| auto address = AnfAlgo::GetMutableOutputAddr(item, 0); | |||||
| auto tensor = inputs_const[input_idx]; | |||||
| auto tensor_address = tensor->device_address(); | |||||
| MS_EXCEPTION_IF_NULL(address); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| if (tensor_address != nullptr && tensor_address != address && | |||||
| (std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != | |||||
| device::DeviceAddressType::kCPU || | |||||
| AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) { | |||||
| tensor->data_sync(false); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | ||||
| VectorRef *outputs) { | VectorRef *outputs) { | ||||
| auto kernel_graph = GetGraph(graph_id); | auto kernel_graph = GetGraph(graph_id); | ||||
| @@ -44,6 +44,8 @@ class CPUSession : public SessionBasic { | |||||
| const std::vector<int64_t> &tensors_mask) override; | const std::vector<int64_t> &tensors_mask) override; | ||||
| void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors, | void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors, | ||||
| VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | ||||
| void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||||
| const std::vector<tensor::TensorPtr> &inputs_const) const override; | |||||
| private: | private: | ||||
| void Reorder(std::vector<CNodePtr> *node_list); | void Reorder(std::vector<CNodePtr> *node_list); | ||||
| @@ -161,6 +161,7 @@ void RunGraphTask::Run() { | |||||
| } | } | ||||
| graph->ResetGraphRunningStatus(); | graph->ResetGraphRunningStatus(); | ||||
| try { | try { | ||||
| session_->LoadInputs(graph_id_, input_tensors_); | |||||
| session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); | session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); | ||||
| UpdateOutputTensors(&outputs_, tensor_to_node_); | UpdateOutputTensors(&outputs_, tensor_to_node_); | ||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| @@ -425,8 +425,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor: | |||||
| MS_LOG(INFO) << "RunGraph graph_id: " << graph_id; | MS_LOG(INFO) << "RunGraph graph_id: " << graph_id; | ||||
| // In pynative mode, device addresses of tensors in value nodes change. | // In pynative mode, device addresses of tensors in value nodes change. | ||||
| SyncValueNodeDeviceAddr(kernel_graph); | SyncValueNodeDeviceAddr(kernel_graph); | ||||
| // Load input data from user input | |||||
| LoadInputData(kernel_graph, inputs); | |||||
| if (debugger_) { | if (debugger_) { | ||||
| debugger_->PreExecute(kernel_graph, graph_sum_); | debugger_->PreExecute(kernel_graph, graph_sum_); | ||||
| } | } | ||||
| @@ -47,6 +47,8 @@ class GPUSession : public SessionBasic { | |||||
| VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | ||||
| std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override; | std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override; | ||||
| std::string GetCommWorldGroup() override { return kNcclWorldGroup; } | std::string GetCommWorldGroup() override { return kNcclWorldGroup; } | ||||
| void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||||
| const std::vector<tensor::TensorPtr> &inputs_const) const override; | |||||
| private: | private: | ||||
| void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| @@ -71,9 +73,6 @@ class GPUSession : public SessionBasic { | |||||
| void RunOpClearMemory(KernelGraph *kernel_graph) const; | void RunOpClearMemory(KernelGraph *kernel_graph) const; | ||||
| void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||||
| const std::vector<tensor::TensorPtr> &inputs_const) const override; | |||||
| void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| @@ -180,8 +180,8 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const { | |||||
| return std::vector<AnfNodePtr>(1, graph_output); | return std::vector<AnfNodePtr>(1, graph_output); | ||||
| } | } | ||||
| void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | |||||
| std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) { | |||||
| void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | |||||
| std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) { | |||||
| MS_EXCEPTION_IF_NULL(visit_queue); | MS_EXCEPTION_IF_NULL(visit_queue); | ||||
| MS_EXCEPTION_IF_NULL(visited_nodes); | MS_EXCEPTION_IF_NULL(visited_nodes); | ||||
| auto it = node_output_edges_.find(node); | auto it = node_output_edges_.find(node); | ||||
| @@ -241,7 +241,7 @@ void KernelGraph::SetExecOrderByDefault() { | |||||
| while (!seed_nodes.empty() || !delay_comm_stack.empty()) { | while (!seed_nodes.empty() || !delay_comm_stack.empty()) { | ||||
| // seed nodes first, then delay comm nodes | // seed nodes first, then delay comm nodes | ||||
| if (seed_nodes.empty()) { | if (seed_nodes.empty()) { | ||||
| VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); | |||||
| EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); | |||||
| delay_comm_stack.pop(); | delay_comm_stack.pop(); | ||||
| } else { | } else { | ||||
| zero_input_nodes.push(seed_nodes.front()); | zero_input_nodes.push(seed_nodes.front()); | ||||
| @@ -272,16 +272,16 @@ void KernelGraph::SetExecOrderByDefault() { | |||||
| } | } | ||||
| if (optimize_comm) { | if (optimize_comm) { | ||||
| while (!delay_comm_stack.empty()) { | while (!delay_comm_stack.empty()) { | ||||
| VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); | |||||
| EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); | |||||
| delay_comm_stack.pop(); | delay_comm_stack.pop(); | ||||
| } | } | ||||
| delay_comm_stack.push(node); | delay_comm_stack.push(node); | ||||
| } else if (is_fused_comm) { | } else if (is_fused_comm) { | ||||
| delay_comm_stack.push(node); | delay_comm_stack.push(node); | ||||
| } else if (is_communication_descendant) { | } else if (is_communication_descendant) { | ||||
| VisitNodeDescendants(node, &communication_descendants, &visited_nodes); | |||||
| EnqueueActiveNodes(node, &communication_descendants, &visited_nodes); | |||||
| } else { | } else { | ||||
| VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes); | |||||
| EnqueueActiveNodes(node, &zero_input_nodes, &visited_nodes); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -283,8 +283,8 @@ class KernelGraph : public FuncGraph { | |||||
| void SetKernelInfoForNode(const AnfNodePtr &node) const; | void SetKernelInfoForNode(const AnfNodePtr &node) const; | ||||
| void ResetInFormat(const AnfNodePtr &node, const std::string &format) const; | void ResetInFormat(const AnfNodePtr &node, const std::string &format) const; | ||||
| AnfNodePtr MakeValueNode(const AnfNodePtr &node); | AnfNodePtr MakeValueNode(const AnfNodePtr &node); | ||||
| void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | |||||
| std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true); | |||||
| void EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | |||||
| std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true); | |||||
| // update node edge list | // update node edge list | ||||
| void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes); | void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes); | ||||
| // add node depend edge by data edge or control depend | // add node depend edge by data edge or control depend | ||||
| @@ -181,6 +181,13 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| const std::map<KernelWithIndex, size_t> &cnode_refcount) {} | const std::map<KernelWithIndex, size_t> &cnode_refcount) {} | ||||
| virtual void SetSummaryNodes(KernelGraph *graph); | virtual void SetSummaryNodes(KernelGraph *graph); | ||||
| void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) { | |||||
| auto kernel_graph = GetGraph(graph_id); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| MS_LOG(INFO) << "Load inputs"; | |||||
| LoadInputData(kernel_graph, inputs_const); | |||||
| } | |||||
| virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | ||||
| const std::vector<tensor::TensorPtr> &inputs_const) const; | const std::vector<tensor::TensorPtr> &inputs_const) const; | ||||
| void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors); | void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors); | ||||
| @@ -283,20 +283,14 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker | |||||
| if (input_nodes.size() != inputs.size()) { | if (input_nodes.size() != inputs.size()) { | ||||
| MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; | MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; | ||||
| } | } | ||||
| size_t input_idx = 0; | |||||
| for (auto &item : input_nodes) { | |||||
| for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) { | |||||
| auto &item = input_nodes[input_idx]; | |||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| if (item->isa<Parameter>() && !HasAbstractMonad(item)) { | if (item->isa<Parameter>() && !HasAbstractMonad(item)) { | ||||
| auto address = AnfAlgo::GetMutableOutputAddr(item, 0); | auto address = AnfAlgo::GetMutableOutputAddr(item, 0); | ||||
| auto tensor = inputs[input_idx]; | auto tensor = inputs[input_idx]; | ||||
| auto tensor_address = tensor->device_address(); | |||||
| MS_EXCEPTION_IF_NULL(address); | MS_EXCEPTION_IF_NULL(address); | ||||
| MS_EXCEPTION_IF_NULL(tensor); | MS_EXCEPTION_IF_NULL(tensor); | ||||
| if (tensor_address != nullptr && tensor_address != address && | |||||
| (std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU || | |||||
| AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) { | |||||
| tensor->data_sync(false); | |||||
| } | |||||
| if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) { | if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) { | ||||
| address->ptr_ = tensor->data_c(); | address->ptr_ = tensor->data_c(); | ||||
| } else { | } else { | ||||
| @@ -318,7 +312,6 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker | |||||
| address->ref_count_ = INIT_NODE_REF; | address->ref_count_ = INIT_NODE_REF; | ||||
| tensor->set_device_address(address); | tensor->set_device_address(address); | ||||
| } | } | ||||
| input_idx++; | |||||
| } | } | ||||
| } | } | ||||