| @@ -20,7 +20,7 @@ | |||
| #include "kernel/aicpu/aicpu_kernel_metadata.h" | |||
| #include "kernel/rts/rt_kernel_info.h" | |||
| #include "kernel/hccl/hccl_kernel_metadata.h" | |||
| #include "kernel/tbe/tbe_kernel_select.h" | |||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| @@ -63,7 +63,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| TbeMetadataInfo(kernel_node, kernel_info_list); | |||
| FilterInvalidKernelInfo(kernel_node, kernel_info_list); | |||
| if (kernel_info_list->empty()) { | |||
| AicpuMetadataInfo(kernel_node, kernel_info_list); | |||
| if (!kernel_info_list->empty()) { | |||
| @@ -114,7 +113,6 @@ bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr | |||
| auto cnode = kernel_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| TbeMetadataInfo(cnode, &kernel_info_list); | |||
| FilterInvalidKernelInfo(cnode, &kernel_info_list); | |||
| return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), | |||
| [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| @@ -126,6 +126,8 @@ class OpInfo { | |||
| bool is_ref() const { return !ref_infos_.empty(); } | |||
| bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } | |||
| void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } | |||
| void ClearInputs() { (void)inputs_ptr_.clear(); } | |||
| void ClearOutputs() { (void)outputs_ptr_.clear(); } | |||
| private: | |||
| std::string op_name_; | |||
| @@ -35,7 +35,7 @@ constexpr auto kKernelName = "kernel_name"; | |||
| constexpr auto kPartialFlag = "partial_flag"; | |||
| constexpr auto kReshapeType = "reshape_type"; | |||
| constexpr auto kOpPattern = "op_pattern"; | |||
| constexpr auto kDynamicFormat = "dynamic_format"; | |||
| constexpr auto kDynamicFormat = "dynamicFormat"; | |||
| constexpr auto kFormatAgnostic = "formatAgnostic"; | |||
| constexpr auto kBroadcast = "broadcast"; | |||
| constexpr auto kReduce = "reduce"; | |||
| @@ -100,7 +100,7 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) | |||
| void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) { | |||
| const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern}, | |||
| {kFormatAgnostic, kBroadcastPattern}, | |||
| {kBroadcast, kBroadcastPattern}, | |||
| {kReduce, kReducePattern}, | |||
| {kDynamicFormat, kDynamicFormatPattern}}; | |||
| op_info->set_async_flag(obj.at(kAsyncFlag)); | |||
| @@ -108,14 +108,19 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p | |||
| op_info->set_compute_cost(obj.at(kComputeCost)); | |||
| op_info->set_kernel_name(obj.at(kKernelName)); | |||
| op_info->set_partial_flag(obj.at(kPartialFlag)); | |||
| if (obj.find(kOpPattern) != obj.end()) { | |||
| if (kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) { | |||
| op_info->set_op_pattern(obj.at(kOpPattern)); | |||
| std::string op_pattern = obj.at(kOpPattern); | |||
| auto find_iter = kOpPatternMap.find(op_pattern); | |||
| if (find_iter == kOpPatternMap.end()) { | |||
| if (!op_pattern.empty()) { | |||
| MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern; | |||
| } | |||
| op_info->set_op_pattern(kCommonPattern); | |||
| } else { | |||
| op_info->set_op_pattern(find_iter->second); | |||
| } | |||
| } | |||
| if (obj.find(kDynamicFormat) != obj.end()) { | |||
| op_info->set_dynamic_format(obj.at(kDynamicFormat)); | |||
| } | |||
| } | |||
| bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type, | |||
| @@ -45,7 +45,7 @@ const std::map<TypeId, std::string> type_id_str_maps = { | |||
| {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, | |||
| {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, | |||
| {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, | |||
| {TypeId::kNumberTypeBool, "bool"}, | |||
| {TypeId::kNumberTypeBool, "int8"}, | |||
| }; | |||
| const std::map<std::string, std::string> type_str_maps = { | |||
| @@ -85,7 +85,7 @@ std::string DtypeToString(const std::string &dtypes) { | |||
| std::string TypeIdToString(TypeId type_id) { | |||
| auto iter = type_id_str_maps.find(type_id); | |||
| if (iter == type_id_str_maps.end()) { | |||
| MS_LOG(EXCEPTION) << "Illegal input dtype." << TypeIdLabel(type_id); | |||
| MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id); | |||
| } | |||
| return iter->second; | |||
| } | |||
| @@ -111,41 +111,20 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node, | |||
| if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) { | |||
| TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_); | |||
| } else { | |||
| // dtype : float16 | |||
| auto tensor_dtype = | |||
| std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index))); | |||
| MS_EXCEPTION_IF_NULL(tensor_dtype); | |||
| std::string dtype = tensor_dtype->element()->ToString(); | |||
| dtype = tbe::DtypeToString(dtype); | |||
| // format | |||
| std::string format = AnfAlgo::GetInputFormat(anf_node, real_input_index); | |||
| if (format == kOpFormat_DEFAULT) { | |||
| format = kOpFormat_NCHW; | |||
| } else if (format == kOpFormat_FRAC_Z) { | |||
| format = kOpFormat_FRACTAL_Z; | |||
| } | |||
| nlohmann::json input_desc_json; | |||
| input_desc_json["dtype"] = dtype; | |||
| input_desc_json["name"] = op_input_name + std::to_string(input_i); | |||
| auto dtype = GetDeviceInputType(anf_node, real_input_index); | |||
| auto format = GetDeviceInputFormat(anf_node, real_input_index); | |||
| auto shape = GetDeviceInputShape(anf_node, real_input_index); | |||
| auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index); | |||
| if (ori_shape.empty()) { | |||
| ori_shape.emplace_back(1); | |||
| } | |||
| nlohmann::json input_desc_json; | |||
| input_desc_json["dtype"] = dtype; | |||
| input_desc_json["name"] = op_input_name + std::to_string(input_i); | |||
| input_desc_json["ori_shape"] = ori_shape; | |||
| input_desc_json["ori_format"] = kOpFormat_NCHW; | |||
| auto shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); | |||
| if (shape.empty()) { | |||
| shape.emplace_back(1); | |||
| } | |||
| if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { | |||
| input_desc_json["shape"] = ori_shape; | |||
| input_desc_json["format"] = kOpFormat_NCHW; | |||
| } else { | |||
| input_desc_json["shape"] = shape; | |||
| input_desc_json["format"] = format; | |||
| } | |||
| input_desc_json["shape"] = shape; | |||
| input_desc_json["format"] = format; | |||
| input_desc_json["valid"] = value; | |||
| input_desc_json["param_type"] = input_ptr->param_type(); | |||
| input_list->emplace_back(input_desc_json); | |||
| @@ -325,40 +304,22 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, co | |||
| MS_EXCEPTION_IF_NULL(output_idx); | |||
| MS_EXCEPTION_IF_NULL(output_list); | |||
| for (size_t i = 0; i < output_obj_num; i++) { | |||
| nlohmann::json output_obj; | |||
| auto type_ptr = std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, *output_idx))); | |||
| std::string dtype = type_ptr->element()->ToString(); | |||
| dtype = tbe::DtypeToString(dtype); | |||
| std::string format = AnfAlgo::GetOutputFormat(anf_node, *output_idx); | |||
| if (format == kOpFormat_DEFAULT) { | |||
| format = kOpFormat_NCHW; | |||
| } else if (format == kOpFormat_FRAC_Z) { | |||
| format = kOpFormat_FRACTAL_Z; | |||
| } | |||
| std::vector<size_t> ori_shape; | |||
| if (AnfAlgo::GetOutputInferShape(anf_node, *output_idx).empty()) { | |||
| auto dtype = GetDeviceOutputType(anf_node, *output_idx); | |||
| auto format = GetDeviceOutputFormat(anf_node, *output_idx); | |||
| auto shape = GetDeviceOutputShape(anf_node, *output_idx); | |||
| std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); | |||
| if (ori_shape.empty()) { | |||
| ori_shape.emplace_back(1); | |||
| } else { | |||
| ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); | |||
| } | |||
| nlohmann::json output_obj; | |||
| output_obj["dtype"] = dtype; | |||
| auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, *output_idx); | |||
| if (shape.empty()) { | |||
| shape.emplace_back(1); | |||
| } | |||
| if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { | |||
| output_obj["shape"] = ori_shape; | |||
| output_obj["format"] = kOpFormat_NCHW; | |||
| } else { | |||
| output_obj["shape"] = shape; | |||
| output_obj["format"] = format; | |||
| } | |||
| output_obj["shape"] = shape; | |||
| output_obj["format"] = format; | |||
| output_obj["ori_shape"] = ori_shape; | |||
| output_obj["ori_format"] = kOpFormat_NCHW; | |||
| output_obj["name"] = output_ptr->name(); | |||
| output_obj["valid"] = true; | |||
| output_obj["param_type"] = output_ptr->param_type(); | |||
| output_list->emplace_back(output_obj); | |||
| (*output_idx)++; | |||
| } | |||
| @@ -456,6 +417,84 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo | |||
| } | |||
| } | |||
| std::vector<size_t> TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| std::vector<size_t> shape; | |||
| if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { | |||
| shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index); | |||
| } else { | |||
| shape = AnfAlgo::GetInputDeviceShape(anf_node, real_index); | |||
| } | |||
| if (shape.empty()) { | |||
| shape.emplace_back(1); | |||
| } | |||
| return shape; | |||
| } | |||
| std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| TypeId type_id; | |||
| if (creater_type_ == OP_SELECT_FORMAT) { | |||
| type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index); | |||
| } else { | |||
| type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_index); | |||
| } | |||
| return tbe::TypeIdToString(type_id); | |||
| } | |||
| std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| std::string format = kOpFormat_NCHW; | |||
| if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { | |||
| format = AnfAlgo::GetInputFormat(anf_node, real_index); | |||
| if (format == kOpFormat_FRAC_Z) { | |||
| format = kOpFormat_FRACTAL_Z; | |||
| } else if (format == kOpFormat_DEFAULT) { | |||
| format = kOpFormat_NCHW; | |||
| } | |||
| } | |||
| return format; | |||
| } | |||
| std::vector<size_t> TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| std::vector<size_t> shape; | |||
| if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { | |||
| shape = AnfAlgo::GetOutputInferShape(anf_node, real_index); | |||
| } else { | |||
| shape = AnfAlgo::GetOutputDeviceShape(anf_node, real_index); | |||
| } | |||
| if (shape.empty()) { | |||
| shape.emplace_back(1); | |||
| } | |||
| return shape; | |||
| } | |||
| std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| TypeId type_id; | |||
| if (creater_type_ == OP_SELECT_FORMAT) { | |||
| type_id = AnfAlgo::GetOutputInferDataType(anf_node, real_index); | |||
| } else { | |||
| type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, real_index); | |||
| } | |||
| return tbe::TypeIdToString(type_id); | |||
| } | |||
| std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| std::string format = kOpFormat_NCHW; | |||
| if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { | |||
| format = AnfAlgo::GetOutputFormat(anf_node, real_index); | |||
| if (format == kOpFormat_FRAC_Z) { | |||
| format = kOpFormat_FRACTAL_Z; | |||
| } else if (format == kOpFormat_DEFAULT) { | |||
| format = kOpFormat_NCHW; | |||
| } | |||
| } | |||
| return format; | |||
| } | |||
| bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list, | |||
| std::vector<size_t> *output_size_list) { | |||
| if (input_size_list == nullptr || output_size_list == nullptr) { | |||
| @@ -93,7 +93,7 @@ class TbeKernelJsonCreator { | |||
| nlohmann::json *outputs_json); | |||
| bool GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info, | |||
| nlohmann::json *attrs_json); | |||
| void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj); | |||
| static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj); | |||
| bool GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value, | |||
| const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i, | |||
| std::vector<nlohmann::json> *input_list); | |||
| @@ -105,6 +105,13 @@ class TbeKernelJsonCreator { | |||
| void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num, | |||
| const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx, | |||
| std::vector<nlohmann::json> *output_list); | |||
| std::vector<size_t> GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const; | |||
| std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const; | |||
| std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const; | |||
| std::vector<size_t> GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const; | |||
| std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const; | |||
| std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const; | |||
| kCreaterType creater_type_; | |||
| std::string json_name_; | |||
| std::string json_info_; | |||
| @@ -230,7 +230,7 @@ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t | |||
| task_iter->second.output_size_list, kernel_pack); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| if (set_kernel_mod) { | |||
| AnfAlgo ::SetKernelMod(kernel_mod, task_iter->second.node); | |||
| AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node); | |||
| } | |||
| auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod); | |||
| (void)task_map_.erase(task_iter); | |||
| @@ -1,664 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/tbe/tbe_kernel_select.h" | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <set> | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "kernel/tbe/tbe_kernel_build.h" | |||
| #include "nlohmann/json.hpp" | |||
| #include "common/utils.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "kernel/tbe/tbe_python_funcs.h" | |||
| #include "pre_activate/common/helper.h" | |||
| #include "kernel/tbe/tbe_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
// JSON keys expected in every item of the op_select_format response.
constexpr auto kName = "name";
constexpr auto kDtype = "dtype";
constexpr auto kFormat = "format";
// Key prefixes used to classify response items as inputs or outputs.
constexpr auto kPrefixInput = "input";
constexpr auto kPrefixOutput = "output";
// Maps TBE format names from the python response to the framework's
// canonical format names.
const std::map<std::string, std::string> DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"},
                                                               {"NHWC", "DefaultFormat"},
                                                               {"ND", "DefaultFormat"},
                                                               {"FRACTAL_Z", "FracZ"},
                                                               {"NDHWC", "DefaultFormat"}};
