Browse Source

!8414 gpu support dynamic shape

From: @wilfchen
Reviewed-by: @limingqi107,@cristoval
Signed-off-by: @cristoval
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
772fa6838d
5 changed files with 43 additions and 50 deletions
  1. +11
    -17
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
  2. +1
    -1
      mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc
  3. +21
    -16
      mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc
  4. +5
    -1
      mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h
  5. +5
    -15
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc

+ 11
- 17
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h View File

@@ -78,44 +78,38 @@ class GpuKernelRegister {
// variable has been created.
#define uchar unsigned char

#define UNIQUE_KERNEL_NAME(kernel) KERNEL_NAME(kernel, __COUNTER__)
#define UNIQUE_KERNEL_NAME(kernel) KERNEL_NAME(g_##kernel##_gpu_kernel_reg, __COUNTER__)
#define KERNEL_NAME(kernel, cnt) MERGE(kernel, cnt)
#define MERGE(kernel, cnt) kernel##cnt


#define MS_REG_GPU_KERNEL(OPNAME, OPCLASS) \
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_gpu_kernel_reg)(#OPNAME, KernelAttr(), \
[]() { return new OPCLASS(); });
#define MS_REG_GPU_KERNEL(OPNAME, OPCLASS) \
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, KernelAttr(), []() { return new OPCLASS(); });


// regular register of fixed accuracy kernels
#define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS) \
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_gpu_kernel_reg)(#OPNAME, ATTR, \
[]() { return new OPCLASS(); });
#define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS) \
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS(); });


// register of mixed accuracy kernels which use template and maintain one typename, ignore input num
#define MS_REG_GPU_KERNEL_SAME(OPNAME, ATTR, OPCLASS, T) \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_gpu_kernel_reg)( \
#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T>(); });


// register of mixed accuracy kernels which use template and maintain one typename
#define MS_REG_GPU_KERNEL_ONE(OPNAME, ATTR, OPCLASS, T) \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_gpu_kernel_reg)( \
#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T>(); });


// register of mixed accuracy kernels which use template and maintain two typename
#define MS_REG_GPU_KERNEL_TWO(OPNAME, ATTR, OPCLASS, T, S) \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_##S##_gpu_kernel_reg)( \
#OPNAME, ATTR, []() { return new OPCLASS<T, S>(); });
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T, S>(); });


// register of mixed accuracy kernels which use template and maintain three typename
#define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G) \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S, G>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg)( \
#OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); });
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); });
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNELFACTORY_H_

+ 1
- 1
mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc View File

@@ -62,7 +62,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
continue;
}
}
if (AnfAlgo::IsDynamicShape(cnode) &&
if (AnfAlgo::IsNodeDynamicShape(cnode) &&
DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
continue;


+ 21
- 16
mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc View File

@@ -42,24 +42,10 @@ void DynamicKernel::Initialize() {
return;
}
MS_LOG(INFO) << "Have depends";
std::vector<int> depends_list;
std::vector<int64_t> depends_list_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode_ptr_, kDynamicShapeDepends);
(void)std::transform(depends_list_me.begin(), depends_list_me.end(), std::back_inserter(depends_list),
(void)std::transform(depends_list_me.begin(), depends_list_me.end(), std::back_inserter(depend_list_),
[](const int64_t &value) { return static_cast<int>(value); });
// Save depend input tensor. Sync data in InferShape.
for (auto depend : depends_list) {
auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, depend);
auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode_ptr_, depend);
std::vector<int64_t> shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
out_tensor->set_device_address(output_addr);


auto ret = depend_tensor_map_.try_emplace(depend, out_tensor);
if (!ret.second) {
MS_LOG(EXCEPTION) << "Insert map failed";
}
}
MS_LOG(INFO) << "Init End";
}


@@ -74,6 +60,22 @@ bool IsTupleGetItem(const AnfNodePtr &anf_node) {
return IsPrimitive(input0, prim::kPrimTupleGetItem);
}


void DynamicKernel::RebuildDependTensor() {
depend_tensor_map_.clear();
for (auto depend : depend_list_) {
auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, depend);
auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode_ptr_, depend);
std::vector<int64_t> shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
out_tensor->set_device_address(output_addr);
auto ret = depend_tensor_map_.try_emplace(depend, out_tensor);
if (!ret.second) {
MS_LOG(EXCEPTION) << "Insert map failed";
}
}
}

