Browse Source

MemScheduler: handle nop-node outputs and inputs coming from other graphs

tags/v1.6.0
tanghuikang 4 years ago
parent
commit
87fccab600
3 changed files with 44 additions and 22 deletions
  1. +18
    -0
      mindspore/ccsrc/backend/session/kernel_graph.cc
  2. +6
    -1
      mindspore/ccsrc/backend/session/kernel_graph.h
  3. +20
    -21
      mindspore/ccsrc/runtime/device/kernel_runtime.cc

+ 18
- 0
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -918,6 +918,24 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
return graph_value_nodes_.erase(value_node) != 0;
}

// Records the graph's output-node-to-tensor mapping. For every output that is a
// nop node, additionally walks backwards through the chain of nop nodes to the
// real producing kernel and remembers (real kernel, index) -> (nop output key),
// so later lookups by the real output can be redirected to the stored tensor.
void KernelGraph::SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor) {
  output_node_to_tensor_ = node_to_tensor;
  for (const auto &entry : output_node_to_tensor_) {
    const auto &origin_output = entry.first;
    // Only nop-node outputs need a redirection entry.
    if (!opt::IsNopNode(origin_output.first)) {
      continue;
    }
    auto real_node = origin_output.first;
    auto real_index = origin_output.second;
    // Follow the nop chain until the first non-nop producer is reached.
    // NOTE(review): assumes nop chains are acyclic — a cycle would loop forever.
    while (opt::IsNopNode(real_node)) {
      const auto prev_with_index = AnfAlgo::GetPrevNodeOutput(real_node, 0);
      real_node = prev_with_index.first;
      real_index = prev_with_index.second;
    }
    // Map the real output back to the nop output key stored in output_node_to_tensor_.
    (void)nop_node_output_map_.emplace(KernelWithIndex{real_node, real_index}, origin_output);
  }
}

void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter) {
// update graph inputs
MS_EXCEPTION_IF_NULL(old_parameter);


+ 6
- 1
mindspore/ccsrc/backend/session/kernel_graph.h View File

@@ -295,13 +295,17 @@ class KernelGraph : public FuncGraph {
void SetInputTensors(const std::vector<tensor::TensorPtr> &input_tensors) { input_tensors_ = input_tensors; }
const std::vector<tensor::TensorPtr> &input_tensors() const { return input_tensors_; }

void SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor) { output_node_to_tensor_ = node_to_tensor; }
void SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor);

// Returns the tensor bound to the given output (node, index), or nullptr if none.
// Falls back to the nop-node redirection map: if the query is a real output that
// a nop node forwards, retry the lookup with the recorded nop output key (this
// recursion is at most one level deep, since redirected keys are direct entries
// of output_node_to_tensor_).
tensor::TensorPtr GetNodeOutputTensor(const session::KernelWithIndex &output_index) const {
  const auto direct = output_node_to_tensor_.find(output_index);
  if (direct != output_node_to_tensor_.end()) {
    return utils::cast<tensor::TensorPtr>(direct->second);
  }
  const auto redirected = nop_node_output_map_.find(output_index);
  if (redirected == nop_node_output_map_.end()) {
    return nullptr;
  }
  return GetNodeOutputTensor(redirected->second);
}

@@ -498,6 +502,7 @@ class KernelGraph : public FuncGraph {
std::vector<AnfNodePtr> input_nodes_;
std::vector<tensor::TensorPtr> input_tensors_;
KernelMapTensor output_node_to_tensor_;
std::map<session::KernelWithIndex, session::KernelWithIndex, session::KernelWithIndexCmp> nop_node_output_map_;
mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;
// The send/recv pairs inserted for allreduce, the key is allreduce kernel, the first of pair is send node, the second


+ 20
- 21
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -1456,23 +1456,23 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
MS_EXCEPTION_IF_NULL(tensor);
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
const auto tensor_size = LongToSize(tensor->data().nbytes());
if (tensor_address == device_address) {
if (tensor->NeedSyncHostToDevice()) {
tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), tensor->data().nbytes(),
tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_);
tensor->set_sync_status(kNoNeedSync);
}
if (mem_scheduler->HasDeviceMem(tensor_address.get())) {
tensor_address->set_ptr(nullptr);
tensor->set_device_address(nullptr);
}
continue;
}
bool need_sync = false;
if (tensor->NeedSyncHostToDevice()) {
mem_scheduler->AddMemNeedInit(device_address.get());
} else if (tensor_address != nullptr) {
need_sync = true;
} else if (tensor_address != device_address) {
tensor->data_sync(false);
mem_scheduler->AddMemNeedInit(device_address.get());
need_sync = true;
}
if (mem_scheduler->HasDeviceMem(device_address.get())) {
device_address->set_ptr(nullptr);
}
if (need_sync) {
if (device_address->GetPtr() != nullptr) {
device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), tensor->data().nbytes(),
tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_);
} else {
mem_scheduler->AddMemNeedInit(device_address.get());
}
}
MemPriority priority = kMemPriorityLow;
const auto &parameter = input_node->cast<ParameterPtr>();
@@ -1642,18 +1642,17 @@ void KernelRuntime::SyncParameter(const session::KernelGraph &graph,
if (!AnfAlgo::IsParameterWeight(parameter) && !graph.IsUpdatedParameter(parameter)) {
continue;
}
auto tensor = input_tensors[i];
MS_EXCEPTION_IF_NULL(tensor);
if (mem_scheduler->HasDeviceMem(device_address.get())) {
auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh);
device_address->set_ptr(device_ptr);
auto tensor = input_tensors[i];
MS_EXCEPTION_IF_NULL(tensor);
auto origin_tensor_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (origin_tensor_device_address != nullptr) {
origin_tensor_device_address->set_ptr(nullptr);
}
tensor->set_device_address(device_address);
tensor->set_sync_status(kNeedSyncDeviceToHost);
}
if (graph.IsUpdatedParameter(parameter)) {
tensor->SetIsUpdateByDevice();
}
}
}



Loading…
Cancel
Save