// Op types that must additionally pass the TBE check_supported python call;
// all other ops are treated as supported without a query (see CheckSupported).
static const std::vector<std::string> CHECK_SUPPORTED_OPTYPE{
  "MatMul", "BatchMatMul", "TopK", "InTopK", "Pack", "GatherNd", "UnsortedSegmentMinD", "UnsortedSegmentProdD", "Cast"};
// Asks the TBE python layer whether `select_kernel_build_info` is actually
// supported for this node. Only op types listed in CHECK_SUPPORTED_OPTYPE are
// queried; every other op type returns true without a query.
// The candidate build info is installed on the node temporarily (the json
// generator reads it from the node) and the original info is restored on
// every exit path.
bool CheckSupported(const AnfNodePtr &anf_node, const KernelBuildInfoPtr &select_kernel_build_info) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(select_kernel_build_info);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  auto iter = std::find(CHECK_SUPPORTED_OPTYPE.begin(), CHECK_SUPPORTED_OPTYPE.end(), op_name);
  if (iter == CHECK_SUPPORTED_OPTYPE.end()) {
    MS_LOG(DEBUG) << "Op " << op_name << "this op does not need to check op supported.";
    return true;
  }
  // replace kernel_info with current kernel info
  auto ori_select_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
  AnfAlgo::SetSelectKernelBuildInfo(select_kernel_build_info, anf_node.get());
  nlohmann::json kernel_json;
  TbeKernelJsonCreator creator(CHECK_SUPPORTED);
  bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
  if (!ret) {
    MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
    // Restore the original build info before bailing out.
    AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
    return false;
  }
  ret = TbePythonFuncs::CheckSupported(kernel_json);
  // The candidate info was installed only for json generation; always restore.
  AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
  return ret;
}
| bool CheckJsonItemValidity(const nlohmann::json &json_obj, const std::string &key_name, | |||
| const std::vector<std::string> &keys) { | |||
| if (!json_obj[key_name].is_object()) { | |||
| MS_LOG(DEBUG) << key_name << "is not an object!"; | |||
| return false; | |||
| } | |||
| for (auto key : keys) { | |||
| if (json_obj[key_name].find(key) == json_obj[key_name].end()) { | |||
| MS_LOG(DEBUG) << "Key" << key << "of " << key_name << " is not found!"; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| std::vector<std::string> SplitStr(const std::string &string, const std::string &sep) { | |||
| std::vector<std::string> result; | |||
| size_t start = 0; | |||
| size_t index = string.find(sep, start); | |||
| std::string substr; | |||
| while (index != std::string::npos) { | |||
| if (string.size() > start) { | |||
| substr = string.substr(start, index - start); | |||
| } | |||
| (void)substr.erase(0, substr.find_first_not_of(' ')); | |||
| (void)substr.erase(substr.find_last_not_of(' ') + 1); | |||
| auto iter = DYNAMIC_FORMAT_MAP.find(substr); | |||
| if (iter != DYNAMIC_FORMAT_MAP.end()) { | |||
| substr = iter->second; | |||
| } | |||
| result.push_back(substr); | |||
| start = index + sep.size(); | |||
| index = string.find(sep, start); | |||
| } | |||
| if (string.size() > start) { | |||
| substr = string.substr(start); | |||
| } | |||
| (void)substr.erase(0, substr.find_first_not_of(' ')); | |||
| (void)substr.erase(substr.find_last_not_of(' ') + 1); | |||
| auto iter = DYNAMIC_FORMAT_MAP.find(substr); | |||
| if (iter != DYNAMIC_FORMAT_MAP.end()) { | |||
| substr = iter->second; | |||
| } | |||
| result.push_back(substr); | |||
| return result; | |||
| } | |||
| void ConvertFormatDtype(const std::string &format, const std::string &dtype, const std::shared_ptr<OpIOInfo> &io_info) { | |||
| MS_EXCEPTION_IF_NULL(io_info); | |||
| std::vector<std::string> format_vec = SplitStr(format, ","); | |||
| std::vector<std::string> dtype_vec = SplitStr(dtype, ","); | |||
| io_info->set_formats(format_vec); | |||
| io_info->set_dtypes(dtype_vec); | |||
| } | |||
// Parses the op_select_format json response into input/output OpIOInfo lists.
// The response must be an object whose keys start with "input" or "output"
// and whose values are objects containing "name", "dtype" and "format".
// Returns false on the first malformed item; entries appended to
// inputs/outputs before a failure are not rolled back.
bool ParseDynamicFormatJson(const std::string &jsonStr, std::vector<std::shared_ptr<OpIOInfo>> *const inputs,
                            std::vector<std::shared_ptr<OpIOInfo>> *const outputs) {
  // NOTE(review): nlohmann::json::parse throws on invalid json rather than
  // returning — confirm callers tolerate that (here the input comes from the
  // TBE python layer).
  nlohmann::json json_obj = nlohmann::json::parse(jsonStr);
  if (!json_obj.is_object()) {
    MS_LOG(DEBUG) << "JsonStr is not an object, the jsonStr is:" << jsonStr;
    return false;
  }
  std::vector<std::string> keys = {kName, kDtype, kFormat};
  for (const auto &item : json_obj.items()) {
    std::string key_name;
    key_name = item.key();
    if (key_name.empty()) {
      MS_LOG(DEBUG) << "Key name is empty!";
      return false;
    }
    if (!CheckJsonItemValidity(json_obj, key_name, keys)) {
      return false;
    }
    // Classify the item by key prefix: "input..." vs "output...".
    if (key_name.compare(0, strlen(kPrefixInput), kPrefixInput) == 0) {
      std::shared_ptr<OpIOInfo> input = std::make_shared<OpIOInfo>();
      MS_EXCEPTION_IF_NULL(input);
      input->set_name(json_obj[key_name].at(kName));
      // Format/dtype arrive as comma-separated lists; split and store.
      ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input);
      inputs->emplace_back(input);
    } else if (key_name.compare(0, strlen(kPrefixOutput), kPrefixOutput) == 0) {
      std::shared_ptr<OpIOInfo> output = std::make_shared<OpIOInfo>();
      MS_EXCEPTION_IF_NULL(output);
      output->set_name(json_obj[key_name].at(kName));
      ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), output);
      outputs->emplace_back(output);
    } else {
      MS_LOG(DEBUG) << "Key name:" << key_name << " is undefined!";
      return false;
    }
  }
  return true;
}
| std::string OpSelectFormat(const std::shared_ptr<AnfNode> &anf_node) { | |||
| nlohmann::json kernel_json; | |||
| std::string res_json_str; | |||
| TbeKernelJsonCreator creator(OP_SELECT_FORMAT); | |||
| bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json); | |||
| if (!ret) { | |||
| MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed"; | |||
| return res_json_str; | |||
| } | |||
| res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json); | |||
| MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str; | |||
| return res_json_str; | |||
| } | |||
// Populates the builder's input dtypes/formats with provisional values
// (inferred dtype, default format) so a tidy KernelBuildInfo can be installed
// before querying op_select_format.
// kernel_info_index counts the expanded (per-tensor) input slots; dynamic
// inputs expand to dyn_input_sizes[dyn_input_idx] slots each.
void SetTidyInputsInfo(const std::shared_ptr<AnfNode> &anf_node,
                       const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
                       const std::vector<std::shared_ptr<OpIOInfo>> &inputs) {
  std::vector<TypeId> inputs_type;
  std::vector<std::string> inputs_format;
  std::vector<int> dyn_input_sizes;
  size_t dyn_input_idx = 0;
  size_t kernel_info_index = 0;
  size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node);
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  // Per-dynamic-input expansion counts, when the primitive declares them.
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
  }
  for (size_t i = 0; i < inputs.size(); i++) {
    MS_EXCEPTION_IF_NULL(inputs[i]);
    std::string param_type = inputs[i]->param_type();
    if (i >= real_input_num) {
      // Registered input has no corresponding real tensor on this node.
      MS_LOG(INFO) << "Input index: " << i << " is out of real_input_num:" << real_input_num;
      continue;
    }
    // NOTE(review): dtype is looked up by registration index i, not by the
    // expanded kernel_info_index — confirm this is intended for dynamic
    // inputs. dyn_input_sizes[dyn_input_idx] is also read without a bounds
    // check beyond the empty() guard.
    auto type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, i);
    auto format = kOpFormat_DEFAULT;
    if (param_type == "dynamic") {
      if (!dyn_input_sizes.empty()) {
        // One slot per expanded tensor of this dynamic input.
        for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
          kernel_info_index++;
          inputs_type.emplace_back(type_id);
          inputs_format.emplace_back(format);
        }
        dyn_input_idx++;
      }
    } else if (param_type == "required") {
      kernel_info_index++;
      inputs_type.emplace_back(type_id);
      inputs_format.emplace_back(format);
    } else {
      // Optional input: only consumes a slot while real tensors remain.
      if (kernel_info_index < real_input_num) {
        MS_LOG(INFO) << "Input type is optional, input index is :" << kernel_info_index;
        kernel_info_index++;
        inputs_type.emplace_back(type_id);
        inputs_format.emplace_back(format);
      }
    }
  }
  builder->SetInputsDeviceType(inputs_type);
  builder->SetInputsFormat(inputs_format);
}
// Populates the builder's output dtypes/formats with provisional values
// (inferred dtype, default format) so a tidy KernelBuildInfo can be installed
// before querying op_select_format.
void SetTidyOutputsInfo(const std::shared_ptr<AnfNode> &anf_node,
                        const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
                        const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
  std::vector<TypeId> outputs_type;
  std::vector<std::string> outputs_format;
  auto real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
  size_t output_idx = 0;
  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    if (output_idx >= real_output_num) {
      // All real tensors already accounted for; skip surplus registrations.
      continue;
    }
    size_t output_num = 0;
    if (output->param_type() == "dynamic") {
      // A dynamic output must be the only registered output; it absorbs all
      // real output tensors.
      if (outputs.size() > 1) {
        MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
      }
      output_num = real_output_num;
    } else if (output->param_type() == "required") {
      output_num = 1;
    } else {
      // Optional output: consumes one slot while real tensors remain
      // (otherwise output_num stays 0 and nothing is emitted).
      if (output_idx < real_output_num) {
        MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
        output_num = 1;
      }
    }
    for (size_t i = 0; i < output_num; i++) {
      auto type_id = AnfAlgo::GetOutputInferDataType(anf_node, output_idx);
      outputs_type.emplace_back(type_id);
      outputs_format.emplace_back(kOpFormat_DEFAULT);
      output_idx++;
    }
  }
  builder->SetOutputsDeviceType(outputs_type);
  builder->SetOutputsFormat(outputs_format);
}
| void GenTidyKernelBuildInfo(const std::shared_ptr<AnfNode> &anf_node, | |||
| const std::vector<std::shared_ptr<OpIOInfo>> &inputs, | |||
| const std::vector<std::shared_ptr<OpIOInfo>> &outputs) { | |||
| auto builder_tmp = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| builder_tmp->SetKernelType(TBE_KERNEL); | |||
| SetTidyInputsInfo(anf_node, builder_tmp, inputs); | |||
| SetTidyOutputsInfo(anf_node, builder_tmp, outputs); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder_tmp->Build(), anf_node.get()); | |||
| } | |||
| void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, | |||
| const std::shared_ptr<OpInfo> &op_info_new_ptr) { | |||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_static = op_info_ptr->inputs_ptr(); | |||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_static = op_info_ptr->outputs_ptr(); | |||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_dyn; | |||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_dyn; | |||
| if ((op_info_ptr->imply_type() == kTBE) && (!mindspore::opt::IsNopNode(kernel_node->cast<AnfNodePtr>()))) { | |||
| // 1. create tidy kernelBuildInfo in order to generate json for calling op_select_format | |||
| auto anf_node = kernel_node->cast<std::shared_ptr<AnfNode>>(); | |||
| auto kernel_build_info_ptr = AnfAlgo::GetSelectKernelBuildInfo(anf_node); | |||
| GenTidyKernelBuildInfo(kernel_node, inputs_static, outputs_static); | |||
| // 2.get dynamic format from op_impl | |||
| std::string res_json_str; | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->execution_mode() != kPynativeMode) { | |||
| res_json_str = OpSelectFormat(kernel_node); | |||
| } | |||
| if (!res_json_str.empty()) { | |||
| (void)ParseDynamicFormatJson(res_json_str, &inputs_dyn, &outputs_dyn); | |||
| } | |||
| if (inputs_static.size() != inputs_dyn.size()) { | |||
| inputs_dyn.clear(); | |||
| } | |||
| if (outputs_static.size() != outputs_dyn.size()) { | |||
| outputs_dyn.clear(); | |||
| } | |||
| // 3. resume kernel node's SelectKernelBuildInfo | |||
| // As it has been replaced by GenTidyKernelBuildInfo in order to call python func | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_ptr, anf_node.get()); | |||
| } | |||
| // 4.replace by dynamic format and dtype | |||
| if (inputs_dyn.empty() && outputs_dyn.empty()) { | |||
| MS_LOG(INFO) << "Dynamic select format response is empty, use static register info."; | |||
| op_info_new_ptr->set_inputs_ptr(inputs_static); | |||
| op_info_new_ptr->set_outputs_ptr(outputs_static); | |||
| } else { | |||
| MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format."; | |||
| for (size_t i = 0; i < inputs_static.size(); i++) { | |||
| inputs_dyn[i]->set_param_type(inputs_static[i]->param_type()); | |||
| inputs_dyn[i]->set_reshape_type(inputs_static[i]->reshape_type()); | |||
| } | |||
| for (size_t j = 0; j < outputs_static.size(); j++) { | |||
| outputs_dyn[j]->set_param_type(outputs_static[j]->param_type()); | |||
| outputs_dyn[j]->set_reshape_type(outputs_static[j]->reshape_type()); | |||
| } | |||
| op_info_new_ptr->set_inputs_ptr(inputs_dyn); | |||
| op_info_new_ptr->set_outputs_ptr(outputs_dyn); | |||
| } | |||
| // 5.copy other opinfo to new op_info_new | |||
| op_info_new_ptr->set_op_name(op_info_ptr->op_name()); | |||
| op_info_new_ptr->set_imply_type(op_info_ptr->imply_type()); | |||
| op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type()); | |||
| } | |||
| bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) { | |||
| for (const auto &c : reshape_type_str) { | |||
| switch (c) { | |||
| case 'N': | |||
| reshape_type_vec->push_back(kernel::N); | |||
| break; | |||
| case 'C': | |||
| reshape_type_vec->push_back(kernel::C); | |||
| break; | |||
| case 'H': | |||
| reshape_type_vec->push_back(kernel::H); | |||
| break; | |||
| case 'W': | |||
| reshape_type_vec->push_back(kernel::W); | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Unknown axis " << c << "in reshape type."; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
// Fills the builder's input dtypes, formats and reshape types for the
// builder_idex-th kernel-info variant of the op registration.
// Each OpIOInfo carries parallel dtype/format lists, one entry per variant;
// builder_idex selects which variant's entry is used.
// kernel_info_index counts expanded (per-tensor) input slots; dynamic inputs
// expand to dyn_input_sizes[dyn_input_idx] slots each.
// Returns false on malformed registration data (list size mismatch, missing
// dyn_input_sizes, bad reshape-type string).
bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
                               size_t builder_idex, const std::vector<int> &dyn_input_sizes,
                               const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  MS_EXCEPTION_IF_NULL(builder);
  std::vector<TypeId> inputs_device_type;
  std::vector<std::string> inputs_format;
  size_t dyn_input_idx = 0;
  size_t kernel_info_index = 0;
  MS_EXCEPTION_IF_NULL(inputs[0]);
  // NOTE(review): inputs[0] is dereferenced without an emptiness check —
  // callers are presumed to guarantee at least one input; verify.
  size_t kernel_info_cnt = inputs[0]->dtypes().size();
  std::vector<std::vector<Axis>> reshape_types;
  for (const auto &input : inputs) {
    MS_EXCEPTION_IF_NULL(input);
    std::string param_type = input->param_type();
    std::vector<std::string> dtypes = input->dtypes();
    std::vector<std::string> formats = input->formats();
    // Every input must declare one dtype and one format per variant.
    if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
      MS_LOG(ERROR) << "Set input kernel builder info, dtyps size != formats size.";
      return false;
    }
    std::vector<Axis> reshape_type;
    if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
      return false;
    }
    if (param_type == "dynamic") {
      if (dyn_input_sizes.empty()) {
        MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
        return false;
      }
      // One slot per expanded tensor of this dynamic input; all share the
      // same dtype/format/reshape type.
      for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
        kernel_info_index++;
        auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
        inputs_device_type.push_back(type_id);
        inputs_format.push_back(formats[builder_idex]);
        reshape_types.push_back(reshape_type);
      }
      dyn_input_idx++;
    } else if (param_type == "required") {
      kernel_info_index++;
      auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
      inputs_device_type.push_back(type_id);
      inputs_format.push_back(formats[builder_idex]);
      reshape_types.push_back(reshape_type);
    } else {
      // Optional input: only consumes a slot while real tensors remain.
      if (kernel_info_index < real_input_num) {
        MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index;
        kernel_info_index++;
        auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
        inputs_device_type.push_back(type_id);
        inputs_format.push_back(formats[builder_idex]);
        reshape_types.push_back(reshape_type);
      }
    }
  }
  builder->SetInputReshapeType(reshape_types);
  builder->SetInputsDeviceType(inputs_device_type);
  builder->SetInputsFormat(inputs_format);
  return true;
}
| bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex, | |||
| const size_t &real_output_num, | |||
| const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) { | |||
| // not now but in the next we need to support dynamic output case | |||
| MS_EXCEPTION_IF_NULL(builder); | |||
| size_t output_idx = 0; | |||
| std::vector<TypeId> outputs_device_type; | |||
| std::vector<std::string> outputs_format; | |||
| MS_EXCEPTION_IF_NULL(outputs[0]); | |||
| size_t kernel_info_cnt = outputs[0]->dtypes().size(); | |||
| std::vector<std::vector<Axis>> reshape_types; | |||
| for (const auto &output : outputs) { | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (output_idx >= real_output_num) { | |||
| MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!"; | |||
| continue; | |||
| } | |||
| std::vector<Axis> reshape_type; | |||
| if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { | |||
| return false; | |||
| } | |||
| size_t output_num = 0; | |||
| if (output->param_type() == "dynamic") { | |||
| if (outputs.size() > 1) { | |||
| MS_LOG(EXCEPTION) << "Dynamic output is unsupported multi output!"; | |||
| } | |||
| output_num = real_output_num; | |||
| } else if (output->param_type() == "required") { | |||
| output_num = 1; | |||
| } else { | |||
| if (output_idx < real_output_num) { | |||
| MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is " << output_idx; | |||
| output_num = 1; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| std::vector<std::string> dtypes = output->dtypes(); | |||
| std::vector<std::string> formats = output->formats(); | |||
| if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) { | |||
| MS_LOG(ERROR) << "Set output kernel builder info, dtyps size != formats size."; | |||
| return false; | |||
| } | |||
| auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); | |||
| outputs_device_type.push_back(type_id); | |||
| outputs_format.push_back(formats[builder_idex]); | |||
| reshape_types.push_back(reshape_type); | |||
| output_idx++; | |||
| } | |||
| } | |||
| builder->SetOutputReshapeType(reshape_types); | |||
| builder->SetOutputsFormat(outputs_format); | |||
| builder->SetOutputsDeviceType(outputs_device_type); | |||
| return true; | |||
| } | |||
| void SetKernelBuildCommonInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, | |||
| Processor processor, const std::shared_ptr<const OpInfo> &op_info_ptr) { | |||
| MS_EXCEPTION_IF_NULL(builder); | |||
| MS_EXCEPTION_IF_NULL(op_info_ptr); | |||
| builder->SetProcessor(processor); | |||
| std::string fusion_type = op_info_ptr->fusion_type(); | |||
| if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) { | |||
| builder->SetFusionType(tbe::GetFusionType(fusion_type)); | |||
| } | |||
| builder->SetOpPattern(op_info_ptr->op_pattern()); | |||
| builder->SetKernelType(TBE_KERNEL); | |||
| } | |||
// Build one KernelBuildInfo per dtype/format column declared in the op info
// and append them to *kernel_info_list.
// @param kernel_node      the CNode being queried.
// @param op_info_ptr      parsed TBE op registration (inputs/outputs tables).
// @param kernel_info_list output; one entry per candidate column.
// @return false when input/output builder info could not be assembled.
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  // NOTE(review): op_info_ptr is dereferenced below without a null check -
  // callers currently guarantee non-null; confirm before relying on it.
  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  std::vector<int> dyn_input_sizes;
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  // "dyn_input_sizes" holds the tensor count of each dynamic input, if any.
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (!inputs.empty()) {
    // The number of candidate kernel infos equals the number of dtype columns
    // declared on the first input.
    MS_EXCEPTION_IF_NULL(inputs[0]);
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);
      if (!SetKernelBuilderInputInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(ERROR) << "Parse kernel metadata, set inputs kernel builder info failed.";
        return false;
      }
      if (!outputs.empty()) {
        if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
          return false;
        }
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else if (!outputs.empty()) {
    // Output-only op: column count comes from the first output instead.
    MS_EXCEPTION_IF_NULL(outputs[0]);
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);
      if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
        return false;
      }
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}
// Check whether a tensor shape is representable in the given device format.
// The checks are order-sensitive; keep them in this sequence.
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  // An unrecognized format string is a programming error, not a selection miss.
  if (kOpFormatList.find(format) == kOpFormatList.end()) {
    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
  }
  // The default format supports every shape.
  if (format == kOpFormat_DEFAULT) {
    return true;
  }
  // NDHWC requires exactly 5 dimensions.
  if (format == kOpFormat_NDHWC && shape.size() != kShape5dDims) {
    return false;
  }
  // An empty shape is a scalar, which every format accepts.
  if (shape.empty()) {
    return true;
  }
  if (shape.size() > kShape4dDims) {
    // NOTE(review): this also rejects 5-D shapes, so a non-scalar NDHWC shape
    // can never reach `return true` below - confirm whether that is intended.
    return false;
  }
  // FRAC_NZ needs at least a 2-D shape to tile the last two axes.
  if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
    return false;
  }
  return true;
}
// Validate one kernel-build-info candidate against the node's inferred shapes:
// every input/output format must match its shape, FRACTAL_Z_C04 is rejected,
// ReduceMean without keep_dims only accepts the default format, and Cast must
// match the inferred dtypes exactly.
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
  const size_t kCAxis = 1;  // C axis index in an NCHW shape
  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
    if (kernel_build_info.GetOutputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
      if (output_shape.size() != kShape4dDims || output_shape[kCAxis] > 4) {
        return false;
      }
      // NOTE(review): FRACTAL_Z_C04 outputs are rejected unconditionally here,
      // which makes the shape check above dead code - confirm intent.
      return false;
    }
    if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
      return false;
    }
    // ReduceMean with keep_dims=false changes rank, so only the default
    // format is safe for its outputs.
    if (kernel_name == "ReduceMean") {
      auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
      if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) {
        return false;
      }
    }
  }
  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
    if (!IsShapeMatchFormat(input_shape, kernel_build_info.GetInputFormat(index))) {
      return false;
    }
    if (kernel_build_info.GetInputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
      if (input_shape.size() != kShape4dDims || input_shape[kCAxis] > 4) {
        return false;
      }
      // NOTE(review): unconditional rejection, same as the output branch above.
      return false;
    }
    if (kernel_name == "ReduceMean") {
      auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
      if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) {
        return false;
      }
    }
  }
  // Cast must agree with the graph's inferred dtypes on both ends.
  if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
    return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
           AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
  }
  return true;
}
| void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list; | |||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); | |||
| if (op_info_ptr == nullptr) { | |||
| return; | |||
| } | |||
| // dynamic get op format and dtype and replace opinfo | |||
| auto op_info_new_ptr = std::make_shared<OpInfo>(); | |||
| ReplaceByDynamicFormatDtype(kernel_node, op_info_ptr, op_info_new_ptr); | |||
| if (!ParseMetadata(kernel_node, op_info_new_ptr, &parse_info_list)) { | |||
| MS_LOG(INFO) << "Tbe parsed metadata of op[" << op_name << "] failed."; | |||
| return; | |||
| } | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| for (const auto &parse_info : parse_info_list) { | |||
| if (IsValidKernelInfo(kernel_node, *(parse_info))) { | |||
| if (CheckSupported(kernel_node, parse_info)) { | |||
| kernel_info_list->push_back(parse_info); | |||
| } else { | |||
| MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info."; | |||
| } | |||
| } | |||
| if (kernel_info_list->empty()) { | |||
| MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "]."; | |||
| } | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,20 +13,18 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_TBE_KERNEL_SELECT_H | |||
| #define MINDSPORE_TBE_KERNEL_SELECT_H | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "kernel/oplib/opinfo.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list); | |||
// Candidate format combinations for one kernel-selection query.
// Outer vector: one entry per supported combination; inner vector: one
// format string per input (resp. output) tensor.
struct SupportFormat {
  std::vector<std::vector<std::string>> input_format;
  std::vector<std::vector<std::string>> output_format;
};
// One row of formats: a format string per tensor.
using SupportFormatItem = std::vector<std::string>;
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_TBE_KERNEL_SELECT_H | |||
| #endif // MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ | |||
| @@ -0,0 +1,319 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" | |||
| #include "utils/utils.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr char kDynInputKey[] = "dyn_input_sizes"; | |||
| constexpr size_t kInputIndex_0 = 0; | |||
| constexpr size_t kChannelN = 0; | |||
| constexpr size_t kChannelC = 1; | |||
| constexpr size_t kAlignmented16 = 16; | |||
// Broadcast format support is decided by three cases:
// 1. no input is scalar and all input shapes are identical;
// 2. some inputs are scalar: non-scalar inputs must have enough dims and aligned axes;
// 3. no input is scalar but shapes differ (broadcast along some dim).
// Collect the node's input/output inferred shapes (scalars padded to [1]) and
// seed *support_format with the default-format combination. Always succeeds.
bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
  MS_EXCEPTION_IF_NULL(support_format);
  input_num_ = 0;
  output_num_ = 0;
  input_shapes_.clear();
  output_shapes_.clear();
  if (AnfAlgo::HasNodeAttr(kDynInputKey, cnode_ptr_)) {
    MS_LOG(INFO) << "This broadcast node has dynamic input.";
    auto dynamic_size_vec = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode_ptr_, kDynInputKey);
    if (dynamic_size_vec.empty() || dynamic_size_vec[0] < 2) {
      MS_LOG(EXCEPTION) << "dynamic attr set error, please check.";
    }
    // For dynamic input only the first tensor's shape is recorded; the node is
    // treated as having a single (shared-shape) input.
    auto dynamic_input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
    PadScalarShape(&dynamic_input_shape0_);
    input_shapes_.emplace_back(dynamic_input_shape0_);
    input_num_ = 1;
  } else {
    input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_);
    for (size_t i = 0; i < input_num_; ++i) {
      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
      PadScalarShape(&input_shape);
      input_shapes_.emplace_back(input_shape);
    }
  }
  output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  for (size_t i = 0; i < output_num_; ++i) {
    auto output = AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
    PadScalarShape(&output);
    output_shapes_.emplace_back(output);
  }
  // The default format is always a valid fallback.
  AssignSupportFormat(kOpFormat_DEFAULT, support_format);
  return true;
}
// Check whether this broadcast node can run in NC1HWC0; on success appends the
// corresponding format combination to *support_format.
bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (IsSameShape()) {
    // Identical non-scalar shapes: every tensor can be NC1HWC0.
    if (!HasScalarInput()) {
      AssignSupportFormat(kOpFormat_NC1HWC0, support_format);
      return true;
    } else {
      return false;
    }
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (HasScalarInput()) {
    // Mixed scalar/tensor inputs: scalars stay in the default format; tensors
    // must be 4-D with a C axis aligned to the C0 block size (16).
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return false;
        }
        if (shape[kChannelC] % kAlignmented16 != 0) {
          return false;
        }
        input_support_format.emplace_back(kOpFormat_NC1HWC0);
      }
    }
  } else {
    // Tensor-tensor broadcast: all must be 4-D and agree on C, because
    // NC1HWC0 cannot broadcast along the C axis.
    for (const auto &shape : input_shapes_) {
      if (!Is4DShape(shape)) {
        return false;
      }
    }
    auto shape_tmp = input_shapes_[0];
    auto broadcast_c_axis = std::any_of(
      input_shapes_.begin(), input_shapes_.end(),
      [&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); });
    if (broadcast_c_axis) {
      MS_LOG(INFO) << "This node broadcast c channel.";
      return false;
    }
    input_support_format.assign(input_num_, kOpFormat_NC1HWC0);
  }
  GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
  return true;
}
// Check whether this broadcast node can run in FRAC_Z; on success appends the
// corresponding format combination to *support_format.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (IsSameShape()) {
    if (!HasScalarInput()) {
      AssignSupportFormat(kOpFormat_FRAC_Z, support_format);
      return true;
    } else {
      return false;
    }
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (HasScalarInput()) {
    // Non-scalar inputs must be 4-D with both N and C aligned to 16.
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return false;
        }
        if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) {
          return false;
        }
        input_support_format.emplace_back(kOpFormat_FRAC_Z);
      }
    }
  } else {
    // FRAC_Z does not support tensor-tensor broadcast with differing shapes.
    return false;
  }
  GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
  return true;
}
// Check whether this broadcast node can run in C1HWNCoC0; on success appends
// the corresponding format combination to *support_format.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (IsSameShape()) {
    if (!HasScalarInput()) {
      AssignSupportFormat(kOpFormat_C1HWNCoC0, support_format);
      return true;
    } else {
      return false;
    }
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (HasScalarInput()) {
    // Non-scalar inputs must be 4-D with N aligned to 16.
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return false;
        }
        if (shape[kChannelN] % kAlignmented16 != 0) {
          return false;
        }
        input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
      }
    }
  } else {
    // Tensor-tensor broadcast: all 4-D, and neither N nor C may broadcast.
    for (const auto &shape : input_shapes_) {
      if (!Is4DShape(shape)) {
        return false;
      }
    }
    auto shape_tmp = input_shapes_[0];
    auto broadcast_nc_axis =
      std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) {
        return (shape_tmp.at(kChannelC) != elem.at(kChannelC) || shape_tmp.at(kChannelN) != elem.at(kChannelN));
      });
    if (broadcast_nc_axis) {
      MS_LOG(INFO) << "This node broadcast n || c channel.";
      return false;
    }
    input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0);
  }
  GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
  return true;
}
// Check whether this broadcast node can run in FRAC_NZ; on success appends
// the corresponding format combination to *support_format.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (IsSameShape()) {
    if (!HasScalarInput()) {
      AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
      return true;
    } else {
      return false;
    }
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (HasScalarInput()) {
    // Non-scalar inputs need at least 2 dims with the last two aligned to 16
    // (FRAC_NZ tiles the trailing two axes).
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (shape.size() < kShape2dDims) {
          return false;
        }
        if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - 2] % kAlignmented16 != 0) {
          return false;
        }
        input_support_format.emplace_back(kOpFormat_FRAC_NZ);
      }
    }
  } else {
    // Tensor-tensor broadcast: all inputs need >= 2 dims and may not
    // broadcast along either of the last two axes.
    auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(),
                                  [](const std::vector<size_t> &elem) { return elem.size() < kShape2dDims; });
    if (less_2dims) {
      MS_LOG(INFO) << "This node dim less 2.";
      return false;
    }
    auto shape_tmp = input_shapes_[0];
    auto broadcast_last_dim =
      std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) {
        return (shape_tmp.at(shape_tmp.size() - 1) != elem.at(elem.size() - 1)) ||
               (shape_tmp.at(shape_tmp.size() - 2) != elem.at(elem.size() - 2));
      });
    if (broadcast_last_dim) {
      MS_LOG(INFO) << "This node broadcast last channel.";
      return false;
    }
    input_support_format.assign(input_num_, kOpFormat_FRAC_NZ);
  }
  GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
  return true;
}
// NDC1HWC0 broadcast support is not implemented yet; always rejects.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  return false;
}
| bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const { | |||
| return shape.size() == kShape4dDims; | |||
| } | |||
| bool TbeKernelBroadCastSelecter::IsSameShape() const { | |||
| auto shape = input_shapes_.begin(); | |||
| for (const auto &item : input_shapes_) { | |||
| if (shape->size() != item.size()) { | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < shape->size(); ++i) { | |||
| if (shape->at(i) != item.at(i)) { | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void TbeKernelBroadCastSelecter::PadScalarShape(std::vector<size_t> *shape) const { | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| if (shape->empty()) { | |||
| shape->emplace_back(1); | |||
| } | |||
| } | |||
| bool TbeKernelBroadCastSelecter::IsScalarShape(const std::vector<size_t> &shape) const { | |||
| return (shape.size() == 1 && shape[0] == 1); | |||
| } | |||
| bool TbeKernelBroadCastSelecter::HasScalarInput() const { | |||
| bool ret = false; | |||
| for (const auto &shape : input_shapes_) { | |||
| if (IsScalarShape(shape)) { | |||
| ret = true; | |||
| break; | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &support_format, | |||
| SupportFormatItem *output_support_item) const { | |||
| MS_EXCEPTION_IF_NULL(output_support_item); | |||
| for (const auto &shape : output_shapes_) { | |||
| if (IsScalarShape(shape)) { | |||
| output_support_item->emplace_back(kOpFormat_DEFAULT); | |||
| } else { | |||
| output_support_item->emplace_back(support_format); | |||
| } | |||
| } | |||
| } | |||
| void TbeKernelBroadCastSelecter::AssignSupportFormat(const std::string &support_format_str, | |||
| mindspore::kernel::SupportFormat *support_format) const { | |||
| MS_EXCEPTION_IF_NULL(support_format); | |||
| SupportFormatItem input_support_format; | |||
| SupportFormatItem output_support_format; | |||
| input_support_format.assign(input_num_, support_format_str); | |||
| output_support_format.assign(output_num_, support_format_str); | |||
| support_format->input_format.emplace_back(input_support_format); | |||
| support_format->output_format.emplace_back(output_support_format); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "ir/anf.h" | |||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
// Decides which device formats a broadcast-style TBE kernel can support,
// based on the input/output inferred shapes of the given CNode.
// Call GetShapeInfo() first to populate the shape caches, then the
// IsBroadCastSupportXxx predicates; each predicate appends one format
// combination to the SupportFormat on success.
class TbeKernelBroadCastSelecter {
 public:
  explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelBroadCastSelecter() = default;
  // Collects input/output shapes and seeds the default-format combination.
  bool GetShapeInfo(SupportFormat *support_format);
  bool IsBroadCastSupport5HD(SupportFormat *support_format) const;
  bool IsBroadCastSupportFracZ(SupportFormat *support_format) const;
  bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const;
  bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const;
  bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const;
 private:
  bool IsSameShape() const;
  void PadScalarShape(std::vector<size_t> *shape) const;
  bool Is4DShape(const std::vector<size_t> &shape) const;
  bool IsScalarShape(const std::vector<size_t> &shape) const;
  bool HasScalarInput() const;
  void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
  void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
  // broadcast
  CNodePtr cnode_ptr_;                             // node under selection
  size_t input_num_{};                             // input tensor count (1 for dynamic input)
  size_t output_num_{};                            // output tensor count
  std::vector<std::vector<size_t>> input_shapes_;  // padded input shapes
  std::vector<std::vector<size_t>> output_shapes_; // padded output shapes
};
| } // namespace kernel | |||
| } // namespace mindspore | |||
#endif // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
| @@ -0,0 +1,180 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" | |||
| #include <string> | |||
| #include <vector> | |||
| #include "utils/utils.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr char kKeepDims[] = "keep_dims"; | |||
| constexpr char kAxis[] = "axis"; | |||
| constexpr char kTypeInt32[] = "Int32"; | |||
| constexpr size_t kInputIndex_0 = 0; | |||
| constexpr size_t kOutputIndex_0 = 0; | |||
| constexpr size_t kChannelN = 0; | |||
| constexpr size_t kChannelC = 1; | |||
| constexpr size_t kReduceNZMinDim = 3; | |||
// Cache the reduce node's input/output shapes (scalars padded to [1]) and its
// keep_dims/axis attrs, then seed *support_format with the default format.
// Raises when the node is not a single-input/single-output reduce.
bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) {
  MS_EXCEPTION_IF_NULL(support_format);
  input_shape_.clear();
  output_shape_.clear();
  axis_.clear();
  auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
  auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  if (input_num != 1 || output_num != 1) {
    MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << input_num
                      << ", output num: " << output_num;
  }
  // get input/output shape
  input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
  PadScalarShape(&input_shape_);
  output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0);
  PadScalarShape(&output_shape_);
  // get keep dim attr
  GetReduceAttrKeepDim();
  // get axis attr
  GetReduceAttrAxis();
  AssignSupportFormat(kOpFormat_DEFAULT, support_format);
  return true;
}
| bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const { | |||
| MS_EXCEPTION_IF_NULL(support_format); | |||
| if (!Is4DShape(input_shape_)) { | |||
| return false; | |||
| } | |||
| if (!keep_dims_ || axis_.empty()) { | |||
| return false; | |||
| } | |||
| auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); }); | |||
| if (reduce_c_axis) { | |||
| return false; | |||
| } | |||
| AssignSupportFormat(kOpFormat_NC1HWC0, support_format); | |||
| return true; | |||
| } | |||
// NDC1HWC0 reduce support is not implemented yet; always rejects.
bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // like to 5HD
  return false;
}
// FRAC_Z shares the N/C-preserving 4-D check with C1HWNCoC0.
bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const {
  return IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format);
}
// C1HWNCoC0 shares the N/C-preserving 4-D check with FRAC_Z.
bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const {
  return IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format);
}
// FRAC_NZ is usable when the input has at least kReduceNZMinDim dims and the
// reduction never touches the last two axes (which FRAC_NZ tiles).
bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (input_shape_.size() < kReduceNZMinDim) {
    return false;
  }
  if (axis_.empty()) {
    return false;
  }
  auto reduce_last_axis = std::any_of(axis_.begin(), axis_.end(), [this](const size_t &elem) {
    return (elem == (this->input_shape_.size() - 1) || elem == (this->input_shape_.size() - 2));
  });
  if (reduce_last_axis) {
    return false;
  }
  AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
  return true;
}
| bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format, | |||
| mindspore::kernel::SupportFormat *support_format) const { | |||
| MS_EXCEPTION_IF_NULL(support_format); | |||
| if (!Is4DShape(input_shape_)) { | |||
| return false; | |||
| } | |||
| if (!keep_dims_ || axis_.empty()) { | |||
| return false; | |||
| } | |||
| auto reduce_n_c_axis = std::any_of(axis_.begin(), axis_.end(), | |||
| [](const size_t &elem) { return (elem == kChannelC || elem == kChannelN); }); | |||
| if (reduce_n_c_axis) { | |||
| return false; | |||
| } | |||
| AssignSupportFormat(format, support_format); | |||
| return true; | |||
| } | |||
| void TbeKernelReduceSelecter::GetReduceAttrAxis() { | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto axis = primitive->GetAttr(kAxis); | |||
| if (axis == nullptr) { | |||
| MS_LOG(INFO) << "This node does't have axie attr."; | |||
| return; | |||
| } | |||
| auto type = axis->type(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| std::vector<int> axis_list; | |||
| if (type->ToString() == kTypeInt32) { | |||
| axis_list.emplace_back(GetValue<int>(axis)); | |||
| } else { | |||
| axis_list = GetValue<std::vector<int>>(axis); | |||
| } | |||
| for (const auto &elem : axis_list) { | |||
| if (elem < 0) { | |||
| axis_.emplace_back(input_shape_.size() + elem); | |||
| } else { | |||
| axis_.emplace_back(IntToSize(elem)); | |||
| } | |||
| } | |||
| } | |||
| void TbeKernelReduceSelecter::GetReduceAttrKeepDim() { | |||
| if (!AnfAlgo::HasNodeAttr(kKeepDims, cnode_ptr_)) { | |||
| MS_LOG(INFO) << "This node does't have keep_attr."; | |||
| keep_dims_ = false; | |||
| return; | |||
| } | |||
| keep_dims_ = AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kKeepDims); | |||
| } | |||
| void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str, | |||
| mindspore::kernel::SupportFormat *support_format) const { | |||
| MS_EXCEPTION_IF_NULL(support_format); | |||
| SupportFormatItem input_support_format; | |||
| SupportFormatItem output_support_format; | |||
| input_support_format.emplace_back(support_format_str); | |||
| output_support_format.emplace_back(support_format_str); | |||
| support_format->input_format.emplace_back(input_support_format); | |||
| support_format->output_format.emplace_back(output_support_format); | |||
| } | |||
| bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape4dDims; } | |||
| void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const { | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| if (shape->empty()) { | |||
| shape->emplace_back(1); | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ | |||
| #include <utility> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
// Decides which data formats a TBE reduce-pattern kernel can support, based on
// the node's input/output shapes and its keep_dims/axis attributes.
class TbeKernelReduceSelecter {
 public:
  explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelReduceSelecter() = default;
  // Caches shapes and attrs, then appends the default format to support_format.
  bool GetShapeInfo(SupportFormat *support_format);
  // Each IsReduceSupportX appends format X to support_format and returns true
  // when the cached shape/attrs allow it; otherwise returns false untouched.
  bool IsReduceSupport5HD(SupportFormat *support_format) const;
  bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const;
  bool IsReduceSupportFracZ(SupportFormat *support_format) const;
  bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const;
  bool IsReduceSupportFracNZ(SupportFormat *support_format) const;

 private:
  // Shared check for FracZ and C1HWNCoC0 (same constraints, different format name).
  bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const;
  void GetReduceAttrAxis();
  void GetReduceAttrKeepDim();
  void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
  bool Is4DShape(const std::vector<size_t> &shape) const;
  // Pads an empty (scalar) shape to {1}.
  void PadScalarShape(std::vector<size_t> *shape) const;
  CNodePtr cnode_ptr_;
  std::vector<size_t> input_shape_{};
  std::vector<size_t> output_shape_{};
  // Reduction axes, normalized to non-negative indices.
  std::vector<size_t> axis_{};
  bool keep_dims_ = false;
};
| } // namespace kernel | |||
| } // namespace mindspore | |||
#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
| @@ -0,0 +1,633 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" | |||
| #include <memory> | |||
| #include <map> | |||
| #include <set> | |||
| #include <utility> | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "kernel/tbe/tbe_kernel_build.h" | |||
| #include "nlohmann/json.hpp" | |||
| #include "utils/context/ms_context.h" | |||
| #include "kernel/tbe/tbe_python_funcs.h" | |||
| #include "pre_activate/common/helper.h" | |||
| #include "kernel/tbe/tbe_convert_utils.h" | |||
| #include "parallel/ops_info/ops_utils.h" | |||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" | |||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" | |||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
// Keys used when parsing the op_select_format json response.
constexpr auto kName = "name";
constexpr auto kDtype = "dtype";
constexpr auto kFormat = "format";
// Json keys containing these substrings are treated as inputs / outputs.
constexpr auto kPrefixInput = "input";
constexpr auto kPrefixOutput = "output";
// Primitive attr holding the tensor count of each dynamic input.
constexpr char kDynInputKey[] = "dyn_input_sizes";
// Values of OpIOInfo::param_type().
constexpr char kParamTypeDynamic[] = "dynamic";
constexpr char kParamTypeRequre[] = "required";
constexpr char kParamTypeOptional[] = "optional";
| void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) { | |||
| auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list); | |||
| tbe_selecter.TbeMetadataInfoEx(); | |||
| } | |||
// Stores the node and the output list pointer; the caller owns the list.
TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list)
    : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {}
