Browse Source

decouple akg_kernel_json_generator

tags/v1.6.0
dayschan 4 years ago
parent
commit
7cc4e170cc
3 changed files with 103 additions and 95 deletions
  1. +103
    -12
      mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
  2. +0
    -80
      mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
  3. +0
    -3
      mindspore/ccsrc/backend/kernel_compiler/common_utils.h

+ 103
- 12
mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc View File

@@ -22,14 +22,11 @@
#include <set>
#include <functional>
#include <algorithm>
#include "backend/kernel_compiler/common_utils.h"
#include "ir/func_graph.h"
#include "utils/anf_utils.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore::graphkernel {
using kernel::GetInputIndex;
using kernel::GetKernelInput;
using kernel::GetOutputIndex;
using kernel::OpAttr;
using kernel::OpImplyType;
using kernel::OpInfo;
@@ -45,6 +42,95 @@ std::vector<int64_t> GetDynInputSizes(const AnfNodePtr &anf_node) {
return dyn_input_sizes;
}

std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);
if (index >= AnfUtils::GetInputTensorNum(anf_node)) {
MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs. Node info : ["
<< anf_node->DebugString() << "]";
}
auto cnode = anf_node->cast<CNodePtr>();
if (cnode == nullptr) {
return AnfUtils::VisitKernel(anf_node, 0);
} else {
return AnfUtils::VisitKernel(cnode->input(index + 1), 0);
}
}

