| @@ -105,6 +105,7 @@ class OpInfo { | |||||
| dynamic_shape_ = opinfo.dynamic_shape_; | dynamic_shape_ = opinfo.dynamic_shape_; | ||||
| op_pattern_ = opinfo.op_pattern(); | op_pattern_ = opinfo.op_pattern(); | ||||
| processor_ = opinfo.processor_; | processor_ = opinfo.processor_; | ||||
| need_check_supported_ = opinfo.need_check_supported(); | |||||
| for (const auto &attr : opinfo.attrs_ptr()) { | for (const auto &attr : opinfo.attrs_ptr()) { | ||||
| attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr)); | attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr)); | ||||
| } | } | ||||
| @@ -125,6 +126,7 @@ class OpInfo { | |||||
| OpPattern op_pattern() const { return op_pattern_; } | OpPattern op_pattern() const { return op_pattern_; } | ||||
| bool dynamic_shape() const { return dynamic_shape_; } | bool dynamic_shape() const { return dynamic_shape_; } | ||||
| std::string processor() const { return processor_; } | std::string processor() const { return processor_; } | ||||
| bool need_check_supported() const { return need_check_supported_; } | |||||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | ||||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | ||||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | ||||
| @@ -142,6 +144,7 @@ class OpInfo { | |||||
| void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | ||||
| void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } | void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } | ||||
| void set_processor(const std::string &processor) { processor_ = processor; } | void set_processor(const std::string &processor) { processor_ = processor; } | ||||
| void set_need_check_supported(const bool need_check_supported) { need_check_supported_ = need_check_supported; } | |||||
| void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } | void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } | ||||
| void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } | void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } | ||||
| void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } | void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } | ||||
| @@ -168,6 +171,7 @@ class OpInfo { | |||||
| bool partial_flag_ = false; | bool partial_flag_ = false; | ||||
| bool dynamic_format_ = false; | bool dynamic_format_ = false; | ||||
| bool dynamic_shape_ = false; | bool dynamic_shape_ = false; | ||||
| bool need_check_supported_ = false; | |||||
| OpPattern op_pattern_ = kCommonPattern; | OpPattern op_pattern_ = kCommonPattern; | ||||
| std::string processor_; | std::string processor_; | ||||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; | std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; | ||||
| @@ -36,6 +36,7 @@ constexpr auto kReshapeType = "reshape_type"; | |||||
| constexpr auto kOpPattern = "op_pattern"; | constexpr auto kOpPattern = "op_pattern"; | ||||
| constexpr auto kDynamicFormat = "dynamicFormat"; | constexpr auto kDynamicFormat = "dynamicFormat"; | ||||
| constexpr auto kFormatAgnostic = "formatAgnostic"; | constexpr auto kFormatAgnostic = "formatAgnostic"; | ||||
| constexpr auto kNeedCheckSupported = "need_check_supported"; | |||||
| constexpr auto kBroadcast = "broadcast"; | constexpr auto kBroadcast = "broadcast"; | ||||
| constexpr auto kReduce = "reduce"; | constexpr auto kReduce = "reduce"; | ||||
| constexpr auto kDynamicShape = "dynamic_shape"; | constexpr auto kDynamicShape = "dynamic_shape"; | ||||
| @@ -111,6 +112,7 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p | |||||
| op_info->set_compute_cost(obj.at(kComputeCost)); | op_info->set_compute_cost(obj.at(kComputeCost)); | ||||
| op_info->set_kernel_name(obj.at(kKernelName)); | op_info->set_kernel_name(obj.at(kKernelName)); | ||||
| op_info->set_partial_flag(obj.at(kPartialFlag)); | op_info->set_partial_flag(obj.at(kPartialFlag)); | ||||
| op_info->set_need_check_supported(obj.at(kNeedCheckSupported)); | |||||
| if (obj.find(kDynamicShape) != obj.end()) { | if (obj.find(kDynamicShape) != obj.end()) { | ||||
| op_info->set_dynamic_shape(obj.at(kDynamicShape)); | op_info->set_dynamic_shape(obj.at(kDynamicShape)); | ||||
| @@ -31,11 +31,9 @@ | |||||
| #include "backend/optimizer/common/helper.h" | #include "backend/optimizer/common/helper.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "backend/session/kernel_build_client.h" | #include "backend/session/kernel_build_client.h" | ||||
| #include "frontend/parallel/ops_info/ops_utils.h" | |||||
| #include "nlohmann/json.hpp" | #include "nlohmann/json.hpp" | ||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| namespace mindspore::kernel { | |||||
| constexpr auto kName = "name"; | constexpr auto kName = "name"; | ||||
| constexpr auto kDtype = "dtype"; | constexpr auto kDtype = "dtype"; | ||||
| constexpr auto kFormat = "format"; | constexpr auto kFormat = "format"; | ||||
| @@ -83,7 +81,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() { | |||||
| MS_LOG(INFO) << "Warning: op pattern is invailed."; | MS_LOG(INFO) << "Warning: op pattern is invailed."; | ||||
| } | } | ||||
| // check support | // check support | ||||
| FilterInVaildKernelInfo(); | |||||
| FilterInVaildKernelInfo(*op_info_ptr); | |||||
| MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select."; | MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select."; | ||||
| } | } | ||||
| @@ -221,7 +219,7 @@ void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "end."; | MS_LOG(INFO) << "end."; | ||||
| } | } | ||||
| void TbeKernelSelect::FilterInVaildKernelInfo() { | |||||
| void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) { | |||||
| if (kernel_info_list_->empty()) { | if (kernel_info_list_->empty()) { | ||||
| MS_LOG(INFO) << "Warning: get kernel build info failed."; | MS_LOG(INFO) << "Warning: get kernel build info failed."; | ||||
| return; | return; | ||||
| @@ -232,9 +230,11 @@ void TbeKernelSelect::FilterInVaildKernelInfo() { | |||||
| MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*iter)->ToString(); | MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*iter)->ToString(); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (!TbeCheckSupported(iter)) { | |||||
| MS_LOG(INFO) << "Check support shape, filter item info: " << (*iter)->ToString(); | |||||
| continue; | |||||
| if (op_info.need_check_supported()) { | |||||
| if (!TbeCheckSupported(iter)) { | |||||
| MS_LOG(INFO) << "Check support shape, filter item info: " << (*iter)->ToString(); | |||||
| continue; | |||||
| } | |||||
| } | } | ||||
| new_kernel_info_list.emplace_back(*iter); | new_kernel_info_list.emplace_back(*iter); | ||||
| } | } | ||||
| @@ -292,22 +292,9 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool TbeKernelSelect::TbeCheckSupported( | |||||
| const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { | |||||
| MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); | |||||
| static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL, | |||||
| parallel::BATCHMATMUL, | |||||
| parallel::TOPK, | |||||
| parallel::IN_TOPK, | |||||
| parallel::PACK, | |||||
| parallel::UNSORTEF_SEGMENT_MIND, | |||||
| parallel::UNSORTEF_SEGMENT_PRODD, | |||||
| parallel::CAST}; | |||||
| auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_); | |||||
| if (iter == kCheckSupportedOpType.end()) { | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter) { | |||||
| MS_LOG(INFO) << "Check support start."; | MS_LOG(INFO) << "Check support start."; | ||||
| MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); | |||||
| // replace kernel_info with current kernel info | // replace kernel_info with current kernel info | ||||
| auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); | auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); | AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); | ||||
| @@ -560,7 +547,7 @@ void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, | |||||
| std::string input_format_item = item.value().at(kFormat); | std::string input_format_item = item.value().at(kFormat); | ||||
| select_input.formats = SplitStrToVec(input_format_item); | select_input.formats = SplitStrToVec(input_format_item); | ||||
| inputs.emplace_back(select_input); | inputs.emplace_back(select_input); | ||||
| } else if (is_output) { | |||||
| } else { | |||||
| SelectOpIOInfo select_output; | SelectOpIOInfo select_output; | ||||
| select_output.name = item.value().at(kName); | select_output.name = item.value().at(kName); | ||||
| std::string input_dtype_item = item.value().at(kDtype); | std::string input_dtype_item = item.value().at(kDtype); | ||||
| @@ -628,5 +615,4 @@ void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) | |||||
| MS_LOG(INFO) << "Support format: " << print_str; | MS_LOG(INFO) << "Support format: " << print_str; | ||||
| } | } | ||||
| } | } | ||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::kernel | |||||
| @@ -43,7 +43,7 @@ class TbeKernelSelect { | |||||
| void GetAgnosticPatternKernelInfo(const OpInfo &op_info); | void GetAgnosticPatternKernelInfo(const OpInfo &op_info); | ||||
| void GetBroadcastPatternKernelInfo(const OpInfo &op_info); | void GetBroadcastPatternKernelInfo(const OpInfo &op_info); | ||||
| void GetReducePatternKernelInfo(const OpInfo &op_info); | void GetReducePatternKernelInfo(const OpInfo &op_info); | ||||
| void FilterInVaildKernelInfo(); | |||||
| void FilterInVaildKernelInfo(const OpInfo &op_info); | |||||
| bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); | bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); | ||||
| static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format); | static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format); | ||||
| bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); | bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); | ||||
| @@ -23,6 +23,7 @@ assign_add_op_info = TBERegOp("AssignAdd") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("assign_add") \ | .kernel_name("assign_add") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .input(0, "ref", False, "required", "all") \ | .input(0, "ref", False, "required", "all") \ | ||||
| .input(1, "value", False, "required", "all") \ | .input(1, "value", False, "required", "all") \ | ||||
| .output(0, "ref", False, "required", "all") \ | .output(0, "ref", False, "required", "all") \ | ||||
| @@ -25,6 +25,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \ | |||||
| .attr("transpose_x1", "required", "bool", "all") \ | .attr("transpose_x1", "required", "bool", "all") \ | ||||
| .attr("transpose_x2", "required", "bool", "all") \ | .attr("transpose_x2", "required", "bool", "all") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .input(0, "x1", False, "required", "all") \ | .input(0, "x1", False, "required", "all") \ | ||||
| .input(1, "x2", False, "required", "all") \ | .input(1, "x2", False, "required", "all") \ | ||||
| .input(2, "bias", False, "optional", "all") \ | .input(2, "bias", False, "optional", "all") \ | ||||
| @@ -24,6 +24,7 @@ cast_op_info = TBERegOp("Cast") \ | |||||
| .kernel_name("cast") \ | .kernel_name("cast") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .attr("dst_type", "required", "int", "all") \ | .attr("dst_type", "required", "int", "all") \ | ||||
| .need_check_supported(True) \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("formatAgnostic") \ | .op_pattern("formatAgnostic") \ | ||||
| @@ -23,6 +23,7 @@ cast_ds_op_info = TBERegOp("Cast") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("cast") \ | .kernel_name("cast") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("dst_type", "required", "int", "all") \ | .attr("dst_type", "required", "int", "all") \ | ||||
| .dynamic_shape(True)\ | .dynamic_shape(True)\ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ fill_d_op_info = TBERegOp("Fill") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("fill_d") \ | .kernel_name("fill_d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("dims", "required", "listInt", "all") \ | .attr("dims", "required", "listInt", "all") \ | ||||
| .input(0, "value", False, "required", "all") \ | .input(0, "value", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ in_top_k_op_info = TBERegOp("InTopK") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("in_top_k") \ | .kernel_name("in_top_k") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("k", "required", "int", "all") \ | .attr("k", "required", "int", "all") \ | ||||
| .input(0, "x1", False, "required", "all") \ | .input(0, "x1", False, "required", "all") \ | ||||
| .input(1, "x2", False, "required", "all") \ | .input(1, "x2", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ inplace_update_op_info = TBERegOp("InplaceUpdate") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("inplace_update_d") \ | .kernel_name("inplace_update_d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("indices", "required", "listInt", "all") \ | .attr("indices", "required", "listInt", "all") \ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .input(1, "v", False, "required", "all") \ | .input(1, "v", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ matmul_op_info = TBERegOp("MatMul") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("mat_mul") \ | .kernel_name("mat_mul") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("transpose_x1", "required", "bool", "all") \ | .attr("transpose_x1", "required", "bool", "all") \ | ||||
| .attr("transpose_x2", "required", "bool", "all") \ | .attr("transpose_x2", "required", "bool", "all") \ | ||||
| .attr("offset_x", "optional", "int", "all") \ | .attr("offset_x", "optional", "int", "all") \ | ||||
| @@ -23,6 +23,7 @@ pack_op_info = TBERegOp("Pack") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("pack") \ | .kernel_name("pack") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("axis", "optional", "int", "all") \ | .attr("axis", "optional", "int", "all") \ | ||||
| .input(0, "x", False, "dynamic", "all") \ | .input(0, "x", False, "dynamic", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ reduce_prod_op_info = TBERegOp("ReduceProd") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("reduce_prod_d") \ | .kernel_name("reduce_prod_d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("axis", "required", "listInt", "all") \ | .attr("axis", "required", "listInt", "all") \ | ||||
| .attr("keep_dims", "optional", "bool", "all") \ | .attr("keep_dims", "optional", "bool", "all") \ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ resize_bilinear_op_info = TBERegOp("ResizeBilinear") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("resize_bilinear_v2_d") \ | .kernel_name("resize_bilinear_v2_d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("size", "required", "listInt", "all") \ | .attr("size", "required", "listInt", "all") \ | ||||
| .attr("align_corners", "optional", "bool", "all") \ | .attr("align_corners", "optional", "bool", "all") \ | ||||
| .attr("half_pixel_centers", "optional", "bool", "all") \ | .attr("half_pixel_centers", "optional", "bool", "all") \ | ||||
| @@ -23,6 +23,7 @@ resize_bilinear_grad_op_info = TBERegOp("ResizeBilinearGrad") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("resize_bilinear_v2_grad") \ | .kernel_name("resize_bilinear_v2_grad") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("align_corners", "optional", "bool", "all") \ | .attr("align_corners", "optional", "bool", "all") \ | ||||
| .attr("half_pixel_centers", "optional", "bool", "all")\ | .attr("half_pixel_centers", "optional", "bool", "all")\ | ||||
| .input(0, "grads", False, "required", "all") \ | .input(0, "grads", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ resize_nearest_neighbor_op_info = TBERegOp("ResizeNearestNeighbor") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("resize_nearest_neighbor_v2_d") \ | .kernel_name("resize_nearest_neighbor_v2_d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("size", "required", "listInt", "all") \ | .attr("size", "required", "listInt", "all") \ | ||||
| .attr("align_corners", "optional", "bool", "all") \ | .attr("align_corners", "optional", "bool", "all") \ | ||||
| .input(0, "images", False, "required", "all") \ | .input(0, "images", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ resize_nearest_neighbor_grad_op_info = TBERegOp("ResizeNearestNeighborGrad") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("resize_nearest_neighbor_v2_grad_d") \ | .kernel_name("resize_nearest_neighbor_v2_grad_d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("size", "required", "listInt", "all") \ | .attr("size", "required", "listInt", "all") \ | ||||
| .attr("align_corners", "optional", "bool", "all") \ | .attr("align_corners", "optional", "bool", "all") \ | ||||
| .input(0, "grads", False, "required", "all") \ | .input(0, "grads", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ top_k_op_info = TBERegOp("TopK") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("top_k_d") \ | .kernel_name("top_k_d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("dim", "optional", "int", "all") \ | .attr("dim", "optional", "int", "all") \ | ||||
| .attr("k", "required", "int", "all") \ | .attr("k", "required", "int", "all") \ | ||||
| .attr("largest", "optional", "bool", "all") \ | .attr("largest", "optional", "bool", "all") \ | ||||
| @@ -23,6 +23,7 @@ unsorted_segment_max_op_info = TBERegOp("UnsortedSegmentMax") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("unsorted_segment_max_d") \ | .kernel_name("unsorted_segment_max_d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("num_segments", "required", "int", "all") \ | .attr("num_segments", "required", "int", "all") \ | ||||
| .input(0, "data", False, "required", "all") \ | .input(0, "data", False, "required", "all") \ | ||||
| .input(1, "segment_ids", False, "required", "all") \ | .input(1, "segment_ids", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ unsorted_segment_min_op_info = TBERegOp("UnsortedSegmentMin") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("unsorted_segment_min_d") \ | .kernel_name("unsorted_segment_min_d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("num_segments", "required", "int", "all") \ | .attr("num_segments", "required", "int", "all") \ | ||||
| .input(0, "data", False, "required", "all") \ | .input(0, "data", False, "required", "all") \ | ||||
| .input(1, "segment_ids", False, "required", "all") \ | .input(1, "segment_ids", False, "required", "all") \ | ||||
| @@ -23,6 +23,7 @@ unsorted_segment_prod_d_op_info = TBERegOp("UnsortedSegmentProd") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("unsorted_segment_prod_d") \ | .kernel_name("unsorted_segment_prod_d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .need_check_supported(True) \ | |||||
| .attr("num_segments", "required", "int", "all") \ | .attr("num_segments", "required", "int", "all") \ | ||||
| .input(0, "data", False, "required", "all") \ | .input(0, "data", False, "required", "all") \ | ||||
| .input(1, "segment_ids", False, "required", "all") \ | .input(1, "segment_ids", False, "required", "all") \ | ||||
| @@ -355,6 +355,7 @@ class TBERegOp(RegOp): | |||||
| self.reshape_type_ = '' | self.reshape_type_ = '' | ||||
| self.dynamic_format_ = False | self.dynamic_format_ = False | ||||
| self.dynamic_shape_ = False | self.dynamic_shape_ = False | ||||
| self.need_check_supported_ = False | |||||
| self.op_pattern_ = "" | self.op_pattern_ = "" | ||||
| def async_flag(self, async_flag): | def async_flag(self, async_flag): | ||||
| @@ -445,6 +446,17 @@ class TBERegOp(RegOp): | |||||
| self.dynamic_shape_ = dynamic_shape | self.dynamic_shape_ = dynamic_shape | ||||
| return self | return self | ||||
| def need_check_supported(self, need_check_supported): | |||||
| """ | |||||
| Whether the operator needs to be checked for support. | |||||
| Args: | |||||
| need_check_supported (bool): Value of need_check_supported. Default: False. | |||||
| """ | |||||
| self._is_bool(need_check_supported) | |||||
| self.need_check_supported_ = need_check_supported | |||||
| return self | |||||
| def op_pattern(self, pattern=None): | def op_pattern(self, pattern=None): | ||||
| """ | """ | ||||
| The behavior type of opeator, such as broadcast, reduce and so on. | The behavior type of opeator, such as broadcast, reduce and so on. | ||||