Browse Source

!5667 add kernel select after optimize pass

Merge pull request !5667 from zyli2020/code_refactor
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
bc4c5afc1a
12 changed files with 103 additions and 58 deletions
  1. +9
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
  2. +5
    -1
      mindspore/ccsrc/backend/optimizer/CMakeLists.txt
  3. +3
    -1
      mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc
  4. +4
    -1
      mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
  5. +3
    -1
      mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc
  6. +3
    -1
      mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc
  7. +18
    -41
      mindspore/ccsrc/backend/session/gpu_session.cc
  8. +0
    -5
      mindspore/ccsrc/backend/session/gpu_session.h
  9. +1
    -1
      mindspore/ccsrc/runtime/device/gpu/cuda_env_checker.cc
  10. +29
    -2
      mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc
  11. +24
    -1
      mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h
  12. +4
    -0
      tests/ut/cpp/CMakeLists.txt

+ 9
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h View File

@@ -96,9 +96,15 @@ class ActivationGradGpuKernel : public GpuKernel {
const int split_dim = 4; const int split_dim = 4;
if (input_shape.size() <= split_dim) { if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape); ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"SetTensor4dDescriptor failed");
if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
shape[0], shape[3], shape[1], shape[2]),
"cudnnSetTensor4dDescriptor failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"cudnnSetTensor4dDescriptor failed");
}
} else { } else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_); CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
} }


+ 5
- 1
mindspore/ccsrc/backend/optimizer/CMakeLists.txt View File

@@ -2,7 +2,6 @@ file(GLOB_RECURSE _PREACTIVATE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"common/*.cc" "common/*.cc"
"mem_reuse/*.cc" "mem_reuse/*.cc"
"pass/*.cc" "pass/*.cc"
"gpu/*.cc"
) )


if (ENABLE_D) if (ENABLE_D)
@@ -10,5 +9,10 @@ if (ENABLE_D)
list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST}) list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST})
endif () endif ()


if (ENABLE_GPU)
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc")
list(APPEND _PREACTIVATE_SRC_LIST ${_GPU_SRC_LIST})
endif ()

set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT) set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST}) add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST})

+ 3
- 1
mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc View File

@@ -23,6 +23,7 @@
#include "ir/primitive.h" #include "ir/primitive.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@@ -46,7 +47,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0); auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
MS_EXCEPTION_IF_NULL(batch_norm_ex); MS_EXCEPTION_IF_NULL(batch_norm_ex);
if (AnfAlgo::GetOutputInferDataType(batch_norm_ex, 0) != kNumberTypeFloat16) {
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) {
return nullptr; return nullptr;
} }
@@ -83,6 +84,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
auto manager = graph->manager(); auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu); manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
return tuple_get_item; return tuple_get_item;
} }
} // namespace opt } // namespace opt


+ 4
- 1
mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc View File

@@ -24,6 +24,7 @@
#include "ir/primitive.h" #include "ir/primitive.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@@ -123,7 +124,8 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
const EquivPtr &) const { const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) {
return nullptr; return nullptr;
} }
@@ -169,6 +171,7 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad); AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad);
SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad); SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad);
ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad); ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad);
device::gpu::SetKernelInfo(fused_batch_norm_add_relu_grad);
return nullptr; return nullptr;
} }
} // namespace opt } // namespace opt


+ 3
- 1
mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc View File

@@ -23,6 +23,7 @@
#include "ir/primitive.h" #include "ir/primitive.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@@ -43,7 +44,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0); auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
MS_EXCEPTION_IF_NULL(batch_norm_ex); MS_EXCEPTION_IF_NULL(batch_norm_ex);
if (AnfAlgo::GetOutputInferDataType(batch_norm_ex, 0) != kNumberTypeFloat16) {
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) {
return nullptr; return nullptr;
} }
@@ -78,6 +79,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
auto manager = graph->manager(); auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
manager->Replace(batch_norm_ex, fused_batch_norm_with_relu); manager->Replace(batch_norm_ex, fused_batch_norm_with_relu);
device::gpu::SetKernelInfo(fused_batch_norm_with_relu);
return tuple_get_item; return tuple_get_item;
} }
} // namespace opt } // namespace opt


+ 3
- 1
mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc View File

@@ -23,6 +23,7 @@
#include "ir/primitive.h" #include "ir/primitive.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@@ -38,7 +39,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) {
return nullptr; return nullptr;
} }
@@ -84,6 +85,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
} }
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_grad_with_relu.get()); AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_grad_with_relu.get());
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_grad_with_relu); AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_grad_with_relu);
device::gpu::SetKernelInfo(fused_batch_norm_grad_with_relu);
return fused_batch_norm_grad_with_relu; return fused_batch_norm_grad_with_relu;
} }
} // namespace opt } // namespace opt


