| @@ -82,13 +82,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() { | |||||
| } | } | ||||
| void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { | void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { | ||||
| // get dynamic inputs | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| std::vector<int64_t> dyn_input_sizes; | |||||
| if (primitive->HasAttr(kAttrDynInputSizes)) { | |||||
| dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes)); | |||||
| } | |||||
| auto dyn_input_sizes = GetNodeDynamicInputs(); | |||||
| // get real input/output num | // get real input/output num | ||||
| size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); | size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); | ||||
| const auto inputs_info = op_info.inputs_ptr(); | const auto inputs_info = op_info.inputs_ptr(); | ||||
| @@ -189,8 +183,9 @@ void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) { | |||||
| return; | return; | ||||
| } | } | ||||
| std::vector<std::shared_ptr<KernelBuildInfo>> new_kernel_info_list; | std::vector<std::shared_ptr<KernelBuildInfo>> new_kernel_info_list; | ||||
| auto dynamic_inputs = GetNodeDynamicInputs(); | |||||
| for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) { | for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) { | ||||
| if (!FilterInVaildShape(iter)) { | |||||
| if (!FilterInVaildShape(iter, !dynamic_inputs.empty())) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (op_info.need_check_supported()) { | if (op_info.need_check_supported()) { | ||||
| @@ -203,13 +198,15 @@ void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) { | |||||
| (*kernel_info_list_) = new_kernel_info_list; | (*kernel_info_list_) = new_kernel_info_list; | ||||
| } | } | ||||
| bool TbeKernelSelect::FilterInVaildShape( | |||||
| const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { | |||||
| bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input) { | |||||
| MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); | MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); | ||||
| const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats(); | const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats(); | ||||
| for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) { | |||||
| // dynamic input just need to check first input, because other inputs copy from 1th input; | |||||
| auto iter_num = | |||||
| is_dynamic_input && !kernel_build_info_inputs_format.empty() ? 1 : kernel_build_info_inputs_format.size(); | |||||
| for (size_t i = 0; i < iter_num; ++i) { | |||||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); | auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); | ||||
| const auto &format = kernel_build_info_inputs_format[i]; | |||||
| const auto &format = kernel_build_info_inputs_format.at(i); | |||||
| if (!IsShapeMatchFormat(shape, format)) { | if (!IsShapeMatchFormat(shape, format)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -279,6 +276,17 @@ void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_ | |||||
| builder->SetKernelType(TBE_KERNEL); | builder->SetKernelType(TBE_KERNEL); | ||||
| } | } | ||||
| std::vector<int64_t> TbeKernelSelect::GetNodeDynamicInputs() { | |||||
| // get dynamic inputs | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| std::vector<int64_t> dyn_input_sizes; | |||||
| if (primitive->HasAttr(kAttrDynInputSizes)) { | |||||
| dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes)); | |||||
| } | |||||
| return dyn_input_sizes; | |||||
| } | |||||
| bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, | bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, | ||||
| const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, | const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, | ||||
| const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats, | const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats, | ||||
| @@ -44,10 +44,11 @@ class TbeKernelSelect { | |||||
| void GetBroadcastPatternKernelInfo(const OpInfo &op_info); | void GetBroadcastPatternKernelInfo(const OpInfo &op_info); | ||||
| void GetReducePatternKernelInfo(const OpInfo &op_info); | void GetReducePatternKernelInfo(const OpInfo &op_info); | ||||
| void FilterInVaildKernelInfo(const OpInfo &op_info); | void FilterInVaildKernelInfo(const OpInfo &op_info); | ||||
| bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); | |||||
| bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input); | |||||
| static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format); | static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format); | ||||
| bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); | bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); | ||||
| static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder); | static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder); | ||||
| std::vector<int64_t> GetNodeDynamicInputs(); | |||||
| bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, | bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, | ||||
| const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, | const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, | ||||
| const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats, | const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats, | ||||
| @@ -18,11 +18,47 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace mindspore::opt { | |||||
| kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat) { | |||||
| MS_EXCEPTION_IF_NULL(concat); | |||||
| std::vector<std::string> inputs_device_format; | |||||
| std::vector<std::string> outputs_device_format; | |||||
| std::vector<TypeId> inputs_device_type; | |||||
| std::vector<TypeId> outputs_device_type; | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(concat); ++input_index) { | |||||
| inputs_device_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(concat, input_index)); | |||||
| inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(concat, input_index)); | |||||
| } | |||||
| // Current only support default format & float16 | |||||
| auto cmp_format = inputs_device_format.begin(); | |||||
| auto format_iter = std::find_if(inputs_device_format.begin(), inputs_device_format.end(), | |||||
| [&](const auto &format) { return format != (*cmp_format); }); | |||||
| if (format_iter != inputs_device_format.end()) { | |||||
| MS_LOG(EXCEPTION) << "Input format is not same, value: " << *format_iter; | |||||
| } | |||||
| auto cmp_dtype = inputs_device_type.begin(); | |||||
| auto dtype_iter = std::find_if(inputs_device_type.begin(), inputs_device_type.end(), | |||||
| [&](const auto &dtype) { return dtype != (*cmp_dtype); }); | |||||
| if (dtype_iter != inputs_device_type.end()) { | |||||
| MS_LOG(EXCEPTION) << "Input dtype is not same, value: " << *dtype_iter; | |||||
| } | |||||
| outputs_device_format.emplace_back(*cmp_format); | |||||
| outputs_device_type.emplace_back(*cmp_dtype); | |||||
| builder.SetFusionType(kernel::FusionType::OPAQUE); | |||||
| builder.SetProcessor(kernel::Processor::AICORE); | |||||
| builder.SetKernelType(TBE_KERNEL); | |||||
| builder.SetInputsFormat(inputs_device_format); | |||||
| builder.SetOutputsFormat(outputs_device_format); | |||||
| builder.SetInputsDeviceType(inputs_device_type); | |||||
| builder.SetOutputsDeviceType(outputs_device_type); | |||||
| return builder.Build(); | |||||
| } | |||||
| AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| const std::vector<AnfNodePtr> &new_tuple_getitems, | const std::vector<AnfNodePtr> &new_tuple_getitems, | ||||
| int64_t rank_size) const { | |||||
| int64_t rank_size) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))}; | std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))}; | ||||
| size_t inputs_size = AnfAlgo::GetInputTensorNum(node); | size_t inputs_size = AnfAlgo::GetInputTensorNum(node); | ||||
| @@ -43,7 +79,8 @@ AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr & | |||||
| AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat); | AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat); | ||||
| std::vector<int64_t> dyn_input_size{rank_size}; | std::vector<int64_t> dyn_input_size{rank_size}; | ||||
| AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat); | AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat); | ||||
| kernel_select_->SelectKernel(concat); | |||||
| auto kernel_build_info = GenerateKernelBuildInfo(concat); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, concat.get()); | |||||
| make_tuple_inputs.push_back(concat); | make_tuple_inputs.push_back(concat); | ||||
| } | } | ||||
| @@ -78,5 +115,4 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra | |||||
| CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs); | CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs); | ||||
| return InsertConcatForOutput(func_graph, node, new_outputs, rank_size); | return InsertConcatForOutput(func_graph, node, new_outputs, rank_size); | ||||
| } | } | ||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::opt | |||||
| @@ -33,8 +33,8 @@ class ConcatOutputsForAllGather : public PatternProcessPass { | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| private: | private: | ||||
| AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size) const; | |||||
| static AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size); | |||||
| KernelSelectPtr kernel_select_; | KernelSelectPtr kernel_select_; | ||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||