|
|
|
@@ -31,11 +31,9 @@ |
|
|
|
#include "backend/optimizer/common/helper.h" |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
#include "backend/session/kernel_build_client.h" |
|
|
|
#include "frontend/parallel/ops_info/ops_utils.h" |
|
|
|
#include "nlohmann/json.hpp" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
namespace mindspore::kernel { |
|
|
|
constexpr auto kName = "name"; |
|
|
|
constexpr auto kDtype = "dtype"; |
|
|
|
constexpr auto kFormat = "format"; |
|
|
|
@@ -83,7 +81,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() { |
|
|
|
MS_LOG(INFO) << "Warning: op pattern is invailed."; |
|
|
|
} |
|
|
|
// check support |
|
|
|
FilterInVaildKernelInfo(); |
|
|
|
FilterInVaildKernelInfo(*op_info_ptr); |
|
|
|
MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select."; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -221,7 +219,7 @@ void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) { |
|
|
|
MS_LOG(INFO) << "end."; |
|
|
|
} |
|
|
|
|
|
|
|
void TbeKernelSelect::FilterInVaildKernelInfo() { |
|
|
|
void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) { |
|
|
|
if (kernel_info_list_->empty()) { |
|
|
|
MS_LOG(INFO) << "Warning: get kernel build info failed."; |
|
|
|
return; |
|
|
|
@@ -232,9 +230,11 @@ void TbeKernelSelect::FilterInVaildKernelInfo() { |
|
|
|
MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*iter)->ToString(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!TbeCheckSupported(iter)) { |
|
|
|
MS_LOG(INFO) << "Check support shape, filter item info: " << (*iter)->ToString(); |
|
|
|
continue; |
|
|
|
if (op_info.need_check_supported()) { |
|
|
|
if (!TbeCheckSupported(iter)) { |
|
|
|
MS_LOG(INFO) << "Check support shape, filter item info: " << (*iter)->ToString(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
new_kernel_info_list.emplace_back(*iter); |
|
|
|
} |
|
|
|
@@ -292,22 +292,9 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool TbeKernelSelect::TbeCheckSupported( |
|
|
|
const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { |
|
|
|
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); |
|
|
|
static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL, |
|
|
|
parallel::BATCHMATMUL, |
|
|
|
parallel::TOPK, |
|
|
|
parallel::IN_TOPK, |
|
|
|
parallel::PACK, |
|
|
|
parallel::UNSORTEF_SEGMENT_MIND, |
|
|
|
parallel::UNSORTEF_SEGMENT_PRODD, |
|
|
|
parallel::CAST}; |
|
|
|
auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_); |
|
|
|
if (iter == kCheckSupportedOpType.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter) { |
|
|
|
MS_LOG(INFO) << "Check support start."; |
|
|
|
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); |
|
|
|
// replace kernel_info with current kernel info |
|
|
|
auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); |
|
|
|
@@ -560,7 +547,7 @@ void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, |
|
|
|
std::string input_format_item = item.value().at(kFormat); |
|
|
|
select_input.formats = SplitStrToVec(input_format_item); |
|
|
|
inputs.emplace_back(select_input); |
|
|
|
} else if (is_output) { |
|
|
|
} else { |
|
|
|
SelectOpIOInfo select_output; |
|
|
|
select_output.name = item.value().at(kName); |
|
|
|
std::string input_dtype_item = item.value().at(kDtype); |
|
|
|
@@ -628,5 +615,4 @@ void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) |
|
|
|
MS_LOG(INFO) << "Support format: " << print_str; |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace kernel |
|
|
|
} // namespace mindspore |
|
|
|
} // namespace mindspore::kernel |