std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list) {
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
for (size_t i = 0; i < input_list.size(); ++i) {
auto const &input = input_list[i];
MS_EXCEPTION_IF_NULL(input);
auto mng = input->func_graph()->manager();
MS_EXCEPTION_IF_NULL(mng);
const NodeUsersMap &users = mng->node_users();
auto input_users = users.find(input);
if (input_users == users.end() || input_users->second.empty()) {
MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
<< input->func_graph()->ToString() << "] has no users.";
}
bool found = false;
for (auto const &input_user : input_users->second) {
for (auto const &anf_node : node_list) {
if (anf_node != input_user.first) {
continue;
}
auto dyn_input_sizes = GetDynInputSizes(anf_node);
if (dyn_input_sizes.empty()) {
input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
found = true;
break;
}
int used_as_idx = input_user.second - 1;
int accum_idx = 0;
for (size_t dyn_i = 0; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
accum_idx += LongToInt(dyn_input_sizes[dyn_i]);
if (used_as_idx < accum_idx) {
input_index.push_back(std::make_pair(
anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
found = true;
break;
}
}
if (found) break;
}
if (found) break;
}
if (found) continue;
MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
<< input->func_graph()->ToString() << "] found no related kernel info.";
}
return input_index;
}

std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list) {
std::vector<std::pair<AnfNodePtr, size_t>> output_index;
for (size_t i = 0; i < output_list.size(); ++i) {
auto const &output = output_list[i];
MS_EXCEPTION_IF_NULL(output);
bool found = false;
auto pree_node = AnfUtils::VisitKernel(output, 0);
auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
if (pos != std::end(node_list)) {
output_index.push_back(pree_node);
continue;
}
auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
if (ret != std::end(input_list)) {
output_index.push_back(std::make_pair(pree_node.first, 0));
found = true;
}
if (!found) {
MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
<< output->func_graph()->ToString() << "] found no related kernel info.";
}
}
return output_index;
}

class OpInfoExtractor {
public:
OpInfoExtractor() = default;
@@ -54,7 +140,7 @@ class OpInfoExtractor {
cnode_ = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode_);
auto op_info = std::make_shared<OpInfo>();
op_info->set_op_name(AnfAlgo::GetCNodeName(cnode_));
op_info->set_op_name(AnfUtils::GetCNodeName(cnode_));
op_info->set_imply_type(OpImplyType::kAKG);
ExtractInputs(op_info);
ExtractOutputs(op_info);
@@ -82,7 +168,7 @@ class OpInfoExtractor {
}

void ExtractOutputs(const OpInfoPtr &op_info) const {
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_);
size_t output_tensor_num = AnfUtils::GetOutputTensorNum(cnode_);
for (size_t i = 0; i < output_tensor_num; i++) {
auto io_info = std::make_shared<OpIOInfo>();
io_info->set_name("output_" + std::to_string(i));
@@ -257,7 +343,7 @@ bool AkgKernelJsonGenerator::CreateOutputDescJson(const AnfNodePtr &anf_node, co
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(op_info);
MS_EXCEPTION_IF_NULL(outputs_json);
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(anf_node);
size_t output_tensor_num = AnfUtils::GetOutputTensorNum(anf_node);

auto outputs = op_info->outputs_ptr();
for (size_t i = 0; i < output_tensor_num; i++) {
@@ -316,7 +402,8 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::
} else if (type == "listStr") {
std::vector<std::string> data_format;
if (op_attr->name() == kJsonKeyDataformat) {
size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
size_t tensor_args_num =
!dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfUtils::GetInputTensorNum(anf_node);
for (size_t format_i = 0; format_i < tensor_args_num; format_i++) {
auto input_format = this->cb_->GetInputFormat(anf_node, format_i);
data_format.push_back(input_format);
@@ -486,7 +573,11 @@ OpInfoPtr AkgKernelJsonGenerator::ExtractOpInfo(const AnfNodePtr &anf_node) cons
if (dump_option_.extract_opinfo_from_anfnode) {
return OpInfoExtractor().Run(anf_node);
} else {
return mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(anf_node), OpImplyType::kAKG);
#ifdef MSLITE_ENABLE_GRAPH_KERNEL
MS_LOG(EXCEPTION) << "OpLib is not supported.";
#else
return kernel::OpLib::FindOp(AnfUtils::GetCNodeName(anf_node), OpImplyType::kAKG);
#endif
}
}

@@ -578,7 +669,7 @@ bool AkgKernelJsonGenerator::GetIOSize(const nlohmann::json &node_json, std::vec
bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::json *kernel_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(kernel_json);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
std::string op_name = AnfUtils::GetCNodeName(anf_node);
MS_LOG(DEBUG) << "Akg start generate kernel json desc, full scope name is : " << anf_node->fullname_with_scope();
is_basic_op_ = true;
if (!GenerateSingleKernelJson(anf_node, kernel_json)) {
@@ -730,7 +821,7 @@ void AkgKernelJsonGenerator::UpdateTensorName(const std::vector<AnfNodePtr> &anf
for (auto const &anf_node : anf_nodes) {
auto dyn_input_sizes = GetDynInputSizes(anf_node);
bool is_dynamic_input = !dyn_input_sizes.empty();
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfUtils::GetInputTensorNum(anf_node);
size_t real_input_index = 0;
for (size_t i = 0; i < input_num; ++i) {
size_t input_tensor_num = is_dynamic_input ? LongToSize(dyn_input_sizes[i]) : 1;


+ 0
- 80
mindspore/ccsrc/backend/kernel_compiler/common_utils.cc View File

@@ -541,86 +541,6 @@ int GetReductionInt(const std::string &reduction) {
}
}

std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);

if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs. Node info : ["
<< anf_node->DebugString() << "]";
}

auto cnode = anf_node->cast<CNodePtr>();
if (cnode == nullptr) {
return AnfAlgo::VisitKernel(anf_node, 0);
} else {
return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0);
}
}

std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list) {
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
for (size_t i = 0; i < input_list.size(); ++i) {
auto const &input = input_list[i];
MS_EXCEPTION_IF_NULL(input);
bool found = false;
auto mng = input->func_graph()->manager();
MS_EXCEPTION_IF_NULL(mng);
const NodeUsersMap &users = mng->node_users();
auto input_users = users.find(input);
if (input_users == users.end() || input_users->second.empty()) {
MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
<< input->func_graph()->ToString() << "] has no users.";
}

for (auto const &input_user : input_users->second) {
for (auto const &anf_node : node_list) {
if (anf_node != input_user.first) {
continue;
}

std::vector<int64_t> dyn_input_sizes;
auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(prim);
if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
dyn_input_sizes = GetValue<const std::vector<int64_t>>(prim->GetAttr(kAttrDynInputSizes));
}

if (dyn_input_sizes.empty()) {
input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
found = true;
break;
} else {
int used_as_idx = input_user.second - 1;
int accum_idx = 0;
size_t dyn_i = 0;
for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
accum_idx += LongToInt(dyn_input_sizes[dyn_i]);
if (used_as_idx < accum_idx) {
input_index.push_back(std::make_pair(
anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
break;
}
}
if (dyn_i != dyn_input_sizes.size()) {
found = true;
break;
}
}
}
if (found) {
break;
}
}

if (!found) {
MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
<< input->func_graph()->ToString() << "] found no related kernel info.";
}
}
return input_index;
}

std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list) {


+ 0
- 3
mindspore/ccsrc/backend/kernel_compiler/common_utils.h View File

@@ -88,9 +88,6 @@ Processor GetProcessor(const string &processor);
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
int Sign(float x);
int GetReductionInt(const std::string &reduction);
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index);
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list);
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list);


Loading…
Cancel
Save