diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h index 7269998b30..63af8fc951 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h @@ -105,6 +105,7 @@ class OpInfo { dynamic_shape_ = opinfo.dynamic_shape_; op_pattern_ = opinfo.op_pattern(); processor_ = opinfo.processor_; + need_check_supported_ = opinfo.need_check_supported(); for (const auto &attr : opinfo.attrs_ptr()) { attrs_ptr_.push_back(std::make_shared(*attr)); } @@ -125,6 +126,7 @@ class OpInfo { OpPattern op_pattern() const { return op_pattern_; } bool dynamic_shape() const { return dynamic_shape_; } std::string processor() const { return processor_; } + bool need_check_supported() const { return need_check_supported_; } std::vector> attrs_ptr() const { return attrs_ptr_; } std::vector> inputs_ptr() const { return inputs_ptr_; } std::vector> outputs_ptr() const { return outputs_ptr_; } @@ -142,6 +144,7 @@ class OpInfo { void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } void set_processor(const std::string &processor) { processor_ = processor; } + void set_need_check_supported(const bool need_check_supported) { need_check_supported_ = need_check_supported; } void add_attrs_ptr(const std::shared_ptr &attr) { attrs_ptr_.push_back(attr); } void add_inputs_ptr(const std::shared_ptr &input) { inputs_ptr_.push_back(input); } void add_outputs_ptr(const std::shared_ptr &output) { outputs_ptr_.push_back(output); } @@ -168,6 +171,7 @@ class OpInfo { bool partial_flag_ = false; bool dynamic_format_ = false; bool dynamic_shape_ = false; + bool need_check_supported_ = false; OpPattern op_pattern_ = kCommonPattern; std::string processor_; std::vector> attrs_ptr_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc index d5cecf79ed..cada10cb34 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc @@ -36,6 +36,7 @@ constexpr auto kReshapeType = "reshape_type"; constexpr auto kOpPattern = "op_pattern"; constexpr auto kDynamicFormat = "dynamicFormat"; constexpr auto kFormatAgnostic = "formatAgnostic"; +constexpr auto kNeedCheckSupported = "need_check_supported"; constexpr auto kBroadcast = "broadcast"; constexpr auto kReduce = "reduce"; constexpr auto kDynamicShape = "dynamic_shape"; @@ -111,6 +112,7 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p op_info->set_compute_cost(obj.at(kComputeCost)); op_info->set_kernel_name(obj.at(kKernelName)); op_info->set_partial_flag(obj.at(kPartialFlag)); + op_info->set_need_check_supported(obj.at(kNeedCheckSupported)); if (obj.find(kDynamicShape) != obj.end()) { op_info->set_dynamic_shape(obj.at(kDynamicShape)); diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc index 52085be9d0..bceff6b478 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -31,11 +31,9 @@ #include "backend/optimizer/common/helper.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/kernel_build_client.h" -#include "frontend/parallel/ops_info/ops_utils.h" #include "nlohmann/json.hpp" -namespace mindspore { -namespace kernel { +namespace mindspore::kernel { constexpr auto kName = "name"; constexpr auto kDtype = "dtype"; constexpr auto kFormat = "format"; @@ -83,7 +81,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() { MS_LOG(INFO) << "Warning: op pattern is invailed."; } // check support - FilterInVaildKernelInfo(); + FilterInVaildKernelInfo(*op_info_ptr); MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select."; } @@ -221,7 +219,7 @@ void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) { MS_LOG(INFO) << "end."; } -void TbeKernelSelect::FilterInVaildKernelInfo() { +void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) { if (kernel_info_list_->empty()) { MS_LOG(INFO) << "Warning: get kernel build info failed."; return; @@ -232,9 +230,11 @@ void TbeKernelSelect::FilterInVaildKernelInfo() { MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*iter)->ToString(); continue; } - if (!TbeCheckSupported(iter)) { - MS_LOG(INFO) << "Check support shape, filter item info: " << (*iter)->ToString(); - continue; + if (op_info.need_check_supported()) { + if (!TbeCheckSupported(iter)) { + MS_LOG(INFO) << "Check support shape, filter item info: " << (*iter)->ToString(); + continue; + } } new_kernel_info_list.emplace_back(*iter); } @@ -292,22 +292,9 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector &shape, const return true; } -bool TbeKernelSelect::TbeCheckSupported( - const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { - MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); - static const std::set kCheckSupportedOpType = {parallel::MATMUL, - parallel::BATCHMATMUL, - parallel::TOPK, - parallel::IN_TOPK, - parallel::PACK, - parallel::UNSORTEF_SEGMENT_MIND, - parallel::UNSORTEF_SEGMENT_PRODD, - parallel::CAST}; - auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_); - if (iter == kCheckSupportedOpType.end()) { - return true; - } +bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter) { MS_LOG(INFO) << "Check support start."; + MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); // replace kernel_info with current kernel info auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); @@ -560,7 +547,7 @@ void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, std::string input_format_item = item.value().at(kFormat); select_input.formats = SplitStrToVec(input_format_item); inputs.emplace_back(select_input); - } else if (is_output) { + } else { SelectOpIOInfo select_output; select_output.name = item.value().at(kName); std::string input_dtype_item = item.value().at(kDtype); @@ -628,5 +615,4 @@ void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) MS_LOG(INFO) << "Support format: " << print_str; } } -} // namespace kernel -} // namespace mindspore +} // namespace mindspore::kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h index 95597afd23..60601a658e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h @@ -43,7 +43,7 @@ class TbeKernelSelect { void GetAgnosticPatternKernelInfo(const OpInfo &op_info); void GetBroadcastPatternKernelInfo(const OpInfo &op_info); void GetReducePatternKernelInfo(const OpInfo &op_info); - void FilterInVaildKernelInfo(); + void FilterInVaildKernelInfo(const OpInfo &op_info); bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); static bool IsShapeMatchFormat(const std::vector &shape, const std::string &format); bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); diff --git a/mindspore/ops/_op_impl/tbe/assign_add.py b/mindspore/ops/_op_impl/tbe/assign_add.py index 5003157b2f..23c3e5f682 100644 --- a/mindspore/ops/_op_impl/tbe/assign_add.py +++ b/mindspore/ops/_op_impl/tbe/assign_add.py @@ -23,6 +23,7 @@ assign_add_op_info = TBERegOp("AssignAdd") \ .compute_cost(10) \ .kernel_name("assign_add") \ .partial_flag(True) \ + .need_check_supported(True) \ .input(0, "ref", False, "required", "all") \ .input(1, "value", False, "required", "all") \ .output(0, "ref", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/batch_matmul.py b/mindspore/ops/_op_impl/tbe/batch_matmul.py index fd1935d369..cd35e97588 100644 --- a/mindspore/ops/_op_impl/tbe/batch_matmul.py +++ b/mindspore/ops/_op_impl/tbe/batch_matmul.py @@ -25,6 +25,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \ .attr("transpose_x1", "required", "bool", "all") \ .attr("transpose_x2", "required", "bool", "all") \ .partial_flag(True) \ + .need_check_supported(True) \ .input(0, "x1", False, "required", "all") \ .input(1, "x2", False, "required", "all") \ .input(2, "bias", False, "optional", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/cast.py b/mindspore/ops/_op_impl/tbe/cast.py index 0a809e28a7..fc43d3a8ee 100644 --- a/mindspore/ops/_op_impl/tbe/cast.py +++ b/mindspore/ops/_op_impl/tbe/cast.py @@ -24,6 +24,7 @@ cast_op_info = TBERegOp("Cast") \ .kernel_name("cast") \ .partial_flag(True) \ .attr("dst_type", "required", "int", "all") \ + .need_check_supported(True) \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ .op_pattern("formatAgnostic") \ diff --git a/mindspore/ops/_op_impl/tbe/cast_ds.py b/mindspore/ops/_op_impl/tbe/cast_ds.py index bb9c472a07..05d99afd94 100644 --- a/mindspore/ops/_op_impl/tbe/cast_ds.py +++ b/mindspore/ops/_op_impl/tbe/cast_ds.py @@ -23,6 +23,7 @@ cast_ds_op_info = TBERegOp("Cast") \ .compute_cost(10) \ .kernel_name("cast") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("dst_type", "required", "int", "all") \ .dynamic_shape(True)\ .input(0, "x", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/fill.py b/mindspore/ops/_op_impl/tbe/fill.py index 90301f123b..37e330a210 100644 --- a/mindspore/ops/_op_impl/tbe/fill.py +++ b/mindspore/ops/_op_impl/tbe/fill.py @@ -23,6 +23,7 @@ fill_d_op_info = TBERegOp("Fill") \ .compute_cost(10) \ .kernel_name("fill_d") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("dims", "required", "listInt", "all") \ .input(0, "value", False, "required", "all") \ .output(0, "y", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/in_top_k.py b/mindspore/ops/_op_impl/tbe/in_top_k.py index 46d7258e2a..dc87df21cd 100644 --- a/mindspore/ops/_op_impl/tbe/in_top_k.py +++ b/mindspore/ops/_op_impl/tbe/in_top_k.py @@ -23,6 +23,7 @@ in_top_k_op_info = TBERegOp("InTopK") \ .compute_cost(10) \ .kernel_name("in_top_k") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("k", "required", "int", "all") \ .input(0, "x1", False, "required", "all") \ .input(1, "x2", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/inplace_update.py b/mindspore/ops/_op_impl/tbe/inplace_update.py index b8c7454d77..8eaa27d7d6 100644 --- a/mindspore/ops/_op_impl/tbe/inplace_update.py +++ b/mindspore/ops/_op_impl/tbe/inplace_update.py @@ -23,6 +23,7 @@ inplace_update_op_info = TBERegOp("InplaceUpdate") \ .compute_cost(10) \ .kernel_name("inplace_update_d") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("indices", "required", "listInt", "all") \ .input(0, "x", False, "required", "all") \ .input(1, "v", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/matmul.py b/mindspore/ops/_op_impl/tbe/matmul.py index be7f7303e4..dde83e1099 100644 --- a/mindspore/ops/_op_impl/tbe/matmul.py +++ b/mindspore/ops/_op_impl/tbe/matmul.py @@ -23,6 +23,7 @@ matmul_op_info = TBERegOp("MatMul") \ .compute_cost(10) \ .kernel_name("mat_mul") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("transpose_x1", "required", "bool", "all") \ .attr("transpose_x2", "required", "bool", "all") \ .attr("offset_x", "optional", "int", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/pack.py b/mindspore/ops/_op_impl/tbe/pack.py index fa1b1a2644..bd63121c74 100644 --- a/mindspore/ops/_op_impl/tbe/pack.py +++ b/mindspore/ops/_op_impl/tbe/pack.py @@ -23,6 +23,7 @@ pack_op_info = TBERegOp("Pack") \ .compute_cost(10) \ .kernel_name("pack") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("axis", "optional", "int", "all") \ .input(0, "x", False, "dynamic", "all") \ .output(0, "y", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/reduce_prod.py b/mindspore/ops/_op_impl/tbe/reduce_prod.py index b6f1386ce2..459f13bfcc 100644 --- a/mindspore/ops/_op_impl/tbe/reduce_prod.py +++ b/mindspore/ops/_op_impl/tbe/reduce_prod.py @@ -23,6 +23,7 @@ reduce_prod_op_info = TBERegOp("ReduceProd") \ .compute_cost(10) \ .kernel_name("reduce_prod_d") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("axis", "required", "listInt", "all") \ .attr("keep_dims", "optional", "bool", "all") \ .input(0, "x", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/resize_bilinear.py b/mindspore/ops/_op_impl/tbe/resize_bilinear.py index 2d8091488c..84db8831e3 100644 --- a/mindspore/ops/_op_impl/tbe/resize_bilinear.py +++ b/mindspore/ops/_op_impl/tbe/resize_bilinear.py @@ -23,6 +23,7 @@ resize_bilinear_op_info = TBERegOp("ResizeBilinear") \ .compute_cost(10) \ .kernel_name("resize_bilinear_v2_d") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("size", "required", "listInt", "all") \ .attr("align_corners", "optional", "bool", "all") \ .attr("half_pixel_centers", "optional", "bool", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py b/mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py index bbc4419458..5ab36f30a3 100644 --- a/mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +++ b/mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py @@ -23,6 +23,7 @@ resize_bilinear_grad_op_info = TBERegOp("ResizeBilinearGrad") \ .compute_cost(10) \ .kernel_name("resize_bilinear_v2_grad") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("align_corners", "optional", "bool", "all") \ .attr("half_pixel_centers", "optional", "bool", "all")\ .input(0, "grads", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py b/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py index 38da10c9f2..20413cc209 100644 --- a/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +++ b/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py @@ -23,6 +23,7 @@ resize_nearest_neighbor_op_info = TBERegOp("ResizeNearestNeighbor") \ .compute_cost(10) \ .kernel_name("resize_nearest_neighbor_v2_d") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("size", "required", "listInt", "all") \ .attr("align_corners", "optional", "bool", "all") \ .input(0, "images", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py b/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py index bf046de668..abbafc6c6b 100644 --- a/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +++ b/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py @@ -23,6 +23,7 @@ resize_nearest_neighbor_grad_op_info = TBERegOp("ResizeNearestNeighborGrad") \ .compute_cost(10) \ .kernel_name("resize_nearest_neighbor_v2_grad_d") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("size", "required", "listInt", "all") \ .attr("align_corners", "optional", "bool", "all") \ .input(0, "grads", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/top_k.py b/mindspore/ops/_op_impl/tbe/top_k.py index a97ecadae0..ecc37716ed 100644 --- a/mindspore/ops/_op_impl/tbe/top_k.py +++ b/mindspore/ops/_op_impl/tbe/top_k.py @@ -23,6 +23,7 @@ top_k_op_info = TBERegOp("TopK") \ .compute_cost(10) \ .kernel_name("top_k_d") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("dim", "optional", "int", "all") \ .attr("k", "required", "int", "all") \ .attr("largest", "optional", "bool", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/unsorted_segment_max.py b/mindspore/ops/_op_impl/tbe/unsorted_segment_max.py index 63596fdb70..e5b1baf2e3 100644 --- a/mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +++ b/mindspore/ops/_op_impl/tbe/unsorted_segment_max.py @@ -23,6 +23,7 @@ unsorted_segment_max_op_info = TBERegOp("UnsortedSegmentMax") \ .compute_cost(10) \ .kernel_name("unsorted_segment_max_d") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("num_segments", "required", "int", "all") \ .input(0, "data", False, "required", "all") \ .input(1, "segment_ids", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/unsorted_segment_min.py b/mindspore/ops/_op_impl/tbe/unsorted_segment_min.py index a26f14048a..1e8f24da9d 100644 --- a/mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +++ b/mindspore/ops/_op_impl/tbe/unsorted_segment_min.py @@ -23,6 +23,7 @@ unsorted_segment_min_op_info = TBERegOp("UnsortedSegmentMin") \ .compute_cost(10) \ .kernel_name("unsorted_segment_min_d") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("num_segments", "required", "int", "all") \ .input(0, "data", False, "required", "all") \ .input(1, "segment_ids", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py b/mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py index 40b04d17c3..22ef624a84 100644 --- a/mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +++ b/mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py @@ -23,6 +23,7 @@ unsorted_segment_prod_d_op_info = TBERegOp("UnsortedSegmentProd") \ .compute_cost(10) \ .kernel_name("unsorted_segment_prod_d") \ .partial_flag(True) \ + .need_check_supported(True) \ .attr("num_segments", "required", "int", "all") \ .input(0, "data", False, "required", "all") \ .input(1, "segment_ids", False, "required", "all") \ diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index 88b9410503..c4df8809db 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -355,6 +355,7 @@ class TBERegOp(RegOp): self.reshape_type_ = '' self.dynamic_format_ = False self.dynamic_shape_ = False + self.need_check_supported_ = False self.op_pattern_ = "" def async_flag(self, async_flag): @@ -445,6 +446,17 @@ class TBERegOp(RegOp): self.dynamic_shape_ = dynamic_shape return self + def need_check_supported(self, need_check_supported): + """ + Whether the operator need check supports. + + Args: + :param need_check_supported: (bool): Value of need_check_supported. Default: false. + """ + self._is_bool(need_check_supported) + self.need_check_supported_ = need_check_supported + return self + def op_pattern(self, pattern=None): """ The behavior type of opeator, such as broadcast, reduce and so on.