
!6240 Fix review_bot and codedex problems

Merge pull request !6240 from DeshiChen/0910_review_bot

mindspore-ci-bot committed 5 years ago, parent commit 014ea619a8 (tags/v1.0.0)

15 changed files with 524 additions and 482 deletions
  1.  +0   -1    mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc
  2.  +0   -2    mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h
  3.  +267 -247  mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
  4.  +2   -3    mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h
  5.  +97  -105  mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc
  6.  +8   -0    mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h
  7.  +59  -39   mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc
  8.  +9   -1    mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h
  9.  +0   -2    mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.h
  10. +1   -1    mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
  11. +33  -36   mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
  12. +2   -2    mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
  13. +18  -15   mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
  14. +4   -2    mindspore/ccsrc/debug/anf_ir_dump.cc
  15. +24  -26   mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc

+0 -1   mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc

@@ -199,6 +199,5 @@ void SetAkgKernelAttrs(const AnfNodePtr &anf_node) {
     it->second(anf_node);
   }
 }
-
 }  // namespace kernel
 }  // namespace mindspore

+0 -2   mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h

@@ -20,9 +20,7 @@

 namespace mindspore {
 namespace kernel {
-
 void SetAkgKernelAttrs(const AnfNodePtr &anf_node);
-
 }  // namespace kernel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H

+267 -247   mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc

@@ -42,153 +42,276 @@
namespace mindspore {
namespace kernel {
namespace {
ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) {
if (type == "str") {
std::string value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "int") {
int value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "bool") {
bool value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "float") {
float value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "listInt") {
std::vector<int> value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "listStr") {
std::vector<std::string> value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else {
MS_LOG(ERROR) << "Unknown type of attr: " << type << ", json: \n" << attr_json;
return nullptr;
}
}
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";


bool DecodeAttrs(const nlohmann::json &attrs_json, std::map<std::string, ValuePtr> *attrs) {
MS_EXCEPTION_IF_NULL(attrs);
MS_LOG(DEBUG) << "start decode attrs, " << attrs_json;
// decode attrs.
if (attrs_json.find(kJsonKeyAttr) == attrs_json.end() || attrs_json[kJsonKeyAttr].is_null()) {
// attrs maybe empty
return true;
class CNodeDecoder {
public:
explicit CNodeDecoder(std::map<std::string, AnfNodePtr> *nodes_map) : nodes_map_(*nodes_map) {}
~CNodeDecoder() = default;
CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, kernel::Processor processor) {
MS_LOG(DEBUG) << "start decode cnode, " << cnode_json;
// decode attrs.
if (!DecodeAttrs(cnode_json)) {
MS_LOG(ERROR) << "Decode attrs failed.";
return nullptr;
}
if (!DecodeInputDesc(cnode_json, func_graph) || cnode_ == nullptr) {
MS_LOG(ERROR) << "Decode inputs failed.";
return nullptr;
}
if (!DecodeOutputDesc(cnode_json, func_graph)) {
MS_LOG(ERROR) << "Decode outputs failed.";
return nullptr;
}
CreateKernelInfo(processor);
return cnode_;
}


std::vector<nlohmann::json> attr_descs = attrs_json[kJsonKeyAttr];
for (const auto &attr_desc : attr_descs) {
std::string name = attr_desc[kJsonKeyName];
std::string type = attr_desc[kJsonKeyDataType];
auto value = ParseValue(attr_desc, type);
if (value == nullptr) {
return false;
private:
ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) {
if (type == "str") {
std::string value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "int") {
int value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "bool") {
bool value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "float") {
float value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "listInt") {
std::vector<int> value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "listStr") {
std::vector<std::string> value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else {
MS_LOG(ERROR) << "Unknown type of attr: " << type << ", json: \n" << attr_json;
return nullptr;
}
(*attrs)[name] = value;
}


return true;
}
bool DecodeAttrs(const nlohmann::json &attrs_json) {
MS_LOG(DEBUG) << "start decode attrs, " << attrs_json;
// attrs maybe empty
if (attrs_json.find(kJsonKeyAttr) == attrs_json.end() || attrs_json[kJsonKeyAttr].is_null()) {
return true;
}


// python utils.
constexpr auto kGetPythonOpFunc = "_get_python_op";
constexpr auto kParallelUtilsModule = "mindspore.parallel._utils";
// almost all ops are defined in this path.
constexpr auto kOperationsModule = "mindspore.ops.operations";
std::vector<nlohmann::json> attr_descs = attrs_json[kJsonKeyAttr];
for (const auto &attr_desc : attr_descs) {
std::string name = attr_desc[kJsonKeyName];
std::string type = attr_desc[kJsonKeyDataType];
auto value = ParseValue(attr_desc, type);
if (value == nullptr) {
return false;
}
cnode_attrs_[name] = value;
}
return true;
}


const std::map<std::string, std::vector<std::string>> op_attrs_map = {
{kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
{kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
{kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
};
bool DecodeInputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) {
std::string op_name = cnode_json[kJsonKeyName];
// new primitive.
auto primitive = GetPrimitive(op_name);
if (primitive == nullptr) {
MS_LOG(ERROR) << "Create primitive failed.";
return false;
}


ValuePtr CreatOpInstance(const std::string &op_name, const std::vector<ValuePtr> &attrs) {
py::module mod = py::module::import(kOperationsModule);
if (!py::hasattr(mod, op_name.c_str())) {
MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name;
return nullptr;
// collect inputs.
auto primitive_v = NewValueNode(primitive);
func_graph->AddValueNode(primitive_v);
std::vector<AnfNodePtr> inputs{primitive_v};
std::vector<nlohmann::json> input_descs = cnode_json[kJsonKeyInputDesc];
for (size_t i = 0; i < input_descs.size(); ++i) {
nlohmann::json input_desc = input_descs[i][0];
std::string name = input_desc[kJsonKeyTensorName];
if (input_desc.find(kJsonKeyValue) != input_desc.end()) {
inputs.push_back(DecodeValueNode(input_desc, func_graph));
} else if (nodes_map_.count(name) == 0) {
MS_LOG(ERROR) << "Input: " << name << " of: " << op_name << " not found.";
return false;
} else {
inputs.push_back(nodes_map_[name]);
}
input_formats_.push_back(input_desc[kJsonKeyFormat]);
input_types_.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType]));
}
// new cnode.
cnode_ = func_graph->NewCNode(inputs);
func_graph->AddNode(cnode_);
return true;
}
std::vector<py::object> arg_list;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
[](const ValuePtr &attr) { return ValuePtrToPyData(attr); });
py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule,
op_name, arg_list);
ValuePtr op_instance = nullptr;
bool succ = parse::ConvertData(obj, &op_instance);
if (!succ) {
MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed.";
return nullptr;

bool DecodeOutputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) {
std::vector<nlohmann::json> output_descs = cnode_json[kJsonKeyOutputDesc];
AbstractBasePtr abstract(nullptr);
if (output_descs.empty()) {
MS_LOG(ERROR) << "No outputs found.";
return false;
} else if (output_descs.size() == 1) {
// single output.
nlohmann::json output_desc = output_descs[0];
output_formats_.push_back(output_desc[kJsonKeyFormat]);
output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
nodes_map_[output_desc[kJsonKeyTensorName]] = cnode_;
} else {
// multi outputs.
for (size_t j = 0; j < output_descs.size(); ++j) {
nlohmann::json output_desc = output_descs[j];
output_formats_.push_back(output_desc[kJsonKeyFormat]);
output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
auto get_item =
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_, NewValueNode(SizeToInt(j))});
func_graph->AddNode(get_item);
nodes_map_[output_desc[kJsonKeyTensorName]] = get_item;
}
}
return true;
}
return op_instance;
}


PrimitivePtr GetPrimitive(const std::string &op_name, const std::map<std::string, ValuePtr> &attrs_val) {
PrimitivePtr primitive{nullptr};
if (op_attrs_map.count(op_name) == 0) {
// no attrs for op instance.
primitive = CreatOpInstance(op_name, std::vector<ValuePtr>{})->cast<PrimitivePtr>();
} else {
// make attrs for op instance.
std::vector<ValuePtr> op_attrs;
const auto &attr_names = op_attrs_map.at(op_name);
for (const auto &attr_name : attr_names) {
if (attrs_val.count(attr_name) == 0) {
MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found.";
return nullptr;
void CreateKernelInfo(kernel::Processor processor) {
auto kernel_info = std::make_shared<device::KernelInfo>();
std::vector<size_t> feature_map_input_indexs;
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
const auto &inputs = cnode_->inputs();
for (size_t index = 1; index < inputs.size(); ++index) {
auto node = AnfAlgo::VisitKernel(inputs[index], 0);
if (AnfAlgo::IsFeatureMapOutput(node.first)) {
feature_map_input_indexs.push_back(index);
}
op_attrs.push_back(attrs_val.at(attr_name));
}
primitive = CreatOpInstance(op_name, op_attrs)->cast<PrimitivePtr>();
if (AnfAlgo::GetCNodeName(cnode_) == prim::kPrimCast->name()) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode_);
}
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true);
}
if (AnfAlgo::IsRealCNodeKernel(cnode_)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode_);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode_);
}
cnode_->set_kernel_info(kernel_info);
// create kernel_build_info.
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetInputsFormat(input_formats_);
builder->SetInputsDeviceType(input_types_);
builder->SetOutputsFormat(output_formats_);
builder->SetOutputsDeviceType(output_types_);
builder->SetProcessor(processor);
builder->SetKernelType(KernelType::AKG_KERNEL);
builder->SetFusionType(kernel::FusionType::OPAQUE);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode_.get());
}


if (primitive != nullptr) {
for (const auto &attr : attrs_val) {
primitive->AddAttr(attr.first, attr.second);
ValuePtr CreatOpInstance(const std::string &op_name, const std::vector<ValuePtr> &attrs) {
// python utils.
constexpr auto kGetPythonOpFunc = "_get_python_op";
constexpr auto kParallelUtilsModule = "mindspore.parallel._utils";
// almost all ops are defined in this path.
constexpr auto kOperationsModule = "mindspore.ops.operations";
py::module mod = py::module::import(kOperationsModule);
if (!py::hasattr(mod, op_name.c_str())) {
MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name;
return nullptr;
}
std::vector<py::object> arg_list;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
[](const ValuePtr &attr) { return ValuePtrToPyData(attr); });
py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule,
op_name, arg_list);
ValuePtr op_instance = nullptr;
bool succ = parse::ConvertData(obj, &op_instance);
if (!succ) {
MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed.";
return nullptr;
}
return op_instance;
}


return primitive;
}
} // namespace
const std::map<std::string, std::vector<std::string>> op_attrs_map_ = {
{kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
{kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
{kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
};


constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
PrimitivePtr GetPrimitive(const std::string &op_name) {
PrimitivePtr primitive{nullptr};
if (op_attrs_map_.count(op_name) == 0) {
// no attrs for op instance.
primitive = CreatOpInstance(op_name, std::vector<ValuePtr>{})->cast<PrimitivePtr>();
} else {
// make attrs for op instance.
std::vector<ValuePtr> op_attrs;
const auto &attr_names = op_attrs_map_.at(op_name);
for (const auto &attr_name : attr_names) {
if (cnode_attrs_.count(attr_name) == 0) {
MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found.";
return nullptr;
}
op_attrs.push_back(cnode_attrs_.at(attr_name));
}
primitive = CreatOpInstance(op_name, op_attrs)->cast<PrimitivePtr>();
}
if (primitive != nullptr) {
for (const auto &attr : cnode_attrs_) {
primitive->AddAttr(attr.first, attr.second);
}
}
return primitive;
}


ScalarPtr AkgKernelJsonDecoder::DecodeScalar(const nlohmann::json &scalar_json) {
auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
switch (type_id) {
case kNumberTypeFloat16:
case kNumberTypeFloat32:
return std::make_shared<FP32Imm>(scalar_json[kJsonKeyValue]);
case kNumberTypeInt32:
return std::make_shared<Int32Imm>(scalar_json[kJsonKeyValue]);
default:
MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType];
break;
ScalarPtr DecodeScalar(const nlohmann::json &scalar_json) {
auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
switch (type_id) {
case kNumberTypeFloat16:
case kNumberTypeFloat32:
return std::make_shared<FP32Imm>(scalar_json[kJsonKeyValue]);
case kNumberTypeInt32:
return std::make_shared<Int32Imm>(scalar_json[kJsonKeyValue]);
default:
MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType];
break;
}
return nullptr;
}
return nullptr;
}


ValueNodePtr AkgKernelJsonDecoder::DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) {
MS_LOG(DEBUG) << "start decode value node, " << value_json;
auto scalar = DecodeScalar(value_json);
auto tensor = ScalarToTensor(scalar);
ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) {
MS_LOG(DEBUG) << "start decode value node, " << value_json;
auto scalar = DecodeScalar(value_json);
auto tensor = ScalarToTensor(scalar);


auto value_node = std::make_shared<ValueNode>(tensor);
value_node->set_abstract(tensor->ToAbstract());
// create kernel_info fo new value node.
auto kernel_info = std::make_shared<device::KernelInfo>();
value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node.
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// layout info.
builder->SetOutputsFormat(std::vector<std::string>{value_json[kJsonKeyFormat]});
builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(value_json[kJsonKeyDataType])});
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
func_graph->AddValueNode(value_node);
MS_LOG(DEBUG) << "decode value node success, " << value_node->DebugString(2);
return value_node;
}
auto value_node = std::make_shared<ValueNode>(tensor);
value_node->set_abstract(tensor->ToAbstract());
// create kernel_info fo new value node.
auto kernel_info = std::make_shared<device::KernelInfo>();
value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node.
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// layout info.
builder->SetOutputsFormat(std::vector<std::string>{value_json[kJsonKeyFormat]});
builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(value_json[kJsonKeyDataType])});
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
func_graph->AddValueNode(value_node);
MS_LOG(DEBUG) << "decode value node success, " << value_node->DebugString(2);
return value_node;
}

std::map<std::string, AnfNodePtr> &nodes_map_;
std::map<std::string, ValuePtr> cnode_attrs_;
std::vector<std::string> input_formats_;
std::vector<std::string> output_formats_;
std::vector<TypeId> input_types_;
std::vector<TypeId> output_types_;
CNodePtr cnode_{nullptr};
};
} // namespace


 ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json &parameter_json,
                                                    const FuncGraphPtr &func_graph) {

@@ -208,118 +331,35 @@ ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json &paramet

 CNodePtr AkgKernelJsonDecoder::DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph,
                                            const std::string &processor) {
+  CNodeDecoder decoder(&nodes_map_);
   Processor p = kernel::GetProcessor(processor);
-  MS_LOG(DEBUG) << "start decode cnode, " << cnode_json;
-  // decode attrs.
-  std::map<std::string, ValuePtr> cnode_attrs;
-  if (!DecodeAttrs(cnode_json, &cnode_attrs)) {
-    MS_LOG(ERROR) << "Error decode attrs.";
-    return nullptr;
-  }
-  std::string op_name = cnode_json[kJsonKeyName];
-  // new primitive.
-  auto primitive = GetPrimitive(op_name, cnode_attrs);
-  if (primitive == nullptr) {
-    MS_LOG(ERROR) << "Create primitive failed.";
-    return nullptr;
-  }
-
-  // data layout info.
-  std::vector<std::string> input_formats;
-  std::vector<TypeId> input_types;
-  std::vector<std::string> output_formats;
-  std::vector<TypeId> output_types;
-
-  // collect inputs.
-  auto primitive_v = NewValueNode(primitive);
-  func_graph->AddValueNode(primitive_v);
-  std::vector<AnfNodePtr> inputs{primitive_v};
-  std::vector<nlohmann::json> input_descs = cnode_json[kJsonKeyInputDesc];
-  for (size_t i = 0; i < input_descs.size(); ++i) {
-    nlohmann::json input_desc = input_descs[i][0];
-    std::string name = input_desc[kJsonKeyTensorName];
-    if (input_desc.find(kJsonKeyValue) != input_desc.end()) {
-      inputs.push_back(DecodeValueNode(input_desc, func_graph));
-    } else if (nodes_map_.count(name) == 0) {
-      MS_LOG(ERROR) << "Input: " << name << " of: " << op_name << " not found.";
-      return nullptr;
-    } else {
-      inputs.push_back(nodes_map_[name]);
-    }
-    input_formats.push_back(input_desc[kJsonKeyFormat]);
-    input_types.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType]));
-  }
-  MS_LOG(DEBUG) << "decode inputs success.";
-
-  // new cnode.
-  auto cnode = func_graph->NewCNode(inputs);
-  func_graph->AddNode(cnode);
-
-  // decode outputs.
-  std::vector<nlohmann::json> output_descs = cnode_json[kJsonKeyOutputDesc];
-  AbstractBasePtr abstract(nullptr);
-  if (output_descs.empty()) {
-    MS_LOG(ERROR) << "No outputs found.";
-    return nullptr;
-  } else if (output_descs.size() == 1) {
-    // single output.
-    nlohmann::json output_desc = output_descs[0];
-    output_formats.push_back(output_desc[kJsonKeyFormat]);
-    output_types.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
-    nodes_map_[output_desc[kJsonKeyTensorName]] = cnode;
-  } else {
-    // multi outputs.
-    for (size_t j = 0; j < output_descs.size(); ++j) {
-      nlohmann::json output_desc = output_descs[j];
-      output_formats.push_back(output_desc[kJsonKeyFormat]);
-      output_types.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
-      auto get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, NewValueNode(SizeToInt(j))});
-      func_graph->AddNode(get_item);
-      nodes_map_[output_desc[kJsonKeyTensorName]] = get_item;
-    }
-  }
-  MS_LOG(DEBUG) << "decode outputs success.";
-
-  // create kernel_info.
-  auto kernel_info = std::make_shared<device::KernelInfo>();
-  std::vector<size_t> feature_map_input_indexs;
-  // if the node only has the primitive(such as getNext) or the node's input has a feature map input
-  // then the node's output is a feature map output
-  for (size_t index = 1; index < inputs.size(); ++index) {
-    auto node = AnfAlgo::VisitKernel(inputs[index], 0);
-    if (AnfAlgo::IsFeatureMapOutput(node.first)) {
-      feature_map_input_indexs.push_back(index);
-    }
-  }
-  if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
-    AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
-  }
-  if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
-    kernel_info->SetFeatureMapFlag(true);
-  }
-  if (AnfAlgo::IsRealCNodeKernel(cnode)) {
-    AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
-    AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
-  }
-  cnode->set_kernel_info(kernel_info);
-  // create kernel_build_info.
-  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  builder->SetInputsFormat(input_formats);
-  builder->SetInputsDeviceType(input_types);
-  builder->SetOutputsFormat(output_formats);
-  builder->SetOutputsDeviceType(output_types);
-  builder->SetProcessor(p);
-  builder->SetKernelType(KernelType::AKG_KERNEL);
-  builder->SetFusionType(kernel::FusionType::OPAQUE);
-  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
-  return cnode;
+  return decoder.DecodeCNode(cnode_json, func_graph, p);
+}
+
+AnfNodePtr AkgKernelJsonDecoder::DecodeOutput(const std::vector<nlohmann::json> &output_descs,
+                                              const FuncGraphPtr &func_graph) {
+  std::vector<AnfNodePtr> outputs{NewValueNode(prim::kPrimMakeTuple)};
+  for (const auto &output_desc : output_descs) {
+    std::string name = output_desc[kJsonKeyTensorName];
+    if (nodes_map_.count(name) == 0) {
+      MS_LOG(ERROR) << "Output: " << name << " of graph not found.";
+      return nullptr;
+    }
+    outputs.push_back(nodes_map_[name]);
+  }
+  if (outputs.size() == 2) {
+    func_graph->set_output(outputs[1]);
+  } else {
+    auto output = func_graph->NewCNode(outputs);
+    func_graph->AddNode(output);
+    func_graph->set_output(output);
+  }
+  return func_graph->output();
 }

 FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel_json) {
   MS_LOG(DEBUG) << "start decode, " << kernel_json;
-  // clear cache.
   nodes_map_.clear();
-  // create a graph.
   auto graph = std::make_shared<FuncGraph>();

   // decode parameters.
@@ -331,10 +371,7 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel
   for (size_t i = 0; i < input_descs.size(); ++i) {
     std::vector<nlohmann::json> input_desc = input_descs[i];
     auto parameter = DecodeParameter(input_desc[0], graph);
-    if (parameter == nullptr) {
-      MS_LOG(ERROR) << "Error decode parameter.";
-      return nullptr;
-    }
+    MS_EXCEPTION_IF_NULL(parameter);
   }
   MS_LOG(DEBUG) << "decode parameters success.";


@@ -346,10 +383,7 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel
   }
   for (const auto &op_desc : op_node_descs) {
     auto op_node = DecodeCNode(op_desc, graph, kernel_json[kJsonKeyProcess]);
-    if (op_node == nullptr) {
-      MS_LOG(ERROR) << "Error decode cnode.";
-      return nullptr;
-    }
+    MS_EXCEPTION_IF_NULL(op_node);
   }
   MS_LOG(DEBUG) << "decode cnodes success.";


@@ -359,22 +393,8 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel
     MS_LOG(ERROR) << "Error decode outputs, no outputs for graph.";
     return nullptr;
   }
-  std::vector<AnfNodePtr> outputs{NewValueNode(prim::kPrimMakeTuple)};
-  for (const auto &output_desc : output_descs) {
-    std::string name = output_desc[kJsonKeyTensorName];
-    if (nodes_map_.count(name) == 0) {
-      MS_LOG(ERROR) << "Output: " << name << " of graph not found.";
-      return nullptr;
-    }
-    outputs.push_back(nodes_map_[name]);
-  }
-  if (outputs.size() == 2) {
-    graph->set_output(outputs[1]);
-  } else {
-    auto output = graph->NewCNode(outputs);
-    graph->AddNode(output);
-    graph->set_output(output);
-  }
+  auto output = DecodeOutput(output_descs, graph);
+  MS_EXCEPTION_IF_NULL(output);
   MS_LOG(DEBUG) << "decode success, " << kernel_json;
   return graph;
 }
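For orientation, the bulk of this file's change replaces the free decode helpers with a single stateful CNodeDecoder that is constructed once per node, shares the caller's tensor-name map, and accumulates formats and types before finalizing the kernel info. The following stand-alone sketch illustrates only that pattern; all names and types here are simplified placeholders, not MindSpore APIs.

// Illustrative sketch only: one decoder instance per node, sharing the caller's
// name -> node map and accumulating per-node state before finalizing it.
#include <map>
#include <string>
#include <vector>

struct Node {};

class NodeDecoderSketch {
 public:
  explicit NodeDecoderSketch(std::map<std::string, Node *> *nodes_map) : nodes_map_(*nodes_map) {}

  // Single public entry point; each private step fills the members below.
  Node *Decode(const std::map<std::string, std::string> &node_desc) {
    if (!DecodeAttrs(node_desc) || !DecodeInputs(node_desc) || !DecodeOutputs(node_desc)) {
      return nullptr;  // each step reports its own error in the real code
    }
    CreateKernelInfo();
    return cnode_;
  }

 private:
  bool DecodeAttrs(const std::map<std::string, std::string> &) { return true; }
  bool DecodeInputs(const std::map<std::string, std::string> &) {
    static Node placeholder;
    cnode_ = &placeholder;                       // stand-in for func_graph->NewCNode(...)
    input_formats_.push_back("DefaultFormat");   // accumulated layout info
    return true;
  }
  bool DecodeOutputs(const std::map<std::string, std::string> &) { return true; }
  void CreateKernelInfo() {}  // would consume input_formats_/output_formats_

  std::map<std::string, Node *> &nodes_map_;
  std::vector<std::string> input_formats_;
  std::vector<std::string> output_formats_;
  Node *cnode_{nullptr};
};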


+2 -3   mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h

@@ -37,11 +37,10 @@ class AkgKernelJsonDecoder {
                         AnfNodePtrList *res_graphs);

  private:
-  ScalarPtr DecodeScalar(const nlohmann::json &scalar_json);
-  ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph);
   ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
   CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
-  std::map<std::string, AnfNodePtr> nodes_map_{};
+  AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
+  std::map<std::string, AnfNodePtr> nodes_map_;
 };
 }  // namespace kernel
 }  // namespace mindspore


+97 -105   mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc

@@ -79,22 +79,16 @@ inline std::string AkgKernelJsonGenerator::GetOutputFormat(const AnfNodePtr &anf

 bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
                                                  nlohmann::json *const inputs_json) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(op_info);
-  MS_EXCEPTION_IF_NULL(inputs_json);
-
   // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
   std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr = op_info->inputs_ptr();
   if (inputs_ptr.empty()) {
-    MS_LOG(DEBUG) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info";
-    return true;
+    MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info";
+    return false;
   }

   // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
   auto dyn_input_sizes = GetDynInputSize(anf_node);
-
   size_t real_input_index = 0;
-  std::vector<nlohmann::json> input_list;
   for (size_t i = 0; i < inputs_ptr.size(); i++) {
     std::shared_ptr<OpIOInfo> input_ptr = inputs_ptr[i];
     if (input_ptr == nullptr) {
@@ -102,10 +96,8 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
       return false;
     }

-    auto op_input_name = input_ptr->name();
     size_t input_tensor_num = dyn_input_sizes.empty() ? 1 : IntToSize(dyn_input_sizes[i]);
-
-    input_list.clear();
+    std::vector<nlohmann::json> input_list;
     for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
       auto type_id = this->GetInputDataType(anf_node, real_input_index);
       std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
@@ -117,7 +109,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
       nlohmann::json input_desc_json;
       input_desc_json[kJsonKeyDataType] = dtype;
       input_desc_json[kJsonKeyFormat] = this->GetInputFormat(anf_node, real_input_index);
-      input_desc_json[kJsonKeyName] = op_input_name;
+      input_desc_json[kJsonKeyName] = input_ptr->name();
       input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
       auto input_shape = this->GetInputShape(anf_node, real_input_index);
       if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) &&
@@ -204,77 +196,56 @@ void AkgKernelJsonGenerator::GetJson(const AnfNodePtr &anf_node, const std::vect

 bool AkgKernelJsonGenerator::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
                                                 nlohmann::json *const attrs_json) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(op_info);
-  MS_EXCEPTION_IF_NULL(attrs_json);
   std::vector<std::shared_ptr<OpAttr>> attrs = op_info->attrs_ptr();
   if (attrs.empty()) {
-    MS_LOG(INFO) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info attrs is empty";
+    MS_LOG(DEBUG) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info attrs is empty";
     return true;
   }
-  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info->inputs_ptr();

-  std::vector<int> dyn_input_sizes;
+  auto dyn_input_sizes = GetDynInputSize(anf_node);
   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
-  MS_EXCEPTION_IF_NULL(primitive);
-  if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
-    dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
-  }
-
-  if (inputs.empty()) {
-    MS_LOG(ERROR) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info inputs is empty";
-    return false;
-  }

   // create input name list for "x_shape" in attr with "x" in primitive.
-  std::map<size_t, std::string> op_info_shape_name;
-  for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) {
-    std::string input_name = inputs[op_info_input_i]->name();
-    std::string x_shape_name = input_name + "_shape";
-    static_cast<void>(op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name)));
+  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info->inputs_ptr();
+  std::map<std::string, size_t> op_info_shape_name;
+  for (size_t i = 0; i < inputs.size(); i++) {
+    op_info_shape_name[inputs[i]->name() + "_shape"] = i;
   }

   for (const auto &op_attr : attrs) {
     nlohmann::json attr_json;
     ValuePtr attr_value = primitive->GetAttr(op_attr->name());
     if (attr_value == nullptr && op_attr->name() != kArgDataformat) {
-      if (op_attr->param_type() == "required") {
-        // match "x_shape" in att with "x" in primitive.
-        std::string attr_name = op_attr->name();
-        auto find_item = std::find_if(
-          op_info_shape_name.begin(), op_info_shape_name.end(),
-          [attr_name](const std::map<size_t, std::string>::value_type item) { return item.second == attr_name; });
-        if (find_item != op_info_shape_name.end()) {
-          if (!dyn_input_sizes.empty()) {
-            if (find_item->first >= dyn_input_sizes.size() - 1) {
-              MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first
-                                << " is out of range:" << dyn_input_sizes.size() - 1 << ".";
-              return false;
-            }
-            size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0));
-            for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) {
-              attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx);
-              attr_json[kJsonKeyName] = op_attr->name();
-              attrs_json->push_back(attr_json);
-              tensor_idx++;
-            }
-          } else {
-            attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first);
-            attr_json[kJsonKeyName] = op_attr->name();
-            attrs_json->push_back(attr_json);
-          }
-        } else {
-          MS_LOG(ERROR) << "op [" << anf_node->fullname_with_scope() << "] should have attr :" << op_attr->name();
-          return false;
-        }
-      }
-      continue;
+      if (op_attr->param_type() != "required") continue;
+      // match "x_shape" in attr with "x" in primitive.
+      auto find_item = op_info_shape_name.find(op_attr->name());
+      if (find_item != op_info_shape_name.end()) {
+        if (!dyn_input_sizes.empty()) {
+          if (find_item->second >= dyn_input_sizes.size() - 1) {
+            MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->second
+                              << " is out of range:" << dyn_input_sizes.size() - 1 << ".";
+            return false;
+          }
+          size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->second], 0));
+          for (int input_i = 0; input_i < dyn_input_sizes[find_item->second]; input_i++) {
+            attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx);
+            attr_json[kJsonKeyName] = op_attr->name();
+            attrs_json->push_back(attr_json);
+            tensor_idx++;
+          }
+        } else {
+          attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->second);
+          attr_json[kJsonKeyName] = op_attr->name();
+          attrs_json->push_back(attr_json);
+        }
+      } else {
+        MS_LOG(ERROR) << "op [" << anf_node->fullname_with_scope() << "] should have attr :" << op_attr->name();
+        return false;
+      }
+    } else {
+      GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value);
+      attr_json[kJsonKeyName] = op_attr->name();
+      attrs_json->push_back(attr_json);
     }
-
-    GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value);
-
-    attr_json[kJsonKeyName] = op_attr->name();
-    attrs_json->push_back(attr_json);
   }
   return true;
 }
@@ -485,7 +456,47 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
MS_LOG(INFO) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size() MS_LOG(INFO) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size()
<< "], output_list: [" << input_list.size() << "]."; << "], output_list: [" << input_list.size() << "].";
std::map<AnfNodePtr, nlohmann::json> node_json_map; std::map<AnfNodePtr, nlohmann::json> node_json_map;
if (!GenSingleJsons(anf_nodes, &node_json_map)) return false;

UpdateTensorName(anf_nodes, &node_json_map);

std::vector<nlohmann::json> node_json_desc;
std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
[&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
(*kernel_json)[kJsonKeyOpDesc] = node_json_desc;

auto inputs_json = CreateInputsJson(anf_nodes, input_list, node_json_map);
(*kernel_json)[kJsonKeyInputDesc] = inputs_json;
(*kernel_json)[kJsonKeyOutputDesc] =
CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map);


size_t hash_id = std::hash<std::string>()(kernel_json->dump());
kernel_name_ = "Fused_";
auto fg = anf_nodes[0]->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (attr_val != nullptr) {
auto fg_attr = GetValue<std::string>(attr_val);
(void)kernel_name_.append(fg_attr).append("_");
}
(void)kernel_name_.append(std::to_string(hash_id));
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
(*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
(*kernel_json)[kJsonKeyComposite] = true;
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();

if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return false;
}

return true;
}

bool AkgKernelJsonGenerator::GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes,
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
  for (auto const &anf_node : anf_nodes) {
    MS_EXCEPTION_IF_NULL(anf_node);
    if (!AnfAlgo::IsRealKernel(anf_node)) {
@@ -507,9 +518,13 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
       node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
     }

-    node_json_map[anf_node] = node_json;
+    (*node_json_map)[anf_node] = node_json;
   }
+  return true;
+}
+
+void AkgKernelJsonGenerator::UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes,
+                                              std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
   for (auto const &anf_node : anf_nodes) {
     auto dyn_input_sizes = GetDynInputSize(anf_node);
     bool is_dynamic_input = !dyn_input_sizes.empty();
@@ -519,11 +534,11 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
       size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1;
       for (size_t j = 0; j < input_tensor_num; ++j) {
         auto tmp_input = GetKernelInput(anf_node, real_input_index);
-        std::string tensor_name = GetTensorName(node_json_map[anf_node], kJsonKeyInputDesc, std::make_pair(i, j));
-        if (node_json_map.find(tmp_input.first) != node_json_map.end()) {
+        std::string tensor_name = GetTensorName((*node_json_map)[anf_node], kJsonKeyInputDesc, std::make_pair(i, j));
+        if (node_json_map->find(tmp_input.first) != node_json_map->end()) {
           std::string new_tensor_name =
-            GetTensorName(node_json_map[tmp_input.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_input.second));
-          SetTensorName(kJsonKeyInputDesc, new_tensor_name, std::make_pair(i, j), &(node_json_map[anf_node]));
+            GetTensorName((*node_json_map)[tmp_input.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_input.second));
+          SetTensorName(kJsonKeyInputDesc, new_tensor_name, std::make_pair(i, j), &((*node_json_map)[anf_node]));
           MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of ["
                         << anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output ["
                         << new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "].";
@@ -535,12 +550,11 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
         }
       }
     }
   }
+}

-  std::vector<nlohmann::json> node_json_desc;
-  std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
-                 [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
-  (*kernel_json)[kJsonKeyOpDesc] = node_json_desc;
-
+nlohmann::json AkgKernelJsonGenerator::CreateInputsJson(const std::vector<AnfNodePtr> &anf_nodes,
+                                                        const std::vector<AnfNodePtr> &input_list,
+                                                        const std::map<AnfNodePtr, nlohmann::json> &node_json_map) {
   nlohmann::json inputs_json;
   auto input_index = GetInputIndex(anf_nodes, input_list);
   for (size_t i = 0; i < input_index.size(); ++i) {
@@ -549,18 +563,22 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
     std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
     nlohmann::json input_desc_json;
     input_desc_json[kJsonKeyTensorName] =
-      GetTensorName(node_json_map[tmp_input.first], kJsonKeyInputDesc, tmp_input.second);
+      GetTensorName(node_json_map.at(tmp_input.first), kJsonKeyInputDesc, tmp_input.second);
     input_desc_json[kJsonKeyDataType] = dtype;
     input_desc_json[kJsonKeyFormat] = this->GetInputFormat(tmp_input.first, tmp_input.second.first);
     input_desc_json[kJsonKeyShape] = this->GetInputShape(tmp_input.first, tmp_input.second.first);
     inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
   }
-  (*kernel_json)[kJsonKeyInputDesc] = inputs_json;
+  return inputs_json;
+}
+
+nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes,
+                                                         const std::vector<AnfNodePtr> &input_list,
+                                                         const std::vector<AnfNodePtr> &output_list,
+                                                         const nlohmann::json &inputs_json,
+                                                         const std::map<AnfNodePtr, nlohmann::json> &node_json_map) {
   nlohmann::json outputs_json;
   auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
-  std::map<size_t, std::vector<std::string>> sub_graphs;
-  std::map<size_t, size_t> dim_infos;
   for (size_t i = 0; i < output_index.size(); ++i) {
     auto tmp_output = output_index[i];
     bool found = false;
@@ -576,7 +594,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
     auto type_id = this->GetOutputDataType(tmp_output.first, tmp_output.second);
     std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
     output_desc_json[kJsonKeyTensorName] =
-      GetTensorName(node_json_map[tmp_output.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second));
+      GetTensorName(node_json_map.at(tmp_output.first), kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second));
     output_desc_json[kJsonKeyDataType] = dtype;
     output_desc_json[kJsonKeyFormat] = this->GetOutputFormat(tmp_output.first, tmp_output.second);
     auto output_shape = this->GetOutputShape(tmp_output.first, tmp_output.second);
@@ -587,33 +605,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
     }
     outputs_json.emplace_back(output_desc_json);
   }
-  (*kernel_json)[kJsonKeyOutputDesc] = outputs_json;
-
-  auto processor = GetProcessorStr(anf_nodes[0]);
-
-  size_t hash_id = std::hash<std::string>()(kernel_json->dump());
-  kernel_name_ = "Fused_";
-  auto fg = anf_nodes[0]->func_graph();
-  MS_EXCEPTION_IF_NULL(fg);
-  auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
-  if (attr_val != nullptr) {
-    auto fg_attr = GetValue<std::string>(attr_val);
-    (void)kernel_name_.append(fg_attr).append("_");
-  }
-  (void)kernel_name_.append(std::to_string(hash_id));
-  (*kernel_json)[kJsonKeyId] = GetOpCntInc();
-  (*kernel_json)[kJsonKeyOp] = kernel_name_;
-  (*kernel_json)[kJsonKeyPlatform] = "AKG";
-  (*kernel_json)[kJsonKeyProcess] = processor;
-  (*kernel_json)[kJsonKeyComposite] = true;
-  (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
-
-  if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
-    MS_LOG(ERROR) << "Cal mem size failed.";
-    return false;
-  }
-
-  return true;
+  return outputs_json;
 }

 bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node) {
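As the restructured CollectFusedJson above shows, the fused kernel name is derived from "Fused_", an optional graph-kernel attribute, and a hash of the kernel JSON dump. A minimal stand-alone illustration of that naming scheme follows; the literal strings are invented placeholders, not values from this patch.

// Illustrative sketch of the "Fused_<attr>_<hash>" naming used above.
#include <functional>
#include <iostream>
#include <string>

int main() {
  // Stand-ins for kernel_json->dump() and the graph-kernel attribute value.
  std::string kernel_json_dump = R"({"op_desc": []})";
  std::string fg_attr = "Cast_Mul";

  size_t hash_id = std::hash<std::string>()(kernel_json_dump);
  std::string kernel_name = "Fused_";
  kernel_name.append(fg_attr).append("_").append(std::to_string(hash_id));

  std::cout << kernel_name << "\n";  // e.g. Fused_Cast_Mul_1234567890
  return 0;
}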


+8 -0   mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h

@@ -94,6 +94,14 @@ class AkgKernelJsonGenerator {
                           nlohmann::json *const attrs_json);
   bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
                  std::vector<size_t> *const output_size);
+  bool GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);
+  void UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);
+  nlohmann::json CreateInputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
+                                  const std::map<AnfNodePtr, nlohmann::json> &node_json_map);
+  nlohmann::json CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
+                                   const std::vector<AnfNodePtr> &output_list, const nlohmann::json &inputs_json,
+                                   const std::map<AnfNodePtr, nlohmann::json> &node_json_map);
+
   int GetOpCntInc();
   size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
   size_t GetOutputTensorIdxInc();


+59 -39   mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc

@@ -36,15 +36,23 @@

 namespace mindspore {
 namespace kernel {
+namespace {
 constexpr int32_t PROCESS_NUM = 16;
 constexpr int32_t TIME_OUT = 300;

-bool AkgAscendKernelBuilder::AkgOpParallelBuild(
-  const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args) {
+void SetKernelMod(const KernelPackPtr &kernel_pack, const AkgKernelJsonGenerator &json_generator,
+                  const AnfNodePtr &anf_node) {
+  auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(kernel_pack);
+  kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
+  kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
+  AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
+}
+}  // namespace
+
+std::vector<std::string> AkgAscendKernelBuilder::GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args) {
   // Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
   std::vector<std::string> jsons;
   std::unordered_set<std::string> kernel_name_set;
-  std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> repeat_nodes;
   for (const auto &[json_generator, anf_node] : build_args) {
     MS_EXCEPTION_IF_NULL(anf_node);
     auto kernel_name = json_generator.kernel_name();
@@ -53,15 +61,12 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild(
     if (cached_kernel_pack != nullptr) {
       MS_LOG(DEBUG) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
                     << anf_node->fullname_with_scope() << "].";
-      auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
-      kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
-      kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
-      AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
+      SetKernelMod(cached_kernel_pack, json_generator, anf_node);
       continue;
     }

     if (kernel_name_set.count(kernel_name) != 0) {
-      repeat_nodes.push_back({json_generator, anf_node});
+      repeat_nodes_.push_back({json_generator, anf_node});
       continue;
     }
     kernel_name_set.insert(kernel_name);
@@ -69,7 +74,43 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild(
     kernel::SaveJsonInfo(kernel_name, kernel_json);
     jsons.push_back(kernel_json);
   }
+  return jsons;
+}
+
+bool AkgAscendKernelBuilder::InsertToCache(const std::vector<JsonNodePair> &build_args) {
+  for (const auto &[json_generator, anf_node] : build_args) {
+    auto kernel_name = json_generator.kernel_name();
+    auto new_kernel_pack = tbe::TbeUtils::InsertCache(kernel_name, GetProcessorStr(anf_node));
+    if (new_kernel_pack == nullptr) {
+      MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
+                    << anf_node->fullname_with_scope() << "].";
+      return false;
+    }
+    SetKernelMod(new_kernel_pack, json_generator, anf_node);
+    MS_LOG(DEBUG) << "Akg compile " << kernel_name << " kernel and insert cache successfully!";
+  }
+  return true;
+}
+
+bool AkgAscendKernelBuilder::HandleRepeatNodes() {
+  for (const auto &[json_generator, anf_node] : repeat_nodes_) {
+    auto kernel_name = json_generator.kernel_name();
+    auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node));
+    if (cached_kernel_pack == nullptr) {
+      MS_LOG(ERROR) << "Use cached kernel failed, kernel_name[" << kernel_name << "], fullname_with_scope["
+                    << anf_node->fullname_with_scope() << "].";
+      return false;
+    }
+    MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
+                 << anf_node->fullname_with_scope() << "].";
+    SetKernelMod(cached_kernel_pack, json_generator, anf_node);
+  }
+  return true;
+}
+
+bool AkgAscendKernelBuilder::AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args) {
+  repeat_nodes_.clear();
+  auto jsons = GetNotCachedKernelJsons(build_args);
   if (jsons.empty()) {
     return true;
   }
@@ -89,56 +130,35 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild(
   }

   // All unique done here, cache them and set kernel.
-  for (const auto &[json_generator, anf_node] : build_args) {
-    auto kernel_name = json_generator.kernel_name();
-    auto new_kernel_pack = tbe::TbeUtils::InsertCache(kernel_name, GetProcessorStr(anf_node));
-    if (new_kernel_pack == nullptr) {
-      MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
-                    << anf_node->fullname_with_scope() << "].";
-      return false;
-    }
-    auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack);
-    kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
-    kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
-    AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
-    MS_LOG(DEBUG) << "Akg compile " << kernel_name << " kernel and insert cache successfully!";
+  if (!InsertToCache(build_args)) {
+    MS_LOG(ERROR) << "Insert cache failed.";
+    return false;
   }

+  // Handle repeated nodes.
-  for (const auto &[json_generator, anf_node] : repeat_nodes) {
-    auto kernel_name = json_generator.kernel_name();
-    auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node));
-    if (cached_kernel_pack == nullptr) return false;
-    MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
-                 << anf_node->fullname_with_scope() << "].";
-    auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
-    kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
-    kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
-    AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
+  if (!HandleRepeatNodes()) {
+    MS_LOG(ERROR) << "Handle repeat nodes failed.";
+    return false;
   }

   return true;
 }

 bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
-  std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> json_and_node;
+  std::vector<JsonNodePair> json_and_node;
   for (const auto &anf_node : anf_nodes) {
     MS_EXCEPTION_IF_NULL(anf_node);
     AkgKernelJsonGenerator akg_kernel_json_generator;
-    KernelPackPtr kernel_pack = nullptr;
     auto cnode = anf_node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (AnfAlgo::IsGraphKernel(cnode)) {
       auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
+      MS_EXCEPTION_IF_NULL(func_graph);
       auto mng = func_graph->manager();
       if (mng == nullptr) {
         mng = Manage(func_graph, true);
         func_graph->set_manager(mng);
       }
-      MS_EXCEPTION_IF_NULL(func_graph);
-      std::vector<AnfNodePtr> node_list;
-      std::vector<AnfNodePtr> input_list;
-      std::vector<AnfNodePtr> output_list;
+      std::vector<AnfNodePtr> node_list, input_list, output_list;
       MS_LOG(INFO) << "Akg start compile composite op[" << anf_node->fullname_with_scope() << "]";
       GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
       if (!akg_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list)) {
@@ -146,7 +166,7 @@ bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
       }
     } else {
       if (!akg_kernel_json_generator.CollectJson(anf_node)) {
-        MS_EXCEPTION(UnknownError) << "Akg build failed op[" << anf_node->fullname_with_scope() << "].";
+        MS_EXCEPTION(UnknownError) << "Akg build failed basic op[" << anf_node->fullname_with_scope() << "].";
       }
     }
     json_and_node.push_back({akg_kernel_json_generator, anf_node});
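The helpers split out above (GetNotCachedKernelJsons, InsertToCache, HandleRepeatNodes) implement a three-way partition of the build batch: reuse kernels already in the cache, compile each unique kernel once, then point repeated nodes at the freshly cached result. The following stand-alone sketch shows only that flow with plain standard containers; it is illustrative and does not use any MindSpore or TBE APIs.

// Simplified sketch of the cached / unique / repeated partitioning used above.
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  std::vector<std::string> kernels = {"fused_a", "fused_b", "fused_a", "fused_c"};
  std::unordered_map<std::string, int> cache = {{"fused_c", 42}};  // pretend compiled artifact

  std::vector<std::string> to_compile;
  std::vector<std::string> repeats;
  std::unordered_set<std::string> seen;
  for (const auto &name : kernels) {
    if (cache.count(name) != 0) continue;   // already cached: just reuse it
    if (!seen.insert(name).second) {        // duplicate within this batch
      repeats.push_back(name);
      continue;
    }
    to_compile.push_back(name);             // unique: needs compiling
  }
  for (const auto &name : to_compile) cache[name] = 1;  // "compile" and insert into the cache
  for (const auto &name : repeats) {
    if (cache.count(name) == 0) return 1;   // must be present after the unique builds
  }
  std::cout << "compiled " << to_compile.size() << ", reused " << repeats.size() << "\n";
  return 0;
}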


+ 9
- 1
mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h View File

@@ -27,12 +27,20 @@


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
using JsonNodePair = std::pair<AkgKernelJsonGenerator, AnfNodePtr>;

class AkgAscendKernelBuilder { class AkgAscendKernelBuilder {
public: public:
AkgAscendKernelBuilder() = default; AkgAscendKernelBuilder() = default;
~AkgAscendKernelBuilder() = default; ~AkgAscendKernelBuilder() = default;
bool AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args);

private:
std::vector<std::string> GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args);
bool InsertToCache(const std::vector<JsonNodePair> &build_args);
bool HandleRepeatNodes();


bool AkgOpParallelBuild(const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args);
std::vector<JsonNodePair> repeat_nodes_;
}; };


bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes); bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
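
The new JsonNodePair alias keeps each json generator together with the node it was generated from, so the build, cache-insert, and repeat-node steps all operate on the same list. A simplified illustration with stubbed types (not the real classes):

#include <memory>
#include <string>
#include <utility>
#include <vector>

struct JsonGeneratorStub { std::string kernel_name; };
struct NodeStub { std::string fullname_with_scope; };
using NodeStubPtr = std::shared_ptr<NodeStub>;
using JsonNodePairStub = std::pair<JsonGeneratorStub, NodeStubPtr>;

int main() {
  std::vector<JsonNodePairStub> build_args;
  build_args.push_back({JsonGeneratorStub{"Fused_Cast_Add"},
                        std::make_shared<NodeStub>(NodeStub{"Default/Fused_Cast_Add"})});
  // After compilation, the node half of the pair is where the kernel mod gets attached.
  return build_args.front().second != nullptr ? 0 : 1;
}
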


+ 0
- 2
mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.h View File

@@ -1,4 +1,3 @@

/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
@@ -32,7 +31,6 @@ class BasicOpsFusion : public Pass {
bool Run(const FuncGraphPtr &func_graph) override; bool Run(const FuncGraphPtr &func_graph) override;
}; };
using FuseBasicPtr = std::shared_ptr<BasicOpsFusion>; using FuseBasicPtr = std::shared_ptr<BasicOpsFusion>;

} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_ #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_

+ 1
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc View File

@@ -128,7 +128,7 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
MS_LOG(DEBUG) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] with input json:\n" << node_desc_str; MS_LOG(DEBUG) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] with input json:\n" << node_desc_str;
auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str); auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str);
// parse result. // parse result.
if (ret.is(py::none())) {
if (py::isinstance<py::none>(ret)) {
MS_LOG(ERROR) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] return invalid result, input json:\n" MS_LOG(ERROR) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] return invalid result, input json:\n"
<< node_desc_str; << node_desc_str;
return nullptr; return nullptr;
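
This hunk (and the matching one in graph_kernel_splitter.cc below) switches the None test on the CallPyFn result from ret.is(py::none()) to py::isinstance<py::none>(ret). A small self-contained sketch of the new form, assuming an embedded interpreter is available; this only illustrates the pybind11 call, not MindSpore code:

#include <pybind11/embed.h>

namespace py = pybind11;

int main() {
  py::scoped_interpreter guard{};
  py::object ret = py::none();                   // stands in for a CallPyFn result
  bool invalid = py::isinstance<py::none>(ret);  // type-based check used by the patch
  return invalid ? 0 : 1;
}
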


+ 33
- 36
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc View File

@@ -211,9 +211,9 @@ AnfNodePtr DeleteAttrInInput(const FuncGraphPtr &func_graph, const CNodePtr &cno
return new_cnode; return new_cnode;
} }


AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) {
AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
AnfNodePtrList outs; AnfNodePtrList outs;
auto out_node = (*fg)->output();
auto out_node = fg->output();
if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) { if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
std::vector<AnfNodePtr> output_args; std::vector<AnfNodePtr> output_args;
auto out_cnode = out_node->cast<CNodePtr>(); auto out_cnode = out_node->cast<CNodePtr>();
@@ -228,8 +228,8 @@ AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *m
} }
} }
if (output_args.size() != out_cnode->inputs().size()) { if (output_args.size() != out_cnode->inputs().size()) {
auto new_out = (*fg)->NewCNode(output_args);
(*mng)->Replace(out_node, new_out);
auto new_out = fg->NewCNode(output_args);
mng->Replace(out_node, new_out);
} }


for (size_t i = 1; i < output_args.size(); ++i) { for (size_t i = 1; i < output_args.size(); ++i) {
@@ -241,6 +241,27 @@ AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *m
outs.push_back(out_node); outs.push_back(out_node);
return outs; return outs;
} }
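
EliminateMakeTuple now takes the graph and manager as const references to the shared pointers rather than pointers to them, which removes the (*fg)-> / (*mng)-> double dereference at every use. A stripped-down illustration of the two signatures with simplified types:

#include <memory>
#include <vector>

struct FuncGraphStub {
  std::vector<int> outputs;
  int output() const { return outputs.empty() ? -1 : outputs.back(); }
};
using FuncGraphStubPtr = std::shared_ptr<FuncGraphStub>;

// Before: the callee had to write (*fg)->output().
int OutputOld(const FuncGraphStubPtr *fg) { return (*fg)->output(); }

// After: passing the shared_ptr by const reference keeps shared ownership
// untouched and lets the callee use fg-> directly.
int OutputNew(const FuncGraphStubPtr &fg) { return fg->output(); }

int main() {
  auto fg = std::make_shared<FuncGraphStub>();
  fg->outputs.push_back(7);
  return (OutputOld(&fg) == OutputNew(fg)) ? 0 : 1;
}
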

bool GenJson(const AnfNodePtrList &op_nodes, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs,
const DumpOption &dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map) {
kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option);
if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, inputs, outputs)) {
MS_LOG(ERROR) << "Collect json desc failed.";
return false;
}

*op_desc = akg_kernel_json_generator.kernel_json();
if (address_node_map != nullptr) {
*address_node_map = akg_kernel_json_generator.address_node_map();
}
std::string fused_name;
std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) {
(void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_");
});
MS_LOG(INFO) << "Collect fusion json: " << fused_name;
return true;
}
} // namespace } // namespace


void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
@@ -457,7 +478,7 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
mng->Replace(n, out); mng->Replace(n, out);
} }


EliminateMakeTuple(&fg, &mng);
EliminateMakeTuple(fg, mng);
// set graphKernel attr // set graphKernel attr
std::string fuse_op_name = ""; std::string fuse_op_name = "";
for (auto &fuse_node : fuse_nodes) { for (auto &fuse_node : fuse_nodes) {
@@ -476,50 +497,26 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
} }


bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann::json *op_desc,
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map) { std::map<std::string, AnfNodePtr> *address_node_map) {
MS_EXCEPTION_IF_NULL(op_desc); MS_EXCEPTION_IF_NULL(op_desc);
if (nodes.empty()) { if (nodes.empty()) {
MS_LOG(ERROR) << "Input nodes is empty."; MS_LOG(ERROR) << "Input nodes is empty.";
return false; return false;
} }
bool has_graph_kernel =
std::any_of(nodes.begin(), nodes.end(), [](const AnfNodePtr &node) { return AnfAlgo::IsGraphKernel(node); });
bool has_graph_kernel = std::any_of(nodes.begin(), nodes.end(), AnfAlgo::IsGraphKernel);
bool is_single_graph_kernel = has_graph_kernel && nodes.size() == 1; bool is_single_graph_kernel = has_graph_kernel && nodes.size() == 1;
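
The std::any_of call above drops its wrapper lambda because the predicate already has the required signature and can be passed directly. A generic stand-alone version of the same simplification:

#include <algorithm>
#include <vector>

static bool IsEven(int v) { return v % 2 == 0; }

int main() {
  std::vector<int> nodes = {1, 3, 4};
  // Before: std::any_of(nodes.begin(), nodes.end(), [](int v) { return IsEven(v); });
  bool has_even = std::any_of(nodes.begin(), nodes.end(), IsEven);
  return has_even ? 0 : 1;
}
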


auto gen_json = [&dump_option, &op_desc, &address_node_map](const AnfNodePtrList &op_nodes,
const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs) -> bool {
kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option);
if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, inputs, outputs)) {
MS_LOG(ERROR) << "Collect json desc failed.";
return false;
}

*op_desc = akg_kernel_json_generator.kernel_json();
if (address_node_map != nullptr) {
*address_node_map = akg_kernel_json_generator.address_node_map();
}
std::string fused_name;
std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) {
(void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_");
});
MS_LOG(INFO) << "Collect fusion json: " << fused_name;
return true;
};

FuncGraphPtr fg; FuncGraphPtr fg;
AnfNodePtrList op_nodes;
AnfNodePtrList inputs;
AnfNodePtrList outputs;
AnfNodePtrList op_nodes, inputs, outputs;
if (is_single_graph_kernel) { if (is_single_graph_kernel) {
fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]); fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs); kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
return gen_json(op_nodes, inputs, outputs);
return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map);
} else if (!has_graph_kernel) { } else if (!has_graph_kernel) {
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes); std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes);
op_nodes = nodes; op_nodes = nodes;
return gen_json(op_nodes, inputs, outputs);
return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map);
} }


std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes); std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes);
@@ -540,10 +537,10 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann
inputs.clear(); inputs.clear();
outputs.clear(); outputs.clear();
kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs); kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
return gen_json(op_nodes, inputs, outputs);
return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map);
} }


bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, DumpOption dump_option, nlohmann::json *op_desc) {
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc) {
MS_EXCEPTION_IF_NULL(op_desc); MS_EXCEPTION_IF_NULL(op_desc);
std::vector<nlohmann::json> graphs_desc; std::vector<nlohmann::json> graphs_desc;
for (auto const &graph_nodes : graphs) { for (auto const &graph_nodes : graphs) {


+ 2
- 2
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h View File

@@ -46,9 +46,9 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix, const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix,
bool is_before_kernel_select); bool is_before_kernel_select);
bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann::json *op_desc,
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map = nullptr); std::map<std::string, AnfNodePtr> *address_node_map = nullptr);
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, DumpOption dump_option, nlohmann::json *op_desc);
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs); FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs);
bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map, bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
std::vector<AnfNodePtrList> *res_graphs); std::vector<AnfNodePtrList> *res_graphs);


+ 18
- 15
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc View File

@@ -57,8 +57,6 @@ inline void TraverseFuncGraph(const FuncGraphPtr &root, std::function<void(AnfNo
TraverseFuncGraphFromCNode(root->get_return(), callback); TraverseFuncGraphFromCNode(root->get_return(), callback);
} }


class AreaGraph;
class Splitter;
class Area { class Area {
public: public:
explicit Area(const AnfNodePtrList &anf_arr) { explicit Area(const AnfNodePtrList &anf_arr) {
@@ -73,6 +71,8 @@ class Area {
} }
} }


~Area() = default;

// Set the external inputs of spy as a Parameter. // Set the external inputs of spy as a Parameter.
void CreateParameters(const FuncGraphPtr &func_graph, std::unordered_map<ParameterPtr, AnfNodePtr> *param_node_map) { void CreateParameters(const FuncGraphPtr &func_graph, std::unordered_map<ParameterPtr, AnfNodePtr> *param_node_map) {
std::unordered_map<AnfNodePtr, ParameterPtr> node_param_map; std::unordered_map<AnfNodePtr, ParameterPtr> node_param_map;
@@ -148,8 +148,8 @@ class Area {
} }
} }


friend AreaGraph;
friend Splitter;
const std::unordered_set<AnfNodePtr> &nodes() const { return nodes_; }
const std::vector<AnfNodePtr> &spy_cnodes() const { return spy_cnodes_; }


private: private:
// This is a CNode that does not belong to this area. // This is a CNode that does not belong to this area.
@@ -170,9 +170,8 @@ class AreaGraph {
// Build an area graph to maintain the relation between areas. // Build an area graph to maintain the relation between areas.
// Input node_groups: A group list, each element is a AnfNode list representing the node set in this group. // Input node_groups: A group list, each element is a AnfNode list representing the node set in this group.
static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) { static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) {
AreaGraph *area_graph_ptr = new (std::nothrow) AreaGraph(node_groups);
if (!area_graph_ptr) return nullptr;
auto area_graph = AreaGraphPtr(area_graph_ptr);
auto area_graph = AreaGraphPtr(new AreaGraph(node_groups));
if (area_graph == nullptr) return nullptr;
if (!area_graph->TopoSort()) { if (!area_graph->TopoSort()) {
MS_LOG(WARNING) << "The groups have a cycle."; MS_LOG(WARNING) << "The groups have a cycle.";
return nullptr; return nullptr;
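
BuildAreaGraph now wraps a plain new in the shared_ptr instead of new (std::nothrow) plus a manual null check. Because AreaGraph's constructor is private (see the later hunk), std::make_shared cannot be used here; note also that a plain new reports allocation failure by throwing std::bad_alloc, so the following null check is mainly defensive. The pattern in miniature, with a stub class:

#include <memory>
#include <vector>

class AreaGraphStub {
 public:
  using Ptr = std::shared_ptr<AreaGraphStub>;

  static Ptr Build(const std::vector<int> &groups) {
    // std::make_shared cannot reach the private constructor, so the factory
    // wraps a plain `new`; allocation failure would surface as std::bad_alloc.
    return Ptr(new AreaGraphStub(groups));
  }

 private:
  explicit AreaGraphStub(const std::vector<int> &groups) : groups_(groups) {}
  std::vector<int> groups_;
};

int main() {
  auto graph = AreaGraphStub::Build({0, 1, 2});
  return graph != nullptr ? 0 : 1;
}
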
@@ -184,12 +183,12 @@ class AreaGraph {
// The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs. // The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs.
// The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting. // The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting.
void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes, void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes,
std::vector<size_t> *cnode_group_id, std::function<void(Area *)> expand_callback) {
std::vector<size_t> *cnode_group_id, std::function<void(const Area &)> expand_callback) {
main_cnodes->clear(); main_cnodes->clear();
main_cnodes->resize(areas_.size(), nullptr); main_cnodes->resize(areas_.size(), nullptr);


for (auto &area : this->areas_) { for (auto &area : this->areas_) {
expand_callback(&area);
expand_callback(area);
} }


for (auto index : topo_order_) { for (auto index : topo_order_) {
@@ -208,6 +207,8 @@ class AreaGraph {
return; return;
} }


~AreaGraph() = default;

private: private:
explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) { explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) {
for (size_t i = 0; i < node_groups.size(); ++i) { for (size_t i = 0; i < node_groups.size(); ++i) {
@@ -217,7 +218,7 @@ class AreaGraph {
} }
} }
for (auto &area : areas_) { for (auto &area : areas_) {
for (auto &spy : area.spy_cnodes_) {
for (auto &spy : area.spy_cnodes()) {
auto cnode = spy->cast<CNodePtr>(); auto cnode = spy->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
size_t v = node_area_map_[spy]; size_t v = node_area_map_[spy];
@@ -333,8 +334,8 @@ class Splitter {


// The output new_subgraph_cnodes are topo sorted, use a list to store its order in split_plan. // The output new_subgraph_cnodes are topo sorted, use a list to store its order in split_plan.
std::vector<size_t> cnodes_group_id; std::vector<size_t> cnodes_group_id;
std::function<void(Area *)> expand_callback = std::bind(&Splitter::AreaExpand, this, std::placeholders::_1);
area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id, expand_callback);
area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id,
[this](const Area &area) { this->AreaExpand(area); });


RebuildGraph(cnodes_group_id); RebuildGraph(cnodes_group_id);
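
The expand callback passed to SplitGraph is now an inline capturing lambda taking const Area & rather than a std::bind expression over a raw Area *. A compact stand-alone version of the same wiring, with class and member names simplified:

#include <functional>
#include <vector>

struct AreaStub { int id = 0; };

class SplitterStub {
 public:
  int Run() {
    std::vector<AreaStub> areas = {{1}, {2}, {3}};
    // Before: std::bind(&SplitterStub::AreaExpand, this, std::placeholders::_1) over AreaStub *.
    std::function<void(const AreaStub &)> expand_callback =
        [this](const AreaStub &area) { this->AreaExpand(area); };
    for (const auto &area : areas) expand_callback(area);
    return expanded_;
  }

 private:
  void AreaExpand(const AreaStub &area) { expanded_ += area.id; }
  int expanded_ = 0;
};

int main() {
  SplitterStub splitter;
  return splitter.Run() == 6 ? 0 : 1;
}
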


@@ -348,6 +349,8 @@ class Splitter {
return SplitterPtr(new Splitter(main_cnode, split_schemer)); return SplitterPtr(new Splitter(main_cnode, split_schemer));
} }


~Splitter() = default;

private: private:
Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer) Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer)
: main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {} : main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {}
@@ -479,9 +482,9 @@ class Splitter {
} }


// Copy all Parameter and ValueNode that the area used. // Copy all Parameter and ValueNode that the area used.
void AreaExpand(Area *area) {
void AreaExpand(const Area &area) {
std::unordered_map<AnfNodePtr, AnfNodePtr> old_valuenode_and_param_map; std::unordered_map<AnfNodePtr, AnfNodePtr> old_valuenode_and_param_map;
for (auto sub_node : area->nodes_) {
for (auto sub_node : area.nodes()) {
auto sub_cnode = sub_node->cast<CNodePtr>(); auto sub_cnode = sub_node->cast<CNodePtr>();
if (sub_cnode == nullptr) continue; if (sub_cnode == nullptr) continue;
for (size_t i = 1; i < sub_cnode->inputs().size(); ++i) { for (size_t i = 1; i < sub_cnode->inputs().size(); ++i) {
@@ -565,7 +568,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
auto json_desc_str = json_desc.dump(); auto json_desc_str = json_desc.dump();
MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json:\n" << json_desc_str; MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json:\n" << json_desc_str;
auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str); auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str);
if (ret.is(py::none())) {
if (py::isinstance<py::none>(ret)) {
MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n" MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
<< json_desc_str; << json_desc_str;
return false; return false;


+ 4
- 2
mindspore/ccsrc/debug/anf_ir_dump.cc View File

@@ -463,13 +463,15 @@ std::string AddGlobalId(const std::string &filename) {
static size_t g_id = 0; static size_t g_id = 0;
std::ostringstream s; std::ostringstream s;
auto i = filename.rfind('/'); auto i = filename.rfind('/');
if (i == string::npos) {
if (i >= filename.size()) {
s << std::setfill('0') << std::setw(4) << g_id << "_"; s << std::setfill('0') << std::setw(4) << g_id << "_";
s << filename; s << filename;
} else { } else {
s << filename.substr(0, i + 1); s << filename.substr(0, i + 1);
s << std::setfill('0') << std::setw(4) << g_id << "_"; s << std::setfill('0') << std::setw(4) << g_id << "_";
s << filename.substr(i + 1);
if (i + 1 < filename.size()) {
s << filename.substr(i + 1);
}
} }
++g_id; ++g_id;
return s.str(); return s.str();
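
AddGlobalId now tests i >= filename.size() instead of i == string::npos and guards the trailing substr. Since std::string::npos is the largest size_t value, the two tests agree for any index returned by rfind; a quick check:

#include <iostream>
#include <string>

int main() {
  std::string filename = "anf_ir_dump.txt";  // contains no '/'
  auto i = filename.rfind('/');
  std::cout << std::boolalpha
            << (i == std::string::npos) << " "   // true
            << (i >= filename.size()) << "\n";   // also true
  return 0;
}
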


+ 24
- 26
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc View File

@@ -236,12 +236,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
} }


void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) { void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(func_graph);

std::vector<AnfNodePtr> node_list;
std::vector<AnfNodePtr> input_list;
std::vector<AnfNodePtr> output_list;
std::vector<AnfNodePtr> node_list, input_list, output_list;
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);


std::vector<std::string> graph_input_format; std::vector<std::string> graph_input_format;
@@ -295,6 +290,22 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_gr
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get()); AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
SetTensorDeviceInfo(*graph_selected_info, kernel_node); SetTensorDeviceInfo(*graph_selected_info, kernel_node);
} }

void PrintUnsupportedTypeException(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type,
const std::vector<TypeId> &outputs_type) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
std::string build_type = "in [";
std::for_each(std::begin(inputs_type), std::end(inputs_type),
[&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
build_type += "] out [";
std::for_each(std::begin(outputs_type), std::end(outputs_type),
[&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
build_type += "]";
auto supported_type_lists = SupportedTypeList(kernel_node);
MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name
<< "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists
<< ", but get " << build_type;
}
} // namespace } // namespace
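
PrintUnsupportedTypeException centralizes the error-message formatting that previously lived inline in SetKernelInfo (removed in the hunk further down). The string it builds looks like this, using plain std::string stand-ins instead of MindSpore type ids:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> inputs_type = {"Float32", "Int32"};
  std::vector<std::string> outputs_type = {"Float32"};
  std::string build_type = "in [";
  std::for_each(inputs_type.begin(), inputs_type.end(),
                [&build_type](const std::string &t) { build_type += t + " "; });
  build_type += "] out [";
  std::for_each(outputs_type.begin(), outputs_type.end(),
                [&build_type](const std::string &t) { build_type += t + " "; });
  build_type += "]";
  std::cout << build_type << "\n";  // prints: in [Float32 Int32 ] out [Float32 ]
  return 0;
}
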


void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) { void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
@@ -329,7 +340,7 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<s


void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
if (AnfAlgo::IsGraphKernel(kernel_node)) { if (AnfAlgo::IsGraphKernel(kernel_node)) {
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(kernel_node);
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
SetGraphKernelInfo(kernel_node, func_graph); SetGraphKernelInfo(kernel_node, func_graph);
return; return;
@@ -351,8 +362,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) { if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
} }
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetOriginDataFormat(origin_data_format); builder->SetOriginDataFormat(origin_data_format);
builder->SetInputsFormat(inputs_format); builder->SetInputsFormat(inputs_format);
builder->SetInputsDeviceType(inputs_type); builder->SetInputsDeviceType(inputs_type);
@@ -360,35 +370,23 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
builder->SetOutputsDeviceType(outputs_type); builder->SetOutputsDeviceType(outputs_type);


bool result = false; bool result = false;
KernelType res_kernel_type = UNKNOWN_KERNEL_TYPE;
if (kernel_type == UNKNOWN_KERNEL_TYPE) { if (kernel_type == UNKNOWN_KERNEL_TYPE) {
result = result =
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build()); kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());


if (!result) { if (!result) {
result = SelectAkgKernel(kernel_node, builder->Build()); result = SelectAkgKernel(kernel_node, builder->Build());
res_kernel_type = AKG_KERNEL;
kernel_type = AKG_KERNEL;
} }
} else if (kernel_type == AKG_KERNEL) { } else if (kernel_type == AKG_KERNEL) {
result = SelectAkgKernel(kernel_node, builder->Build()); result = SelectAkgKernel(kernel_node, builder->Build());
res_kernel_type = AKG_KERNEL;
} }


if (!result) { if (!result) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
std::string build_type = "in [";
std::for_each(std::begin(inputs_type), std::end(inputs_type),
[&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
build_type += "] out [";
std::for_each(std::begin(outputs_type), std::end(outputs_type),
[&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
build_type += "]";
auto supported_type_lists = SupportedTypeList(kernel_node);
MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name
<< "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists
<< ", but get " << build_type;
}
builder->SetKernelType(res_kernel_type);
PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type);
return;
}
builder->SetKernelType(kernel_type);
builder->SetProcessor(kernel::Processor::CUDA); builder->SetProcessor(kernel::Processor::CUDA);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
SetTensorDeviceInfo(*(builder->Build()), kernel_node); SetTensorDeviceInfo(*(builder->Build()), kernel_node);

