Browse Source

add checksupport func

tags/v1.2.0-rc1
jjfeing 4 years ago
parent
commit
4f0ecf9857
23 changed files with 49 additions and 27 deletions
  1. +4
    -0
      mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h
  2. +2
    -0
      mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc
  3. +12
    -26
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
  4. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h
  5. +1
    -0
      mindspore/ops/_op_impl/tbe/assign_add.py
  6. +1
    -0
      mindspore/ops/_op_impl/tbe/batch_matmul.py
  7. +1
    -0
      mindspore/ops/_op_impl/tbe/cast.py
  8. +1
    -0
      mindspore/ops/_op_impl/tbe/cast_ds.py
  9. +1
    -0
      mindspore/ops/_op_impl/tbe/fill.py
  10. +1
    -0
      mindspore/ops/_op_impl/tbe/in_top_k.py
  11. +1
    -0
      mindspore/ops/_op_impl/tbe/inplace_update.py
  12. +1
    -0
      mindspore/ops/_op_impl/tbe/matmul.py
  13. +1
    -0
      mindspore/ops/_op_impl/tbe/pack.py
  14. +1
    -0
      mindspore/ops/_op_impl/tbe/reduce_prod.py
  15. +1
    -0
      mindspore/ops/_op_impl/tbe/resize_bilinear.py
  16. +1
    -0
      mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py
  17. +1
    -0
      mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py
  18. +1
    -0
      mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py
  19. +1
    -0
      mindspore/ops/_op_impl/tbe/top_k.py
  20. +1
    -0
      mindspore/ops/_op_impl/tbe/unsorted_segment_max.py
  21. +1
    -0
      mindspore/ops/_op_impl/tbe/unsorted_segment_min.py
  22. +1
    -0
      mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py
  23. +12
    -0
      mindspore/ops/op_info_register.py

+ 4
- 0
mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h View File

@@ -105,6 +105,7 @@ class OpInfo {
dynamic_shape_ = opinfo.dynamic_shape_;
op_pattern_ = opinfo.op_pattern();
processor_ = opinfo.processor_;
need_check_supported_ = opinfo.need_check_supported();
for (const auto &attr : opinfo.attrs_ptr()) {
attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
}
@@ -125,6 +126,7 @@ class OpInfo {
OpPattern op_pattern() const { return op_pattern_; }
bool dynamic_shape() const { return dynamic_shape_; }
std::string processor() const { return processor_; }
bool need_check_supported() const { return need_check_supported_; }
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
@@ -142,6 +144,7 @@ class OpInfo {
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
void set_processor(const std::string &processor) { processor_ = processor; }
void set_need_check_supported(const bool need_check_supported) { need_check_supported_ = need_check_supported; }
void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
@@ -168,6 +171,7 @@ class OpInfo {
bool partial_flag_ = false;
bool dynamic_format_ = false;
bool dynamic_shape_ = false;
bool need_check_supported_ = false;
OpPattern op_pattern_ = kCommonPattern;
std::string processor_;
std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;


+ 2
- 0
mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc View File

@@ -36,6 +36,7 @@ constexpr auto kReshapeType = "reshape_type";
constexpr auto kOpPattern = "op_pattern";
constexpr auto kDynamicFormat = "dynamicFormat";
constexpr auto kFormatAgnostic = "formatAgnostic";
constexpr auto kNeedCheckSupported = "need_check_supported";
constexpr auto kBroadcast = "broadcast";
constexpr auto kReduce = "reduce";
constexpr auto kDynamicShape = "dynamic_shape";
@@ -111,6 +112,7 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
op_info->set_compute_cost(obj.at(kComputeCost));
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
op_info->set_need_check_supported(obj.at(kNeedCheckSupported));

if (obj.find(kDynamicShape) != obj.end()) {
op_info->set_dynamic_shape(obj.at(kDynamicShape));


+ 12
- 26
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc View File

@@ -31,11 +31,9 @@
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_build_client.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "nlohmann/json.hpp"

namespace mindspore {
namespace kernel {
namespace mindspore::kernel {
constexpr auto kName = "name";
constexpr auto kDtype = "dtype";
constexpr auto kFormat = "format";
@@ -83,7 +81,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
MS_LOG(INFO) << "Warning: op pattern is invailed.";
}
// check support
FilterInVaildKernelInfo();
FilterInVaildKernelInfo(*op_info_ptr);
MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select.";
}

@@ -221,7 +219,7 @@ void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
MS_LOG(INFO) << "end.";
}

void TbeKernelSelect::FilterInVaildKernelInfo() {
void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
if (kernel_info_list_->empty()) {
MS_LOG(INFO) << "Warning: get kernel build info failed.";
return;
@@ -232,9 +230,11 @@ void TbeKernelSelect::FilterInVaildKernelInfo() {
MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*iter)->ToString();
continue;
}
if (!TbeCheckSupported(iter)) {
MS_LOG(INFO) << "Check support shape, filter item info: " << (*iter)->ToString();
continue;
if (op_info.need_check_supported()) {
if (!TbeCheckSupported(iter)) {
MS_LOG(INFO) << "Check support shape, filter item info: " << (*iter)->ToString();
continue;
}
}
new_kernel_info_list.emplace_back(*iter);
}
@@ -292,22 +292,9 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const
return true;
}

bool TbeKernelSelect::TbeCheckSupported(
const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL,
parallel::BATCHMATMUL,
parallel::TOPK,
parallel::IN_TOPK,
parallel::PACK,
parallel::UNSORTEF_SEGMENT_MIND,
parallel::UNSORTEF_SEGMENT_PRODD,
parallel::CAST};
auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_);
if (iter == kCheckSupportedOpType.end()) {
return true;
}
bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter) {
MS_LOG(INFO) << "Check support start.";
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
// replace kernel_info with current kernel info
auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
@@ -560,7 +547,7 @@ void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
std::string input_format_item = item.value().at(kFormat);
select_input.formats = SplitStrToVec(input_format_item);
inputs.emplace_back(select_input);
} else if (is_output) {
} else {
SelectOpIOInfo select_output;
select_output.name = item.value().at(kName);
std::string input_dtype_item = item.value().at(kDtype);
@@ -628,5 +615,4 @@ void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format)
MS_LOG(INFO) << "Support format: " << print_str;
}
}
} // namespace kernel
} // namespace mindspore
} // namespace mindspore::kernel

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h View File

