Browse Source

remove transpose check supported

tags/v1.2.0
liubuyu 5 years ago
parent
commit
03fa52d213
2 changed files with 1 additions and 8 deletions
  1. +0
    -7
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
  2. +1
    -1
      mindspore/ops/_op_impl/tbe/transpose_d.py

+ 0
- 7
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc View File

@@ -42,8 +42,6 @@ constexpr auto kPrefixOutput = "output";
constexpr char kParamTypeDynamic[] = "dynamic";
constexpr char kParamTypeRequre[] = "required";
constexpr char kParamTypeOptional[] = "optional";
const std::set<TypeId> transpose_unsupported = {kNumberTypeInt8, kNumberTypeUInt8, kNumberTypeBool, kNumberTypeInt64,
kNumberTypeUInt64};
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list);
tbe_selecter.TbeMetadataInfoEx();
@@ -189,7 +187,6 @@ void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
}
std::vector<std::shared_ptr<KernelBuildInfo>> new_kernel_info_list;
auto dynamic_inputs = GetNodeDynamicInputs();
auto op_name = AnfAlgo::GetCNodeName(cnode_ptr_);
for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) {
if (!FilterInVaildShape(iter, !dynamic_inputs.empty())) {
continue;
@@ -199,10 +196,6 @@ void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
continue;
}
}
if (op_name == kTransposeOpName &&
transpose_unsupported.find((*iter)->GetInputDeviceType(0)) != transpose_unsupported.end()) {
continue;
}
new_kernel_info_list.emplace_back(*iter);
}
(*kernel_info_list_) = new_kernel_info_list;


+ 1
- 1
mindspore/ops/_op_impl/tbe/transpose_d.py View File

@@ -26,7 +26,7 @@ transpose_d_op_info = TBERegOp("Transpose") \
.attr("perm", "optional", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.need_check_supported(False) \
.need_check_supported(True) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \


Loading…
Cancel
Save