void DynamicKernel::InferShape() {
if (!is_input_dynamic_shape_ && is_output_dynamic_shape_ && !have_depends()) {
return;
@@ -88,12 +90,15 @@ void DynamicKernel::InferShape() {
AbstractBasePtrList args_spec_list;
auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);


// rebuild depend tensor map for gpu dynamic memory allocation.
RebuildDependTensor();

auto input_size = AnfAlgo::GetInputTensorNum(cnode_ptr_);
for (size_t i = 0; i < input_size; ++i) {
auto input_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, i);
auto real_input = input_with_index.first;

MS_EXCEPTION_IF_NULL(real_input);

auto ret = depend_tensor_map_.find(i);
if (ret != depend_tensor_map_.end()) {
auto tensor_ptr = ret->second;


+ 5
- 1
mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h View File

@@ -19,6 +19,7 @@


#include <memory>
#include <string>
#include <vector>
#include <map>
#include "ir/anf.h"
#include "ir/tensor.h"
@@ -44,16 +45,19 @@ class DynamicKernel {
bool is_dynamic_shape() const { return is_dynamic_shape_; }
bool is_input_dynamic_shape() const { return is_input_dynamic_shape_; }
bool is_output_dynamic_shape() const { return is_output_dynamic_shape_; }
bool have_depends() const { return !depend_tensor_map_.empty(); }
bool have_depends() const { return !depend_list_.empty(); }
virtual void Initialize();
std::string GetKernelName() { return cnode_ptr_->fullname_with_scope(); }


protected:
void RebuildDependTensor();

void *stream_;
const CNodePtr cnode_ptr_;
bool is_dynamic_shape_;
bool is_input_dynamic_shape_;
bool is_output_dynamic_shape_;
std::vector<uint32_t> depend_list_;
std::map<uint32_t, tensor::TensorPtr> depend_tensor_map_;
};
using DynamicKernelPtr = std::shared_ptr<DynamicKernel>;


+ 5
- 15
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc View File

@@ -37,7 +37,6 @@
#include "utils/shape_utils.h"
#include "debug/data_dump/dump_json_parser.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "runtime/device/executor/executor_callback.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debug_services.h"
#endif
@@ -369,7 +368,7 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) {
bool GPUKernelRuntime::RunOneStep(const session::KernelGraph *graph) {
bool ret = true;
auto graph_id = graph->graph_id();
if (!is_first_step_map_[graph_id]) {
if (!is_first_step_map_[graph_id] || graph->is_dynamic_shape()) {
// Normally run graph
ret = LaunchKernelDynamic(graph);
} else {
@@ -603,16 +602,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo
dynamic_kernel = gpu_kernel->DynamicKernel();
}


if (dynamic_kernel && dynamic_kernel->have_depends()) {
MS_LOG(INFO) << "Match Dynamic Kernel, Start SyncStream";
if (!SyncStream()) {
MS_LOG(ERROR) << "SyncStream failed";
return false;
}
}

if (dynamic_kernel && dynamic_kernel->is_dynamic_shape()) {
ExecutorCallback::GetInstance().Consume();
dynamic_kernel->InferShape();
dynamic_kernel->UpdateArgs();
}
@@ -645,9 +635,10 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo
LaunchKernelWithTimeProfiling(kernel, kernel_inputs, kernel_workspaces, kernel_outputs);
}


ExecutorCallback::GetInstance().RegistCallback([&gpu_kernel] {
if (gpu_kernel) gpu_kernel->PostExecute();
});
if (gpu_kernel && dynamic_kernel && dynamic_kernel->is_dynamic_shape()) {
gpu_kernel->PostExecute();
}

// called once per kernel to collect the outputs to the kernel (does a SyncDeviceToHost)
LoadKernelData(debugger_.get(), kernel, kernel_inputs, kernel_workspaces, kernel_outputs, exec_order, stream_,
dump_enabled);
@@ -666,7 +657,6 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo
// collect weights and bias for dump mode
debugger_->LoadParametersAndConst();
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
ExecutorCallback::GetInstance().Consume();
}
ClearSwapInfo(mock);
return true;


Loading…
Cancel
Save