@@ -43,7 +43,7 @@ class TbeKernelSelect {
void GetAgnosticPatternKernelInfo(const OpInfo &op_info);
void GetBroadcastPatternKernelInfo(const OpInfo &op_info);
void GetReducePatternKernelInfo(const OpInfo &op_info);
void FilterInVaildKernelInfo();
void FilterInVaildKernelInfo(const OpInfo &op_info);
bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter);
static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format);
bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter);


+ 1
- 0
mindspore/ops/_op_impl/tbe/assign_add.py View File

@@ -23,6 +23,7 @@ assign_add_op_info = TBERegOp("AssignAdd") \
.compute_cost(10) \
.kernel_name("assign_add") \
.partial_flag(True) \
.need_check_supported(True) \
.input(0, "ref", False, "required", "all") \
.input(1, "value", False, "required", "all") \
.output(0, "ref", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/batch_matmul.py View File

@@ -25,6 +25,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \
.attr("transpose_x1", "required", "bool", "all") \
.attr("transpose_x2", "required", "bool", "all") \
.partial_flag(True) \
.need_check_supported(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "bias", False, "optional", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/cast.py View File

@@ -24,6 +24,7 @@ cast_op_info = TBERegOp("Cast") \
.kernel_name("cast") \
.partial_flag(True) \
.attr("dst_type", "required", "int", "all") \
.need_check_supported(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("formatAgnostic") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/cast_ds.py View File

@@ -23,6 +23,7 @@ cast_ds_op_info = TBERegOp("Cast") \
.compute_cost(10) \
.kernel_name("cast") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("dst_type", "required", "int", "all") \
.dynamic_shape(True)\
.input(0, "x", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/fill.py View File

@@ -23,6 +23,7 @@ fill_d_op_info = TBERegOp("Fill") \
.compute_cost(10) \
.kernel_name("fill_d") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("dims", "required", "listInt", "all") \
.input(0, "value", False, "required", "all") \
.output(0, "y", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/in_top_k.py View File

@@ -23,6 +23,7 @@ in_top_k_op_info = TBERegOp("InTopK") \
.compute_cost(10) \
.kernel_name("in_top_k") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("k", "required", "int", "all") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/inplace_update.py View File

@@ -23,6 +23,7 @@ inplace_update_op_info = TBERegOp("InplaceUpdate") \
.compute_cost(10) \
.kernel_name("inplace_update_d") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("indices", "required", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "v", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/matmul.py View File

@@ -23,6 +23,7 @@ matmul_op_info = TBERegOp("MatMul") \
.compute_cost(10) \
.kernel_name("mat_mul") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("transpose_x1", "required", "bool", "all") \
.attr("transpose_x2", "required", "bool", "all") \
.attr("offset_x", "optional", "int", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/pack.py View File

@@ -23,6 +23,7 @@ pack_op_info = TBERegOp("Pack") \
.compute_cost(10) \
.kernel_name("pack") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("axis", "optional", "int", "all") \
.input(0, "x", False, "dynamic", "all") \
.output(0, "y", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/reduce_prod.py View File

@@ -23,6 +23,7 @@ reduce_prod_op_info = TBERegOp("ReduceProd") \
.compute_cost(10) \
.kernel_name("reduce_prod_d") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("axis", "required", "listInt", "all") \
.attr("keep_dims", "optional", "bool", "all") \
.input(0, "x", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/resize_bilinear.py View File

@@ -23,6 +23,7 @@ resize_bilinear_op_info = TBERegOp("ResizeBilinear") \
.compute_cost(10) \
.kernel_name("resize_bilinear_v2_d") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("size", "required", "listInt", "all") \
.attr("align_corners", "optional", "bool", "all") \
.attr("half_pixel_centers", "optional", "bool", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py View File

@@ -23,6 +23,7 @@ resize_bilinear_grad_op_info = TBERegOp("ResizeBilinearGrad") \
.compute_cost(10) \
.kernel_name("resize_bilinear_v2_grad") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("align_corners", "optional", "bool", "all") \
.attr("half_pixel_centers", "optional", "bool", "all")\
.input(0, "grads", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py View File

@@ -23,6 +23,7 @@ resize_nearest_neighbor_op_info = TBERegOp("ResizeNearestNeighbor") \
.compute_cost(10) \
.kernel_name("resize_nearest_neighbor_v2_d") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("size", "required", "listInt", "all") \
.attr("align_corners", "optional", "bool", "all") \
.input(0, "images", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py View File

@@ -23,6 +23,7 @@ resize_nearest_neighbor_grad_op_info = TBERegOp("ResizeNearestNeighborGrad") \
.compute_cost(10) \
.kernel_name("resize_nearest_neighbor_v2_grad_d") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("size", "required", "listInt", "all") \
.attr("align_corners", "optional", "bool", "all") \
.input(0, "grads", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/top_k.py View File

@@ -23,6 +23,7 @@ top_k_op_info = TBERegOp("TopK") \
.compute_cost(10) \
.kernel_name("top_k_d") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("dim", "optional", "int", "all") \
.attr("k", "required", "int", "all") \
.attr("largest", "optional", "bool", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/unsorted_segment_max.py View File

@@ -23,6 +23,7 @@ unsorted_segment_max_op_info = TBERegOp("UnsortedSegmentMax") \
.compute_cost(10) \
.kernel_name("unsorted_segment_max_d") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("num_segments", "required", "int", "all") \
.input(0, "data", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/unsorted_segment_min.py View File

@@ -23,6 +23,7 @@ unsorted_segment_min_op_info = TBERegOp("UnsortedSegmentMin") \
.compute_cost(10) \
.kernel_name("unsorted_segment_min_d") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("num_segments", "required", "int", "all") \
.input(0, "data", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all") \


+ 1
- 0
mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py View File

@@ -23,6 +23,7 @@ unsorted_segment_prod_d_op_info = TBERegOp("UnsortedSegmentProd") \
.compute_cost(10) \
.kernel_name("unsorted_segment_prod_d") \
.partial_flag(True) \
.need_check_supported(True) \
.attr("num_segments", "required", "int", "all") \
.input(0, "data", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all") \


+ 12
- 0
mindspore/ops/op_info_register.py View File

@@ -355,6 +355,7 @@ class TBERegOp(RegOp):
self.reshape_type_ = ''
self.dynamic_format_ = False
self.dynamic_shape_ = False
self.need_check_supported_ = False
self.op_pattern_ = ""

def async_flag(self, async_flag):
@@ -445,6 +446,17 @@ class TBERegOp(RegOp):
self.dynamic_shape_ = dynamic_shape
return self

def need_check_supported(self, need_check_supported):
"""
Whether the operator needs to run a check_supported verification.

Args:
need_check_supported (bool): Value of need_check_supported. Default: False.
"""
self._is_bool(need_check_supported)
self.need_check_supported_ = need_check_supported
return self

def op_pattern(self, pattern=None):
"""
The behavior type of operator, such as broadcast, reduce and so on.


Loading…
Cancel
Save