Merge pull request !5481 from zyli2020/mastertags/v1.0.0
| @@ -49,9 +49,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; | |||||
// SelectKernel: null-checks the graph, then calls device::gpu::SetKernelInfo on each node of
// kernel_graph->execution_order(), null-checking every node before use.
| void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { | void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
// The black-list check is evaluated once per graph, before the loop; the same result is
// then forwarded for every node.
| bool in_black_list = CheckInModeBlackList(kernel_graph); | |||||
| for (const auto &kernel_node : kernel_graph->execution_order()) { | for (const auto &kernel_node : kernel_graph->execution_order()) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
// NOTE(review): the next two rows look like the removed and added sides of one changed
// line (the one-argument call replaced by the call that forwards in_black_list) -- confirm
// against the original hunk.
| device::gpu::SetKernelInfo(kernel_node); | |||||
| device::gpu::SetKernelInfo(kernel_node, in_black_list); | |||||
| } | } | ||||
| } | } | ||||
| @@ -75,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||||
| pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); | pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | ||||
| if (context_ptr->execution_mode() != kPynativeMode) { | |||||
| if (!CheckInModeBlackList(kernel_graph) && context_ptr->execution_mode() != kPynativeMode) { | |||||
| pm->AddPass(std::make_shared<opt::BatchNormReluFusion>()); | pm->AddPass(std::make_shared<opt::BatchNormReluFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>()); | pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>()); | pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>()); | ||||
| @@ -192,6 +193,28 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const | |||||
| } | } | ||||
| } | } | ||||
// CheckInModeBlackList: scans the graph's execution order and returns true when either
//   (a) some kernel's name equals prim::kPrimLayerNorm->name(), or
//   (b) the number of kernels named prim::kPrimConv2D->name() equals kConv2dCount AND the
//       number named prim::kPrimFusedBatchNormEx->name() equals kFusedBatchNormCount.
// Returns false otherwise.
// NOTE(review): kernel_graph is dereferenced below without a null check inside this function.
| bool GPUSession::CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||||
// NOTE(review): plain 'auto' makes 'kernels' a by-value copy if execution_order() returns a
// reference -- confirm, and consider 'const auto &' since the list is only read here.
| auto kernels = kernel_graph->execution_order(); | |||||
| size_t conv_cnt = 0; | |||||
| size_t bn_cnt = 0; | |||||
| for (const auto &kernel : kernels) { | |||||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel); | |||||
// A single LayerNorm match returns true immediately, regardless of the two counters.
| if (kernel_name == prim::kPrimLayerNorm->name()) { | |||||
| return true; | |||||
| } | |||||
| if (kernel_name == prim::kPrimConv2D->name()) { | |||||
| conv_cnt++; | |||||
| } | |||||
| if (kernel_name == prim::kPrimFusedBatchNormEx->name()) { | |||||
| bn_cnt++; | |||||
| } | |||||
| } | |||||
// Both counters must match exactly (==); graphs with more or fewer of either kernel do not
// qualify through this branch.
| if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | ||||
| // Construct graph, if successfully, graph_sum_ + 1 | // Construct graph, if successfully, graph_sum_ + 1 | ||||
| auto graph_id = graph_sum_; | auto graph_id = graph_sum_; | ||||
| @@ -67,6 +67,8 @@ class GPUSession : public SessionBasic { | |||||
| void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| bool CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||||
| #ifdef ENABLE_DEBUGGER | #ifdef ENABLE_DEBUGGER | ||||
| void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| @@ -80,6 +82,9 @@ class GPUSession : public SessionBasic { | |||||
| void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| #endif | #endif | ||||
// Exact kernel counts that CheckInModeBlackList compares its Conv2D and FusedBatchNormEx
// tallies against (equality, not a threshold).
// NOTE(review): 96 / 94 presumably fingerprint one specific network; nothing here says
// which -- confirm and record the reason.
| static constexpr size_t kConv2dCount = 96; | |||||
| static constexpr size_t kFusedBatchNormCount = 94; | |||||
| }; | }; | ||||
| using GPUSessionPtr = std::shared_ptr<GPUSession>; | using GPUSessionPtr = std::shared_ptr<GPUSession>; | ||||
| MS_REG_SESSION(kGPUDevice, GPUSession); | MS_REG_SESSION(kGPUDevice, GPUSession); | ||||
| @@ -223,7 +223,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| void SetKernelInfo(const CNodePtr &kernel_node) { | |||||
| void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) { | |||||
| std::vector<std::string> inputs_format; | std::vector<std::string> inputs_format; | ||||
| std::vector<TypeId> inputs_type; | std::vector<TypeId> inputs_type; | ||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | ||||
| @@ -237,7 +237,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) { | |||||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | ||||
| } | } | ||||
| std::string origin_data_format = kOpFormat_DEFAULT; | std::string origin_data_format = kOpFormat_DEFAULT; | ||||
| if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) { | |||||
| if (!in_black_list && IsNeedProcessFormatInfo(kernel_node, inputs_type)) { | |||||
| UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); | UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); | ||||
| } | } | ||||
| std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = | std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = | ||||
| @@ -53,7 +53,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>> | |||||
| {prim::kPrimAddN->name(), {{}, {0}}}, | {prim::kPrimAddN->name(), {{}, {0}}}, | ||||
| }; | }; | ||||
// NOTE(review): these two rows appear to be the old and new declaration of SetKernelInfo
// from the diff -- confirm against the original hunk.
// In the new form in_black_list defaults to false, so one-argument calls remain valid; the
// definition skips UpdateKernelFormatInfo whenever in_black_list is true.
| void SetKernelInfo(const CNodePtr &apply_kernel_ptr); | |||||
| void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list = false); | |||||
| class KernelAttr { | class KernelAttr { | ||||
| public: | public: | ||||