Merge pull request !6240 from DeshiChen/0910_review_bot
@@ -199,6 +199,5 @@ void SetAkgKernelAttrs(const AnfNodePtr &anf_node) {
it->second(anf_node);
}
}
} // namespace kernel
} // namespace mindspore
@@ -20,9 +20,7 @@
namespace mindspore {
namespace kernel {
void SetAkgKernelAttrs(const AnfNodePtr &anf_node);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
@@ -42,153 +42,276 @@
namespace mindspore {
namespace kernel {
namespace {
ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) {
if (type == "str") {
std::string value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "int") {
int value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "bool") {
bool value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "float") {
float value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "listInt") {
std::vector<int> value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "listStr") {
std::vector<std::string> value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else {
MS_LOG(ERROR) << "Unknown type of attr: " << type << ", json: \n" << attr_json;
return nullptr;
}
}
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
bool DecodeAttrs(const nlohmann::json &attrs_json, std::map<std::string, ValuePtr> *attrs) {
MS_EXCEPTION_IF_NULL(attrs);
MS_LOG(DEBUG) << "start decode attrs, " << attrs_json;
// decode attrs.
if (attrs_json.find(kJsonKeyAttr) == attrs_json.end() || attrs_json[kJsonKeyAttr].is_null()) {
// attrs maybe empty
return true;
class CNodeDecoder {
public:
explicit CNodeDecoder(std::map<std::string, AnfNodePtr> *nodes_map) : nodes_map_(*nodes_map) {}
~CNodeDecoder() = default;
CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, kernel::Processor processor) {
MS_LOG(DEBUG) << "start decode cnode, " << cnode_json;
// decode attrs.
if (!DecodeAttrs(cnode_json)) {
MS_LOG(ERROR) << "Decode attrs failed.";
return nullptr;
}
if (!DecodeInputDesc(cnode_json, func_graph) || cnode_ == nullptr) {
MS_LOG(ERROR) << "Decode inputs failed.";
return nullptr;
}
if (!DecodeOutputDesc(cnode_json, func_graph)) {
MS_LOG(ERROR) << "Decode outputs failed.";
return nullptr;
}
CreateKernelInfo(processor);
return cnode_;
}
std::vector<nlohmann::json> attr_descs = attrs_json[kJsonKeyAttr];
for (const auto &attr_desc : attr_descs) {
std::string name = attr_desc[kJsonKeyName];
std::string type = attr_desc[kJsonKeyDataType];
auto value = ParseValue(attr_desc, type);
if (value == nullptr) {
return false;
private:
ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) {
if (type == "str") {
std::string value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "int") {
int value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "bool") {
bool value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "float") {
float value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "listInt") {
std::vector<int> value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "listStr") {
std::vector<std::string> value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else {
MS_LOG(ERROR) << "Unknown type of attr: " << type << ", json: \n" << attr_json;
return nullptr;
}
(*attrs)[name] = value;
}
return true;
}
bool DecodeAttrs(const nlohmann::json &attrs_json) {
MS_LOG(DEBUG) << "start decode attrs, " << attrs_json;
// attrs maybe empty
if (attrs_json.find(kJsonKeyAttr) == attrs_json.end() || attrs_json[kJsonKeyAttr].is_null()) {
return true;
}
// python utils.
constexpr auto kGetPythonOpFunc = "_get_python_op";
constexpr auto kParallelUtilsModule = "mindspore.parallel._utils";
// almost all ops are defined in this path.
constexpr auto kOperationsModule = "mindspore.ops.operations";
std::vector<nlohmann::json> attr_descs = attrs_json[kJsonKeyAttr];
for (const auto &attr_desc : attr_descs) {
std::string name = attr_desc[kJsonKeyName];
std::string type = attr_desc[kJsonKeyDataType];
auto value = ParseValue(attr_desc, type);
if (value == nullptr) {
return false;
}
cnode_attrs_[name] = value;
}
return true;
}
const std::map<std::string, std::vector<std::string>> op_attrs_map = {
{kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
{kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
{kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
};
bool DecodeInputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) {
std::string op_name = cnode_json[kJsonKeyName];
// new primitive.
auto primitive = GetPrimitive(op_name);
if (primitive == nullptr) {
MS_LOG(ERROR) << "Create primitive failed.";
return false;
}
ValuePtr CreatOpInstance(const std::string &op_name, const std::vector<ValuePtr> &attrs) {
py::module mod = py::module::import(kOperationsModule);
if (!py::hasattr(mod, op_name.c_str())) {
MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name;
return nullptr;
// collect inputs.
auto primitive_v = NewValueNode(primitive);
func_graph->AddValueNode(primitive_v);
std::vector<AnfNodePtr> inputs{primitive_v};
std::vector<nlohmann::json> input_descs = cnode_json[kJsonKeyInputDesc];
for (size_t i = 0; i < input_descs.size(); ++i) {
nlohmann::json input_desc = input_descs[i][0];
std::string name = input_desc[kJsonKeyTensorName];
if (input_desc.find(kJsonKeyValue) != input_desc.end()) {
inputs.push_back(DecodeValueNode(input_desc, func_graph));
} else if (nodes_map_.count(name) == 0) {
MS_LOG(ERROR) << "Input: " << name << " of: " << op_name << " not found.";
return false;
} else {
inputs.push_back(nodes_map_[name]);
}
input_formats_.push_back(input_desc[kJsonKeyFormat]);
input_types_.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType]));
}
// new cnode.
cnode_ = func_graph->NewCNode(inputs);
func_graph->AddNode(cnode_);
return true;
}
std::vector<py::object> arg_list;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
[](const ValuePtr &attr) { return ValuePtrToPyData(attr); });
py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule,
op_name, arg_list);
ValuePtr op_instance = nullptr;
bool succ = parse::ConvertData(obj, &op_instance);
if (!succ) {
MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed.";
return nullptr;
bool DecodeOutputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) {
std::vector<nlohmann::json> output_descs = cnode_json[kJsonKeyOutputDesc];
AbstractBasePtr abstract(nullptr);
if (output_descs.empty()) {
MS_LOG(ERROR) << "No outputs found.";
return false;
} else if (output_descs.size() == 1) {
// single output.
nlohmann::json output_desc = output_descs[0];
output_formats_.push_back(output_desc[kJsonKeyFormat]);
output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
nodes_map_[output_desc[kJsonKeyTensorName]] = cnode_;
} else {
// multi outputs.
for (size_t j = 0; j < output_descs.size(); ++j) {
nlohmann::json output_desc = output_descs[j];
output_formats_.push_back(output_desc[kJsonKeyFormat]);
output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
auto get_item =
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_, NewValueNode(SizeToInt(j))});
func_graph->AddNode(get_item);
nodes_map_[output_desc[kJsonKeyTensorName]] = get_item;
}
}
return true;
}
return op_instance;
}
PrimitivePtr GetPrimitive(const std::string &op_name, const std::map<std::string, ValuePtr> &attrs_val) {
PrimitivePtr primitive{nullptr};
if (op_attrs_map.count(op_name) == 0) {
// no attrs for op instance.
primitive = CreatOpInstance(op_name, std::vector<ValuePtr>{})->cast<PrimitivePtr>();
} else {
// make attrs for op instance.
std::vector<ValuePtr> op_attrs;
const auto &attr_names = op_attrs_map.at(op_name);
for (const auto &attr_name : attr_names) {
if (attrs_val.count(attr_name) == 0) {
MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found.";
return nullptr;
void CreateKernelInfo(kernel::Processor processor) {
auto kernel_info = std::make_shared<device::KernelInfo>();
std::vector<size_t> feature_map_input_indexs;
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
const auto &inputs = cnode_->inputs();
for (size_t index = 1; index < inputs.size(); ++index) {
auto node = AnfAlgo::VisitKernel(inputs[index], 0);
if (AnfAlgo::IsFeatureMapOutput(node.first)) {
feature_map_input_indexs.push_back(index);
}
op_attrs.push_back(attrs_val.at(attr_name));
}
primitive = CreatOpInstance(op_name, op_attrs)->cast<PrimitivePtr>();
if (AnfAlgo::GetCNodeName(cnode_) == prim::kPrimCast->name()) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode_);
}
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true);
}
if (AnfAlgo::IsRealCNodeKernel(cnode_)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode_);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode_);
}
cnode_->set_kernel_info(kernel_info);
// create kernel_build_info.
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetInputsFormat(input_formats_);
builder->SetInputsDeviceType(input_types_);
builder->SetOutputsFormat(output_formats_);
builder->SetOutputsDeviceType(output_types_);
builder->SetProcessor(processor);
builder->SetKernelType(KernelType::AKG_KERNEL);
builder->SetFusionType(kernel::FusionType::OPAQUE);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode_.get());
}
if (primitive != nullptr) {
for (const auto &attr : attrs_val) {
primitive->AddAttr(attr.first, attr.second);
ValuePtr CreatOpInstance(const std::string &op_name, const std::vector<ValuePtr> &attrs) {
// python utils.
constexpr auto kGetPythonOpFunc = "_get_python_op";
constexpr auto kParallelUtilsModule = "mindspore.parallel._utils";
// almost all ops are defined in this path.
constexpr auto kOperationsModule = "mindspore.ops.operations";
py::module mod = py::module::import(kOperationsModule);
if (!py::hasattr(mod, op_name.c_str())) {
MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name;
return nullptr;
}
std::vector<py::object> arg_list;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
[](const ValuePtr &attr) { return ValuePtrToPyData(attr); });
py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule,
op_name, arg_list);
ValuePtr op_instance = nullptr;
bool succ = parse::ConvertData(obj, &op_instance);
if (!succ) {
MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed.";
return nullptr;
}
return op_instance;
}
return primitive;
}
} // namespace
const std::map<std::string, std::vector<std::string>> op_attrs_map_ = {
{kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
{kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
{kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
};
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
PrimitivePtr GetPrimitive(const std::string &op_name) {
PrimitivePtr primitive{nullptr};
if (op_attrs_map_.count(op_name) == 0) {
// no attrs for op instance.
primitive = CreatOpInstance(op_name, std::vector<ValuePtr>{})->cast<PrimitivePtr>();
} else {
// make attrs for op instance.
std::vector<ValuePtr> op_attrs;
const auto &attr_names = op_attrs_map_.at(op_name);
for (const auto &attr_name : attr_names) {
if (cnode_attrs_.count(attr_name) == 0) {
MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found.";
return nullptr;
}
op_attrs.push_back(cnode_attrs_.at(attr_name));
}
primitive = CreatOpInstance(op_name, op_attrs)->cast<PrimitivePtr>();
}
if (primitive != nullptr) {
for (const auto &attr : cnode_attrs_) {
primitive->AddAttr(attr.first, attr.second);
}
}
return primitive;
}
ScalarPtr AkgKernelJsonDecoder::DecodeScalar(const nlohmann::json &scalar_json) {
auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
switch (type_id) {
case kNumberTypeFloat16:
case kNumberTypeFloat32:
return std::make_shared<FP32Imm>(scalar_json[kJsonKeyValue]);
case kNumberTypeInt32:
return std::make_shared<Int32Imm>(scalar_json[kJsonKeyValue]);
default:
MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType];
break;
ScalarPtr DecodeScalar(const nlohmann::json &scalar_json) {
auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
switch (type_id) {
case kNumberTypeFloat16:
case kNumberTypeFloat32:
return std::make_shared<FP32Imm>(scalar_json[kJsonKeyValue]);
case kNumberTypeInt32:
return std::make_shared<Int32Imm>(scalar_json[kJsonKeyValue]);
default:
MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType];
break;
}
return nullptr;
}
return nullptr;
}
ValueNodePtr AkgKernelJsonDecoder::DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) {
MS_LOG(DEBUG) << "start decode value node, " << value_json;
auto scalar = DecodeScalar(value_json);
auto tensor = ScalarToTensor(scalar);
ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) {
MS_LOG(DEBUG) << "start decode value node, " << value_json;
auto scalar = DecodeScalar(value_json);
auto tensor = ScalarToTensor(scalar);
auto value_node = std::make_shared<ValueNode>(tensor);
value_node->set_abstract(tensor->ToAbstract());
// create kernel_info fo new value node.
auto kernel_info = std::make_shared<device::KernelInfo>();
value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node.
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// layout info.
builder->SetOutputsFormat(std::vector<std::string>{value_json[kJsonKeyFormat]});
builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(value_json[kJsonKeyDataType])});
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
func_graph->AddValueNode(value_node);
MS_LOG(DEBUG) << "decode value node success, " << value_node->DebugString(2);
return value_node;
}
auto value_node = std::make_shared<ValueNode>(tensor);
value_node->set_abstract(tensor->ToAbstract());
// create kernel_info fo new value node.
auto kernel_info = std::make_shared<device::KernelInfo>();
value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node.
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// layout info.
builder->SetOutputsFormat(std::vector<std::string>{value_json[kJsonKeyFormat]});
builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(value_json[kJsonKeyDataType])});
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
func_graph->AddValueNode(value_node);
MS_LOG(DEBUG) << "decode value node success, " << value_node->DebugString(2);
return value_node;
}
std::map<std::string, AnfNodePtr> &nodes_map_;
std::map<std::string, ValuePtr> cnode_attrs_;
std::vector<std::string> input_formats_;
std::vector<std::string> output_formats_;
std::vector<TypeId> input_types_;
std::vector<TypeId> output_types_;
CNodePtr cnode_{nullptr};
};
} // namespace
ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json &parameter_json,
const FuncGraphPtr &func_graph) {
@@ -208,118 +331,35 @@ ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json &paramet
CNodePtr AkgKernelJsonDecoder::DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph,
const std::string &processor) {
CNodeDecoder decoder(&nodes_map_);
Processor p = kernel::GetProcessor(processor);
MS_LOG(DEBUG) << "start decode cnode, " << cnode_json;
// decode attrs.
std::map<std::string, ValuePtr> cnode_attrs;
if (!DecodeAttrs(cnode_json, &cnode_attrs)) {
MS_LOG(ERROR) << "Error decode attrs.";
return nullptr;
}
std::string op_name = cnode_json[kJsonKeyName];
// new primitive.
auto primitive = GetPrimitive(op_name, cnode_attrs);
if (primitive == nullptr) {
MS_LOG(ERROR) << "Create primitive failed.";
return nullptr;
}
// data layout info.
std::vector<std::string> input_formats;
std::vector<TypeId> input_types;
std::vector<std::string> output_formats;
std::vector<TypeId> output_types;
return decoder.DecodeCNode(cnode_json, func_graph, p);
}
// collect inputs.
auto primitive_v = NewValueNode(primitive);
func_graph->AddValueNode(primitive_v);
std::vector<AnfNodePtr> inputs{primitive_v};
std::vector<nlohmann::json> input_descs = cnode_json[kJsonKeyInputDesc];
for (size_t i = 0; i < input_descs.size(); ++i) {
nlohmann::json input_desc = input_descs[i][0];
std::string name = input_desc[kJsonKeyTensorName];
if (input_desc.find(kJsonKeyValue) != input_desc.end()) {
inputs.push_back(DecodeValueNode(input_desc, func_graph));
} else if (nodes_map_.count(name) == 0) {
MS_LOG(ERROR) << "Input: " << name << " of: " << op_name << " not found.";
AnfNodePtr AkgKernelJsonDecoder::DecodeOutput(const std::vector<nlohmann::json> &output_descs,
const FuncGraphPtr &func_graph) {
std::vector<AnfNodePtr> outputs{NewValueNode(prim::kPrimMakeTuple)};
for (const auto &output_desc : output_descs) {
std::string name = output_desc[kJsonKeyTensorName];
if (nodes_map_.count(name) == 0) {
MS_LOG(ERROR) << "Output: " << name << " of graph not found.";
return nullptr;
} else {
inputs.push_back(nodes_map_[name]);
}
input_formats.push_back(input_desc[kJsonKeyFormat]);
input_types.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType]));
outputs.push_back(nodes_map_[name]);
}
MS_LOG(DEBUG) << "decode inputs success.";
// new cnode.
auto cnode = func_graph->NewCNode(inputs);
func_graph->AddNode(cnode);
// decode outputs.
std::vector<nlohmann::json> output_descs = cnode_json[kJsonKeyOutputDesc];
AbstractBasePtr abstract(nullptr);
if (output_descs.empty()) {
MS_LOG(ERROR) << "No outputs found.";
return nullptr;
} else if (output_descs.size() == 1) {
// single output.
nlohmann::json output_desc = output_descs[0];
output_formats.push_back(output_desc[kJsonKeyFormat]);
output_types.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
nodes_map_[output_desc[kJsonKeyTensorName]] = cnode;
if (outputs.size() == 2) {
func_graph->set_output(outputs[1]);
} else {
// multi outputs.
for (size_t j = 0; j < output_descs.size(); ++j) {
nlohmann::json output_desc = output_descs[j];
output_formats.push_back(output_desc[kJsonKeyFormat]);
output_types.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
auto get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, NewValueNode(SizeToInt(j))});
func_graph->AddNode(get_item);
nodes_map_[output_desc[kJsonKeyTensorName]] = get_item;
}
}
MS_LOG(DEBUG) << "decode outputs success.";
// create kernel_info.
auto kernel_info = std::make_shared<device::KernelInfo>();
std::vector<size_t> feature_map_input_indexs;
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
for (size_t index = 1; index < inputs.size(); ++index) {
auto node = AnfAlgo::VisitKernel(inputs[index], 0);
if (AnfAlgo::IsFeatureMapOutput(node.first)) {
feature_map_input_indexs.push_back(index);
}
auto output = func_graph->NewCNode(outputs);
func_graph->AddNode(output);
func_graph->set_output(output);
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
}
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true);
}
if (AnfAlgo::IsRealCNodeKernel(cnode)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
}
cnode->set_kernel_info(kernel_info);
// create kernel_build_info.
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetInputsFormat(input_formats);
builder->SetInputsDeviceType(input_types);
builder->SetOutputsFormat(output_formats);
builder->SetOutputsDeviceType(output_types);
builder->SetProcessor(p);
builder->SetKernelType(KernelType::AKG_KERNEL);
builder->SetFusionType(kernel::FusionType::OPAQUE);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
return cnode;
return func_graph->output();
}
FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel_json) {
MS_LOG(DEBUG) << "start decode, " << kernel_json;
// clear cache.
nodes_map_.clear();
// create a graph.
auto graph = std::make_shared<FuncGraph>();
// decode parameters.
@@ -331,10 +371,7 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel
for (size_t i = 0; i < input_descs.size(); ++i) {
std::vector<nlohmann::json> input_desc = input_descs[i];
auto parameter = DecodeParameter(input_desc[0], graph);
if (parameter == nullptr) {
MS_LOG(ERROR) << "Error decode parameter.";
return nullptr;
}
MS_EXCEPTION_IF_NULL(parameter);
}
MS_LOG(DEBUG) << "decode parameters success.";
@@ -346,10 +383,7 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel
}
for (const auto &op_desc : op_node_descs) {
auto op_node = DecodeCNode(op_desc, graph, kernel_json[kJsonKeyProcess]);
if (op_node == nullptr) {
MS_LOG(ERROR) << "Error decode cnode.";
return nullptr;
}
MS_EXCEPTION_IF_NULL(op_node);
}
MS_LOG(DEBUG) << "decode cnodes success.";
@@ -359,22 +393,8 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel
MS_LOG(ERROR) << "Error decode outputs, no outputs for graph.";
return nullptr;
}
std::vector<AnfNodePtr> outputs{NewValueNode(prim::kPrimMakeTuple)};
for (const auto &output_desc : output_descs) {
std::string name = output_desc[kJsonKeyTensorName];
if (nodes_map_.count(name) == 0) {
MS_LOG(ERROR) << "Output: " << name << " of graph not found.";
return nullptr;
}
outputs.push_back(nodes_map_[name]);
}
if (outputs.size() == 2) {
graph->set_output(outputs[1]);
} else {
auto output = graph->NewCNode(outputs);
graph->AddNode(output);
graph->set_output(output);
}
auto output = DecodeOutput(output_descs, graph);
MS_EXCEPTION_IF_NULL(output);
MS_LOG(DEBUG) << "decode success, " << kernel_json;
return graph;
}
@@ -37,11 +37,10 @@ class AkgKernelJsonDecoder {
AnfNodePtrList *res_graphs);
private:
ScalarPtr DecodeScalar(const nlohmann::json &scalar_json);
ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph);
ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
std::map<std::string, AnfNodePtr> nodes_map_{};
AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
std::map<std::string, AnfNodePtr> nodes_map_;
};
} // namespace kernel
} // namespace mindspore
@@ -79,22 +79,16 @@ inline std::string AkgKernelJsonGenerator::GetOutputFormat(const AnfNodePtr &anf
bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *const inputs_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(op_info);
MS_EXCEPTION_IF_NULL(inputs_json);
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr = op_info->inputs_ptr();
if (inputs_ptr.empty()) {
MS_LOG(DEBUG) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info";
return true;
MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info";
return false;
}
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
auto dyn_input_sizes = GetDynInputSize(anf_node);
size_t real_input_index = 0;
std::vector<nlohmann::json> input_list;
for (size_t i = 0; i < inputs_ptr.size(); i++) {
std::shared_ptr<OpIOInfo> input_ptr = inputs_ptr[i];
if (input_ptr == nullptr) {
@@ -102,10 +96,8 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
return false;
}
auto op_input_name = input_ptr->name();
size_t input_tensor_num = dyn_input_sizes.empty() ? 1 : IntToSize(dyn_input_sizes[i]);
input_list.clear();
std::vector<nlohmann::json> input_list;
for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
auto type_id = this->GetInputDataType(anf_node, real_input_index);
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
@@ -117,7 +109,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
nlohmann::json input_desc_json;
input_desc_json[kJsonKeyDataType] = dtype;
input_desc_json[kJsonKeyFormat] = this->GetInputFormat(anf_node, real_input_index);
input_desc_json[kJsonKeyName] = op_input_name;
input_desc_json[kJsonKeyName] = input_ptr->name();
input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
auto input_shape = this->GetInputShape(anf_node, real_input_index);
if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) &&
@@ -204,77 +196,56 @@ void AkgKernelJsonGenerator::GetJson(const AnfNodePtr &anf_node, const std::vect
bool AkgKernelJsonGenerator::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *const attrs_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(op_info);
MS_EXCEPTION_IF_NULL(attrs_json);
std::vector<std::shared_ptr<OpAttr>> attrs = op_info->attrs_ptr();
if (attrs.empty()) {
MS_LOG(INFO) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info attrs is empty";
MS_LOG(DEBUG) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info attrs is empty";
return true;
}
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info->inputs_ptr();
std::vector<int> dyn_input_sizes;
auto dyn_input_sizes = GetDynInputSize(anf_node);
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
}
if (inputs.empty()) {
MS_LOG(ERROR) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info inputs is empty";
return false;
}
// create input name list for "x_shape" in attr with "x" in primitive.
std::map<size_t, std::string> op_info_shape_name;
for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) {
std::string input_name = inputs[op_info_input_i]->name();
std::string x_shape_name = input_name + "_shape";
static_cast<void>(op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name)));
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info->inputs_ptr();
std::map<std::string, size_t> op_info_shape_name;
for (size_t i = 0; i < inputs.size(); i++) {
op_info_shape_name[inputs[i]->name() + "_shape"] = i;
}
for (const auto &op_attr : attrs) {
nlohmann::json attr_json;
ValuePtr attr_value = primitive->GetAttr(op_attr->name());
if (attr_value == nullptr && op_attr->name() != kArgDataformat) {
if (op_attr->param_type() == "required") {
// match "x_shape" in att with "x" in primitive.
std::string attr_name = op_attr->name();
auto find_item = std::find_if(
op_info_shape_name.begin(), op_info_shape_name.end(),
[attr_name](const std::map<size_t, std::string>::value_type item) { return item.second == attr_name; });
if (find_item != op_info_shape_name.end()) {
if (!dyn_input_sizes.empty()) {
if (find_item->first >= dyn_input_sizes.size() - 1) {
MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first
<< " is out of range:" << dyn_input_sizes.size() - 1 << ".";
return false;
}
size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0));
for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) {
attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx);
attr_json[kJsonKeyName] = op_attr->name();
attrs_json->push_back(attr_json);
tensor_idx++;
}
} else {
attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first);
if (op_attr->param_type() != "required") continue;
// match "x_shape" in attr with "x" in primitive.
auto find_item = op_info_shape_name.find(op_attr->name());
if (find_item != op_info_shape_name.end()) {
if (!dyn_input_sizes.empty()) {
if (find_item->second >= dyn_input_sizes.size() - 1) {
MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->second
<< " is out of range:" << dyn_input_sizes.size() - 1 << ".";
return false;
}
size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->second], 0));
for (int input_i = 0; input_i < dyn_input_sizes[find_item->second]; input_i++) {
attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx);
attr_json[kJsonKeyName] = op_attr->name();
attrs_json->push_back(attr_json);
tensor_idx++;
}
} else {
MS_LOG(ERROR) << "op [" << anf_node->fullname_with_scope() << "] should have attr :" << op_attr->name();
return false;
attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->second);
attr_json[kJsonKeyName] = op_attr->name();
attrs_json->push_back(attr_json);
}
} else {
MS_LOG(ERROR) << "op [" << anf_node->fullname_with_scope() << "] should have attr :" << op_attr->name();
return false;
}
continue;
} else {
GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value);
attr_json[kJsonKeyName] = op_attr->name();
attrs_json->push_back(attr_json);
}
GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value);
attr_json[kJsonKeyName] = op_attr->name();
attrs_json->push_back(attr_json);
}
return true;
}
@@ -485,7 +456,47 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
MS_LOG(INFO) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size()
<< "], output_list: [" << input_list.size() << "].";
std::map<AnfNodePtr, nlohmann::json> node_json_map;
if (!GenSingleJsons(anf_nodes, &node_json_map)) return false;
UpdateTensorName(anf_nodes, &node_json_map);
std::vector<nlohmann::json> node_json_desc;
std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
[&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
(*kernel_json)[kJsonKeyOpDesc] = node_json_desc;
auto inputs_json = CreateInputsJson(anf_nodes, input_list, node_json_map);
(*kernel_json)[kJsonKeyInputDesc] = inputs_json;
(*kernel_json)[kJsonKeyOutputDesc] =
CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map);
size_t hash_id = std::hash<std::string>()(kernel_json->dump());
kernel_name_ = "Fused_";
auto fg = anf_nodes[0]->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (attr_val != nullptr) {
auto fg_attr = GetValue<std::string>(attr_val);
(void)kernel_name_.append(fg_attr).append("_");
}
(void)kernel_name_.append(std::to_string(hash_id));
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
(*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
(*kernel_json)[kJsonKeyComposite] = true;
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return false;
}
return true;
}
bool AkgKernelJsonGenerator::GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes,
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
for (auto const &anf_node : anf_nodes) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!AnfAlgo::IsRealKernel(anf_node)) {
@@ -507,9 +518,13 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
}
node_json_map[anf_node] = node_json;
(*node_json_map)[anf_node] = node_json;
}
return true;
}
void AkgKernelJsonGenerator::UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes,
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
for (auto const &anf_node : anf_nodes) {
auto dyn_input_sizes = GetDynInputSize(anf_node);
bool is_dynamic_input = !dyn_input_sizes.empty();
@@ -519,11 +534,11 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1;
for (size_t j = 0; j < input_tensor_num; ++j) {
auto tmp_input = GetKernelInput(anf_node, real_input_index);
std::string tensor_name = GetTensorName(node_json_map[anf_node], kJsonKeyInputDesc, std::make_pair(i, j));
if (node_json_map.find(tmp_input.first) != node_json_map.end()) {
std::string tensor_name = GetTensorName((*node_json_map)[anf_node], kJsonKeyInputDesc, std::make_pair(i, j));
if (node_json_map->find(tmp_input.first) != node_json_map->end()) {
std::string new_tensor_name =
GetTensorName(node_json_map[tmp_input.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_input.second));
SetTensorName(kJsonKeyInputDesc, new_tensor_name, std::make_pair(i, j), &(node_json_map[anf_node]));
GetTensorName((*node_json_map)[tmp_input.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_input.second));
SetTensorName(kJsonKeyInputDesc, new_tensor_name, std::make_pair(i, j), &((*node_json_map)[anf_node]));
MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of ["
<< anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output ["
<< new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "].";
@@ -535,12 +550,11 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
}
}
}
}
std::vector<nlohmann::json> node_json_desc;
std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
[&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
(*kernel_json)[kJsonKeyOpDesc] = node_json_desc;
nlohmann::json AkgKernelJsonGenerator::CreateInputsJson(const std::vector<AnfNodePtr> &anf_nodes,
const std::vector<AnfNodePtr> &input_list,
const std::map<AnfNodePtr, nlohmann::json> &node_json_map) {
nlohmann::json inputs_json;
auto input_index = GetInputIndex(anf_nodes, input_list);
for (size_t i = 0; i < input_index.size(); ++i) {
@@ -549,18 +563,22 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
nlohmann::json input_desc_json;
input_desc_json[kJsonKeyTensorName] =
GetTensorName(node_json_map[tmp_input.first], kJsonKeyInputDesc, tmp_input.second);
GetTensorName(node_json_map.at(tmp_input.first), kJsonKeyInputDesc, tmp_input.second);
input_desc_json[kJsonKeyDataType] = dtype;
input_desc_json[kJsonKeyFormat] = this->GetInputFormat(tmp_input.first, tmp_input.second.first);
input_desc_json[kJsonKeyShape] = this->GetInputShape(tmp_input.first, tmp_input.second.first);
inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
}
(*kernel_json)[kJsonKeyInputDesc] = inputs_json;
return inputs_json;
}
nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list,
const nlohmann::json &inputs_json,
const std::map<AnfNodePtr, nlohmann::json> &node_json_map) {
nlohmann::json outputs_json;
auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
std::map<size_t, std::vector<std::string>> sub_graphs;
std::map<size_t, size_t> dim_infos;
for (size_t i = 0; i < output_index.size(); ++i) {
auto tmp_output = output_index[i];
bool found = false;
@@ -576,7 +594,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
auto type_id = this->GetOutputDataType(tmp_output.first, tmp_output.second);
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
output_desc_json[kJsonKeyTensorName] =
GetTensorName(node_json_map[tmp_output.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second));
GetTensorName(node_json_map.at(tmp_output.first), kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second));
output_desc_json[kJsonKeyDataType] = dtype;
output_desc_json[kJsonKeyFormat] = this->GetOutputFormat(tmp_output.first, tmp_output.second);
auto output_shape = this->GetOutputShape(tmp_output.first, tmp_output.second);
@@ -587,33 +605,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
}
outputs_json.emplace_back(output_desc_json);
}
(*kernel_json)[kJsonKeyOutputDesc] = outputs_json;
auto processor = GetProcessorStr(anf_nodes[0]);
size_t hash_id = std::hash<std::string>()(kernel_json->dump());
kernel_name_ = "Fused_";
auto fg = anf_nodes[0]->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (attr_val != nullptr) {
auto fg_attr = GetValue<std::string>(attr_val);
(void)kernel_name_.append(fg_attr).append("_");
}
(void)kernel_name_.append(std::to_string(hash_id));
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
(*kernel_json)[kJsonKeyProcess] = processor;
(*kernel_json)[kJsonKeyComposite] = true;
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return false;
}
return true;
return outputs_json;
}
bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node) {
@@ -94,6 +94,14 @@ class AkgKernelJsonGenerator {
nlohmann::json *const attrs_json);
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size);
bool GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);
void UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);
nlohmann::json CreateInputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
const std::map<AnfNodePtr, nlohmann::json> &node_json_map);
nlohmann::json CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list, const nlohmann::json &inputs_json,
const std::map<AnfNodePtr, nlohmann::json> &node_json_map);
int GetOpCntInc();
size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
size_t GetOutputTensorIdxInc();
| @@ -36,15 +36,23 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| namespace { | |||||
| constexpr int32_t PROCESS_NUM = 16; | constexpr int32_t PROCESS_NUM = 16; | ||||
| constexpr int32_t TIME_OUT = 300; | constexpr int32_t TIME_OUT = 300; | ||||
| bool AkgAscendKernelBuilder::AkgOpParallelBuild( | |||||
| const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args) { | |||||
| void SetKernelMod(const KernelPackPtr &kernel_pack, const AkgKernelJsonGenerator &json_generator, | |||||
| const AnfNodePtr &anf_node) { | |||||
| auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(kernel_pack); | |||||
| kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list()); | |||||
| kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list()); | |||||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||||
| } | |||||
| } // namespace | |||||
| std::vector<std::string> AkgAscendKernelBuilder::GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args) { | |||||
| // Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess. | // Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess. | ||||
| std::vector<std::string> jsons; | std::vector<std::string> jsons; | ||||
| std::unordered_set<std::string> kernel_name_set; | std::unordered_set<std::string> kernel_name_set; | ||||
| std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> repeat_nodes; | |||||
| for (const auto &[json_generator, anf_node] : build_args) { | for (const auto &[json_generator, anf_node] : build_args) { | ||||
| MS_EXCEPTION_IF_NULL(anf_node); | MS_EXCEPTION_IF_NULL(anf_node); | ||||
| auto kernel_name = json_generator.kernel_name(); | auto kernel_name = json_generator.kernel_name(); | ||||
| @@ -53,15 +61,12 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild( | |||||
| if (cached_kernel_pack != nullptr) { | if (cached_kernel_pack != nullptr) { | ||||
| MS_LOG(DEBUG) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope[" | MS_LOG(DEBUG) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope[" | ||||
| << anf_node->fullname_with_scope() << "]."; | << anf_node->fullname_with_scope() << "]."; | ||||
| auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack); | |||||
| kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list()); | |||||
| kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list()); | |||||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||||
| SetKernelMod(cached_kernel_pack, json_generator, anf_node); | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (kernel_name_set.count(kernel_name) != 0) { | if (kernel_name_set.count(kernel_name) != 0) { | ||||
| repeat_nodes.push_back({json_generator, anf_node}); | |||||
| repeat_nodes_.push_back({json_generator, anf_node}); | |||||
| continue; | continue; | ||||
| } | } | ||||
| kernel_name_set.insert(kernel_name); | kernel_name_set.insert(kernel_name); | ||||
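The loop above walks the build arguments once: kernels already in the compile cache are reused immediately, each remaining kernel name is compiled only the first time it appears, and repeats are parked in repeat_nodes_ so they can be resolved from the cache after the parallel build finishes. A minimal standalone sketch of that dedupe-then-defer pattern, using a hypothetical Job type in place of the (AkgKernelJsonGenerator, AnfNodePtr) pair:

    #include <string>
    #include <unordered_set>
    #include <vector>

    struct Job {
      std::string key;   // analogous to kernel_name
      int payload = 0;   // stands in for the json generator / node pair
    };

    // Split jobs into first-seen ones (built now) and repeats (resolved from the cache later).
    void Partition(const std::vector<Job> &jobs, std::vector<Job> *unique_jobs, std::vector<Job> *repeats) {
      std::unordered_set<std::string> seen;
      for (const auto &job : jobs) {
        if (seen.count(job.key) != 0) {
          repeats->push_back(job);  // this kernel was already scheduled for compilation
          continue;
        }
        (void)seen.insert(job.key);
        unique_jobs->push_back(job);
      }
    }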
| @@ -69,7 +74,43 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild( | |||||
| kernel::SaveJsonInfo(kernel_name, kernel_json); | kernel::SaveJsonInfo(kernel_name, kernel_json); | ||||
| jsons.push_back(kernel_json); | jsons.push_back(kernel_json); | ||||
| } | } | ||||
| return jsons; | |||||
| } | |||||
| bool AkgAscendKernelBuilder::InsertToCache(const std::vector<JsonNodePair> &build_args) { | |||||
| for (const auto &[json_generator, anf_node] : build_args) { | |||||
| auto kernel_name = json_generator.kernel_name(); | |||||
| auto new_kernel_pack = tbe::TbeUtils::InsertCache(kernel_name, GetProcessorStr(anf_node)); | |||||
| if (new_kernel_pack == nullptr) { | |||||
| MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope[" | |||||
| << anf_node->fullname_with_scope() << "]."; | |||||
| return false; | |||||
| } | |||||
| SetKernelMod(new_kernel_pack, json_generator, anf_node); | |||||
| MS_LOG(DEBUG) << "Akg compile " << kernel_name << " kernel and insert cache successfully!"; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool AkgAscendKernelBuilder::HandleRepeatNodes() { | |||||
| for (const auto &[json_generator, anf_node] : repeat_nodes_) { | |||||
| auto kernel_name = json_generator.kernel_name(); | |||||
| auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node)); | |||||
| if (cached_kernel_pack == nullptr) { | |||||
| MS_LOG(ERROR) << "Use cached kernel failed, kernel_name[" << kernel_name << "], fullname_with_scope[" | |||||
| << anf_node->fullname_with_scope() << "]."; | |||||
| return false; | |||||
| } | |||||
| MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope[" | |||||
| << anf_node->fullname_with_scope() << "]."; | |||||
| SetKernelMod(cached_kernel_pack, json_generator, anf_node); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool AkgAscendKernelBuilder::AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args) { | |||||
| repeat_nodes_.clear(); | |||||
| auto jsons = GetNotCachedKernelJsons(build_args); | |||||
| if (jsons.empty()) { | if (jsons.empty()) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -89,56 +130,35 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild( | |||||
| } | } | ||||
| // All unique kernels are built at this point; cache them and set the kernel mod. | // All unique kernels are built at this point; cache them and set the kernel mod. | ||||
| for (const auto &[json_generator, anf_node] : build_args) { | |||||
| auto kernel_name = json_generator.kernel_name(); | |||||
| auto new_kernel_pack = tbe::TbeUtils::InsertCache(kernel_name, GetProcessorStr(anf_node)); | |||||
| if (new_kernel_pack == nullptr) { | |||||
| MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope[" | |||||
| << anf_node->fullname_with_scope() << "]."; | |||||
| return false; | |||||
| } | |||||
| auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack); | |||||
| kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list()); | |||||
| kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list()); | |||||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||||
| MS_LOG(DEBUG) << "Akg compile " << kernel_name << " kernel and insert cache successfully!"; | |||||
| if (!InsertToCache(build_args)) { | |||||
| MS_LOG(ERROR) << "Insert cache failed."; | |||||
| return false; | |||||
| } | } | ||||
| // Handle repeated nodes. | |||||
| for (const auto &[json_generator, anf_node] : repeat_nodes) { | |||||
| auto kernel_name = json_generator.kernel_name(); | |||||
| auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node)); | |||||
| if (cached_kernel_pack == nullptr) return false; | |||||
| MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope[" | |||||
| << anf_node->fullname_with_scope() << "]."; | |||||
| auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack); | |||||
| kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list()); | |||||
| kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list()); | |||||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||||
| if (!HandleRepeatNodes()) { | |||||
| MS_LOG(ERROR) << "Handle repeat nodes failed."; | |||||
| return false; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) { | bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) { | ||||
| std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> json_and_node; | |||||
| std::vector<JsonNodePair> json_and_node; | |||||
| for (const auto &anf_node : anf_nodes) { | for (const auto &anf_node : anf_nodes) { | ||||
| MS_EXCEPTION_IF_NULL(anf_node); | MS_EXCEPTION_IF_NULL(anf_node); | ||||
| AkgKernelJsonGenerator akg_kernel_json_generator; | AkgKernelJsonGenerator akg_kernel_json_generator; | ||||
| KernelPackPtr kernel_pack = nullptr; | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | auto cnode = anf_node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (AnfAlgo::IsGraphKernel(cnode)) { | if (AnfAlgo::IsGraphKernel(cnode)) { | ||||
| auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); | auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto mng = func_graph->manager(); | auto mng = func_graph->manager(); | ||||
| if (mng == nullptr) { | if (mng == nullptr) { | ||||
| mng = Manage(func_graph, true); | mng = Manage(func_graph, true); | ||||
| func_graph->set_manager(mng); | func_graph->set_manager(mng); | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::vector<AnfNodePtr> node_list; | |||||
| std::vector<AnfNodePtr> input_list; | |||||
| std::vector<AnfNodePtr> output_list; | |||||
| std::vector<AnfNodePtr> node_list, input_list, output_list; | |||||
| MS_LOG(INFO) << "Akg start compile composite op[" << anf_node->fullname_with_scope() << "]"; | MS_LOG(INFO) << "Akg start compile composite op[" << anf_node->fullname_with_scope() << "]"; | ||||
| GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); | GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); | ||||
| if (!akg_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list)) { | if (!akg_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list)) { | ||||
| @@ -146,7 +166,7 @@ bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) { | |||||
| } | } | ||||
| } else { | } else { | ||||
| if (!akg_kernel_json_generator.CollectJson(anf_node)) { | if (!akg_kernel_json_generator.CollectJson(anf_node)) { | ||||
| MS_EXCEPTION(UnknownError) << "Akg build failed op[" << anf_node->fullname_with_scope() << "]."; | |||||
| MS_EXCEPTION(UnknownError) << "Akg build failed basic op[" << anf_node->fullname_with_scope() << "]."; | |||||
| } | } | ||||
| } | } | ||||
| json_and_node.push_back({akg_kernel_json_generator, anf_node}); | json_and_node.push_back({akg_kernel_json_generator, anf_node}); | ||||
| @@ -27,12 +27,20 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| using JsonNodePair = std::pair<AkgKernelJsonGenerator, AnfNodePtr>; | |||||
| class AkgAscendKernelBuilder { | class AkgAscendKernelBuilder { | ||||
| public: | public: | ||||
| AkgAscendKernelBuilder() = default; | AkgAscendKernelBuilder() = default; | ||||
| ~AkgAscendKernelBuilder() = default; | ~AkgAscendKernelBuilder() = default; | ||||
| bool AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args); | |||||
| private: | |||||
| std::vector<std::string> GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args); | |||||
| bool InsertToCache(const std::vector<JsonNodePair> &build_args); | |||||
| bool HandleRepeatNodes(); | |||||
| bool AkgOpParallelBuild(const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args); | |||||
| std::vector<JsonNodePair> repeat_nodes_; | |||||
| }; | }; | ||||
| bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes); | bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes); | ||||
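The header change above introduces `using JsonNodePair = std::pair<AkgKernelJsonGenerator, AnfNodePtr>;` so the builder's interface names the pair type once, and the .cc file iterates such vectors with C++17 structured bindings (`for (const auto &[json_generator, anf_node] : build_args)`). A small self-contained illustration of the same idiom, with placeholder element types:

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    // Placeholder alias mirroring JsonNodePair: name the pair type once, reuse it everywhere.
    using NamePayloadPair = std::pair<std::string, int>;

    int main() {
      std::vector<NamePayloadPair> build_args = {{"Add_fp32", 1}, {"Mul_fp32", 2}};
      // C++17 structured bindings unpack each pair directly in the range-for.
      for (const auto &[name, payload] : build_args) {
        std::cout << name << " -> " << payload << '\n';
      }
      return 0;
    }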
| @@ -1,4 +1,3 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | * Copyright 2020 Huawei Technologies Co., Ltd | ||||
| * | * | ||||
| @@ -32,7 +31,6 @@ class BasicOpsFusion : public Pass { | |||||
| bool Run(const FuncGraphPtr &func_graph) override; | bool Run(const FuncGraphPtr &func_graph) override; | ||||
| }; | }; | ||||
| using FuseBasicPtr = std::shared_ptr<BasicOpsFusion>; | using FuseBasicPtr = std::shared_ptr<BasicOpsFusion>; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_ | #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_ | ||||
| @@ -128,7 +128,7 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) { | |||||
| MS_LOG(DEBUG) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] with input json:\n" << node_desc_str; | MS_LOG(DEBUG) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] with input json:\n" << node_desc_str; | ||||
| auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str); | auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str); | ||||
| // parse result. | // parse result. | ||||
| if (ret.is(py::none())) { | |||||
| if (py::isinstance<py::none>(ret)) { | |||||
| MS_LOG(ERROR) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] return invalid result, input json:\n" | MS_LOG(ERROR) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] return invalid result, input json:\n" | ||||
| << node_desc_str; | << node_desc_str; | ||||
| return nullptr; | return nullptr; | ||||
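The expander now tests the Python callback's return value with `py::isinstance<py::none>(ret)` instead of an identity comparison against `py::none()`. A minimal pybind11 sketch of that check, assuming pybind11 headers are available and an interpreter has already been initialized by the caller:

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    // Returns true when a Python call produced None (i.e. no usable result).
    bool ReturnedNone(const py::object &ret) {
      // Type-based check, as used after this change.
      return py::isinstance<py::none>(ret);
      // The previous form, ret.is(py::none()), compared object identity instead.
    }

Both forms identify None (it is a singleton), so the switch reads mainly as a readability and consistency cleanup.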
| @@ -211,9 +211,9 @@ AnfNodePtr DeleteAttrInInput(const FuncGraphPtr &func_graph, const CNodePtr &cno | |||||
| return new_cnode; | return new_cnode; | ||||
| } | } | ||||
| AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) { | |||||
| AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { | |||||
| AnfNodePtrList outs; | AnfNodePtrList outs; | ||||
| auto out_node = (*fg)->output(); | |||||
| auto out_node = fg->output(); | |||||
| if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) { | if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) { | ||||
| std::vector<AnfNodePtr> output_args; | std::vector<AnfNodePtr> output_args; | ||||
| auto out_cnode = out_node->cast<CNodePtr>(); | auto out_cnode = out_node->cast<CNodePtr>(); | ||||
| @@ -228,8 +228,8 @@ AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *m | |||||
| } | } | ||||
| } | } | ||||
| if (output_args.size() != out_cnode->inputs().size()) { | if (output_args.size() != out_cnode->inputs().size()) { | ||||
| auto new_out = (*fg)->NewCNode(output_args); | |||||
| (*mng)->Replace(out_node, new_out); | |||||
| auto new_out = fg->NewCNode(output_args); | |||||
| mng->Replace(out_node, new_out); | |||||
| } | } | ||||
| for (size_t i = 1; i < output_args.size(); ++i) { | for (size_t i = 1; i < output_args.size(); ++i) { | ||||
| @@ -241,6 +241,27 @@ AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *m | |||||
| outs.push_back(out_node); | outs.push_back(out_node); | ||||
| return outs; | return outs; | ||||
| } | } | ||||
| bool GenJson(const AnfNodePtrList &op_nodes, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs, | |||||
| const DumpOption &dump_option, nlohmann::json *op_desc, | |||||
| std::map<std::string, AnfNodePtr> *address_node_map) { | |||||
| kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option); | |||||
| if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, inputs, outputs)) { | |||||
| MS_LOG(ERROR) << "Collect json desc failed."; | |||||
| return false; | |||||
| } | |||||
| *op_desc = akg_kernel_json_generator.kernel_json(); | |||||
| if (address_node_map != nullptr) { | |||||
| *address_node_map = akg_kernel_json_generator.address_node_map(); | |||||
| } | |||||
| std::string fused_name; | |||||
| std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) { | |||||
| (void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_"); | |||||
| }); | |||||
| MS_LOG(INFO) << "Collect fusion json: " << fused_name; | |||||
| return true; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, | ||||
| @@ -457,7 +478,7 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, | |||||
| mng->Replace(n, out); | mng->Replace(n, out); | ||||
| } | } | ||||
| EliminateMakeTuple(&fg, &mng); | |||||
| EliminateMakeTuple(fg, mng); | |||||
| // set graphKernel attr | // set graphKernel attr | ||||
| std::string fuse_op_name = ""; | std::string fuse_op_name = ""; | ||||
| for (auto &fuse_node : fuse_nodes) { | for (auto &fuse_node : fuse_nodes) { | ||||
| @@ -476,50 +497,26 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, | |||||
| fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); | fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); | ||||
| } | } | ||||
| bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann::json *op_desc, | |||||
| bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, | |||||
| std::map<std::string, AnfNodePtr> *address_node_map) { | std::map<std::string, AnfNodePtr> *address_node_map) { | ||||
| MS_EXCEPTION_IF_NULL(op_desc); | MS_EXCEPTION_IF_NULL(op_desc); | ||||
| if (nodes.empty()) { | if (nodes.empty()) { | ||||
| MS_LOG(ERROR) << "Input nodes is empty."; | MS_LOG(ERROR) << "Input nodes is empty."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| bool has_graph_kernel = | |||||
| std::any_of(nodes.begin(), nodes.end(), [](const AnfNodePtr &node) { return AnfAlgo::IsGraphKernel(node); }); | |||||
| bool has_graph_kernel = std::any_of(nodes.begin(), nodes.end(), AnfAlgo::IsGraphKernel); | |||||
| bool is_single_graph_kernel = has_graph_kernel && nodes.size() == 1; | bool is_single_graph_kernel = has_graph_kernel && nodes.size() == 1; | ||||
| auto gen_json = [&dump_option, &op_desc, &address_node_map](const AnfNodePtrList &op_nodes, | |||||
| const AnfNodePtrList &inputs, | |||||
| const AnfNodePtrList &outputs) -> bool { | |||||
| kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option); | |||||
| if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, inputs, outputs)) { | |||||
| MS_LOG(ERROR) << "Collect json desc failed."; | |||||
| return false; | |||||
| } | |||||
| *op_desc = akg_kernel_json_generator.kernel_json(); | |||||
| if (address_node_map != nullptr) { | |||||
| *address_node_map = akg_kernel_json_generator.address_node_map(); | |||||
| } | |||||
| std::string fused_name; | |||||
| std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) { | |||||
| (void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_"); | |||||
| }); | |||||
| MS_LOG(INFO) << "Collect fusion json: " << fused_name; | |||||
| return true; | |||||
| }; | |||||
| FuncGraphPtr fg; | FuncGraphPtr fg; | ||||
| AnfNodePtrList op_nodes; | |||||
| AnfNodePtrList inputs; | |||||
| AnfNodePtrList outputs; | |||||
| AnfNodePtrList op_nodes, inputs, outputs; | |||||
| if (is_single_graph_kernel) { | if (is_single_graph_kernel) { | ||||
| fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]); | fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]); | ||||
| kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs); | kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs); | ||||
| return gen_json(op_nodes, inputs, outputs); | |||||
| return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map); | |||||
| } else if (!has_graph_kernel) { | } else if (!has_graph_kernel) { | ||||
| std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes); | std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes); | ||||
| op_nodes = nodes; | op_nodes = nodes; | ||||
| return gen_json(op_nodes, inputs, outputs); | |||||
| return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map); | |||||
| } | } | ||||
| std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes); | std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes); | ||||
| @@ -540,10 +537,10 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann | |||||
| inputs.clear(); | inputs.clear(); | ||||
| outputs.clear(); | outputs.clear(); | ||||
| kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs); | kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs); | ||||
| return gen_json(op_nodes, inputs, outputs); | |||||
| return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map); | |||||
| } | } | ||||
| bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, DumpOption dump_option, nlohmann::json *op_desc) { | |||||
| bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc) { | |||||
| MS_EXCEPTION_IF_NULL(op_desc); | MS_EXCEPTION_IF_NULL(op_desc); | ||||
| std::vector<nlohmann::json> graphs_desc; | std::vector<nlohmann::json> graphs_desc; | ||||
| for (auto const &graph_nodes : graphs) { | for (auto const &graph_nodes : graphs) { | ||||
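Two patterns recur in the hunks above: `EliminateMakeTuple` now takes the func graph and its manager by const reference rather than by pointer-to-shared_ptr, and `AnfToJsonDesc` takes DumpOption by const reference instead of by value. When a callee only uses the pointee and never reseats or copies the smart pointer, a const reference to the shared_ptr is enough and removes the &/* noise at call sites. A small sketch with hypothetical Graph/GraphPtr types:

    #include <memory>
    #include <string>
    #include <vector>

    struct Graph {
      std::vector<std::string> nodes;
    };
    using GraphPtr = std::shared_ptr<Graph>;

    // Before: void AddNode(GraphPtr *g) forced call sites to write AddNode(&g) and the body
    // to dereference twice. Passing the shared_ptr by const reference is sufficient when the
    // callee only works on the pointee and never reseats or copies the pointer itself.
    void AddNode(const GraphPtr &g, const std::string &name) {
      g->nodes.push_back(name);  // mutating the pointee is still allowed
    }

    int main() {
      auto g = std::make_shared<Graph>();
      AddNode(g, "MakeTuple");
      return g->nodes.size() == 1 ? 0 : 1;  // exit code 0 when the node was added
    }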
| @@ -46,9 +46,9 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new | |||||
| void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, | void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, | ||||
| const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix, | const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix, | ||||
| bool is_before_kernel_select); | bool is_before_kernel_select); | ||||
| bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann::json *op_desc, | |||||
| bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, | |||||
| std::map<std::string, AnfNodePtr> *address_node_map = nullptr); | std::map<std::string, AnfNodePtr> *address_node_map = nullptr); | ||||
| bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, DumpOption dump_option, nlohmann::json *op_desc); | |||||
| bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc); | |||||
| FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs); | FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs); | ||||
| bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map, | bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map, | ||||
| std::vector<AnfNodePtrList> *res_graphs); | std::vector<AnfNodePtrList> *res_graphs); | ||||
| @@ -57,8 +57,6 @@ inline void TraverseFuncGraph(const FuncGraphPtr &root, std::function<void(AnfNo | |||||
| TraverseFuncGraphFromCNode(root->get_return(), callback); | TraverseFuncGraphFromCNode(root->get_return(), callback); | ||||
| } | } | ||||
| class AreaGraph; | |||||
| class Splitter; | |||||
| class Area { | class Area { | ||||
| public: | public: | ||||
| explicit Area(const AnfNodePtrList &anf_arr) { | explicit Area(const AnfNodePtrList &anf_arr) { | ||||
| @@ -73,6 +71,8 @@ class Area { | |||||
| } | } | ||||
| } | } | ||||
| ~Area() = default; | |||||
| // Set the external inputs of spy as a Parameter. | // Set the external inputs of spy as a Parameter. | ||||
| void CreateParameters(const FuncGraphPtr &func_graph, std::unordered_map<ParameterPtr, AnfNodePtr> *param_node_map) { | void CreateParameters(const FuncGraphPtr &func_graph, std::unordered_map<ParameterPtr, AnfNodePtr> *param_node_map) { | ||||
| std::unordered_map<AnfNodePtr, ParameterPtr> node_param_map; | std::unordered_map<AnfNodePtr, ParameterPtr> node_param_map; | ||||
| @@ -148,8 +148,8 @@ class Area { | |||||
| } | } | ||||
| } | } | ||||
| friend AreaGraph; | |||||
| friend Splitter; | |||||
| const std::unordered_set<AnfNodePtr> &nodes() const { return nodes_; } | |||||
| const std::vector<AnfNodePtr> &spy_cnodes() const { return spy_cnodes_; } | |||||
| private: | private: | ||||
| // This is a CNode that does not belong to this area. | // This is a CNode that does not belong to this area. | ||||
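Rather than declaring AreaGraph and Splitter as friends, Area now exposes read-only accessors (`nodes()` and `spy_cnodes()`), keeping its members private while still giving collaborators the views they need. A tiny sketch of that substitution, with hypothetical names:

    #include <unordered_set>
    #include <vector>

    class Region {
     public:
      // Read-only views replace `friend` access from collaborating classes.
      const std::unordered_set<int> &nodes() const { return nodes_; }
      const std::vector<int> &spies() const { return spies_; }

     private:
      std::unordered_set<int> nodes_;
      std::vector<int> spies_;
    };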
| @@ -170,9 +170,8 @@ class AreaGraph { | |||||
| // Build an area graph to maintain the relation between areas. | // Build an area graph to maintain the relation between areas. | ||||
| // Input node_groups: a group list; each element is an AnfNode list representing the node set in that group. | // Input node_groups: a group list; each element is an AnfNode list representing the node set in that group. | ||||
| static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) { | static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) { | ||||
| AreaGraph *area_graph_ptr = new (std::nothrow) AreaGraph(node_groups); | |||||
| if (!area_graph_ptr) return nullptr; | |||||
| auto area_graph = AreaGraphPtr(area_graph_ptr); | |||||
| auto area_graph = AreaGraphPtr(new AreaGraph(node_groups)); | |||||
| if (area_graph == nullptr) return nullptr; | |||||
| if (!area_graph->TopoSort()) { | if (!area_graph->TopoSort()) { | ||||
| MS_LOG(WARNING) << "The groups have a cycle."; | MS_LOG(WARNING) << "The groups have a cycle."; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -184,12 +183,12 @@ class AreaGraph { | |||||
| // The output `main_cnodes` is a topo-sorted cnode list in the main graph, holding the new sub_func_graphs. | // The output `main_cnodes` is a topo-sorted cnode list in the main graph, holding the new sub_func_graphs. | ||||
| // The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting. | // The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting. | ||||
| void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes, | void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes, | ||||
| std::vector<size_t> *cnode_group_id, std::function<void(Area *)> expand_callback) { | |||||
| std::vector<size_t> *cnode_group_id, std::function<void(const Area &)> expand_callback) { | |||||
| main_cnodes->clear(); | main_cnodes->clear(); | ||||
| main_cnodes->resize(areas_.size(), nullptr); | main_cnodes->resize(areas_.size(), nullptr); | ||||
| for (auto &area : this->areas_) { | for (auto &area : this->areas_) { | ||||
| expand_callback(&area); | |||||
| expand_callback(area); | |||||
| } | } | ||||
| for (auto index : topo_order_) { | for (auto index : topo_order_) { | ||||
| @@ -208,6 +207,8 @@ class AreaGraph { | |||||
| return; | return; | ||||
| } | } | ||||
| ~AreaGraph() = default; | |||||
| private: | private: | ||||
| explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) { | explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) { | ||||
| for (size_t i = 0; i < node_groups.size(); ++i) { | for (size_t i = 0; i < node_groups.size(); ++i) { | ||||
| @@ -217,7 +218,7 @@ class AreaGraph { | |||||
| } | } | ||||
| } | } | ||||
| for (auto &area : areas_) { | for (auto &area : areas_) { | ||||
| for (auto &spy : area.spy_cnodes_) { | |||||
| for (auto &spy : area.spy_cnodes()) { | |||||
| auto cnode = spy->cast<CNodePtr>(); | auto cnode = spy->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| size_t v = node_area_map_[spy]; | size_t v = node_area_map_[spy]; | ||||
| @@ -333,8 +334,8 @@ class Splitter { | |||||
| // The output new_subgraph_cnodes are topo-sorted; use a list to store their order in split_plan. | // The output new_subgraph_cnodes are topo-sorted; use a list to store their order in split_plan. | ||||
| std::vector<size_t> cnodes_group_id; | std::vector<size_t> cnodes_group_id; | ||||
| std::function<void(Area *)> expand_callback = std::bind(&Splitter::AreaExpand, this, std::placeholders::_1); | |||||
| area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id, expand_callback); | |||||
| area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id, | |||||
| [this](const Area &area) { this->AreaExpand(area); }); | |||||
| RebuildGraph(cnodes_group_id); | RebuildGraph(cnodes_group_id); | ||||
| @@ -348,6 +349,8 @@ class Splitter { | |||||
| return SplitterPtr(new Splitter(main_cnode, split_schemer)); | return SplitterPtr(new Splitter(main_cnode, split_schemer)); | ||||
| } | } | ||||
| ~Splitter() = default; | |||||
| private: | private: | ||||
| Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer) | Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer) | ||||
| : main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {} | : main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {} | ||||
| @@ -479,9 +482,9 @@ class Splitter { | |||||
| } | } | ||||
| // Copy all Parameters and ValueNodes that the area uses. | // Copy all Parameters and ValueNodes that the area uses. | ||||
| void AreaExpand(Area *area) { | |||||
| void AreaExpand(const Area &area) { | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> old_valuenode_and_param_map; | std::unordered_map<AnfNodePtr, AnfNodePtr> old_valuenode_and_param_map; | ||||
| for (auto sub_node : area->nodes_) { | |||||
| for (auto sub_node : area.nodes()) { | |||||
| auto sub_cnode = sub_node->cast<CNodePtr>(); | auto sub_cnode = sub_node->cast<CNodePtr>(); | ||||
| if (sub_cnode == nullptr) continue; | if (sub_cnode == nullptr) continue; | ||||
| for (size_t i = 1; i < sub_cnode->inputs().size(); ++i) { | for (size_t i = 1; i < sub_cnode->inputs().size(); ++i) { | ||||
| @@ -565,7 +568,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { | |||||
| auto json_desc_str = json_desc.dump(); | auto json_desc_str = json_desc.dump(); | ||||
| MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json:\n" << json_desc_str; | MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json:\n" << json_desc_str; | ||||
| auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str); | auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str); | ||||
| if (ret.is(py::none())) { | |||||
| if (py::isinstance<py::none>(ret)) { | |||||
| MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n" | MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n" | ||||
| << json_desc_str; | << json_desc_str; | ||||
| return false; | return false; | ||||
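In Splitter, the `std::bind(&Splitter::AreaExpand, this, std::placeholders::_1)` callback becomes an inline capturing lambda, and the callback (together with AreaExpand itself) now takes `const Area &` instead of `Area *`. A self-contained sketch of the lambda-over-bind form, with hypothetical Walker/Region names:

    #include <functional>
    #include <iostream>
    #include <vector>

    struct Region { int id = 0; };

    class Walker {
     public:
      void Visit(const Region &r) { std::cout << "visit " << r.id << '\n'; }

      void Run(const std::vector<Region> &regions) {
        // A capturing lambda replaces std::bind(&Walker::Visit, this, std::placeholders::_1);
        // it is easier to read and keeps the const-reference parameter visible at the call site.
        std::function<void(const Region &)> callback = [this](const Region &r) { this->Visit(r); };
        for (const auto &r : regions) {
          callback(r);
        }
      }
    };

    int main() {
      Walker w;
      std::vector<Region> regions = {{1}, {2}};
      w.Run(regions);
      return 0;
    }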
| @@ -463,13 +463,15 @@ std::string AddGlobalId(const std::string &filename) { | |||||
| static size_t g_id = 0; | static size_t g_id = 0; | ||||
| std::ostringstream s; | std::ostringstream s; | ||||
| auto i = filename.rfind('/'); | auto i = filename.rfind('/'); | ||||
| if (i == string::npos) { | |||||
| if (i >= filename.size()) { | |||||
| s << std::setfill('0') << std::setw(4) << g_id << "_"; | s << std::setfill('0') << std::setw(4) << g_id << "_"; | ||||
| s << filename; | s << filename; | ||||
| } else { | } else { | ||||
| s << filename.substr(0, i + 1); | s << filename.substr(0, i + 1); | ||||
| s << std::setfill('0') << std::setw(4) << g_id << "_"; | s << std::setfill('0') << std::setw(4) << g_id << "_"; | ||||
| s << filename.substr(i + 1); | |||||
| if (i + 1 < filename.size()) { | |||||
| s << filename.substr(i + 1); | |||||
| } | |||||
| } | } | ||||
| ++g_id; | ++g_id; | ||||
| return s.str(); | return s.str(); | ||||
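AddGlobalId now detects a missing '/' with `i >= filename.size()` (npos is larger than any valid index) and guards the trailing `substr`, so a filename ending in '/' no longer indexes past the end. A freestanding sketch of the same insertion logic, with the id passed in instead of kept in a static counter:

    #include <iomanip>
    #include <sstream>
    #include <string>

    // Put a zero-padded counter in front of the basename, guarding the substr
    // call for paths that end with '/'.
    std::string InsertId(const std::string &filename, size_t id) {
      std::ostringstream s;
      auto i = filename.rfind('/');
      if (i >= filename.size()) {  // rfind returned npos: no directory part
        s << std::setfill('0') << std::setw(4) << id << "_" << filename;
      } else {
        s << filename.substr(0, i + 1);
        s << std::setfill('0') << std::setw(4) << id << "_";
        if (i + 1 < filename.size()) {
          s << filename.substr(i + 1);
        }
      }
      return s.str();
    }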
| @@ -236,12 +236,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI | |||||
| } | } | ||||
| void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) { | void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::vector<AnfNodePtr> node_list; | |||||
| std::vector<AnfNodePtr> input_list; | |||||
| std::vector<AnfNodePtr> output_list; | |||||
| std::vector<AnfNodePtr> node_list, input_list, output_list; | |||||
| kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); | kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); | ||||
| std::vector<std::string> graph_input_format; | std::vector<std::string> graph_input_format; | ||||
| @@ -295,6 +290,22 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_gr | |||||
| AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get()); | AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get()); | ||||
| SetTensorDeviceInfo(*graph_selected_info, kernel_node); | SetTensorDeviceInfo(*graph_selected_info, kernel_node); | ||||
| } | } | ||||
| void PrintUnsupportedTypeException(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type, | |||||
| const std::vector<TypeId> &outputs_type) { | |||||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| std::string build_type = "in ["; | |||||
| std::for_each(std::begin(inputs_type), std::end(inputs_type), | |||||
| [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; }); | |||||
| build_type += "] out ["; | |||||
| std::for_each(std::begin(outputs_type), std::end(outputs_type), | |||||
| [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; }); | |||||
| build_type += "]"; | |||||
| auto supported_type_lists = SupportedTypeList(kernel_node); | |||||
| MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name | |||||
| << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists | |||||
| << ", but get " << build_type; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | ||||
| @@ -329,7 +340,7 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<s | |||||
| void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { | void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { | ||||
| if (AnfAlgo::IsGraphKernel(kernel_node)) { | if (AnfAlgo::IsGraphKernel(kernel_node)) { | ||||
| auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex)); | |||||
| auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| SetGraphKernelInfo(kernel_node, func_graph); | SetGraphKernelInfo(kernel_node, func_graph); | ||||
| return; | return; | ||||
| @@ -351,8 +362,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { | |||||
| if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) { | if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) { | ||||
| UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); | UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); | ||||
| } | } | ||||
| std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = | |||||
| std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); | |||||
| auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); | |||||
| builder->SetOriginDataFormat(origin_data_format); | builder->SetOriginDataFormat(origin_data_format); | ||||
| builder->SetInputsFormat(inputs_format); | builder->SetInputsFormat(inputs_format); | ||||
| builder->SetInputsDeviceType(inputs_type); | builder->SetInputsDeviceType(inputs_type); | ||||
| @@ -360,35 +370,23 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { | |||||
| builder->SetOutputsDeviceType(outputs_type); | builder->SetOutputsDeviceType(outputs_type); | ||||
| bool result = false; | bool result = false; | ||||
| KernelType res_kernel_type = UNKNOWN_KERNEL_TYPE; | |||||
| if (kernel_type == UNKNOWN_KERNEL_TYPE) { | if (kernel_type == UNKNOWN_KERNEL_TYPE) { | ||||
| result = | result = | ||||
| kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build()); | kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build()); | ||||
| if (!result) { | if (!result) { | ||||
| result = SelectAkgKernel(kernel_node, builder->Build()); | result = SelectAkgKernel(kernel_node, builder->Build()); | ||||
| res_kernel_type = AKG_KERNEL; | |||||
| kernel_type = AKG_KERNEL; | |||||
| } | } | ||||
| } else if (kernel_type == AKG_KERNEL) { | } else if (kernel_type == AKG_KERNEL) { | ||||
| result = SelectAkgKernel(kernel_node, builder->Build()); | result = SelectAkgKernel(kernel_node, builder->Build()); | ||||
| res_kernel_type = AKG_KERNEL; | |||||
| } | } | ||||
| if (!result) { | if (!result) { | ||||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| std::string build_type = "in ["; | |||||
| std::for_each(std::begin(inputs_type), std::end(inputs_type), | |||||
| [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; }); | |||||
| build_type += "] out ["; | |||||
| std::for_each(std::begin(outputs_type), std::end(outputs_type), | |||||
| [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; }); | |||||
| build_type += "]"; | |||||
| auto supported_type_lists = SupportedTypeList(kernel_node); | |||||
| MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name | |||||
| << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists | |||||
| << ", but get " << build_type; | |||||
| } | |||||
| builder->SetKernelType(res_kernel_type); | |||||
| PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type); | |||||
| return; | |||||
| } | |||||
| builder->SetKernelType(kernel_type); | |||||
| builder->SetProcessor(kernel::Processor::CUDA); | builder->SetProcessor(kernel::Processor::CUDA); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); | AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); | ||||
| SetTensorDeviceInfo(*(builder->Build()), kernel_node); | SetTensorDeviceInfo(*(builder->Build()), kernel_node); | ||||
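The failure path is now factored into PrintUnsupportedTypeException, which renders the input and output TypeId lists as an "in [...] out [...]" string before raising the TypeError. A standalone sketch of that string building, with a hypothetical TypeToString standing in for mindspore::kernel::TypeId2String:

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for mindspore::kernel::TypeId2String.
    std::string TypeToString(int type_id) { return "T" + std::to_string(type_id); }

    // Build the "in [...] out [...]" summary printed in the exception message.
    std::string BuildTypeSummary(const std::vector<int> &inputs, const std::vector<int> &outputs) {
      std::string build_type = "in [";
      std::for_each(inputs.begin(), inputs.end(),
                    [&build_type](int i) { build_type += TypeToString(i) + " "; });
      build_type += "] out [";
      std::for_each(outputs.begin(), outputs.end(),
                    [&build_type](int i) { build_type += TypeToString(i) + " "; });
      build_type += "]";
      return build_type;
    }

    int main() {
      std::cout << BuildTypeSummary({1, 2}, {3}) << '\n';  // prints: in [T1 T2 ] out [T3 ]
      return 0;
    }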