@@ -82,13 +82,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
 }
 void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
   // get dynamic inputs
-  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
-  MS_EXCEPTION_IF_NULL(primitive);
-  std::vector<int64_t> dyn_input_sizes;
-  if (primitive->HasAttr(kAttrDynInputSizes)) {
-    dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
-  }
+  auto dyn_input_sizes = GetNodeDynamicInputs();
   // get real input/output num
   size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
   const auto inputs_info = op_info.inputs_ptr();
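The removed block is moved verbatim into the new GetNodeDynamicInputs() helper (see the hunk at @@ -279,6 +276,17 @@), so every call site shares one reading of kAttrDynInputSizes. As a hedged sketch (the exact consumers are outside this diff), the vector is conventionally interpreted as one entry per logical input, where a positive entry gives the number of real tensors folded into a dynamic input and -1 marks a plain input:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hedged sketch, not part of the patch: how a dyn_input_sizes vector is
    // conventionally consumed. A positive entry folds that many real tensors
    // into one logical input; -1 (or a missing entry) is a plain input.
    std::size_t CountRealInputs(const std::vector<int64_t> &dyn_input_sizes, std::size_t logical_input_num) {
      std::size_t real_num = 0;
      for (std::size_t i = 0; i < logical_input_num; ++i) {
        if (i < dyn_input_sizes.size() && dyn_input_sizes[i] > 0) {
          real_num += static_cast<std::size_t>(dyn_input_sizes[i]);  // dynamic slot expands
        } else {
          real_num += 1;  // plain input contributes one tensor
        }
      }
      return real_num;
    }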
@@ -189,8 +183,9 @@ void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
     return;
   }
   std::vector<std::shared_ptr<KernelBuildInfo>> new_kernel_info_list;
+  auto dynamic_inputs = GetNodeDynamicInputs();
   for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) {
-    if (!FilterInVaildShape(iter)) {
+    if (!FilterInVaildShape(iter, !dynamic_inputs.empty())) {
       continue;
     }
     if (op_info.need_check_supported()) {
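GetNodeDynamicInputs() is loop-invariant, so the hunk computes it once before iterating the candidate list and passes only the boolean fact downward. The surrounding filter follows the usual keep-if pattern; a minimal standalone sketch (names are illustrative, not from the codebase):

    #include <algorithm>
    #include <iterator>
    #include <memory>
    #include <vector>

    // Minimal sketch of the keep-if filtering pattern used above: node-level
    // facts are computed once, then each candidate is tested independently.
    template <typename Info, typename Pred>
    std::vector<std::shared_ptr<Info>> FilterCandidates(const std::vector<std::shared_ptr<Info>> &candidates,
                                                        Pred keep) {
      std::vector<std::shared_ptr<Info>> kept;
      std::copy_if(candidates.begin(), candidates.end(), std::back_inserter(kept), keep);
      return kept;
    }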
@@ -203,13 +198,15 @@ void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
   (*kernel_info_list_) = new_kernel_info_list;
 }
 
-bool TbeKernelSelect::FilterInVaildShape(
-  const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
+bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input) {
   MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
   const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
-  for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) {
+  // dynamic input only needs to check the first input, because the other inputs are copied from the first input
+  auto iter_num =
+    is_dynamic_input && !kernel_build_info_inputs_format.empty() ? 1 : kernel_build_info_inputs_format.size();
+  for (size_t i = 0; i < iter_num; ++i) {
     auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
-    const auto &format = kernel_build_info_inputs_format[i];
+    const auto &format = kernel_build_info_inputs_format.at(i);
     if (!IsShapeMatchFormat(shape, format)) {
       return false;
     }
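The single-iteration shortcut rests on how candidate build infos are expanded for dynamic inputs: the one declared input format is replicated across every real input tensor, so if index 0 matches the inferred shape, the rest match too. A hedged illustration of that replication (assumed convention, not shown in this diff):

    #include <cstddef>
    #include <string>
    #include <vector>

    // Hedged illustration: expanding one declared format over the N real
    // tensors of a dynamic input yields N identical entries, so validating
    // index 0 decides for all of them.
    std::vector<std::string> ExpandDynamicFormats(const std::string &declared_format, std::size_t real_input_num) {
      return std::vector<std::string>(real_input_num, declared_format);
    }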
@@ -279,6 +276,17 @@ void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_
   builder->SetKernelType(TBE_KERNEL);
 }
 
+std::vector<int64_t> TbeKernelSelect::GetNodeDynamicInputs() {
+  // get dynamic inputs
+  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
+  MS_EXCEPTION_IF_NULL(primitive);
+  std::vector<int64_t> dyn_input_sizes;
+  if (primitive->HasAttr(kAttrDynInputSizes)) {
+    dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
+  }
+  return dyn_input_sizes;
+}
+
 bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
                                      const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
                                      const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats,
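The helper only reads the attribute; it is produced by graph passes. The fragment below mirrors the ConcatOutputsForAllGather hunk later in this patch (identifiers taken from there, so it is a fragment rather than a standalone program):

    // Fragment, identifiers from the hunks below: a pass that folds rank_size
    // real inputs into one dynamic input records that fact on the node, and
    // GetNodeDynamicInputs() later reads it back during kernel selection.
    std::vector<int64_t> dyn_input_size{rank_size};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);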
@@ -44,10 +44,11 @@ class TbeKernelSelect {
   void GetBroadcastPatternKernelInfo(const OpInfo &op_info);
   void GetReducePatternKernelInfo(const OpInfo &op_info);
   void FilterInVaildKernelInfo(const OpInfo &op_info);
-  bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter);
+  bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input);
   static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format);
   bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter);
   static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder);
+  std::vector<int64_t> GetNodeDynamicInputs();
   bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
                       const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
                       const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats,
@@ -18,11 +18,47 @@
+#include <string>
 #include "backend/session/anf_runtime_algorithm.h"
 
-namespace mindspore {
-namespace opt {
+namespace mindspore::opt {
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat) {
+  MS_EXCEPTION_IF_NULL(concat);
+  std::vector<std::string> inputs_device_format;
+  std::vector<std::string> outputs_device_format;
+  std::vector<TypeId> inputs_device_type;
+  std::vector<TypeId> outputs_device_type;
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(concat); ++input_index) {
+    inputs_device_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(concat, input_index));
+    inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(concat, input_index));
+  }
+  // Currently only default format & float16 are supported
+  auto cmp_format = inputs_device_format.begin();
+  auto format_iter = std::find_if(inputs_device_format.begin(), inputs_device_format.end(),
+                                  [&](const auto &format) { return format != (*cmp_format); });
+  if (format_iter != inputs_device_format.end()) {
+    MS_LOG(EXCEPTION) << "Input formats are not the same, value: " << *format_iter;
+  }
+  auto cmp_dtype = inputs_device_type.begin();
+  auto dtype_iter = std::find_if(inputs_device_type.begin(), inputs_device_type.end(),
+                                 [&](const auto &dtype) { return dtype != (*cmp_dtype); });
+  if (dtype_iter != inputs_device_type.end()) {
+    MS_LOG(EXCEPTION) << "Input dtypes are not the same, value: " << *dtype_iter;
+  }
+  outputs_device_format.emplace_back(*cmp_format);
+  outputs_device_type.emplace_back(*cmp_dtype);
+  builder.SetFusionType(kernel::FusionType::OPAQUE);
+  builder.SetProcessor(kernel::Processor::AICORE);
+  builder.SetKernelType(TBE_KERNEL);
+  builder.SetInputsFormat(inputs_device_format);
+  builder.SetOutputsFormat(outputs_device_format);
+  builder.SetInputsDeviceType(inputs_device_type);
+  builder.SetOutputsDeviceType(outputs_device_type);
+  return builder.Build();
+}
+
 AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                             const std::vector<AnfNodePtr> &new_tuple_getitems,
-                                                            int64_t rank_size) const {
+                                                            int64_t rank_size) {
   MS_EXCEPTION_IF_NULL(func_graph);
   std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
   size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
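GenerateKernelBuildInfo derives the Concat's build info purely from its producers: it enforces one shared format and dtype by comparing every element against the first. Dereferencing cmp_format and cmp_dtype assumes at least one input, which holds for a Concat inserted over AllGather outputs; std::find_if also needs <algorithm>, presumably reachable through the existing includes. The same uniformity test can be phrased with std::adjacent_find; a standalone sketch:

    #include <algorithm>
    #include <functional>
    #include <string>
    #include <vector>

    // Equivalent uniformity check: adjacent_find locates the first neighboring
    // pair that differs; if none exists, all entries are the same.
    bool AllSame(const std::vector<std::string> &values) {
      return std::adjacent_find(values.begin(), values.end(), std::not_equal_to<>()) == values.end();
    }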
@@ -43,7 +79,8 @@ AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &
     AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
     std::vector<int64_t> dyn_input_size{rank_size};
     AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
-    kernel_select_->SelectKernel(concat);
+    auto kernel_build_info = GenerateKernelBuildInfo(concat);
+    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, concat.get());
     make_tuple_inputs.push_back(concat);
   }
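This is the behavioral core of the patch: instead of running the full TBE kernel selection on each inserted Concat, the pass pins the node directly to its producers' formats and dtypes, which by construction cannot disagree with the AllGather outputs feeding it (that mismatch risk is presumably the motivation). A fragment restating the two paths, identifiers from this hunk:

    // Before: full selection chose a kernel for the Concat on its own,
    // independently of the formats its tuple-getitem producers emit.
    // kernel_select_->SelectKernel(concat);

    // After: the build info is derived from the inputs, then attached directly.
    auto kernel_build_info = GenerateKernelBuildInfo(concat);
    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, concat.get());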
@@ -78,5 +115,4 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
   CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs);
   return InsertConcatForOutput(func_graph, node, new_outputs, rank_size);
 }
-}  // namespace opt
-}  // namespace mindspore
+}  // namespace mindspore::opt
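The namespace change here and at the top of the file adopts the C++17 nested namespace definition; the two spellings are equivalent:

    // Pre-C++17: two nested definitions, two closing braces.
    namespace mindspore { namespace opt { /* ... */ } }  // namespace mindspore::opt

    // C++17: one definition, one closing brace.
    namespace mindspore::opt { /* ... */ }  // namespace mindspore::opt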
@@ -33,8 +33,8 @@ class ConcatOutputsForAllGather : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
-  AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                   const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size) const;
+  static AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                          const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size);
   KernelSelectPtr kernel_select_;
 };
 }  // namespace opt
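InsertConcatForOutput can become static because, after this patch, it no longer touches instance state: its only use of kernel_select_ was removed in the .cc hunk above, though the member itself is still declared here, at least within the hunks shown. A minimal illustration of the refactor (hypothetical names, not from the codebase):

    // A const member function that stops using instance state can be declared
    // static and called without an object; the signature drops 'const'.
    class Pass {
     public:
      static int Process(int x) { return x + 1; }  // no 'this' required
    };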