+ 18
- 41
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -53,10 +53,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;


void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
bool graph_format_transform = IsSupportFormatTransform(kernel_graph);
device::gpu::FormatTransformChecker::GetInstance().CheckSupportFormatTransform(kernel_graph);
for (const auto &kernel_node : kernel_graph->execution_order()) { for (const auto &kernel_node : kernel_graph->execution_order()) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
device::gpu::SetKernelInfo(kernel_node, graph_format_transform);
device::gpu::SetKernelInfo(kernel_node);
} }
} }


@@ -82,12 +82,6 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
if (IsSupportFormatTransform(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
// pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
}
optimizer->AddPassManager(pm); optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph); (void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault(); kernel_graph->SetExecOrderByDefault();
@@ -96,6 +90,10 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) { void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>(); auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
// pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>()); pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>()); pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>()); pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
@@ -201,28 +199,6 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
} }
} }


// Reports whether format transformation is supported for `kernel_graph`.
//
// Returns false when the graph's execution order contains a LayerNorm kernel, or
// when it contains exactly kConv2dCount Conv2D kernels together with exactly
// kFusedBatchNormCount FusedBatchNormEx kernels. Returns true otherwise.
// NOTE(review): the exact-count rule looks like a fingerprint for one specific
// network; the rationale is not visible here -- confirm before changing it.
bool GPUSession::IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  // Bind by const reference: the list is only read, so copying it is wasted work.
  const auto &kernels = kernel_graph->execution_order();
  size_t conv_cnt = 0;
  size_t bn_cnt = 0;
  for (const auto &kernel : kernels) {
    const auto kernel_name = AnfAlgo::GetCNodeName(kernel);
    // A single LayerNorm kernel is enough to rule the whole graph out.
    if (kernel_name == prim::kPrimLayerNorm->name()) {
      return false;
    }
    if (kernel_name == prim::kPrimConv2D->name()) {
      conv_cnt++;
    }
    if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
      bn_cnt++;
    }
  }
  return !(conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount);
}

GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
// Construct graph, if successfully, graph_sum_ + 1 // Construct graph, if successfully, graph_sum_ + 1
auto graph_id = graph_sum_; auto graph_id = graph_sum_;
@@ -232,26 +208,27 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
// Optimize
// Dump .pb graph before graph optimization
if (save_graphs) {
DumpIRProto(graph, "before_opt_" + std::to_string(graph_id));
}
// Graph optimization irrelevant to device data format
Optimize(graph); Optimize(graph);
// Select kernel build info // Select kernel build info
SelectKernel(graph); SelectKernel(graph);
// Graph optimization relevant to device data format
HardwareOptimize(graph);
// Dump .pb graph after graph optimization
if (save_graphs) {
DumpIRProto(graph, "after_opt_" + std::to_string(graph_id));
}

#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
// Assign parameter keys. // Assign parameter keys.
AssignParamKey(graph); AssignParamKey(graph);
#endif #endif
// Start gpu kernel runtime // Start gpu kernel runtime
StartKernelRT(); StartKernelRT();
// Dump .pb graph before hardware optimization
if (save_graphs) {
DumpIRProto(graph, "before_hwopt_" + std::to_string(graph_id));
}
// HardwareOptimize
HardwareOptimize(graph);
// Dump .pb graph after hardware optimization
if (save_graphs) {
DumpIRProto(graph, "after_hwopt_" + std::to_string(graph_id));
}
// Assign CUDA streams // Assign CUDA streams
AssignStream(graph); AssignStream(graph);
// Hide NopOp from execution graph // Hide NopOp from execution graph


+ 0
- 5
mindspore/ccsrc/backend/session/gpu_session.h View File

@@ -67,8 +67,6 @@ class GPUSession : public SessionBasic {


void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const; void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;


bool IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const;

#ifdef ENABLE_DEBUGGER #ifdef ENABLE_DEBUGGER
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const; void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;


@@ -82,9 +80,6 @@ class GPUSession : public SessionBasic {


void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const; void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
#endif #endif

static constexpr size_t kConv2dCount = 96;
static constexpr size_t kFusedBatchNormCount = 94;
}; };
using GPUSessionPtr = std::shared_ptr<GPUSession>; using GPUSessionPtr = std::shared_ptr<GPUSession>;
MS_REG_SESSION(kGPUDevice, GPUSession); MS_REG_SESSION(kGPUDevice, GPUSession);


+ 1
- 1
mindspore/ccsrc/runtime/device/gpu/cuda_env_checker.cc View File

@@ -47,7 +47,7 @@ bool CudaEnvChecker::CheckNvccInPath() {
}; };


