Merge pull request !5517 from limingqi107/mastertags/v1.0.0
| @@ -49,10 +49,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; | |||||
| void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { | void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| bool in_black_list = CheckInModeBlackList(kernel_graph); | |||||
| bool graph_format_transform = IsSupportFormatTransform(kernel_graph); | |||||
| for (const auto &kernel_node : kernel_graph->execution_order()) { | for (const auto &kernel_node : kernel_graph->execution_order()) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| device::gpu::SetKernelInfo(kernel_node, in_black_list); | |||||
| device::gpu::SetKernelInfo(kernel_node, graph_format_transform); | |||||
| } | } | ||||
| } | } | ||||
| @@ -76,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||||
| pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); | pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | ||||
| if (!CheckInModeBlackList(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||||
| if (IsSupportFormatTransform(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||||
| pm->AddPass(std::make_shared<opt::BatchNormReluFusion>()); | pm->AddPass(std::make_shared<opt::BatchNormReluFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>()); | pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>()); | pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>()); | ||||
| @@ -193,14 +193,14 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const | |||||
| } | } | ||||
| } | } | ||||
| bool GPUSession::CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||||
| bool GPUSession::IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||||
| auto kernels = kernel_graph->execution_order(); | auto kernels = kernel_graph->execution_order(); | ||||
| size_t conv_cnt = 0; | size_t conv_cnt = 0; | ||||
| size_t bn_cnt = 0; | size_t bn_cnt = 0; | ||||
| for (const auto &kernel : kernels) { | for (const auto &kernel : kernels) { | ||||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel); | auto kernel_name = AnfAlgo::GetCNodeName(kernel); | ||||
| if (kernel_name == prim::kPrimLayerNorm->name()) { | if (kernel_name == prim::kPrimLayerNorm->name()) { | ||||
| return true; | |||||
| return false; | |||||
| } | } | ||||
| if (kernel_name == prim::kPrimConv2D->name()) { | if (kernel_name == prim::kPrimConv2D->name()) { | ||||
| conv_cnt++; | conv_cnt++; | ||||
| @@ -210,9 +210,9 @@ bool GPUSession::CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel | |||||
| } | } | ||||
| } | } | ||||
| if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) { | if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) { | ||||
| return true; | |||||
| return false; | |||||
| } | } | ||||
| return false; | |||||
| return true; | |||||
| } | } | ||||
| GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | ||||
| @@ -67,7 +67,7 @@ class GPUSession : public SessionBasic { | |||||
| void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| bool CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||||
| bool IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||||
| #ifdef ENABLE_DEBUGGER | #ifdef ENABLE_DEBUGGER | ||||
| void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| @@ -404,7 +404,9 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v | |||||
| // Release the kernel resource. | // Release the kernel resource. | ||||
| for (const auto &kernel : execution_order) { | for (const auto &kernel : execution_order) { | ||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | ||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||||
| if (kernel_mod == nullptr) { | |||||
| continue; | |||||
| } | |||||
| kernel_mod->ReleaseResource(); | kernel_mod->ReleaseResource(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -176,9 +176,18 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type | |||||
| if (inputs_type.size() == 0) { | if (inputs_type.size() == 0) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| if (input_shape.size() != 4) { | |||||
| return false; | |||||
| auto inputs_format_position = iter->second.first; | |||||
| // If input position is empty, then insert all the input positions, because the input numbers of this op are variable. | |||||
| if (inputs_format_position.size() == 0) { | |||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); input_index++) { | |||||
| inputs_format_position.push_back(input_index); | |||||
| } | |||||
| } | |||||
| for (const auto &input_format_position : inputs_format_position) { | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position); | |||||
| if (input_shape.size() != 4) { | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -223,7 +232,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) { | |||||
| void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) { | |||||
| std::vector<std::string> inputs_format; | std::vector<std::string> inputs_format; | ||||
| std::vector<TypeId> inputs_type; | std::vector<TypeId> inputs_type; | ||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | ||||
| @@ -237,7 +246,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) { | |||||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | ||||
| } | } | ||||
| std::string origin_data_format = kOpFormat_DEFAULT; | std::string origin_data_format = kOpFormat_DEFAULT; | ||||
| if (!in_black_list && IsNeedProcessFormatInfo(kernel_node, inputs_type)) { | |||||
| if (graph_format_transform && IsNeedProcessFormatInfo(kernel_node, inputs_type)) { | |||||
| UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); | UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); | ||||
| } | } | ||||
| std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = | std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = | ||||
| @@ -53,7 +53,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>> | |||||
| {prim::kPrimAddN->name(), {{}, {0}}}, | {prim::kPrimAddN->name(), {{}, {0}}}, | ||||
| }; | }; | ||||
| void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list = false); | |||||
| void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform = false); | |||||
| class KernelAttr { | class KernelAttr { | ||||
| public: | public: | ||||