|
|
|
@@ -22,14 +22,11 @@ |
|
|
|
#include <set> |
|
|
|
#include <functional> |
|
|
|
#include <algorithm> |
|
|
|
#include "backend/kernel_compiler/common_utils.h" |
|
|
|
#include "ir/func_graph.h" |
|
|
|
#include "utils/anf_utils.h" |
|
|
|
#include "backend/kernel_compiler/oplib/oplib.h" |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
|
|
|
|
namespace mindspore::graphkernel { |
|
|
|
using kernel::GetInputIndex; |
|
|
|
using kernel::GetKernelInput; |
|
|
|
using kernel::GetOutputIndex; |
|
|
|
using kernel::OpAttr; |
|
|
|
using kernel::OpImplyType; |
|
|
|
using kernel::OpInfo; |
|
|
|
@@ -45,6 +42,95 @@ std::vector<int64_t> GetDynInputSizes(const AnfNodePtr &anf_node) { |
|
|
|
return dyn_input_sizes; |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
if (index >= AnfUtils::GetInputTensorNum(anf_node)) { |
|
|
|
MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs. Node info : [" |
|
|
|
<< anf_node->DebugString() << "]"; |
|
|
|
} |
|
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr) { |
|
|
|
return AnfUtils::VisitKernel(anf_node, 0); |
|
|
|
} else { |
|
|
|
return AnfUtils::VisitKernel(cnode->input(index + 1), 0); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list, |
|
|
|
const std::vector<AnfNodePtr> &input_list) { |
|
|
|
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index; |
|
|
|
for (size_t i = 0; i < input_list.size(); ++i) { |
|
|
|
auto const &input = input_list[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
|
auto mng = input->func_graph()->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(mng); |
|
|
|
const NodeUsersMap &users = mng->node_users(); |
|
|
|
auto input_users = users.find(input); |
|
|
|
if (input_users == users.end() || input_users->second.empty()) { |
|
|
|
MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" |
|
|
|
<< input->func_graph()->ToString() << "] has no users."; |
|
|
|
} |
|
|
|
bool found = false; |
|
|
|
for (auto const &input_user : input_users->second) { |
|
|
|
for (auto const &anf_node : node_list) { |
|
|
|
if (anf_node != input_user.first) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto dyn_input_sizes = GetDynInputSizes(anf_node); |
|
|
|
if (dyn_input_sizes.empty()) { |
|
|
|
input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0))); |
|
|
|
found = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
int used_as_idx = input_user.second - 1; |
|
|
|
int accum_idx = 0; |
|
|
|
for (size_t dyn_i = 0; dyn_i < dyn_input_sizes.size(); ++dyn_i) { |
|
|
|
accum_idx += LongToInt(dyn_input_sizes[dyn_i]); |
|
|
|
if (used_as_idx < accum_idx) { |
|
|
|
input_index.push_back(std::make_pair( |
|
|
|
anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i]))))); |
|
|
|
found = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (found) break; |
|
|
|
} |
|
|
|
if (found) break; |
|
|
|
} |
|
|
|
if (found) continue; |
|
|
|
MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" |
|
|
|
<< input->func_graph()->ToString() << "] found no related kernel info."; |
|
|
|
} |
|
|
|
return input_index; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list, |
|
|
|
const std::vector<AnfNodePtr> &input_list, |
|
|
|
const std::vector<AnfNodePtr> &output_list) { |
|
|
|
std::vector<std::pair<AnfNodePtr, size_t>> output_index; |
|
|
|
for (size_t i = 0; i < output_list.size(); ++i) { |
|
|
|
auto const &output = output_list[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
bool found = false; |
|
|
|
auto pree_node = AnfUtils::VisitKernel(output, 0); |
|
|
|
auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first); |
|
|
|
if (pos != std::end(node_list)) { |
|
|
|
output_index.push_back(pree_node); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first); |
|
|
|
if (ret != std::end(input_list)) { |
|
|
|
output_index.push_back(std::make_pair(pree_node.first, 0)); |
|
|
|
found = true; |
|
|
|
} |
|
|
|
if (!found) { |
|
|
|
MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of [" |
|
|
|
<< output->func_graph()->ToString() << "] found no related kernel info."; |
|
|
|
} |
|
|
|
} |
|
|
|
return output_index; |
|
|
|
} |
|
|
|
|
|
|
|
class OpInfoExtractor { |
|
|
|
public: |
|
|
|
OpInfoExtractor() = default; |
|
|
|
@@ -54,7 +140,7 @@ class OpInfoExtractor { |
|
|
|
cnode_ = anf_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode_); |
|
|
|
auto op_info = std::make_shared<OpInfo>(); |
|
|
|
op_info->set_op_name(AnfAlgo::GetCNodeName(cnode_)); |
|
|
|
op_info->set_op_name(AnfUtils::GetCNodeName(cnode_)); |
|
|
|
op_info->set_imply_type(OpImplyType::kAKG); |
|
|
|
ExtractInputs(op_info); |
|
|
|
ExtractOutputs(op_info); |
|
|
|
@@ -82,7 +168,7 @@ class OpInfoExtractor { |
|
|
|
} |
|
|
|
|
|
|
|
void ExtractOutputs(const OpInfoPtr &op_info) const { |
|
|
|
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_); |
|
|
|
size_t output_tensor_num = AnfUtils::GetOutputTensorNum(cnode_); |
|
|
|
for (size_t i = 0; i < output_tensor_num; i++) { |
|
|
|
auto io_info = std::make_shared<OpIOInfo>(); |
|
|
|
io_info->set_name("output_" + std::to_string(i)); |
|
|
|
@@ -257,7 +343,7 @@ bool AkgKernelJsonGenerator::CreateOutputDescJson(const AnfNodePtr &anf_node, co |
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
|
MS_EXCEPTION_IF_NULL(outputs_json); |
|
|
|
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(anf_node); |
|
|
|
size_t output_tensor_num = AnfUtils::GetOutputTensorNum(anf_node); |
|
|
|
|
|
|
|
auto outputs = op_info->outputs_ptr(); |
|
|
|
for (size_t i = 0; i < output_tensor_num; i++) { |
|
|
|
@@ -316,7 +402,8 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std:: |
|
|
|
} else if (type == "listStr") { |
|
|
|
std::vector<std::string> data_format; |
|
|
|
if (op_attr->name() == kJsonKeyDataformat) { |
|
|
|
size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); |
|
|
|
size_t tensor_args_num = |
|
|
|
!dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfUtils::GetInputTensorNum(anf_node); |
|
|
|
for (size_t format_i = 0; format_i < tensor_args_num; format_i++) { |
|
|
|
auto input_format = this->cb_->GetInputFormat(anf_node, format_i); |
|
|
|
data_format.push_back(input_format); |
|
|
|
@@ -486,7 +573,11 @@ OpInfoPtr AkgKernelJsonGenerator::ExtractOpInfo(const AnfNodePtr &anf_node) cons |
|
|
|
if (dump_option_.extract_opinfo_from_anfnode) { |
|
|
|
return OpInfoExtractor().Run(anf_node); |
|
|
|
} else { |
|
|
|
return mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(anf_node), OpImplyType::kAKG); |
|
|
|
#ifdef MSLITE_ENABLE_GRAPH_KERNEL |
|
|
|
MS_LOG(EXCEPTION) << "OpLib is not supported."; |
|
|
|
#else |
|
|
|
return kernel::OpLib::FindOp(AnfUtils::GetCNodeName(anf_node), OpImplyType::kAKG); |
|
|
|
#endif |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -578,7 +669,7 @@ bool AkgKernelJsonGenerator::GetIOSize(const nlohmann::json &node_json, std::vec |
|
|
|
bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::json *kernel_json) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_json); |
|
|
|
std::string op_name = AnfAlgo::GetCNodeName(anf_node); |
|
|
|
std::string op_name = AnfUtils::GetCNodeName(anf_node); |
|
|
|
MS_LOG(DEBUG) << "Akg start generate kernel json desc, full scope name is : " << anf_node->fullname_with_scope(); |
|
|
|
is_basic_op_ = true; |
|
|
|
if (!GenerateSingleKernelJson(anf_node, kernel_json)) { |
|
|
|
@@ -730,7 +821,7 @@ void AkgKernelJsonGenerator::UpdateTensorName(const std::vector<AnfNodePtr> &anf |
|
|
|
for (auto const &anf_node : anf_nodes) { |
|
|
|
auto dyn_input_sizes = GetDynInputSizes(anf_node); |
|
|
|
bool is_dynamic_input = !dyn_input_sizes.empty(); |
|
|
|
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); |
|
|
|
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfUtils::GetInputTensorNum(anf_node); |
|
|
|
size_t real_input_index = 0; |
|
|
|
for (size_t i = 0; i < input_num; ++i) { |
|
|
|
size_t input_tensor_num = is_dynamic_input ? LongToSize(dyn_input_sizes[i]) : 1; |
|
|
|
|