auto cuda_paths = GetCudaRealPaths(); auto cuda_paths = GetCudaRealPaths();
find_nvcc_ = any_of(cuda_paths.begin(), cuda_paths.end(), checker);
find_nvcc_ = std::any_of(cuda_paths.begin(), cuda_paths.end(), checker);
already_check_nvcc_ = true; already_check_nvcc_ = true;
return find_nvcc_; return find_nvcc_;
} }


+ 29
- 2
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc View File

@@ -165,6 +165,9 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return false; return false;
} }
if (!FormatTransformChecker::GetInstance().format_transform()) {
return false;
}
if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) { if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {
return false; return false;
} }
@@ -232,7 +235,31 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
} }
} // namespace } // namespace


void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) {
// Inspects `kernel_graph` and records in format_transform_ whether format
// transformation is supported for it.
//
// The flag is set to false when the graph's execution order contains a LayerNorm
// kernel, or when it contains exactly kConv2dCount Conv2D kernels together with
// exactly kFusedBatchNormCount FusedBatchNormEx kernels; otherwise it is set to
// true. The flag is overwritten on every call.
// NOTE(review): the exact-count rule looks like a fingerprint for one specific
// network; the rationale is not visible here -- confirm before changing it.
void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  // Bind by const reference: the list is only read, so copying it is wasted work.
  const auto &kernels = kernel_graph->execution_order();
  size_t conv_cnt = 0;
  size_t bn_cnt = 0;
  for (const auto &kernel : kernels) {
    const auto kernel_name = AnfAlgo::GetCNodeName(kernel);
    // A single LayerNorm kernel is enough to rule the whole graph out.
    if (kernel_name == prim::kPrimLayerNorm->name()) {
      format_transform_ = false;
      return;
    }
    if (kernel_name == prim::kPrimConv2D->name()) {
      conv_cnt++;
    }
    if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
      bn_cnt++;
    }
  }
  format_transform_ = !(conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount);
}

void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format; std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type; std::vector<TypeId> inputs_type;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
@@ -246,7 +273,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
} }
std::string origin_data_format = kOpFormat_DEFAULT; std::string origin_data_format = kOpFormat_DEFAULT;
if (graph_format_transform && IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
} }
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =


+ 24
- 1
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h View File

@@ -20,11 +20,13 @@
#include <utility> #include <utility>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#include <map> #include <map>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/dtype.h" #include "ir/dtype.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "backend/session/kernel_graph.h"


namespace mindspore { namespace mindspore {
namespace device { namespace device {
@@ -53,7 +55,28 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
{prim::kPrimAddN->name(), {{}, {0}}}, {prim::kPrimAddN->name(), {{}, {0}}},
}; };


void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform = false);
void SetKernelInfo(const CNodePtr &kernel_node);

// Process-wide holder for the "format transform supported" decision.
//
// CheckSupportFormatTransform() evaluates a kernel graph and stores the result;
// format_transform() reads the most recently stored result. The single instance
// is obtained through GetInstance(); the class cannot be constructed or copied
// from outside.
// NOTE(review): the stored flag is shared by every caller of GetInstance() and
// is accessed without synchronization in this class -- confirm that graphs are
// checked and consumed one at a time.
class FormatTransformChecker {
 public:
  // Examines `kernel_graph` and updates the stored flag (defined in the .cc file).
  void CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph);

  // Returns the flag stored by the last CheckSupportFormatTransform() call, or
  // true if it has never been called.
  bool format_transform() const { return format_transform_; }

  // Returns the single shared instance, created on first use.
  static FormatTransformChecker &GetInstance() {
    static FormatTransformChecker instance;
    return instance;
  }

 private:
  FormatTransformChecker() = default;
  ~FormatTransformChecker() = default;
  // Copying is explicitly deleted so misuse fails at compile time, not link time.
  FormatTransformChecker(const FormatTransformChecker &) = delete;
  FormatTransformChecker &operator=(const FormatTransformChecker &) = delete;

  bool format_transform_{true};
  // Kernel counts compared against the graph in CheckSupportFormatTransform().
  // NOTE(review): the origin of these specific values is not visible here.
  static constexpr size_t kConv2dCount = 96;
  static constexpr size_t kFusedBatchNormCount = 94;
};


class KernelAttr { class KernelAttr {
public: public:


+ 4
- 0
tests/ut/cpp/CMakeLists.txt View File

@@ -133,6 +133,10 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/scheduler.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/scheduler.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc")


add_library(_ut_mindspore_obj OBJECT ${MINDSPORE_SRC_LIST}) add_library(_ut_mindspore_obj OBJECT ${MINDSPORE_SRC_LIST})
add_library(_ut_ut_obj OBJECT ${UT_SRCS}) add_library(_ut_ut_obj OBJECT ${UT_SRCS})


Loading…
Cancel
Save