| void TbeKernelSelect::TbeMetadataInfoEx() { | |||
| MS_EXCEPTION_IF_NULL(cnode_ptr_); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list_); | |||
| node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_); | |||
| auto op_info_ptr = OpLib::FindOp(node_name_, kTBE); | |||
| if (!op_info_ptr) { | |||
| MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Start to tbe metadata info. node type: " << node_name_ | |||
| << ", node name: " << cnode_ptr_->fullname_with_scope(); | |||
| OpPattern pattern = op_info_ptr->op_pattern(); | |||
| if (pattern == kCommonPattern) { | |||
| GetCommonPatternKernelInfo(*op_info_ptr); | |||
| } else if (pattern == kDynamicFormatPattern) { | |||
| GetDynamicFormatPatternKernelInfo(*op_info_ptr); | |||
| } else if (pattern == kFormatAgnosticPattern) { | |||
| GetAgnosticPatternKernelInfo(*op_info_ptr); | |||
| } else if (pattern == kBroadcastPattern) { | |||
| GetBroadcastPatternKernelInfo(*op_info_ptr); | |||
| } else if (pattern == kReducePattern) { | |||
| GetReducePatternKernelInfo(*op_info_ptr); | |||
| } else { | |||
| MS_LOG(INFO) << "Warning: op pattern is invailed."; | |||
| } | |||
| // check support | |||
| FilterInVaildKernelInfo(); | |||
| MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select."; | |||
| } | |||
// Builds one KernelBuildInfo per registered dtype/format variant, expanding
// dynamic inputs/outputs to the node's real tensor counts.
void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
  MS_LOG(INFO) << "start.";
  // get dynamic inputs
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(primitive);
  std::vector<int> dyn_input_sizes;
  if (primitive->HasAttr(kDynInputKey)) {
    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kDynInputKey));
  }
  // get real input/output num
  size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
  const auto inputs_info = op_info.inputs_ptr();
  size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  const auto outputs_info = op_info.outputs_ptr();
  if (inputs_info.empty() && outputs_info.empty()) {
    MS_LOG(EXCEPTION) << "op info input & output is null, please check.";
  }
  // create kernel build info from opinfo
  // Each io's dtypes() list holds one entry per kernel variant, so the variant
  // count is taken from the first io.
  size_t kernel_build_info_num =
    inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size();
  for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) {
    auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
    SetTbeBuildCommonInfo(op_info, &builder);
    std::vector<std::string> inputs_format;
    std::vector<TypeId> inputs_device_type;
    std::vector<std::vector<Axis>> inputs_reshape_type;
    // input
    // GenBuilderItem returns false when the variant does not match the node's
    // real io layout; such variants are skipped entirely (loop break).
    if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes,
                        &inputs_format, &inputs_device_type, &inputs_reshape_type)) {
      break;
    }
    builder.SetInputsDeviceType(inputs_device_type);
    builder.SetInputsFormat(inputs_format);
    builder.SetInputReshapeType(inputs_reshape_type);
    // output
    std::vector<std::string> outputs_format;
    std::vector<TypeId> outputs_device_type;
    std::vector<std::vector<Axis>> outputs_reshape_type;
    if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes,
                        &outputs_format, &outputs_device_type, &outputs_reshape_type)) {
      break;
    }
    builder.SetOutputsDeviceType(outputs_device_type);
    builder.SetOutputsFormat(outputs_format);
    builder.SetOutputReshapeType(outputs_reshape_type);
    kernel_info_list_->emplace_back(builder.Build());
  }
  MS_LOG(INFO) << "end.";
}
| void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) { | |||
| MS_LOG(INFO) << "start."; | |||
| // | |||
| OpInfo op_info_new; | |||
| CreateNewOpInfo(op_info, &op_info_new); | |||
| GetCommonPatternKernelInfo(op_info_new); | |||
| MS_LOG(INFO) << "end."; | |||
| } | |||
| void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) { | |||
| MS_LOG(INFO) << "start."; | |||
| if (op_info.inputs_ptr().size() != 1) { | |||
| MS_LOG(EXCEPTION) << "AgnosticPattern only support one input."; | |||
| } | |||
| auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0); | |||
| if (kOpFormatList.find(format) == kOpFormatList.end()) { | |||
| MS_LOG(INFO) << "Got the unknown format " << format; | |||
| format = kOpFormat_DEFAULT; | |||
| } | |||
| SupportFormat support_format; | |||
| SupportFormatItem input_item; | |||
| SupportFormatItem output_item; | |||
| input_item.assign(op_info.inputs_ptr().size(), format); | |||
| output_item.assign(op_info.outputs_ptr().size(), format); | |||
| support_format.input_format.emplace_back(input_item); | |||
| support_format.output_format.emplace_back(output_item); | |||
| PrintSupportedFormat(support_format); | |||
| OpInfo op_info_new; | |||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | |||
| GetCommonPatternKernelInfo(op_info_new); | |||
| MS_LOG(INFO) << "end."; | |||
| } | |||
| void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) { | |||
| MS_LOG(INFO) << "start."; | |||
| auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_); | |||
| SupportFormat support_format; | |||
| broadcast_selecter.GetShapeInfo(&support_format); | |||
| if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) { | |||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD."; | |||
| } | |||
| if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) { | |||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ."; | |||
| } | |||
| if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) { | |||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0."; | |||
| } | |||
| if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) { | |||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ."; | |||
| } | |||
| PrintSupportedFormat(support_format); | |||
| OpInfo op_info_new; | |||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | |||
| GetCommonPatternKernelInfo(op_info_new); | |||
| MS_LOG(INFO) << "end."; | |||
| } | |||
| void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) { | |||
| MS_LOG(INFO) << "start."; | |||
| auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_); | |||
| SupportFormat support_format; | |||
| reduce_selecter.GetShapeInfo(&support_format); | |||
| if (!reduce_selecter.IsReduceSupport5HD(&support_format)) { | |||
| MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support 5HD."; | |||
| } | |||
| if (reduce_selecter.IsReduceSupportFracZ(&support_format)) { | |||
| MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracZ."; | |||
| } | |||
| if (reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) { | |||
| MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support C1HWNCoC0."; | |||
| } | |||
| if (reduce_selecter.IsReduceSupportFracNZ(&support_format)) { | |||
| MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracNZ."; | |||
| } | |||
| PrintSupportedFormat(support_format); | |||
| OpInfo op_info_new; | |||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | |||
| GetCommonPatternKernelInfo(op_info_new); | |||
| MS_LOG(INFO) << "end."; | |||
| } | |||
| void TbeKernelSelect::FilterInVaildKernelInfo() { | |||
| if (kernel_info_list_->empty()) { | |||
| MS_LOG(INFO) << "Warning: get kernel build info failed."; | |||
| return; | |||
| } | |||
| auto kernel_build_info_iter = kernel_info_list_->begin(); | |||
| while (kernel_build_info_iter != kernel_info_list_->end()) { | |||
| if (!FilterInVaildShape(kernel_build_info_iter)) { | |||
| MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*kernel_build_info_iter)->ToString(); | |||
| kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); | |||
| continue; | |||
| } | |||
| if (!TbeCheckSupported(kernel_build_info_iter)) { | |||
| MS_LOG(INFO) << "Check support shape, filter item info: " << (*kernel_build_info_iter)->ToString(); | |||
| kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); | |||
| continue; | |||
| } | |||
| kernel_build_info_iter++; | |||
| } | |||
| } | |||
| bool TbeKernelSelect::FilterInVaildShape( | |||
| const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { | |||
| MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); | |||
| auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats(); | |||
| for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) { | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); | |||
| auto format = kernel_build_info_inputs_format.at(i); | |||
| if (!IsShapeMatchFormat(shape, format)) { | |||
| MS_LOG(INFO) << "The " << i << "th input check failed."; | |||
| return false; | |||
| } | |||
| } | |||
| auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats(); | |||
| for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) { | |||
| auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j); | |||
| auto format = kernel_build_info_outputs_format.at(j); | |||
| if (!IsShapeMatchFormat(shape, format)) { | |||
| MS_LOG(INFO) << "The " << j << "th input check failed."; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) { | |||
| if (format == kOpFormat_DEFAULT) { | |||
| return true; | |||
| } | |||
| static std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; | |||
| // if format is default, it remarkes support all format | |||
| if (kOpFormatList.find(format) == kOpFormatList.end()) { | |||
| MS_LOG(EXCEPTION) << "Got the unknown format " << format; | |||
| } | |||
| // server not support format with C04 suffix | |||
| if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) != | |||
| kServerNotSupportFormat.end()) { | |||
| MS_LOG(INFO) << "Warning: Server not support format with C04 suffix."; | |||
| return false; | |||
| } | |||
| // not support format: | |||
| // 1 NDHWC with shape size != 5 | |||
| // 2 FRAC_NZ with shape size < 2 | |||
| // 3 !NDHWC with shape size > 4 | |||
| if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) || | |||
| (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) || | |||
| (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) { | |||
| MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool TbeKernelSelect::TbeCheckSupported( | |||
| const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { | |||
| MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); | |||
| static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL, | |||
| parallel::BATCHMATMUL, | |||
| parallel::TOPK, | |||
| parallel::IN_TOPK, | |||
| parallel::PACK, | |||
| parallel::GATHER_ND, | |||
| parallel::UNSORTEF_SEGMENT_MIND, | |||
| parallel::UNSORTEF_SEGMENT_PRODD, | |||
| parallel::CAST}; | |||
| auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_); | |||
| if (iter == kCheckSupportedOpType.end()) { | |||
| return true; | |||
| } | |||
| MS_LOG(INFO) << "Check support start."; | |||
| // replace kernel_info with current kernel info | |||
| auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); | |||
| AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); | |||
| nlohmann::json kernel_json; | |||
| TbeKernelJsonCreator creator(CHECK_SUPPORTED); | |||
| bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); | |||
| if (!ret) { | |||
| MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed."; | |||
| } | |||
| ret = TbePythonFuncs::CheckSupported(kernel_json); | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get()); | |||
| return ret; | |||
| } | |||
| void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info, | |||
| mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) { | |||
| MS_EXCEPTION_IF_NULL(builder); | |||
| builder->SetProcessor(AICORE); | |||
| std::string fusion_type = op_info.fusion_type(); | |||
| if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) { | |||
| builder->SetFusionType(tbe::GetFusionType(fusion_type)); | |||
| } | |||
| builder->SetOpPattern(op_info.op_pattern()); | |||
| builder->SetKernelType(TBE_KERNEL); | |||
| } | |||
// Expands one kernel variant's io registration into per-tensor format/dtype/
// reshape-type lists, honoring dynamic ("dynamic") and scalar ("required"/
// "optional") parameter kinds. Returns false when the registration does not
// account for all of the node's real tensors; true otherwise.
bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
                                     const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
                                     const std::vector<int> &dyn_input_sizes, std::vector<std::string> *formats,
                                     std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types) {
  MS_EXCEPTION_IF_NULL(formats);
  MS_EXCEPTION_IF_NULL(device_types);
  MS_EXCEPTION_IF_NULL(reshape_types);
  size_t dynamic_input_index = 0;
  size_t real_io_tensor_index = 0;
  size_t io_info_index = 0;
  size_t io_info_num = ios_info.size();
  // Walk registered ios and real tensors in lock-step until either runs out.
  for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) {
    std::shared_ptr<OpIOInfo> io_info_item = ios_info[io_info_index];
    // dtypes()/formats() are indexed by variant; formats may be absent.
    auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index);
    std::string kernel_build_info_format;
    if (!io_info_item->formats().empty()) {
      kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index);
    }
    std::string io_param_type = io_info_item->param_type();
    std::vector<Axis> reshape_type;
    StringToAxisVector(io_info_item->reshape_type(), &reshape_type);
    if (io_param_type == kParamTypeDynamic) {
      // dynamic io
      if (is_input) {
        // A dynamic input consumes dyn_input_sizes[dynamic_input_index] real tensors,
        // each receiving the same dtype/format/reshape type.
        if (dynamic_input_index >= dyn_input_sizes.size()) {
          MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index
                            << ", dyn_input_sizes size: " << dyn_input_sizes.size();
        }
        int dynamic_input_size = dyn_input_sizes[dynamic_input_index];
        for (int i = 0; i < dynamic_input_size; ++i) {
          device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
          formats->emplace_back(kernel_build_info_format);
          reshape_types->emplace_back(reshape_type);
        }
        dynamic_input_index++;
        real_io_tensor_index += dynamic_input_size;
      } else {
        // A dynamic output must be the only registered output and covers all
        // real output tensors.
        if (ios_info.size() != 1) {
          MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output.";
        }
        for (size_t i = 0; i < real_io_tensor_num; ++i) {
          device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
          formats->emplace_back(kernel_build_info_format);
          reshape_types->emplace_back(reshape_type);
        }
        real_io_tensor_index += real_io_tensor_num;
      }
    } else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) {
      // requre or optional io: exactly one real tensor.
      device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
      formats->emplace_back(kernel_build_info_format);
      reshape_types->emplace_back(reshape_type);
      real_io_tensor_index++;
    } else {
      MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type;
    }
  }
  // Leftover registered ios are assumed optional and absent on this node.
  if (io_info_index != io_info_num) {
    MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num
                 << "), this node may has optional input/output.";
  }
  // Leftover real tensors mean this variant cannot describe the node.
  if (real_io_tensor_index != real_io_tensor_num) {
    std::string io_type = is_input ? "inputs " : "outputs";
    MS_LOG(INFO) << node_name_ << "'s " << io_type << "op io info num: " << io_info_num
                 << ", real io tensor num:" << real_io_tensor_num << "real_io_tensor_index(" << real_io_tensor_index
                 << ") != real_io_tensor_num(" << real_io_tensor_num << ")";
    return false;
  }
  return true;
}
| void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) { | |||
| MS_EXCEPTION_IF_NULL(reshape_type_vec); | |||
| for (const auto &c : reshape_type_str) { | |||
| switch (c) { | |||
| case 'N': | |||
| reshape_type_vec->push_back(kernel::N); | |||
| break; | |||
| case 'C': | |||
| reshape_type_vec->push_back(kernel::C); | |||
| break; | |||
| case 'H': | |||
| reshape_type_vec->push_back(kernel::H); | |||
| break; | |||
| case 'W': | |||
| reshape_type_vec->push_back(kernel::W); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type."; | |||
| } | |||
| } | |||
| } | |||
| void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, | |||
| const std::vector<std::vector<std::string>> &support_format_item, size_t index, | |||
| mindspore::kernel::OpIOInfo *op_io_info_new) { | |||
| MS_EXCEPTION_IF_NULL(op_io_info_new); | |||
| op_io_info_new->set_index(op_io_info.index()); | |||
| op_io_info_new->set_name(op_io_info.name()); | |||
| op_io_info_new->set_param_type(op_io_info.param_type()); | |||
| op_io_info_new->set_need_compile(op_io_info.need_compile()); | |||
| op_io_info_new->set_reshape_type(op_io_info.reshape_type()); | |||
| op_io_info_new->set_shape(op_io_info.shape()); | |||
| // dtype | |||
| std::vector<std::string> dtype_new; | |||
| auto dtype = op_io_info.dtypes(); | |||
| for (size_t i = 0; i < support_format_item.size(); ++i) { | |||
| dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end()); | |||
| } | |||
| op_io_info_new->set_dtypes(dtype_new); | |||
| // format | |||
| std::vector<std::string> format_new; | |||
| for (const auto &formats : support_format_item) { | |||
| auto format = formats.at(index); | |||
| for (size_t j = 0; j < dtype.size(); ++j) { | |||
| format_new.emplace_back(format); | |||
| } | |||
| } | |||
| op_io_info_new->set_formats(format_new); | |||
| } | |||
| std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) { | |||
| const std::map<std::string, std::string> kDynamicFormatMap = { | |||
| {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}}; | |||
| if (op_select_json_item.empty()) { | |||
| MS_LOG(EXCEPTION) << "Op select ret item is null."; | |||
| } | |||
| const char space = ' '; | |||
| const char sep = ','; | |||
| std::string op_select_tmp = op_select_json_item + ","; | |||
| std::vector<std::string> ret; | |||
| auto begin = op_select_tmp.find_first_not_of(space, 0); | |||
| auto sep_pos = op_select_tmp.find(sep); | |||
| while (sep_pos != std::string::npos) { | |||
| auto obj = op_select_tmp.substr(begin, sep_pos - begin); | |||
| if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) { | |||
| obj = kDynamicFormatMap.at(obj); | |||
| } | |||
| ret.emplace_back(obj); | |||
| begin = op_select_tmp.find_first_not_of(space, sep_pos + 1); | |||
| sep_pos = op_select_tmp.find(sep, begin); | |||
| } | |||
| return ret; | |||
| } | |||
| std::string TbeKernelSelect::OpSelectFormat() { | |||
| nlohmann::json kernel_json; | |||
| std::string res_json_str; | |||
| TbeKernelJsonCreator creator(OP_SELECT_FORMAT); | |||
| bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); | |||
| if (!ret) { | |||
| MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed."; | |||
| } | |||
| res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json); | |||
| if (res_json_str.empty()) { | |||
| MS_LOG(EXCEPTION) << "op select format error."; | |||
| } | |||
| MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str; | |||
| return res_json_str; | |||
| } | |||
| void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format, | |||
| mindspore::kernel::OpInfo *op_info_new) { | |||
| MS_EXCEPTION_IF_NULL(op_info_new); | |||
| if (op_info.inputs_ptr().size() != support_format.input_format[0].size() || | |||
| op_info.outputs_ptr().size() != support_format.output_format[0].size()) { | |||
| MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size() | |||
| << ", input support size: " << support_format.input_format[0].size() | |||
| << ", op info output size: " << op_info.outputs_ptr().size() | |||
| << ", output support size: " << support_format.output_format[0].size(); | |||
| } | |||
| *op_info_new = op_info; | |||
| op_info_new->ClearInputs(); | |||
| op_info_new->ClearOutputs(); | |||
| for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { | |||
| auto input = op_info.inputs_ptr().at(i); | |||
| auto input_new = std::make_shared<OpIOInfo>(); | |||
| CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get()); | |||
| op_info_new->add_inputs_ptr(input_new); | |||
| } | |||
| for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) { | |||
| auto output = op_info.outputs_ptr().at(j); | |||
| auto output_new = std::make_shared<OpIOInfo>(); | |||
| CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get()); | |||
| op_info_new->add_outputs_ptr(output_new); | |||
| } | |||
| } | |||
// One input/output entry parsed from TBE's op_select_format json response.
struct SelectOpIOInfo {
  std::string name;
  // Parsed from the comma-separated "dtype" json field.
  std::vector<std::string> dtypes;
  // Parsed from the comma-separated "format" json field; parallel to dtypes.
  std::vector<std::string> formats;
};
// Rebuilds op_info from TBE's op_select_format json response: each "inputN"/
// "outputN" entry supplies the dtype/format lists for the matching io.
void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
                                      mindspore::kernel::OpInfo *op_info_new) {
  MS_EXCEPTION_IF_NULL(op_info_new);
  // OpSelectFormat throws on failure, so an empty response never reaches here;
  // the guard is defensive only.
  auto op_seclect_json = OpSelectFormat();
  if (!op_seclect_json.empty()) {
    nlohmann::json json_obj = nlohmann::json::parse(op_seclect_json);
    if (!json_obj.is_object()) {
      MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_seclect_json;
    }
    std::vector<SelectOpIOInfo> inputs;
    std::vector<SelectOpIOInfo> outputs;
    // Keys are classified by substring ("input" / "output"); relies on the
    // response iterating in io order so parsed lists align with op_info's ios.
    for (const auto &item : json_obj.items()) {
      const std::string &item_name = item.key();
      bool is_input = (item_name.find(kPrefixInput) != std::string::npos);
      bool is_output = (item_name.find(kPrefixOutput) != std::string::npos);
      if (!is_input && !is_output) {
        MS_LOG(EXCEPTION) << "op select ret json is error.";
      }
      if (is_input) {
        SelectOpIOInfo select_input;
        select_input.name = item.value().at(kName);
        std::string input_dtype_item = item.value().at(kDtype);
        select_input.dtypes = SplitStrToVec(input_dtype_item);
        std::string input_format_item = item.value().at(kFormat);
        select_input.formats = SplitStrToVec(input_format_item);
        inputs.emplace_back(select_input);
      } else if (is_output) {
        SelectOpIOInfo select_output;
        select_output.name = item.value().at(kName);
        std::string input_dtype_item = item.value().at(kDtype);
        select_output.dtypes = SplitStrToVec(input_dtype_item);
        std::string input_format_item = item.value().at(kFormat);
        select_output.formats = SplitStrToVec(input_format_item);
        outputs.emplace_back(select_output);
      }
    }
    if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) {
      MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register.";
    }
    // Start from a copy, then replace the io lists with the selected dtype/format pairs.
    *op_info_new = op_info;
    op_info_new->ClearInputs();
    op_info_new->ClearOutputs();
    for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
      auto input_new = std::make_shared<OpIOInfo>();
      CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get());
      op_info_new->add_inputs_ptr(input_new);
    }
    for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) {
      auto output_new = std::make_shared<OpIOInfo>();
      CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get());
      op_info_new->add_outputs_ptr(output_new);
    }
  }
}
| void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, | |||
| const std::vector<std::string> &support_dtype, | |||
| const std::vector<std::string> &support_format, | |||
| mindspore::kernel::OpIOInfo *op_io_info_new) { | |||
| MS_EXCEPTION_IF_NULL(op_io_info_new); | |||
| op_io_info_new->set_index(op_io_info.index()); | |||
| op_io_info_new->set_name(op_io_info.name()); | |||
| op_io_info_new->set_param_type(op_io_info.param_type()); | |||
| op_io_info_new->set_need_compile(op_io_info.need_compile()); | |||
| op_io_info_new->set_reshape_type(op_io_info.reshape_type()); | |||
| op_io_info_new->set_shape(op_io_info.shape()); | |||
| // dtype | |||
| std::vector<std::string> dtype_new; | |||
| for (size_t i = 0; i < support_format.size(); ++i) { | |||
| dtype_new.insert(dtype_new.end(), support_dtype.begin(), support_dtype.end()); | |||
| } | |||
| op_io_info_new->set_dtypes(dtype_new); | |||
| // format | |||
| std::vector<std::string> format_new; | |||
| for (const auto &format : support_format) { | |||
| for (size_t j = 0; j < support_dtype.size(); ++j) { | |||
| format_new.emplace_back(format); | |||
| } | |||
| } | |||
| op_io_info_new->set_formats(format_new); | |||
| } | |||
| void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) { | |||
| if (support_format.input_format.size() != support_format.output_format.size()) { | |||
| MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output(" | |||
| << support_format.output_format.size() << ") size not match."; | |||
| } | |||
| for (size_t i = 0; i < support_format.input_format.size(); ++i) { | |||
| auto input_items = support_format.input_format.at(i); | |||
| auto output_items = support_format.output_format.at(i); | |||
| std::string print_str = "["; | |||
| for (const auto &input : input_items) { | |||
| print_str.append(input); | |||
| print_str.append(", "); | |||
| } | |||
| print_str.append("] -->"); | |||
| for (const auto &output : output_items) { | |||
| print_str.append(output); | |||
| print_str.append(", "); | |||
| } | |||
| MS_LOG(INFO) << "Support format: " << print_str; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,77 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_TBE_KERNEL_SELECT_H | |||
| #define MINDSPORE_TBE_KERNEL_SELECT_H | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "kernel/oplib/opinfo.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list); | |||
| class TbeKernelSelect { | |||
| using OpInfoPtr = std::shared_ptr<OpInfo>; | |||
| using KernelBuildInfoIter = std::vector<std::shared_ptr<KernelBuildInfo>>::iterator; | |||
| public: | |||
| TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list); | |||
| ~TbeKernelSelect() = default; | |||
| void TbeMetadataInfoEx(); | |||
| private: | |||
| void GetCommonPatternKernelInfo(const OpInfo &op_info); | |||
| void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info); | |||
| void GetAgnosticPatternKernelInfo(const OpInfo &op_info); | |||
| void GetBroadcastPatternKernelInfo(const OpInfo &op_info); | |||
| void GetReducePatternKernelInfo(const OpInfo &op_info); | |||
| void FilterInVaildKernelInfo(); | |||
| bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); | |||
| static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format); | |||
| bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); | |||
| static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder); | |||
| bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, | |||
| const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, const std::vector<int> &dyn_input_sizes, | |||
| std::vector<std::string> *formats, std::vector<TypeId> *device_types, | |||
| std::vector<std::vector<Axis>> *reshape_types); | |||
| static void StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec); | |||
| static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new); | |||
| static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, | |||
| const std::vector<std::vector<std::string>> &support_format_item, size_t index, | |||
| OpIOInfo *op_io_info_new); | |||
| // op select(dynamic) | |||
| void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new); | |||
| static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector<std::string> &support_dtype, | |||
| const std::vector<std::string> &support_format, OpIOInfo *op_io_info_new); | |||
| static std::vector<std::string> SplitStrToVec(const std::string &op_select_json_item); | |||
| std::string OpSelectFormat(); | |||
| static void PrintSupportedFormat(const SupportFormat &support_format); | |||
| private: | |||
| CNodePtr cnode_ptr_; | |||
| std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list_; | |||
| std::string node_name_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_TBE_KERNEL_SELECT_H | |||
| @@ -216,6 +216,13 @@ constexpr char NEG[] = "Neg"; | |||
| constexpr char BATCH_MATMUL[] = "BatchMatMul"; | |||
| constexpr char EXPAND_DIMS[] = "ExpandDims"; | |||
| constexpr char SQUARE[] = "Square"; | |||
| constexpr char BATCHMATMUL[] = "BatchMatMul"; | |||
| constexpr char TOPK[] = "TopK"; | |||
| constexpr char IN_TOPK[] = "InTopK"; | |||
| constexpr char PACK[] = "Pack"; | |||
| constexpr char GATHER_ND[] = "GatherNd"; | |||
| constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD"; | |||
| constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; | |||
| // Parallel don't care | |||
| constexpr char TUPLE_GETITEM[] = "tuple_getitem"; | |||
| @@ -21,7 +21,6 @@ | |||
| #include <vector> | |||
| #include "device/ascend/kernel_select_ascend.h" | |||
| #include "kernel/kernel_query.h" | |||
| #include "kernel/tbe/tbe_kernel_select.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| @@ -34,7 +34,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph | |||
| return nullptr; | |||
| } | |||
| auto node_name = AnfAlgo::GetCNodeName(node); | |||
| if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) { | |||
| if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) { | |||
| return nullptr; | |||
| } | |||
| auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); | |||
| @@ -26,12 +26,9 @@ abs_op_info = TBERegOp("Abs") \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .input(0, "x", None, "required", None) \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.I32_None) \ | |||
| .get_op_info() | |||
| @@ -23,7 +23,6 @@ abs_grad_op_info = TBERegOp("AbsGrad") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("abs_grad") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .input(0, "y", None, "required", None) \ | |||
| .input(1, "dy", None, "required", None) \ | |||
| .output(0, "z", False, "required", "all") \ | |||
| @@ -26,6 +26,7 @@ add_op_info = TBERegOp("Add") \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| @@ -26,17 +26,10 @@ add_n_op_info = TBERegOp("AddN") \ | |||
| .attr("n", "required", "int", "all") \ | |||
| .input(0, "x", False, "dynamic", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \ | |||
| .op_pattern("broadcast") \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.I32_None) \ | |||
| .get_op_info() | |||
| @@ -29,6 +29,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "bias", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \ | |||
| @@ -27,6 +27,7 @@ bias_add_grad_op_info = TBERegOp("BiasAdd") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "bias", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| @@ -26,6 +26,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \ | |||
| .input(0, "x", False, "required", "all", reshape_type="NC") \ | |||
| .output(0, "sum", False, "required", "all") \ | |||
| .output(1, "square_sum", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @@ -32,6 +32,7 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \ | |||
| .input(5, "batch_mean", False, "required", "all") \ | |||
| .input(6, "batch_variance", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all", reshape_type="NC") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| @@ -30,6 +30,7 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \ | |||
| .input(3, "batch_variance", False, "required", "all") \ | |||
| .output(0, "diff_scale", False, "required", "all") \ | |||
| .output(1, "diff_offset", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| @@ -32,6 +32,7 @@ bn_training_update_v2_op_info = TBERegOp("BNTrainingUpdateV2") \ | |||
| .output(0, "y", False, "required", "all", reshape_type="NC") \ | |||
| .output(1, "batch_mean", False, "required", "all") \ | |||
| .output(2, "batch_variance", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, | |||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||
| @@ -26,32 +26,27 @@ cast_op_info = TBERegOp("Cast") \ | |||
| .attr("dst_type", "required", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default) \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .dtype_format(DataType.BOOL_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.BOOL_None, DataType.U8_None) \ | |||
| .dtype_format(DataType.BOOL_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.BOOL_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.I8_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.I8_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.I8_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.U8_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.U8_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.U8_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.BOOL_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.I8_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.U8_None) \ | |||
| .dtype_format(DataType.F16_None, DataType.U8_None) \ | |||
| .dtype_format(DataType.F16_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.F16_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.I32_None) \ | |||
| .get_op_info() | |||
| @@ -26,6 +26,7 @@ concat_op_info = TBERegOp("Concat") \ | |||
| .attr("axis", "required", "int", "all") \ | |||
| .input(0, "input_values", False, "dynamic", "all") \ | |||
| .output(0, "output_data", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| @@ -23,6 +23,7 @@ conv2d_op_info = TBERegOp("Conv2D") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("conv2d") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .attr("stride", "required", "listInt", "all") \ | |||
| .attr("pad_list", "required", "listInt", "all") \ | |||
| .attr("dilation", "required", "listInt", "all") \ | |||
| @@ -32,8 +33,7 @@ conv2d_op_info = TBERegOp("Conv2D") \ | |||
| .input(2, "bias", False, "optional", "all") \ | |||
| .input(3, "offset_w", False, "optional", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default, | |||
| DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.I8_None, DataType.F16_None) \ | |||
| .get_op_info() | |||
| @@ -27,6 +27,7 @@ drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \ | |||
| .input(1, "mask", False, "required", "all") \ | |||
| .input(2, "keep_prob", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @@ -28,9 +28,7 @@ elu_op_info = TBERegOp("Elu") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @@ -26,9 +26,7 @@ erf_op_info = TBERegOp("Erf") \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @@ -26,9 +26,7 @@ erfc_op_info = TBERegOp("Erfc") \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @@ -27,9 +27,7 @@ expm1_op_info = TBERegOp("Expm1") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @@ -27,6 +27,7 @@ fused_mul_add_op_info = TBERegOp("FusedMulAdd") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .input(2, "x3", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ | |||
| @@ -32,6 +32,7 @@ layer_norm_op_info = TBERegOp("LayerNorm") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .output(1, "mean", False, "required", "all") \ | |||
| .output(2, "variance", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| @@ -30,6 +30,7 @@ layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop") | |||
| .input(3, "mean", False, "required", "all") \ | |||
| .output(0, "pd_gamma", False, "required", "all") \ | |||
| .output(1, "pd_beta", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| @@ -29,6 +29,7 @@ layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \ | |||
| .input(3, "mean", False, "required", "all") \ | |||
| .input(4, "gamma", False, "required", "all") \ | |||
| .output(0, "pd_x", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | |||
| @@ -26,21 +26,8 @@ mul_op_info = TBERegOp("Mul") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "y", False, "required", "all") \ | |||
| .output(0, "output", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ | |||
| .dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \ | |||
| .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \ | |||
| .get_op_info() | |||
| @@ -26,10 +26,9 @@ realdiv_op_info = TBERegOp("RealDiv") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "y", False, "required", "all") \ | |||
| .output(0, "z", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .op_pattern("broadcast") \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||
| .get_op_info() | |||
| @@ -25,6 +25,7 @@ reciprocal_op_info = TBERegOp("Reciprocal") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \ | |||
| @@ -27,11 +27,11 @@ reduce_mean_op_info = TBERegOp("ReduceMean") \ | |||
| .attr("keep_dims", "optional", "bool", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .op_pattern("reduce") \ | |||
| .dtype_format(DataType.I8_None, DataType.I8_None) \ | |||
| .dtype_format(DataType.U8_None, DataType.U8_None) \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||
| .get_op_info() | |||
| @@ -24,7 +24,7 @@ relu_grad_v2_op_info = TBERegOp("ReluGradV2") \ | |||
| .kernel_name("relu_grad_v2") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "gradients", False, "required", "all") \ | |||
| .input(1, "mask", False, "rerequired", "all") \ | |||
| .input(1, "mask", False, "required", "all") \ | |||
| .output(0, "backprops", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \ | |||
| @@ -27,6 +27,7 @@ select_op_info = TBERegOp("Select") \ | |||
| .input(1, "x1", False, "required", "all") \ | |||
| .input(2, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| @@ -27,11 +27,8 @@ sign_op_info = TBERegOp("Sign") \ | |||
| .input(0, "x", None, "required", None) \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .get_op_info() | |||
| @@ -30,6 +30,7 @@ softmax_grad_ext_op_info = TBERegOp("SoftmaxGradExt") \ | |||
| .input(1, "x1", False, "required", "all") \ | |||
| .input(2, "x2", False, "required", "all") \ | |||
| .output(0, "y", True, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, | |||
| @@ -27,9 +27,7 @@ softplus_op_info = TBERegOp("Softplus") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @@ -28,9 +28,7 @@ softplus_grad_op_info = TBERegOp("SoftplusGrad") \ | |||
| .input(1, "features", False, "required", "all") \ | |||
| .output(0, "backprops", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .get_op_info() | |||
| @@ -27,6 +27,7 @@ split_d_op_info = TBERegOp("Split") \ | |||
| .attr("output_num", "required", "int", "all") \ | |||
| .input(0, "value", False, "required", "all") \ | |||
| .output(0, "output", False, "dynamic", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| @@ -26,6 +26,7 @@ tensor_add_op_info = TBERegOp("TensorAdd") \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| @@ -27,6 +27,7 @@ unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "segment_ids", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ | |||
| @@ -97,6 +97,7 @@ class RegOp: | |||
| """ | |||
| if not isinstance(value, str): | |||
| raise TypeError("%s value must be str" % str(value)) | |||
| return True | |||
| def _is_int(self, value): | |||
| """ | |||
| @@ -110,6 +111,7 @@ class RegOp: | |||
| """ | |||
| if not isinstance(value, int): | |||
| raise TypeError("%s value must be int" % str(value)) | |||
| return True | |||
| def _is_bool(self, value): | |||
| """ | |||
| @@ -123,6 +125,7 @@ class RegOp: | |||
| """ | |||
| if not isinstance(value, bool): | |||
| raise TypeError("%s value must be bool" % str(value)) | |||
| return True | |||
| def _check_param(self, param_list, key_list, fn_list, kwargs): | |||
| """ | |||
| @@ -494,6 +497,7 @@ class DataType: | |||
| The current list below maybe not completed. If necessary, please add it. | |||
| """ | |||
| None_None = ("", "") | |||
| BOOL_None = ("bool", "") | |||
| BOOL_Default = ("bool", "DefaultFormat") | |||
| BOOL_5HD = ("bool", "NC1HWC0") | |||