Browse Source

Critical path performance optimization

tags/v1.1.0
wilfChen 5 years ago
parent
commit
b8e1c03cdc
3 changed files with 50 additions and 35 deletions
  1. +33
    -35
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  2. +15
    -0
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
  3. +2
    -0
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h

+ 33
- 35
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -375,7 +375,7 @@ std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodeP
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]";
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -389,7 +389,7 @@ std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePt
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]";
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -403,7 +403,7 @@ std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]";
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -421,7 +421,7 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
if (!AnfAlgo::IsRealKernel(node)) {
return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -443,7 +443,7 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
if (!IsRealKernel(node)) {
return GetPrevNodeOutputFormat(node, input_idx);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -549,7 +549,7 @@ std::vector<Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &nod
if (!IsRealKernel(node)) {
return GetPrevNodeOutputReshapeType(node, input_idx);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -568,7 +568,7 @@ std::vector<Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &no
if (!IsRealKernel(node)) {
return GetPrevNodeOutputReshapeType(node, output_idx);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -624,7 +624,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
if (!IsRealKernel(node)) {
return GetPrevNodeOutputDeviceDataType(node, output_idx);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -645,7 +645,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
if (!IsRealKernel(node)) {
return GetPrevNodeOutputDeviceDataType(node, 0);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -675,7 +675,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
}
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto addr = kernel_info->GetOutputAddr(output_idx);
if (addr == nullptr) {
@@ -697,7 +697,8 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
}
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
// Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto addr = kernel_info->GetMutableOutputAddr(output_idx);
if (addr == nullptr) {
@@ -710,11 +711,8 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
// get output device addr of anf_node
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (output_idx > GetOutputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
// Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->OutputAddrExist(output_idx);
}
@@ -734,7 +732,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNode
// set output device addr of anf_node
void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
if (!kernel_info->SetOutputAddr(addr, output_idx)) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
@@ -744,7 +742,7 @@ void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t out
// set workspace device addr of anf_node
void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
@@ -754,7 +752,7 @@ void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t
// get workspace device addr of anf_node
DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto addr = kernel_info->GetWorkspaceAddr(output_idx);
if (addr == nullptr) {
@@ -767,7 +765,7 @@ DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, siz
// get workspace device mutable addr of anf_node
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto addr = kernel_info->GetMutableWorkspaceAddr(index);
if (addr == nullptr) {
@@ -810,7 +808,7 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_

kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
// select_kernel_build_info() has checked whether return pointer is null
auto build_info = kernel_info->select_kernel_build_info();
@@ -821,7 +819,7 @@ kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
// get KernelBuildType of node, such as ATT,RT,FWK and so on
KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
// select_kernel_build_info() has checked whether return pointer is null
auto build_info = kernel_info->select_kernel_build_info();
@@ -831,7 +829,7 @@ KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {

kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -840,7 +838,7 @@ kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {

kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
@@ -850,7 +848,7 @@ kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
// set select kernel_build_info
void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
}
@@ -858,7 +856,7 @@ void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &sel
// get select kernel_build_info
KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->GetMutableSelectKernelBuildInfo();
}
@@ -866,7 +864,7 @@ KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePt
// get kernelMode
KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->MutableKernelMod();
}
@@ -874,7 +872,7 @@ KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
// set kernel mod
void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
kernel_info->set_kernel_mod(kernel_mod);
}
@@ -940,42 +938,42 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {

void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
kernel_info->set_stream_id(stream_id);
}

uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->stream_id();
}

void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
kernel_info->set_stream_distinction_label(stream_label);
}

uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<const device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->stream_distinction_label();
}

void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
kernel_info->set_graph_id(graph_id);
}

uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<const device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->graph_id();
}
@@ -1003,7 +1001,7 @@ bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
if (node->isa<ValueNode>()) {
return false;
}
auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
auto kernel_info = static_cast<const device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->is_feature_map();
}


+ 15
- 0
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc View File

@@ -281,6 +281,11 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
}

void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph) {
if (is_alloc_inplace_res_[graph->graph_id()]) {
return;
}
is_alloc_inplace_res_[graph->graph_id()] = true;

std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
auto kernel_cnodes = graph->execution_order();
for (auto &kernel : kernel_cnodes) {
@@ -921,6 +926,11 @@ bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::K
}

void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph *graph) {
if (is_alloc_communication_res_[graph->graph_id()]) {
return;
}
is_alloc_communication_res_[graph->graph_id()] = true;

MS_EXCEPTION_IF_NULL(graph);
auto &kernels = graph->execution_order();
for (auto &kernel : kernels) {
@@ -1031,6 +1041,11 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
}
}

auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i);
if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {
continue;
}

auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i);
if (kernel_ref_count_ptr == nullptr) {
continue;


+ 2
- 0
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h View File

@@ -101,6 +101,8 @@ class GPUKernelRuntime : public KernelRuntime {
std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
std::unordered_map<uint32_t, bool> is_first_step_map_;
std::unordered_map<uint32_t, std::set<AnfNodePtr>> graph_output_map_;
std::unordered_map<uint32_t, bool> is_alloc_communication_res_;
std::unordered_map<uint32_t, bool> is_alloc_inplace_res_;

MemReuseUtilPtr mem_reuse_util_{nullptr};
MemSwapManagerPtr mem_swap_manager_{nullptr};


Loading…
Cancel
Save