
extract load input

Branch: pull/14484/head
Author: kswang · 4 years ago
Commit: 97a97e02db
10 changed files with 48 additions and 24 deletions:

  1. mindspore/ccsrc/backend/session/ascend_session.cc (+0, -2)
  2. mindspore/ccsrc/backend/session/cpu_session.cc (+26, -0)
  3. mindspore/ccsrc/backend/session/cpu_session.h (+2, -0)
  4. mindspore/ccsrc/backend/session/executor.cc (+1, -0)
  5. mindspore/ccsrc/backend/session/gpu_session.cc (+0, -2)
  6. mindspore/ccsrc/backend/session/gpu_session.h (+2, -3)
  7. mindspore/ccsrc/backend/session/kernel_graph.cc (+6, -6)
  8. mindspore/ccsrc/backend/session/kernel_graph.h (+2, -2)
  9. mindspore/ccsrc/backend/session/session_basic.h (+7, -0)
  10. mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc (+2, -9)

mindspore/ccsrc/backend/session/ascend_session.cc (+0, -2)

@@ -571,8 +571,6 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tens
   std::set<KernelGraphPtr> memo;
   SyncDataToExtraParams(NOT_NULL(kernel_graph), NOT_NULL(&memo));
   memo.clear();
-  // load input data from user input
-  LoadInputData(kernel_graph, inputs);
   if (debugger_) {
     debugger_->PreExecute(kernel_graph, graph_sum_);
   }


mindspore/ccsrc/backend/session/cpu_session.cc (+26, -0)

@@ -130,6 +130,32 @@ void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &ker
   runtime_.SyncValueNodeDeviceAddr(kernel_graph.get());
 }

+void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
+                               const std::vector<tensor::TensorPtr> &inputs_const) const {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto &input_nodes = kernel_graph->inputs();
+  if (input_nodes.size() != inputs_const.size()) {
+    MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
+  }
+  for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) {
+    auto &item = input_nodes[input_idx];
+    MS_EXCEPTION_IF_NULL(item);
+    if (item->isa<Parameter>() && !HasAbstractMonad(item)) {
+      auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
+      auto tensor = inputs_const[input_idx];
+      auto tensor_address = tensor->device_address();
+      MS_EXCEPTION_IF_NULL(address);
+      MS_EXCEPTION_IF_NULL(tensor);
+      if (tensor_address != nullptr && tensor_address != address &&
+          (std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() !=
+             device::DeviceAddressType::kCPU ||
+           AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) {
+        tensor->data_sync(false);
+      }
+    }
+  }
+}
+
 void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                               VectorRef *outputs) {
   auto kernel_graph = GetGraph(graph_id);
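Note on the condition above: host data is refreshed via tensor->data_sync(false) only when the tensor already carries a different device address and that address is either off-CPU or backs a weight parameter. A minimal standalone sketch of that predicate (plain C++ with hypothetical simplified types, not the MindSpore API):

#include <iostream>
#include <memory>

// Hypothetical stand-ins for device::DeviceAddressType / device::DeviceAddress.
enum class DeviceAddressType { kCPU, kGPU };
struct DeviceAddress {
  DeviceAddressType type;
};

// Mirrors the condition guarding tensor->data_sync(false): sync only if the
// tensor holds a *different* address and that address is either off-CPU or
// backs a weight parameter.
bool NeedsHostSync(const std::shared_ptr<DeviceAddress> &tensor_address,
                   const std::shared_ptr<DeviceAddress> &node_address,
                   bool is_weight) {
  return tensor_address != nullptr && tensor_address != node_address &&
         (tensor_address->type != DeviceAddressType::kCPU || is_weight);
}

int main() {
  auto node_addr = std::make_shared<DeviceAddress>(DeviceAddress{DeviceAddressType::kCPU});
  auto other_cpu = std::make_shared<DeviceAddress>(DeviceAddress{DeviceAddressType::kCPU});
  auto other_gpu = std::make_shared<DeviceAddress>(DeviceAddress{DeviceAddressType::kGPU});

  std::cout << NeedsHostSync(other_gpu, node_addr, false) << "\n";  // 1: data lives off-CPU
  std::cout << NeedsHostSync(other_cpu, node_addr, false) << "\n";  // 0: CPU data, not a weight
  std::cout << NeedsHostSync(other_cpu, node_addr, true) << "\n";   // 1: weight parameter
  std::cout << NeedsHostSync(node_addr, node_addr, true) << "\n";   // 0: already the same address
}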


mindspore/ccsrc/backend/session/cpu_session.h (+2, -0)

@@ -44,6 +44,8 @@ class CPUSession : public SessionBasic {
                  const std::vector<int64_t> &tensors_mask) override;
   void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
                  VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
+  void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
+                     const std::vector<tensor::TensorPtr> &inputs_const) const override;

  private:
   void Reorder(std::vector<CNodePtr> *node_list);


mindspore/ccsrc/backend/session/executor.cc (+1, -0)

@@ -161,6 +161,7 @@ void RunGraphTask::Run() {
   }
   graph->ResetGraphRunningStatus();
   try {
+    session_->LoadInputs(graph_id_, input_tensors_);
     session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
     UpdateOutputTensors(&outputs_, tensor_to_node_);
   } catch (const std::exception &e) {
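With this line, input loading moves out of the backend RunGraphImpl implementations (see the ascend_session.cc and gpu_session.cc deletions) and into the task itself, inside the same try block, so load failures surface through the existing error path. A hedged sketch of the resulting flow, with FakeSession as a hypothetical stand-in for the real session class:

#include <exception>
#include <iostream>
#include <vector>

// Hypothetical stand-in for SessionBasic; not the real class.
struct FakeSession {
  void LoadInputs(int graph_id, const std::vector<float> &inputs) {
    std::cout << "graph " << graph_id << ": load " << inputs.size() << " inputs\n";
  }
  void RunGraphImpl(int graph_id, const std::vector<float> &inputs) {
    std::cout << "graph " << graph_id << ": run\n";
  }
};

void RunGraphTask(FakeSession *session, int graph_id, const std::vector<float> &inputs) {
  try {
    session->LoadInputs(graph_id, inputs);    // new step: load inputs up front
    session->RunGraphImpl(graph_id, inputs);  // backend no longer loads them itself
  } catch (const std::exception &e) {
    std::cout << "run failed: " << e.what() << "\n";
  }
}

int main() {
  FakeSession session;
  RunGraphTask(&session, 0, {1.0f, 2.0f, 3.0f});
}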


mindspore/ccsrc/backend/session/gpu_session.cc (+0, -2)

@@ -425,8 +425,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
   MS_LOG(INFO) << "RunGraph graph_id: " << graph_id;
   // In pynative mode, device addresses of tensors in value nodes change.
   SyncValueNodeDeviceAddr(kernel_graph);
-  // Load input data from user input
-  LoadInputData(kernel_graph, inputs);
   if (debugger_) {
     debugger_->PreExecute(kernel_graph, graph_sum_);
   }


mindspore/ccsrc/backend/session/gpu_session.h (+2, -3)

@@ -47,6 +47,8 @@ class GPUSession : public SessionBasic {
                  VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
   std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;
   std::string GetCommWorldGroup() override { return kNcclWorldGroup; }
+  void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
+                     const std::vector<tensor::TensorPtr> &inputs_const) const override;

  private:
   void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
@@ -71,9 +73,6 @@ class GPUSession : public SessionBasic {

   void RunOpClearMemory(KernelGraph *kernel_graph) const;

-  void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
-                     const std::vector<tensor::TensorPtr> &inputs_const) const override;
-
   void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;

   void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;


mindspore/ccsrc/backend/session/kernel_graph.cc (+6, -6)

@@ -180,8 +180,8 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
   return std::vector<AnfNodePtr>(1, graph_output);
 }

-void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
-                                       std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
+void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
+                                     std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
   MS_EXCEPTION_IF_NULL(visit_queue);
   MS_EXCEPTION_IF_NULL(visited_nodes);
   auto it = node_output_edges_.find(node);
@@ -241,7 +241,7 @@ void KernelGraph::SetExecOrderByDefault() {
   while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
     // seed nodes first, then delay comm nodes
     if (seed_nodes.empty()) {
-      VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
+      EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
       delay_comm_stack.pop();
     } else {
       zero_input_nodes.push(seed_nodes.front());
@@ -272,16 +272,16 @@ void KernelGraph::SetExecOrderByDefault() {
     }
     if (optimize_comm) {
       while (!delay_comm_stack.empty()) {
-        VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
+        EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
         delay_comm_stack.pop();
       }
       delay_comm_stack.push(node);
     } else if (is_fused_comm) {
       delay_comm_stack.push(node);
     } else if (is_communication_descendant) {
-      VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
+      EnqueueActiveNodes(node, &communication_descendants, &visited_nodes);
     } else {
-      VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
+      EnqueueActiveNodes(node, &zero_input_nodes, &visited_nodes);
     }
   }
 }
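The rename reflects what the helper actually does: it walks a visited node's output edges and enqueues each successor once all of its inputs have been consumed, i.e. once it becomes active. A self-contained toy version of that Kahn-style step (integer node ids, not MindSpore's AnfNodePtr graph):

#include <iostream>
#include <queue>
#include <unordered_map>
#include <vector>

// node id -> ids of nodes consuming its output (hypothetical toy graph).
using Edges = std::unordered_map<int, std::vector<int>>;

// Push a successor onto the visit queue once its last pending input is
// consumed, i.e. once it has become "active".
void EnqueueActiveNodes(int node, const Edges &edges,
                        std::unordered_map<int, int> *pending_inputs,
                        std::queue<int> *visit_queue) {
  auto it = edges.find(node);
  if (it == edges.end()) return;
  for (int succ : it->second) {
    if (--(*pending_inputs)[succ] == 0) visit_queue->push(succ);
  }
}

int main() {
  Edges edges = {{0, {2}}, {1, {2}}, {2, {3}}};
  std::unordered_map<int, int> pending = {{0, 0}, {1, 0}, {2, 2}, {3, 1}};
  std::queue<int> queue;
  for (const auto &entry : pending)
    if (entry.second == 0) queue.push(entry.first);  // seed nodes: no inputs
  while (!queue.empty()) {
    int node = queue.front();
    queue.pop();
    std::cout << "exec " << node << "\n";  // prints a valid topological order
    EnqueueActiveNodes(node, edges, &pending, &queue);
  }
}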


mindspore/ccsrc/backend/session/kernel_graph.h (+2, -2)

@@ -283,8 +283,8 @@ class KernelGraph : public FuncGraph {
   void SetKernelInfoForNode(const AnfNodePtr &node) const;
   void ResetInFormat(const AnfNodePtr &node, const std::string &format) const;
   AnfNodePtr MakeValueNode(const AnfNodePtr &node);
-  void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
-                            std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true);
+  void EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
+                          std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true);
   // update node edge list
   void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
   // add node depend edge by data edge or control depend


mindspore/ccsrc/backend/session/session_basic.h (+7, -0)

@@ -181,6 +181,13 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
                 const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
   virtual void SetSummaryNodes(KernelGraph *graph);

+  void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) {
+    auto kernel_graph = GetGraph(graph_id);
+    MS_EXCEPTION_IF_NULL(kernel_graph);
+    MS_LOG(INFO) << "Load inputs";
+    LoadInputData(kernel_graph, inputs_const);
+  }
+
   virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                              const std::vector<tensor::TensorPtr> &inputs_const) const;
   void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors);
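LoadInputs is the new public, non-virtual entry point the executor calls; it resolves the graph and delegates to the virtual LoadInputData hook that the CPU and GPU sessions override. A minimal sketch of the pattern, using hypothetical simplified classes rather than the real KernelGraph/TensorPtr signatures:

#include <iostream>
#include <memory>
#include <vector>

class SessionBasic {
 public:
  // Non-virtual entry point used by the executor: resolve the graph, then
  // hand off to the backend-specific hook.
  void LoadInputs(int graph_id, const std::vector<float> &inputs) {
    std::cout << "resolve graph " << graph_id << "\n";
    LoadInputData(inputs);
  }
  virtual ~SessionBasic() = default;

 protected:
  virtual void LoadInputData(const std::vector<float> &inputs) {
    std::cout << "default load: " << inputs.size() << " inputs\n";
  }
};

class CPUSession : public SessionBasic {
 protected:
  void LoadInputData(const std::vector<float> &inputs) override {
    std::cout << "CPU load: " << inputs.size() << " inputs\n";
  }
};

int main() {
  std::unique_ptr<SessionBasic> session = std::make_unique<CPUSession>();
  session->LoadInputs(0, {1.0f, 2.0f});  // dispatches to CPUSession::LoadInputData
}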


mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc (+2, -9)

@@ -283,20 +283,14 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker
   if (input_nodes.size() != inputs.size()) {
     MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
   }
-  size_t input_idx = 0;
-  for (auto &item : input_nodes) {
+  for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) {
+    auto &item = input_nodes[input_idx];
     MS_EXCEPTION_IF_NULL(item);
     if (item->isa<Parameter>() && !HasAbstractMonad(item)) {
       auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
       auto tensor = inputs[input_idx];
-      auto tensor_address = tensor->device_address();
       MS_EXCEPTION_IF_NULL(address);
       MS_EXCEPTION_IF_NULL(tensor);
-      if (tensor_address != nullptr && tensor_address != address &&
-          (std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU ||
-           AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) {
-        tensor->data_sync(false);
-      }
       if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) {
         address->ptr_ = tensor->data_c();
       } else {
@@ -318,7 +312,6 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker
       address->ref_count_ = INIT_NODE_REF;
       tensor->set_device_address(address);
     }
-    input_idx++;
   }
 }
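The loop rewrite here trades a range-based loop plus a manually maintained counter for a single index that pairs each input node with its tensor; the trailing input_idx++ it removes is exactly the kind of line that gets lost when a block (here, the data_sync logic) is extracted elsewhere. A trivial standalone illustration:

#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> nodes = {"x", "w", "b"};
  std::vector<float> tensors = {1.0f, 2.0f, 3.0f};

  // Before: range-based loop plus a separately maintained counter.
  size_t input_idx = 0;
  for (const auto &node : nodes) {
    std::cout << node << " <- " << tensors[input_idx] << "\n";
    input_idx++;  // easy to drop when code is refactored or moved
  }

  // After: one index drives both sequences, keeping node and tensor paired.
  for (size_t i = 0; i < nodes.size(); ++i) {
    std::cout << nodes[i] << " <- " << tensors[i] << "\n";
  }
}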


