From 28d92b1e894cfa7ea8a6d44c9a13b9cbb29facc0 Mon Sep 17 00:00:00 2001 From: jjfeing Date: Mon, 1 Mar 2021 20:20:23 +0800 Subject: [PATCH] optimizer concat output pass --- .../tbe_kernel_select/tbe_kernel_select.cc | 32 ++++++++----- .../tbe/tbe_kernel_select/tbe_kernel_select.h | 3 +- .../enhancer/concat_outputs_for_all_gather.cc | 48 ++++++++++++++++--- .../enhancer/concat_outputs_for_all_gather.h | 4 +- 4 files changed, 66 insertions(+), 21 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc index 15b1b12da7..b957e3974d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -82,13 +82,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() { } void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { - // get dynamic inputs - auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); - MS_EXCEPTION_IF_NULL(primitive); - std::vector dyn_input_sizes; - if (primitive->HasAttr(kAttrDynInputSizes)) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } + auto dyn_input_sizes = GetNodeDynamicInputs(); // get real input/output num size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); const auto inputs_info = op_info.inputs_ptr(); @@ -189,8 +183,9 @@ void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) { return; } std::vector> new_kernel_info_list; + auto dynamic_inputs = GetNodeDynamicInputs(); for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) { - if (!FilterInVaildShape(iter)) { + if (!FilterInVaildShape(iter, !dynamic_inputs.empty())) { continue; } if (op_info.need_check_supported()) { @@ -203,13 +198,15 @@ void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) { (*kernel_info_list_) = 
new_kernel_info_list; } -bool TbeKernelSelect::FilterInVaildShape( - const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { +bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input) { MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats(); - for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) { + // For dynamic inputs only the first input needs to be checked, because the other inputs are copied from the first input. + auto iter_num = + is_dynamic_input && !kernel_build_info_inputs_format.empty() ? 1 : kernel_build_info_inputs_format.size(); + for (size_t i = 0; i < iter_num; ++i) { auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); - const auto &format = kernel_build_info_inputs_format[i]; + const auto &format = kernel_build_info_inputs_format.at(i); if (!IsShapeMatchFormat(shape, format)) { return false; } @@ -279,6 +276,17 @@ void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_ builder->SetKernelType(TBE_KERNEL); } +std::vector TbeKernelSelect::GetNodeDynamicInputs() { + // get dynamic inputs + auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); + MS_EXCEPTION_IF_NULL(primitive); + std::vector dyn_input_sizes; + if (primitive->HasAttr(kAttrDynInputSizes)) { + dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); + } + return dyn_input_sizes; +} + bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, const std::vector> &ios_info, const std::vector &dyn_input_sizes, std::vector *formats, diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h index 60601a658e..2547260ad2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h 
+++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h @@ -44,10 +44,11 @@ class TbeKernelSelect { void GetBroadcastPatternKernelInfo(const OpInfo &op_info); void GetReducePatternKernelInfo(const OpInfo &op_info); void FilterInVaildKernelInfo(const OpInfo &op_info); - bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); + bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input); static bool IsShapeMatchFormat(const std::vector &shape, const std::string &format); bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder); + std::vector GetNodeDynamicInputs(); bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, const std::vector> &ios_info, const std::vector &dyn_input_sizes, std::vector *formats, diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc index fbf79d9352..29c941725e 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc @@ -18,11 +18,47 @@ #include #include "backend/session/anf_runtime_algorithm.h" -namespace mindspore { -namespace opt { +namespace mindspore::opt { +kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat) { + MS_EXCEPTION_IF_NULL(concat); + std::vector inputs_device_format; + std::vector outputs_device_format; + std::vector inputs_device_type; + std::vector outputs_device_type; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(concat); ++input_index) { + inputs_device_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(concat, 
input_index)); + inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(concat, input_index)); + } + // Currently only the default format & float16 are supported + auto cmp_format = inputs_device_format.begin(); + auto format_iter = std::find_if(inputs_device_format.begin(), inputs_device_format.end(), + [&](const auto &format) { return format != (*cmp_format); }); + if (format_iter != inputs_device_format.end()) { + MS_LOG(EXCEPTION) << "Input format is not same, value: " << *format_iter; + } + auto cmp_dtype = inputs_device_type.begin(); + auto dtype_iter = std::find_if(inputs_device_type.begin(), inputs_device_type.end(), + [&](const auto &dtype) { return dtype != (*cmp_dtype); }); + if (dtype_iter != inputs_device_type.end()) { + MS_LOG(EXCEPTION) << "Input dtype is not same, value: " << *dtype_iter; + } + outputs_device_format.emplace_back(*cmp_format); + outputs_device_type.emplace_back(*cmp_dtype); + + builder.SetFusionType(kernel::FusionType::OPAQUE); + builder.SetProcessor(kernel::Processor::AICORE); + builder.SetKernelType(TBE_KERNEL); + builder.SetInputsFormat(inputs_device_format); + builder.SetOutputsFormat(outputs_device_format); + builder.SetInputsDeviceType(inputs_device_type); + builder.SetOutputsDeviceType(outputs_device_type); + return builder.Build(); +} + AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::vector &new_tuple_getitems, - int64_t rank_size) const { + int64_t rank_size) { MS_EXCEPTION_IF_NULL(func_graph); std::vector make_tuple_inputs{NewValueNode(std::make_shared(prim::kPrimMakeTuple->name()))}; size_t inputs_size = AnfAlgo::GetInputTensorNum(node); @@ -43,7 +79,8 @@ AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr & AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat); std::vector dyn_input_size{rank_size}; AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat); - 
kernel_select_->SelectKernel(concat); + auto kernel_build_info = GenerateKernelBuildInfo(concat); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, concat.get()); make_tuple_inputs.push_back(concat); } @@ -78,5 +115,4 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs); return InsertConcatForOutput(func_graph, node, new_outputs, rank_size); } -} // namespace opt -} // namespace mindspore +} // namespace mindspore::opt diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h index adaeb83f31..df18cb0d4b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h @@ -33,8 +33,8 @@ class ConcatOutputsForAllGather : public PatternProcessPass { const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; private: - AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const std::vector &new_tuple_getitems, int64_t rank_size) const; + static AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::vector &new_tuple_getitems, int64_t rank_size); KernelSelectPtr kernel_select_; }; } // namespace opt