@@ -20,7 +20,7 @@
 #include "kernel/aicpu/aicpu_kernel_metadata.h"
 #include "kernel/rts/rt_kernel_info.h"
 #include "kernel/hccl/hccl_kernel_metadata.h"
-#include "kernel/tbe/tbe_kernel_select.h"
+#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
 #include "session/anf_runtime_algorithm.h"
 namespace mindspore {
@@ -63,7 +63,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
   MS_EXCEPTION_IF_NULL(kernel_node);
   MS_EXCEPTION_IF_NULL(kernel_info_list);
   TbeMetadataInfo(kernel_node, kernel_info_list);
-  FilterInvalidKernelInfo(kernel_node, kernel_info_list);
   if (kernel_info_list->empty()) {
     AicpuMetadataInfo(kernel_node, kernel_info_list);
     if (!kernel_info_list->empty()) {
@@ -114,7 +113,6 @@ bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
   auto cnode = kernel_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   TbeMetadataInfo(cnode, &kernel_info_list);
-  FilterInvalidKernelInfo(cnode, &kernel_info_list);
   return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
                      [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
                        MS_EXCEPTION_IF_NULL(item);
@@ -126,6 +126,8 @@ class OpInfo {
   bool is_ref() const { return !ref_infos_.empty(); }
   bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
   void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
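+  // Remove all registered input/output descriptions.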
+  void ClearInputs() { (void)inputs_ptr_.clear(); }
+  void ClearOutputs() { (void)outputs_ptr_.clear(); }
  private:
   std::string op_name_;
@@ -35,7 +35,7 @@ constexpr auto kKernelName = "kernel_name";
 constexpr auto kPartialFlag = "partial_flag";
 constexpr auto kReshapeType = "reshape_type";
 constexpr auto kOpPattern = "op_pattern";
-constexpr auto kDynamicFormat = "dynamic_format";
+constexpr auto kDynamicFormat = "dynamicFormat";
 constexpr auto kFormatAgnostic = "formatAgnostic";
 constexpr auto kBroadcast = "broadcast";
 constexpr auto kReduce = "reduce";
@@ -100,7 +100,7 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
 void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
   const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
-                                                                  {kFormatAgnostic, kBroadcastPattern},
+                                                                  {kBroadcast, kBroadcastPattern},
                                                                   {kReduce, kReducePattern},
                                                                   {kDynamicFormat, kDynamicFormatPattern}};
   op_info->set_async_flag(obj.at(kAsyncFlag));
@@ -108,14 +108,19 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
   op_info->set_compute_cost(obj.at(kComputeCost));
   op_info->set_kernel_name(obj.at(kKernelName));
   op_info->set_partial_flag(obj.at(kPartialFlag));
   if (obj.find(kOpPattern) != obj.end()) {
-    if (kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) {
-      op_info->set_op_pattern(obj.at(kOpPattern));
+    std::string op_pattern = obj.at(kOpPattern);
+    auto find_iter = kOpPatternMap.find(op_pattern);
+    if (find_iter == kOpPatternMap.end()) {
+      if (!op_pattern.empty()) {
+        MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern;
+      }
+      op_info->set_op_pattern(kCommonPattern);
+    } else {
+      op_info->set_op_pattern(find_iter->second);
     }
   }
-  if (obj.find(kDynamicFormat) != obj.end()) {
-    op_info->set_dynamic_format(obj.at(kDynamicFormat));
-  }
 }
 bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
@@ -45,7 +45,7 @@ const std::map<TypeId, std::string> type_id_str_maps = {
   {TypeId::kNumberTypeInt64, "int64"},   {TypeId::kNumberTypeUInt, "uint"},
   {TypeId::kNumberTypeUInt8, "uint8"},   {TypeId::kNumberTypeUInt16, "uint16"},
   {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
-  {TypeId::kNumberTypeBool, "bool"},
+  {TypeId::kNumberTypeBool, "int8"},
 };
 const std::map<std::string, std::string> type_str_maps = {
@@ -85,7 +85,7 @@ std::string DtypeToString(const std::string &dtypes) {
 std::string TypeIdToString(TypeId type_id) {
   auto iter = type_id_str_maps.find(type_id);
   if (iter == type_id_str_maps.end()) {
-    MS_LOG(EXCEPTION) << "Illegal input dtype." << TypeIdLabel(type_id);
+    MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id);
   }
   return iter->second;
 }
@@ -111,41 +111,20 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node,
   if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
     TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
   } else {
-    // dtype : float16
-    auto tensor_dtype =
-      std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index)));
-    MS_EXCEPTION_IF_NULL(tensor_dtype);
-    std::string dtype = tensor_dtype->element()->ToString();
-    dtype = tbe::DtypeToString(dtype);
-    // format
-    std::string format = AnfAlgo::GetInputFormat(anf_node, real_input_index);
-    if (format == kOpFormat_DEFAULT) {
-      format = kOpFormat_NCHW;
-    } else if (format == kOpFormat_FRAC_Z) {
-      format = kOpFormat_FRACTAL_Z;
-    }
-    nlohmann::json input_desc_json;
-    input_desc_json["dtype"] = dtype;
-    input_desc_json["name"] = op_input_name + std::to_string(input_i);
+    auto dtype = GetDeviceInputType(anf_node, real_input_index);
+    auto format = GetDeviceInputFormat(anf_node, real_input_index);
+    auto shape = GetDeviceInputShape(anf_node, real_input_index);
     auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
     if (ori_shape.empty()) {
       ori_shape.emplace_back(1);
     }
+    nlohmann::json input_desc_json;
+    input_desc_json["dtype"] = dtype;
+    input_desc_json["name"] = op_input_name + std::to_string(input_i);
     input_desc_json["ori_shape"] = ori_shape;
     input_desc_json["ori_format"] = kOpFormat_NCHW;
-    auto shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
-    if (shape.empty()) {
-      shape.emplace_back(1);
-    }
-    if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
-      input_desc_json["shape"] = ori_shape;
-      input_desc_json["format"] = kOpFormat_NCHW;
-    } else {
-      input_desc_json["shape"] = shape;
-      input_desc_json["format"] = format;
-    }
+    input_desc_json["shape"] = shape;
+    input_desc_json["format"] = format;
     input_desc_json["valid"] = value;
     input_desc_json["param_type"] = input_ptr->param_type();
     input_list->emplace_back(input_desc_json);
@@ -325,40 +304,22 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, co
   MS_EXCEPTION_IF_NULL(output_idx);
   MS_EXCEPTION_IF_NULL(output_list);
   for (size_t i = 0; i < output_obj_num; i++) {
-    nlohmann::json output_obj;
-    auto type_ptr = std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, *output_idx)));
-    std::string dtype = type_ptr->element()->ToString();
-    dtype = tbe::DtypeToString(dtype);
-    std::string format = AnfAlgo::GetOutputFormat(anf_node, *output_idx);
-    if (format == kOpFormat_DEFAULT) {
-      format = kOpFormat_NCHW;
-    } else if (format == kOpFormat_FRAC_Z) {
-      format = kOpFormat_FRACTAL_Z;
-    }
-    std::vector<size_t> ori_shape;
-    if (AnfAlgo::GetOutputInferShape(anf_node, *output_idx).empty()) {
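+    // Same pattern as GenInputDescJson: the helpers pick the infer or device view per creater_type_.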
+    auto dtype = GetDeviceOutputType(anf_node, *output_idx);
+    auto format = GetDeviceOutputFormat(anf_node, *output_idx);
+    auto shape = GetDeviceOutputShape(anf_node, *output_idx);
+    std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
+    if (ori_shape.empty()) {
       ori_shape.emplace_back(1);
-    } else {
-      ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
     }
+    nlohmann::json output_obj;
     output_obj["dtype"] = dtype;
-    auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, *output_idx);
-    if (shape.empty()) {
-      shape.emplace_back(1);
-    }
-    if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
-      output_obj["shape"] = ori_shape;
-      output_obj["format"] = kOpFormat_NCHW;
-    } else {
-      output_obj["shape"] = shape;
-      output_obj["format"] = format;
-    }
+    output_obj["shape"] = shape;
+    output_obj["format"] = format;
     output_obj["ori_shape"] = ori_shape;
     output_obj["ori_format"] = kOpFormat_NCHW;
     output_obj["name"] = output_ptr->name();
     output_obj["valid"] = true;
     output_obj["param_type"] = output_ptr->param_type();
     output_list->emplace_back(output_obj);
     (*output_idx)++;
   }
@@ -456,6 +417,84 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo
   }
 }
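+// The helpers below choose between the inferred view (for OP_SELECT_FORMAT / CHECK_SUPPORTED) and the
+// device view selected during kernel build; note the dtype helpers only fall back for OP_SELECT_FORMAT.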
+std::vector<size_t> TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::vector<size_t> shape;
+  if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
+    shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index);
+  } else {
+    shape = AnfAlgo::GetInputDeviceShape(anf_node, real_index);
+  }
+  if (shape.empty()) {
+    shape.emplace_back(1);
+  }
+  return shape;
+}
+std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  TypeId type_id;
+  if (creater_type_ == OP_SELECT_FORMAT) {
+    type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index);
+  } else {
+    type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_index);
+  }
+  return tbe::TypeIdToString(type_id);
+}
+std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::string format = kOpFormat_NCHW;
+  if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
+    format = AnfAlgo::GetInputFormat(anf_node, real_index);
+    if (format == kOpFormat_FRAC_Z) {
+      format = kOpFormat_FRACTAL_Z;
+    } else if (format == kOpFormat_DEFAULT) {
+      format = kOpFormat_NCHW;
+    }
+  }
+  return format;
+}
+std::vector<size_t> TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::vector<size_t> shape;
+  if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
+    shape = AnfAlgo::GetOutputInferShape(anf_node, real_index);
+  } else {
+    shape = AnfAlgo::GetOutputDeviceShape(anf_node, real_index);
+  }
+  if (shape.empty()) {
+    shape.emplace_back(1);
+  }
+  return shape;
+}
+std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  TypeId type_id;
+  if (creater_type_ == OP_SELECT_FORMAT) {
+    type_id = AnfAlgo::GetOutputInferDataType(anf_node, real_index);
+  } else {
+    type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, real_index);
+  }
+  return tbe::TypeIdToString(type_id);
+}
+std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::string format = kOpFormat_NCHW;
+  if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
+    format = AnfAlgo::GetOutputFormat(anf_node, real_index);
+    if (format == kOpFormat_FRAC_Z) {
+      format = kOpFormat_FRACTAL_Z;
+    } else if (format == kOpFormat_DEFAULT) {
+      format = kOpFormat_NCHW;
+    }
+  }
+  return format;
+}
 bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
                                std::vector<size_t> *output_size_list) {
   if (input_size_list == nullptr || output_size_list == nullptr) {
@@ -93,7 +93,7 @@ class TbeKernelJsonCreator {
                       nlohmann::json *outputs_json);
   bool GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
                       nlohmann::json *attrs_json);
-  void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
+  static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
   bool GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
                         const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i,
                         std::vector<nlohmann::json> *input_list);
@@ -105,6 +105,13 @@ class TbeKernelJsonCreator {
   void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
                      const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
                      std::vector<nlohmann::json> *output_list);
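+  // creater_type_-aware accessors: infer view for format selection / support checks, device view otherwise.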
+  std::vector<size_t> GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const;
+  std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const;
+  std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
+  std::vector<size_t> GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const;
+  std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const;
+  std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
   kCreaterType creater_type_;
   std::string json_name_;
   std::string json_info_;
@@ -230,7 +230,7 @@ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t
                                               task_iter->second.output_size_list, kernel_pack);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   if (set_kernel_mod) {
-    AnfAlgo ::SetKernelMod(kernel_mod, task_iter->second.node);
+    AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node);
   }
   auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod);
   (void)task_map_.erase(task_iter);
@@ -1,664 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "kernel/tbe/tbe_kernel_select.h"
-#include <unordered_map>
-#include <memory>
-#include <map>
-#include <set>
-#include "session/anf_runtime_algorithm.h"
-#include "kernel/oplib/oplib.h"
-#include "kernel/tbe/tbe_kernel_build.h"
-#include "nlohmann/json.hpp"
-#include "common/utils.h"
-#include "utils/context/ms_context.h"
-#include "kernel/tbe/tbe_python_funcs.h"
-#include "pre_activate/common/helper.h"
-#include "kernel/tbe/tbe_convert_utils.h"
-namespace mindspore {
-namespace kernel {
-constexpr auto kName = "name";
-constexpr auto kDtype = "dtype";
-constexpr auto kFormat = "format";
-constexpr auto kPrefixInput = "input";
-constexpr auto kPrefixOutput = "output";
-const std::map<std::string, std::string> DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"},
-                                                               {"NHWC", "DefaultFormat"},
-                                                               {"ND", "DefaultFormat"},
-                                                               {"FRACTAL_Z", "FracZ"},
-                                                               {"NDHWC", "DefaultFormat"}};
-static const std::vector<std::string> CHECK_SUPPORTED_OPTYPE{
-  "MatMul", "BatchMatMul", "TopK", "InTopK", "Pack", "GatherNd", "UnsortedSegmentMinD", "UnsortedSegmentProdD", "Cast"};
-bool CheckSupported(const AnfNodePtr &anf_node, const KernelBuildInfoPtr &select_kernel_build_info) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(select_kernel_build_info);
-  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
-  auto iter = std::find(CHECK_SUPPORTED_OPTYPE.begin(), CHECK_SUPPORTED_OPTYPE.end(), op_name);
-  if (iter == CHECK_SUPPORTED_OPTYPE.end()) {
-    MS_LOG(DEBUG) << "Op " << op_name << "this op does not need to check op supported.";
-    return true;
-  }
-  // replace kernel_info with current kernel info
-  auto ori_select_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
-  AnfAlgo::SetSelectKernelBuildInfo(select_kernel_build_info, anf_node.get());
-  nlohmann::json kernel_json;
-  TbeKernelJsonCreator creator(CHECK_SUPPORTED);
-  bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
-  if (!ret) {
-    MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
-    AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
-    return false;
-  }
-  ret = TbePythonFuncs::CheckSupported(kernel_json);
-  AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
-  return ret;
-}
-bool CheckJsonItemValidity(const nlohmann::json &json_obj, const std::string &key_name,
-                           const std::vector<std::string> &keys) {
-  if (!json_obj[key_name].is_object()) {
-    MS_LOG(DEBUG) << key_name << "is not an object!";
-    return false;
-  }
-  for (auto key : keys) {
-    if (json_obj[key_name].find(key) == json_obj[key_name].end()) {
-      MS_LOG(DEBUG) << "Key" << key << "of " << key_name << " is not found!";
-      return false;
-    }
-  }
-  return true;
-}
-std::vector<std::string> SplitStr(const std::string &string, const std::string &sep) {
-  std::vector<std::string> result;
-  size_t start = 0;
-  size_t index = string.find(sep, start);
-  std::string substr;
-  while (index != std::string::npos) {
-    if (string.size() > start) {
-      substr = string.substr(start, index - start);
-    }
-    (void)substr.erase(0, substr.find_first_not_of(' '));
-    (void)substr.erase(substr.find_last_not_of(' ') + 1);
-    auto iter = DYNAMIC_FORMAT_MAP.find(substr);
-    if (iter != DYNAMIC_FORMAT_MAP.end()) {
-      substr = iter->second;
-    }
-    result.push_back(substr);
-    start = index + sep.size();
-    index = string.find(sep, start);
-  }
-  if (string.size() > start) {
-    substr = string.substr(start);
-  }
-  (void)substr.erase(0, substr.find_first_not_of(' '));
-  (void)substr.erase(substr.find_last_not_of(' ') + 1);
-  auto iter = DYNAMIC_FORMAT_MAP.find(substr);
-  if (iter != DYNAMIC_FORMAT_MAP.end()) {
-    substr = iter->second;
-  }
-  result.push_back(substr);
-  return result;
-}
-void ConvertFormatDtype(const std::string &format, const std::string &dtype, const std::shared_ptr<OpIOInfo> &io_info) {
-  MS_EXCEPTION_IF_NULL(io_info);
-  std::vector<std::string> format_vec = SplitStr(format, ",");
-  std::vector<std::string> dtype_vec = SplitStr(dtype, ",");
-  io_info->set_formats(format_vec);
-  io_info->set_dtypes(dtype_vec);
-}
-bool ParseDynamicFormatJson(const std::string &jsonStr, std::vector<std::shared_ptr<OpIOInfo>> *const inputs,
-                            std::vector<std::shared_ptr<OpIOInfo>> *const outputs) {
-  nlohmann::json json_obj = nlohmann::json::parse(jsonStr);
-  if (!json_obj.is_object()) {
-    MS_LOG(DEBUG) << "JsonStr is not an object, the jsonStr is:" << jsonStr;
-    return false;
-  }
-  std::vector<std::string> keys = {kName, kDtype, kFormat};
-  for (const auto &item : json_obj.items()) {
-    std::string key_name;
-    key_name = item.key();
-    if (key_name.empty()) {
-      MS_LOG(DEBUG) << "Key name is empty!";
-      return false;
-    }
-    if (!CheckJsonItemValidity(json_obj, key_name, keys)) {
-      return false;
-    }
-    if (key_name.compare(0, strlen(kPrefixInput), kPrefixInput) == 0) {
-      std::shared_ptr<OpIOInfo> input = std::make_shared<OpIOInfo>();
-      MS_EXCEPTION_IF_NULL(input);
-      input->set_name(json_obj[key_name].at(kName));
-      ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input);
-      inputs->emplace_back(input);
-    } else if (key_name.compare(0, strlen(kPrefixOutput), kPrefixOutput) == 0) {
-      std::shared_ptr<OpIOInfo> output = std::make_shared<OpIOInfo>();
-      MS_EXCEPTION_IF_NULL(output);
-      output->set_name(json_obj[key_name].at(kName));
-      ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), output);
-      outputs->emplace_back(output);
-    } else {
-      MS_LOG(DEBUG) << "Key name:" << key_name << " is undefined!";
-      return false;
-    }
-  }
-  return true;
-}
-std::string OpSelectFormat(const std::shared_ptr<AnfNode> &anf_node) {
-  nlohmann::json kernel_json;
-  std::string res_json_str;
-  TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
-  bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
-  if (!ret) {
-    MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
-    return res_json_str;
-  }
-  res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json);
-  MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str;
-  return res_json_str;
-}
-void SetTidyInputsInfo(const std::shared_ptr<AnfNode> &anf_node,
-                       const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
-                       const std::vector<std::shared_ptr<OpIOInfo>> &inputs) {
-  std::vector<TypeId> inputs_type;
-  std::vector<std::string> inputs_format;
-  std::vector<int> dyn_input_sizes;
-  size_t dyn_input_idx = 0;
-  size_t kernel_info_index = 0;
-  size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node);
-  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
-  MS_EXCEPTION_IF_NULL(primitive);
-  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
-    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
-  }
-  for (size_t i = 0; i < inputs.size(); i++) {
-    MS_EXCEPTION_IF_NULL(inputs[i]);
-    std::string param_type = inputs[i]->param_type();
-    if (i >= real_input_num) {
-      MS_LOG(INFO) << "Input index: " << i << " is out of real_input_num:" << real_input_num;
-      continue;
-    }
-    auto type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, i);
-    auto format = kOpFormat_DEFAULT;
-    if (param_type == "dynamic") {
-      if (!dyn_input_sizes.empty()) {
-        for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
-          kernel_info_index++;
-          inputs_type.emplace_back(type_id);
-          inputs_format.emplace_back(format);
-        }
-        dyn_input_idx++;
-      }
-    } else if (param_type == "required") {
-      kernel_info_index++;
-      inputs_type.emplace_back(type_id);
-      inputs_format.emplace_back(format);
-    } else {
-      if (kernel_info_index < real_input_num) {
-        MS_LOG(INFO) << "Input type is optional, input index is :" << kernel_info_index;
-        kernel_info_index++;
-        inputs_type.emplace_back(type_id);
-        inputs_format.emplace_back(format);
-      }
-    }
-  }
-  builder->SetInputsDeviceType(inputs_type);
-  builder->SetInputsFormat(inputs_format);
-}
-void SetTidyOutputsInfo(const std::shared_ptr<AnfNode> &anf_node,
-                        const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
-                        const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
-  std::vector<TypeId> outputs_type;
-  std::vector<std::string> outputs_format;
-  auto real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
-  size_t output_idx = 0;
-  for (const auto &output : outputs) {
-    MS_EXCEPTION_IF_NULL(output);
-    if (output_idx >= real_output_num) {
-      continue;
-    }
-    size_t output_num = 0;
-    if (output->param_type() == "dynamic") {
-      if (outputs.size() > 1) {
-        MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
-      }
-      output_num = real_output_num;
-    } else if (output->param_type() == "required") {
-      output_num = 1;
-    } else {
-      if (output_idx < real_output_num) {
-        MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
-        output_num = 1;
-      }
-    }
-    for (size_t i = 0; i < output_num; i++) {
-      auto type_id = AnfAlgo::GetOutputInferDataType(anf_node, output_idx);
-      outputs_type.emplace_back(type_id);
-      outputs_format.emplace_back(kOpFormat_DEFAULT);
-      output_idx++;
-    }
-  }
-  builder->SetOutputsDeviceType(outputs_type);
-  builder->SetOutputsFormat(outputs_format);
-}
-void GenTidyKernelBuildInfo(const std::shared_ptr<AnfNode> &anf_node,
-                            const std::vector<std::shared_ptr<OpIOInfo>> &inputs,
-                            const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
-  auto builder_tmp = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
-  builder_tmp->SetKernelType(TBE_KERNEL);
-  SetTidyInputsInfo(anf_node, builder_tmp, inputs);
-  SetTidyOutputsInfo(anf_node, builder_tmp, outputs);
-  AnfAlgo::SetSelectKernelBuildInfo(builder_tmp->Build(), anf_node.get());
-}
-void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr,
-                                 const std::shared_ptr<OpInfo> &op_info_new_ptr) {
-  std::vector<std::shared_ptr<OpIOInfo>> inputs_static = op_info_ptr->inputs_ptr();
-  std::vector<std::shared_ptr<OpIOInfo>> outputs_static = op_info_ptr->outputs_ptr();
-  std::vector<std::shared_ptr<OpIOInfo>> inputs_dyn;
-  std::vector<std::shared_ptr<OpIOInfo>> outputs_dyn;
-  if ((op_info_ptr->imply_type() == kTBE) && (!mindspore::opt::IsNopNode(kernel_node->cast<AnfNodePtr>()))) {
-    // 1. create tidy kernelBuildInfo in order to generate json for calling op_select_format
-    auto anf_node = kernel_node->cast<std::shared_ptr<AnfNode>>();
-    auto kernel_build_info_ptr = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
-    GenTidyKernelBuildInfo(kernel_node, inputs_static, outputs_static);
-    // 2. get dynamic format from op_impl
-    std::string res_json_str;
-    auto context_ptr = MsContext::GetInstance();
-    MS_EXCEPTION_IF_NULL(context_ptr);
-    if (context_ptr->execution_mode() != kPynativeMode) {
-      res_json_str = OpSelectFormat(kernel_node);
-    }
-    if (!res_json_str.empty()) {
-      (void)ParseDynamicFormatJson(res_json_str, &inputs_dyn, &outputs_dyn);
-    }
-    if (inputs_static.size() != inputs_dyn.size()) {
-      inputs_dyn.clear();
-    }
-    if (outputs_static.size() != outputs_dyn.size()) {
-      outputs_dyn.clear();
-    }
-    // 3. resume kernel node's SelectKernelBuildInfo
-    // As it has been replaced by GenTidyKernelBuildInfo in order to call python func
-    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_ptr, anf_node.get());
-  }
-  // 4. replace by dynamic format and dtype
-  if (inputs_dyn.empty() && outputs_dyn.empty()) {
-    MS_LOG(INFO) << "Dynamic select format response is empty, use static register info.";
-    op_info_new_ptr->set_inputs_ptr(inputs_static);
-    op_info_new_ptr->set_outputs_ptr(outputs_static);
-  } else {
-    MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format.";
-    for (size_t i = 0; i < inputs_static.size(); i++) {
-      inputs_dyn[i]->set_param_type(inputs_static[i]->param_type());
-      inputs_dyn[i]->set_reshape_type(inputs_static[i]->reshape_type());
-    }
-    for (size_t j = 0; j < outputs_static.size(); j++) {
-      outputs_dyn[j]->set_param_type(outputs_static[j]->param_type());
-      outputs_dyn[j]->set_reshape_type(outputs_static[j]->reshape_type());
-    }
-    op_info_new_ptr->set_inputs_ptr(inputs_dyn);
-    op_info_new_ptr->set_outputs_ptr(outputs_dyn);
-  }
-  // 5. copy other opinfo to new op_info_new
-  op_info_new_ptr->set_op_name(op_info_ptr->op_name());
-  op_info_new_ptr->set_imply_type(op_info_ptr->imply_type());
-  op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type());
-}
-bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
-  for (const auto &c : reshape_type_str) {
-    switch (c) {
-      case 'N':
-        reshape_type_vec->push_back(kernel::N);
-        break;
-      case 'C':
-        reshape_type_vec->push_back(kernel::C);
-        break;
-      case 'H':
-        reshape_type_vec->push_back(kernel::H);
-        break;
-      case 'W':
-        reshape_type_vec->push_back(kernel::W);
-        break;
-      default:
-        MS_LOG(ERROR) << "Unknown axis " << c << "in reshape type.";
-        return false;
-    }
-  }
-  return true;
-}
-bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
-                               size_t builder_idex, const std::vector<int> &dyn_input_sizes,
-                               const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
-  MS_EXCEPTION_IF_NULL(builder);
-  std::vector<TypeId> inputs_device_type;
-  std::vector<std::string> inputs_format;
-  size_t dyn_input_idx = 0;
-  size_t kernel_info_index = 0;
-  MS_EXCEPTION_IF_NULL(inputs[0]);
-  size_t kernel_info_cnt = inputs[0]->dtypes().size();
-  std::vector<std::vector<Axis>> reshape_types;
-  for (const auto &input : inputs) {
-    MS_EXCEPTION_IF_NULL(input);
-    std::string param_type = input->param_type();
-    std::vector<std::string> dtypes = input->dtypes();
-    std::vector<std::string> formats = input->formats();
-    if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
-      MS_LOG(ERROR) << "Set input kernel builder info, dtyps size != formats size.";
-      return false;
-    }
-    std::vector<Axis> reshape_type;
-    if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
-      return false;
-    }
-    if (param_type == "dynamic") {
-      if (dyn_input_sizes.empty()) {
-        MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
-        return false;
-      }
-      for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
-        kernel_info_index++;
-        auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
-        inputs_device_type.push_back(type_id);
-        inputs_format.push_back(formats[builder_idex]);
-        reshape_types.push_back(reshape_type);
-      }
-      dyn_input_idx++;
-    } else if (param_type == "required") {
-      kernel_info_index++;
-      auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
-      inputs_device_type.push_back(type_id);
-      inputs_format.push_back(formats[builder_idex]);
-      reshape_types.push_back(reshape_type);
-    } else {
-      if (kernel_info_index < real_input_num) {
-        MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index;
-        kernel_info_index++;
-        auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
-        inputs_device_type.push_back(type_id);
-        inputs_format.push_back(formats[builder_idex]);
-        reshape_types.push_back(reshape_type);
-      }
-    }
-  }
-  builder->SetInputReshapeType(reshape_types);
-  builder->SetInputsDeviceType(inputs_device_type);
-  builder->SetInputsFormat(inputs_format);
-  return true;
-}
-bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
-                                const size_t &real_output_num,
-                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
-  // not now but in the next we need to support dynamic output case
-  MS_EXCEPTION_IF_NULL(builder);
-  size_t output_idx = 0;
-  std::vector<TypeId> outputs_device_type;
-  std::vector<std::string> outputs_format;
-  MS_EXCEPTION_IF_NULL(outputs[0]);
-  size_t kernel_info_cnt = outputs[0]->dtypes().size();
-  std::vector<std::vector<Axis>> reshape_types;
-  for (const auto &output : outputs) {
-    MS_EXCEPTION_IF_NULL(output);
-    if (output_idx >= real_output_num) {
-      MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!";
-      continue;
-    }
-    std::vector<Axis> reshape_type;
-    if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
-      return false;
-    }
-    size_t output_num = 0;
-    if (output->param_type() == "dynamic") {
-      if (outputs.size() > 1) {
-        MS_LOG(EXCEPTION) << "Dynamic output is unsupported multi output!";
-      }
-      output_num = real_output_num;
-    } else if (output->param_type() == "required") {
-      output_num = 1;
-    } else {
-      if (output_idx < real_output_num) {
-        MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is " << output_idx;
-        output_num = 1;
-      }
-    }
-    for (size_t i = 0; i < output_num; i++) {
-      std::vector<std::string> dtypes = output->dtypes();
-      std::vector<std::string> formats = output->formats();
-      if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
-        MS_LOG(ERROR) << "Set output kernel builder info, dtyps size != formats size.";
-        return false;
-      }
-      auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
-      outputs_device_type.push_back(type_id);
-      outputs_format.push_back(formats[builder_idex]);
-      reshape_types.push_back(reshape_type);
-      output_idx++;
-    }
-  }
-  builder->SetOutputReshapeType(reshape_types);
-  builder->SetOutputsFormat(outputs_format);
-  builder->SetOutputsDeviceType(outputs_device_type);
-  return true;
-}
-void SetKernelBuildCommonInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
-                              Processor processor, const std::shared_ptr<const OpInfo> &op_info_ptr) {
-  MS_EXCEPTION_IF_NULL(builder);
-  MS_EXCEPTION_IF_NULL(op_info_ptr);
-  builder->SetProcessor(processor);
-  std::string fusion_type = op_info_ptr->fusion_type();
-  if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) {
-    builder->SetFusionType(tbe::GetFusionType(fusion_type));
-  }
-  builder->SetOpPattern(op_info_ptr->op_pattern());
-  builder->SetKernelType(TBE_KERNEL);
-}
-bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr,
-                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_info_list);
-  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
-  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
-  std::vector<int> dyn_input_sizes;
-  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
-  MS_EXCEPTION_IF_NULL(primitive);
-  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
-    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
-  }
-  if (!inputs.empty()) {
-    MS_EXCEPTION_IF_NULL(inputs[0]);
-    size_t kernel_info_cnt = inputs[0]->dtypes().size();
-    for (size_t j = 0; j < kernel_info_cnt; j++) {
-      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
-      MS_EXCEPTION_IF_NULL(builder);
-      SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);
-      if (!SetKernelBuilderInputInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
-        MS_LOG(ERROR) << "Parse kernel metadata, set inputs kernel builder info failed.";
-        return false;
-      }
-      if (!outputs.empty()) {
-        if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
-          MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
-          return false;
-        }
-      }
-      kernel_info_list->push_back(builder->Build());
-    }
-  } else if (!outputs.empty()) {
-    MS_EXCEPTION_IF_NULL(outputs[0]);
-    size_t kernel_info_cnt = outputs[0]->dtypes().size();
-    for (size_t j = 0; j < kernel_info_cnt; j++) {
-      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
-      MS_EXCEPTION_IF_NULL(builder);
-      SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);
-      if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
-        MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
-        return false;
-      }
-      kernel_info_list->push_back(builder->Build());
-    }
-  }
-  return true;
-}
-bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
-  // if format is default, it remarkes support all format
-  if (kOpFormatList.find(format) == kOpFormatList.end()) {
-    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
-  }
-  if (format == kOpFormat_DEFAULT) {
-    return true;
-  }
-  if (format == kOpFormat_NDHWC && shape.size() != kShape5dDims) {
-    return false;
-  }
-  // if shape size is 0, the shape will be a scalar
-  if (shape.empty()) {
-    return true;
-  }
-  if (shape.size() > kShape4dDims) {
-    return false;
-  }
-  if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
-    return false;
-  }
-  return true;
-}
-bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-  const size_t kCAxis = 1;
-  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
-    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
-    if (kernel_build_info.GetOutputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
-      if (output_shape.size() != kShape4dDims || output_shape[kCAxis] > 4) {
-        return false;
-      }
-      return false;
-    }
-    if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
-      return false;
-    }
-    if (kernel_name == "ReduceMean") {
-      auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
-      if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) {
-        return false;
-      }
-    }
-  }
-  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
-    if (!IsShapeMatchFormat(input_shape, kernel_build_info.GetInputFormat(index))) {
-      return false;
-    }
-    if (kernel_build_info.GetInputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
-      if (input_shape.size() != kShape4dDims || input_shape[kCAxis] > 4) {
-        return false;
-      }
-      return false;
-    }
-    if (kernel_name == "ReduceMean") {
-      auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
-      if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) {
-        return false;
-      }
-    }
-  }
-  if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
-    return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
-           AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
-  }
-  return true;
-}
-void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_info_list);
-  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list;
-  std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
-  auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
-  if (op_info_ptr == nullptr) {
-    return;
-  }
-  // dynamic get op format and dtype and replace opinfo
-  auto op_info_new_ptr = std::make_shared<OpInfo>();
-  ReplaceByDynamicFormatDtype(kernel_node, op_info_ptr, op_info_new_ptr);
-  if (!ParseMetadata(kernel_node, op_info_new_ptr, &parse_info_list)) {
-    MS_LOG(INFO) << "Tbe parsed metadata of op[" << op_name << "] failed.";
-    return;
-  }
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  for (const auto &parse_info : parse_info_list) {
-    if (IsValidKernelInfo(kernel_node, *(parse_info))) {
-      if (CheckSupported(kernel_node, parse_info)) {
-        kernel_info_list->push_back(parse_info);
-      } else {
-        MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info.";
-      }
-    }
-    if (kernel_info_list->empty()) {
-      MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "].";
-    }
-  }
-}
-}  // namespace kernel
-}  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,20 +13,18 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_TBE_KERNEL_SELECT_H
-#define MINDSPORE_TBE_KERNEL_SELECT_H
+#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
+#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
 #include <string>
 #include <vector>
-#include <memory>
-#include "kernel/oplib/opinfo.h"
-#include "kernel/kernel_build_info.h"
 namespace mindspore {
 namespace kernel {
-void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
+struct SupportFormat {
+  std::vector<std::vector<std::string>> input_format;
+  std::vector<std::vector<std::string>> output_format;
+};
+using SupportFormatItem = std::vector<std::string>;
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_TBE_KERNEL_SELECT_H
+#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
@@ -0,0 +1,319 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
+#include "utils/utils.h"
+#include "session/anf_runtime_algorithm.h"
+#include "kernel/tbe/tbe_kernel_select/common_utils.h"
+namespace mindspore {
+namespace kernel {
+constexpr char kDynInputKey[] = "dyn_input_sizes";
+constexpr size_t kInputIndex_0 = 0;
+constexpr size_t kChannelN = 0;
+constexpr size_t kChannelC = 1;
+constexpr size_t kAlignmented16 = 16;
+// Supported broadcast cases:
+// 1. no scalar inputs and all shapes identical;
+// 2. some scalar inputs: every non-scalar shape must pass the rank and alignment checks below;
+// 3. no scalar inputs but shapes differ (broadcast on some axis).
+bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
+  MS_EXCEPTION_IF_NULL(support_format);
+  input_num_ = 0;
+  output_num_ = 0;
+  input_shapes_.clear();
+  output_shapes_.clear();
+  if (AnfAlgo::HasNodeAttr(kDynInputKey, cnode_ptr_)) {
+    MS_LOG(INFO) << "This broadcast node has dynamic input.";
+    auto dynamic_size_vec = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode_ptr_, kDynInputKey);
+    if (dynamic_size_vec.empty() || dynamic_size_vec[0] < 2) {
+      MS_LOG(EXCEPTION) << "dynamic attr set error, please check.";
+    }
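+    // Dynamic inputs are folded into a single entry: record the first input's shape and count them as one input.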
+    auto dynamic_input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
+    PadScalarShape(&dynamic_input_shape0_);
+    input_shapes_.emplace_back(dynamic_input_shape0_);
+    input_num_ = 1;
+  } else {
+    input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_);
+    for (size_t i = 0; i < input_num_; ++i) {
+      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
+      PadScalarShape(&input_shape);
+      input_shapes_.emplace_back(input_shape);
+    }
+  }
+  output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
+  for (size_t i = 0; i < output_num_; ++i) {
+    auto output = AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
+    PadScalarShape(&output);
+    output_shapes_.emplace_back(output);
+  }
+  AssignSupportFormat(kOpFormat_DEFAULT, support_format);
+  return true;
+}
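+// NC1HWC0 tiles C into C1 blocks of C0 = 16 elements, hence the 16-alignment requirement on the C axis
+// whenever a scalar input is broadcast against a non-scalar one.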
| bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| if (IsSameShape()) { | |||||
| if (!HasScalarInput()) { | |||||
| AssignSupportFormat(kOpFormat_NC1HWC0, support_format); | |||||
| return true; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| SupportFormatItem input_support_format; | |||||
| SupportFormatItem output_support_format; | |||||
| if (HasScalarInput()) { | |||||
| for (const auto &shape : input_shapes_) { | |||||
| if (IsScalarShape(shape)) { | |||||
| input_support_format.emplace_back(kOpFormat_DEFAULT); | |||||
| } else { | |||||
| if (!Is4DShape(shape)) { | |||||
| return false; | |||||
| } | |||||
| if (shape[kChannelC] % kAlignmented16 != 0) { | |||||
| return false; | |||||
| } | |||||
| input_support_format.emplace_back(kOpFormat_NC1HWC0); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (const auto &shape : input_shapes_) { | |||||
| if (!Is4DShape(shape)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| auto shape_tmp = input_shapes_[0]; | |||||
| auto broadcast_c_axis = std::any_of( | |||||
| input_shapes_.begin(), input_shapes_.end(), | |||||
| [&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); }); | |||||
| if (broadcast_c_axis) { | |||||
| MS_LOG(INFO) << "This node broadcast c channel."; | |||||
| return false; | |||||
| } | |||||
| input_support_format.assign(input_num_, kOpFormat_NC1HWC0); | |||||
| } | |||||
| GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format); | |||||
| support_format->input_format.emplace_back(input_support_format); | |||||
| support_format->output_format.emplace_back(output_support_format); | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracZ(SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| if (IsSameShape()) { | |||||
| if (!HasScalarInput()) { | |||||
| AssignSupportFormat(kOpFormat_FRAC_Z, support_format); | |||||
| return true; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| SupportFormatItem input_support_format; | |||||
| SupportFormatItem output_support_format; | |||||
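| // FRAC_Z tiles both the N and C axes, so a non-scalar input needs N and C to be 16-aligned. | |||||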
| if (HasScalarInput()) { | |||||
| for (const auto &shape : input_shapes_) { | |||||
| if (IsScalarShape(shape)) { | |||||
| input_support_format.emplace_back(kOpFormat_DEFAULT); | |||||
| } else { | |||||
| if (!Is4DShape(shape)) { | |||||
| return false; | |||||
| } | |||||
| if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) { | |||||
| return false; | |||||
| } | |||||
| input_support_format.emplace_back(kOpFormat_FRAC_Z); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format); | |||||
| support_format->input_format.emplace_back(input_support_format); | |||||
| support_format->output_format.emplace_back(output_support_format); | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| if (IsSameShape()) { | |||||
| if (!HasScalarInput()) { | |||||
| AssignSupportFormat(kOpFormat_C1HWNCoC0, support_format); | |||||
| return true; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| SupportFormatItem input_support_format; | |||||
| SupportFormatItem output_support_format; | |||||
| if (HasScalarInput()) { | |||||
| for (const auto &shape : input_shapes_) { | |||||
| if (IsScalarShape(shape)) { | |||||
| input_support_format.emplace_back(kOpFormat_DEFAULT); | |||||
| } else { | |||||
| if (!Is4DShape(shape)) { | |||||
| return false; | |||||
| } | |||||
| if (shape[kChannelN] % kAlignmented16 != 0) { | |||||
| return false; | |||||
| } | |||||
| input_support_format.emplace_back(kOpFormat_C1HWNCoC0); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (const auto &shape : input_shapes_) { | |||||
| if (!Is4DShape(shape)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
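| // C1HWNCoC0 tiles the N and C axes, so broadcasting along either of them is not expressible. | |||||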
| auto shape_tmp = input_shapes_[0]; | |||||
| auto broadcast_nc_axis = | |||||
| std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) { | |||||
| return (shape_tmp.at(kChannelC) != elem.at(kChannelC) || shape_tmp.at(kChannelN) != elem.at(kChannelN)); | |||||
| }); | |||||
| if (broadcast_nc_axis) { | |||||
| MS_LOG(INFO) << "This node broadcast n || c channel."; | |||||
| return false; | |||||
| } | |||||
| input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0); | |||||
| } | |||||
| GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format); | |||||
| support_format->input_format.emplace_back(input_support_format); | |||||
| support_format->output_format.emplace_back(output_support_format); | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| if (IsSameShape()) { | |||||
| if (!HasScalarInput()) { | |||||
| AssignSupportFormat(kOpFormat_FRAC_NZ, support_format); | |||||
| return true; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| SupportFormatItem input_support_format; | |||||
| SupportFormatItem output_support_format; | |||||
| if (HasScalarInput()) { | |||||
| for (const auto &shape : input_shapes_) { | |||||
| if (IsScalarShape(shape)) { | |||||
| input_support_format.emplace_back(kOpFormat_DEFAULT); | |||||
| } else { | |||||
| if (shape.size() < kShape2dDims) { | |||||
| return false; | |||||
| } | |||||
| if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - 2] % kAlignmented16 != 0) { | |||||
| return false; | |||||
| } | |||||
| input_support_format.emplace_back(kOpFormat_FRAC_NZ); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(), | |||||
| [](const std::vector<size_t> &elem) { return elem.size() < kShape2dDims; }); | |||||
| if (less_2dims) { | |||||
| MS_LOG(INFO) << "This node dim less 2."; | |||||
| return false; | |||||
| } | |||||
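| // FRAC_NZ tiles the last two dims, so they must match across all inputs (no broadcast on them). | |||||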
| auto shape_tmp = input_shapes_[0]; | |||||
| auto broadcast_last_dim = | |||||
| std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) { | |||||
| return (shape_tmp.at(shape_tmp.size() - 1) != elem.at(elem.size() - 1)) || | |||||
| (shape_tmp.at(shape_tmp.size() - 2) != elem.at(elem.size() - 2)); | |||||
| }); | |||||
| if (broadcast_last_dim) { | |||||
| MS_LOG(INFO) << "This node broadcast last channel."; | |||||
| return false; | |||||
| } | |||||
| input_support_format.assign(input_num_, kOpFormat_FRAC_NZ); | |||||
| } | |||||
| GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format); | |||||
| support_format->input_format.emplace_back(input_support_format); | |||||
| support_format->output_format.emplace_back(output_support_format); | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
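| // NDC1HWC0 broadcast support is not implemented yet. | |||||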
| return false; | |||||
| } | |||||
| bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const { | |||||
| return shape.size() == kShape4dDims; | |||||
| } | |||||
| bool TbeKernelBroadCastSelecter::IsSameShape() const { | |||||
| auto shape = input_shapes_.begin(); | |||||
| for (const auto &item : input_shapes_) { | |||||
| if (shape->size() != item.size()) { | |||||
| return false; | |||||
| } | |||||
| for (size_t i = 0; i < shape->size(); ++i) { | |||||
| if (shape->at(i) != item.at(i)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
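| // Pad an empty (scalar) shape to the rank-1 shape [1] so later format checks can treat all shapes uniformly. | |||||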
| void TbeKernelBroadCastSelecter::PadScalarShape(std::vector<size_t> *shape) const { | |||||
| MS_EXCEPTION_IF_NULL(shape); | |||||
| if (shape->empty()) { | |||||
| shape->emplace_back(1); | |||||
| } | |||||
| } | |||||
| bool TbeKernelBroadCastSelecter::IsScalarShape(const std::vector<size_t> &shape) const { | |||||
| return (shape.size() == 1 && shape[0] == 1); | |||||
| } | |||||
| bool TbeKernelBroadCastSelecter::HasScalarInput() const { | |||||
| return std::any_of(input_shapes_.begin(), input_shapes_.end(), | |||||
| [this](const std::vector<size_t> &shape) { return IsScalarShape(shape); }); | |||||
| } | |||||
| void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &support_format, | |||||
| SupportFormatItem *output_support_item) const { | |||||
| MS_EXCEPTION_IF_NULL(output_support_item); | |||||
| for (const auto &shape : output_shapes_) { | |||||
| if (IsScalarShape(shape)) { | |||||
| output_support_item->emplace_back(kOpFormat_DEFAULT); | |||||
| } else { | |||||
| output_support_item->emplace_back(support_format); | |||||
| } | |||||
| } | |||||
| } | |||||
| void TbeKernelBroadCastSelecter::AssignSupportFormat(const std::string &support_format_str, | |||||
| mindspore::kernel::SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| SupportFormatItem input_support_format; | |||||
| SupportFormatItem output_support_format; | |||||
| input_support_format.assign(input_num_, support_format_str); | |||||
| output_support_format.assign(output_num_, support_format_str); | |||||
| support_format->input_format.emplace_back(input_support_format); | |||||
| support_format->output_format.emplace_back(output_support_format); | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include "ir/anf.h" | |||||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
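| // Selects the candidate input/output formats for broadcast-pattern TBE ops. | |||||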
| class TbeKernelBroadCastSelecter { | |||||
| public: | |||||
| explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {} | |||||
| ~TbeKernelBroadCastSelecter() = default; | |||||
| bool GetShapeInfo(SupportFormat *support_format); | |||||
| bool IsBroadCastSupport5HD(SupportFormat *support_format) const; | |||||
| bool IsBroadCastSupportFracZ(SupportFormat *support_format) const; | |||||
| bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const; | |||||
| bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const; | |||||
| bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const; | |||||
| private: | |||||
| bool IsSameShape() const; | |||||
| void PadScalarShape(std::vector<size_t> *shape) const; | |||||
| bool Is4DShape(const std::vector<size_t> &shape) const; | |||||
| bool IsScalarShape(const std::vector<size_t> &shape) const; | |||||
| bool HasScalarInput() const; | |||||
| void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const; | |||||
| void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; | |||||
| // broadcast | |||||
| CNodePtr cnode_ptr_; | |||||
| size_t input_num_{}; | |||||
| size_t output_num_{}; | |||||
| std::vector<std::vector<size_t>> input_shapes_; | |||||
| std::vector<std::vector<size_t>> output_shapes_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ | |||||
| @@ -0,0 +1,180 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "utils/utils.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| constexpr char kKeepDims[] = "keep_dims"; | |||||
| constexpr char kAxis[] = "axis"; | |||||
| constexpr char kTypeInt32[] = "Int32"; | |||||
| constexpr size_t kInputIndex_0 = 0; | |||||
| constexpr size_t kOutputIndex_0 = 0; | |||||
| constexpr size_t kChannelN = 0; | |||||
| constexpr size_t kChannelC = 1; | |||||
| constexpr size_t kReduceNZMinDim = 3; | |||||
| bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| input_shape_.clear(); | |||||
| output_shape_.clear(); | |||||
| axis_.clear(); | |||||
| auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); | |||||
| auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_); | |||||
| if (input_num != 1 || output_num != 1) { | |||||
| MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << input_num | |||||
| << ", output num: " << output_num; | |||||
| } | |||||
| // get input/output shape | |||||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0); | |||||
| PadScalarShape(&input_shape_); | |||||
| output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0); | |||||
| PadScalarShape(&output_shape_); | |||||
| // get keep dim attr | |||||
| GetReduceAttrKeepDim(); | |||||
| // get axis attr | |||||
| GetReduceAttrAxis(); | |||||
| AssignSupportFormat(kOpFormat_DEFAULT, support_format); | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| if (!Is4DShape(input_shape_)) { | |||||
| return false; | |||||
| } | |||||
| if (!keep_dims_ || axis_.empty()) { | |||||
| return false; | |||||
| } | |||||
| auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); }); | |||||
| if (reduce_c_axis) { | |||||
| return false; | |||||
| } | |||||
| AssignSupportFormat(kOpFormat_NC1HWC0, support_format); | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| // Similar to the 5HD check; NDC1HWC0 is not supported yet. | |||||
| return false; | |||||
| } | |||||
| bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const { | |||||
| return IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format); | |||||
| } | |||||
| bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const { | |||||
| return IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format); | |||||
| } | |||||
| bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| if (input_shape_.size() < kReduceNZMinDim) { | |||||
| return false; | |||||
| } | |||||
| if (axis_.empty()) { | |||||
| return false; | |||||
| } | |||||
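| // FRAC_NZ tiles the last two dims, so reducing either of them is not supported. | |||||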
| auto reduce_last_axis = std::any_of(axis_.begin(), axis_.end(), [this](const size_t &elem) { | |||||
| return (elem == (this->input_shape_.size() - 1) || elem == (this->input_shape_.size() - 2)); | |||||
| }); | |||||
| if (reduce_last_axis) { | |||||
| return false; | |||||
| } | |||||
| AssignSupportFormat(kOpFormat_FRAC_NZ, support_format); | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format, | |||||
| mindspore::kernel::SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| if (!Is4DShape(input_shape_)) { | |||||
| return false; | |||||
| } | |||||
| if (!keep_dims_ || axis_.empty()) { | |||||
| return false; | |||||
| } | |||||
| auto reduce_n_c_axis = std::any_of(axis_.begin(), axis_.end(), | |||||
| [](const size_t &elem) { return (elem == kChannelC || elem == kChannelN); }); | |||||
| if (reduce_n_c_axis) { | |||||
| return false; | |||||
| } | |||||
| AssignSupportFormat(format, support_format); | |||||
| return true; | |||||
| } | |||||
| void TbeKernelReduceSelecter::GetReduceAttrAxis() { | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto axis = primitive->GetAttr(kAxis); | |||||
| if (axis == nullptr) { | |||||
| MS_LOG(INFO) << "This node does't have axie attr."; | |||||
| return; | |||||
| } | |||||
| auto type = axis->type(); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
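| // The axis attr may be a single Int32 or a vector of ints; negative axes are normalized by the input rank below. | |||||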
| std::vector<int> axis_list; | |||||
| if (type->ToString() == kTypeInt32) { | |||||
| axis_list.emplace_back(GetValue<int>(axis)); | |||||
| } else { | |||||
| axis_list = GetValue<std::vector<int>>(axis); | |||||
| } | |||||
| for (const auto &elem : axis_list) { | |||||
| if (elem < 0) { | |||||
| axis_.emplace_back(input_shape_.size() + elem); | |||||
| } else { | |||||
| axis_.emplace_back(IntToSize(elem)); | |||||
| } | |||||
| } | |||||
| } | |||||
| void TbeKernelReduceSelecter::GetReduceAttrKeepDim() { | |||||
| if (!AnfAlgo::HasNodeAttr(kKeepDims, cnode_ptr_)) { | |||||
| MS_LOG(INFO) << "This node does't have keep_attr."; | |||||
| keep_dims_ = false; | |||||
| return; | |||||
| } | |||||
| keep_dims_ = AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kKeepDims); | |||||
| } | |||||
| void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str, | |||||
| mindspore::kernel::SupportFormat *support_format) const { | |||||
| MS_EXCEPTION_IF_NULL(support_format); | |||||
| SupportFormatItem input_support_format; | |||||
| SupportFormatItem output_support_format; | |||||
| input_support_format.emplace_back(support_format_str); | |||||
| output_support_format.emplace_back(support_format_str); | |||||
| support_format->input_format.emplace_back(input_support_format); | |||||
| support_format->output_format.emplace_back(output_support_format); | |||||
| } | |||||
| bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape4dDims; } | |||||
| void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const { | |||||
| MS_EXCEPTION_IF_NULL(shape); | |||||
| if (shape->empty()) { | |||||
| shape->emplace_back(1); | |||||
| } | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ | |||||
| #include <utility> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "ir/anf.h" | |||||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
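| // Selects the candidate input/output formats for reduce-pattern TBE ops. | |||||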
| class TbeKernelReduceSelecter { | |||||
| public: | |||||
| explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {} | |||||
| ~TbeKernelReduceSelecter() = default; | |||||
| bool GetShapeInfo(SupportFormat *support_format); | |||||
| bool IsReduceSupport5HD(SupportFormat *support_format) const; | |||||
| bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const; | |||||
| bool IsReduceSupportFracZ(SupportFormat *support_format) const; | |||||
| bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const; | |||||
| bool IsReduceSupportFracNZ(SupportFormat *support_format) const; | |||||
| private: | |||||
| bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const; | |||||
| void GetReduceAttrAxis(); | |||||
| void GetReduceAttrKeepDim(); | |||||
| void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; | |||||
| bool Is4DShape(const std::vector<size_t> &shape) const; | |||||
| void PadScalarShape(std::vector<size_t> *shape) const; | |||||
| CNodePtr cnode_ptr_; | |||||
| std::vector<size_t> input_shape_{}; | |||||
| std::vector<size_t> output_shape_{}; | |||||
| std::vector<size_t> axis_{}; | |||||
| bool keep_dims_ = false; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ | |||||
| @@ -0,0 +1,633 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <set> | |||||
| #include <utility> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "kernel/oplib/oplib.h" | |||||
| #include "kernel/tbe/tbe_kernel_build.h" | |||||
| #include "nlohmann/json.hpp" | |||||
| #include "utils/context/ms_context.h" | |||||
| #include "kernel/tbe/tbe_python_funcs.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| #include "kernel/tbe/tbe_convert_utils.h" | |||||
| #include "parallel/ops_info/ops_utils.h" | |||||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" | |||||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" | |||||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| constexpr auto kName = "name"; | |||||
| constexpr auto kDtype = "dtype"; | |||||
| constexpr auto kFormat = "format"; | |||||
| constexpr auto kPrefixInput = "input"; | |||||
| constexpr auto kPrefixOutput = "output"; | |||||
| constexpr char kDynInputKey[] = "dyn_input_sizes"; | |||||
| constexpr char kParamTypeDynamic[] = "dynamic"; | |||||
| constexpr char kParamTypeRequired[] = "required"; | |||||
| constexpr char kParamTypeOptional[] = "optional"; | |||||
| void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) { | |||||
| auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list); | |||||
| tbe_selecter.TbeMetadataInfoEx(); | |||||
| } | |||||
| TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) | |||||
| : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {} | |||||
| void TbeKernelSelect::TbeMetadataInfoEx() { | |||||
| MS_EXCEPTION_IF_NULL(cnode_ptr_); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_list_); | |||||
| node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_); | |||||
| auto op_info_ptr = OpLib::FindOp(node_name_, kTBE); | |||||
| if (!op_info_ptr) { | |||||
| MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Start to tbe metadata info. node type: " << node_name_ | |||||
| << ", node name: " << cnode_ptr_->fullname_with_scope(); | |||||
| OpPattern pattern = op_info_ptr->op_pattern(); | |||||
| if (pattern == kCommonPattern) { | |||||
| GetCommonPatternKernelInfo(*op_info_ptr); | |||||
| } else if (pattern == kDynamicFormatPattern) { | |||||
| GetDynamicFormatPatternKernelInfo(*op_info_ptr); | |||||
| } else if (pattern == kFormatAgnosticPattern) { | |||||
| GetAgnosticPatternKernelInfo(*op_info_ptr); | |||||
| } else if (pattern == kBroadcastPattern) { | |||||
| GetBroadcastPatternKernelInfo(*op_info_ptr); | |||||
| } else if (pattern == kReducePattern) { | |||||
| GetReducePatternKernelInfo(*op_info_ptr); | |||||
| } else { | |||||
| MS_LOG(INFO) << "Warning: op pattern is invailed."; | |||||
| } | |||||
| // check support | |||||
| FilterInvalidKernelInfo(); | |||||
| MS_LOG(INFO) << "End TBE select, kernel build info size: " << kernel_info_list_->size() << "."; | |||||
| } | |||||
| void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "start."; | |||||
| // get dynamic inputs | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| std::vector<int> dyn_input_sizes; | |||||
| if (primitive->HasAttr(kDynInputKey)) { | |||||
| dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kDynInputKey)); | |||||
| } | |||||
| // get real input/output num | |||||
| size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); | |||||
| const auto inputs_info = op_info.inputs_ptr(); | |||||
| size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_); | |||||
| const auto outputs_info = op_info.outputs_ptr(); | |||||
| if (inputs_info.empty() && outputs_info.empty()) { | |||||
| MS_LOG(EXCEPTION) << "op info input & output is null, please check."; | |||||
| } | |||||
| // create kernel build info from opinfo | |||||
| size_t kernel_build_info_num = | |||||
| inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size(); | |||||
| for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) { | |||||
| auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); | |||||
| SetTbeBuildCommonInfo(op_info, &builder); | |||||
| std::vector<std::string> inputs_format; | |||||
| std::vector<TypeId> inputs_device_type; | |||||
| std::vector<std::vector<Axis>> inputs_reshape_type; | |||||
| // input | |||||
| if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes, | |||||
| &inputs_format, &inputs_device_type, &inputs_reshape_type)) { | |||||
| break; | |||||
| } | |||||
| builder.SetInputsDeviceType(inputs_device_type); | |||||
| builder.SetInputsFormat(inputs_format); | |||||
| builder.SetInputReshapeType(inputs_reshape_type); | |||||
| // output | |||||
| std::vector<std::string> outputs_format; | |||||
| std::vector<TypeId> outputs_device_type; | |||||
| std::vector<std::vector<Axis>> outputs_reshape_type; | |||||
| if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes, | |||||
| &outputs_format, &outputs_device_type, &outputs_reshape_type)) { | |||||
| break; | |||||
| } | |||||
| builder.SetOutputsDeviceType(outputs_device_type); | |||||
| builder.SetOutputsFormat(outputs_format); | |||||
| builder.SetOutputReshapeType(outputs_reshape_type); | |||||
| kernel_info_list_->emplace_back(builder.Build()); | |||||
| } | |||||
| MS_LOG(INFO) << "end."; | |||||
| } | |||||
| void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "start."; | |||||
| // Query op_select_format on the python side and rebuild the op info from its result. | |||||
| OpInfo op_info_new; | |||||
| CreateNewOpInfo(op_info, &op_info_new); | |||||
| GetCommonPatternKernelInfo(op_info_new); | |||||
| MS_LOG(INFO) << "end."; | |||||
| } | |||||
| void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "start."; | |||||
| if (op_info.inputs_ptr().size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "AgnosticPattern only support one input."; | |||||
| } | |||||
| auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0); | |||||
| if (kOpFormatList.find(format) == kOpFormatList.end()) { | |||||
| MS_LOG(INFO) << "Got the unknown format " << format; | |||||
| format = kOpFormat_DEFAULT; | |||||
| } | |||||
| SupportFormat support_format; | |||||
| SupportFormatItem input_item; | |||||
| SupportFormatItem output_item; | |||||
| input_item.assign(op_info.inputs_ptr().size(), format); | |||||
| output_item.assign(op_info.outputs_ptr().size(), format); | |||||
| support_format.input_format.emplace_back(input_item); | |||||
| support_format.output_format.emplace_back(output_item); | |||||
| PrintSupportedFormat(support_format); | |||||
| OpInfo op_info_new; | |||||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | |||||
| GetCommonPatternKernelInfo(op_info_new); | |||||
| MS_LOG(INFO) << "end."; | |||||
| } | |||||
| void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "start."; | |||||
| auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_); | |||||
| SupportFormat support_format; | |||||
| broadcast_selecter.GetShapeInfo(&support_format); | |||||
| if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) { | |||||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD."; | |||||
| } | |||||
| if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) { | |||||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ."; | |||||
| } | |||||
| if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) { | |||||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0."; | |||||
| } | |||||
| if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) { | |||||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ."; | |||||
| } | |||||
| PrintSupportedFormat(support_format); | |||||
| OpInfo op_info_new; | |||||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | |||||
| GetCommonPatternKernelInfo(op_info_new); | |||||
| MS_LOG(INFO) << "end."; | |||||
| } | |||||
| void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) { | |||||
| MS_LOG(INFO) << "start."; | |||||
| auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_); | |||||
| SupportFormat support_format; | |||||
| reduce_selecter.GetShapeInfo(&support_format); | |||||
| if (!reduce_selecter.IsReduceSupport5HD(&support_format)) { | |||||
| MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support 5HD."; | |||||
| } | |||||
| if (!reduce_selecter.IsReduceSupportFracZ(&support_format)) { | |||||
| MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracZ."; | |||||
| } | |||||
| if (!reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) { | |||||
| MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support C1HWNCoC0."; | |||||
| } | |||||
| if (!reduce_selecter.IsReduceSupportFracNZ(&support_format)) { | |||||
| MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracNZ."; | |||||
| } | |||||
| PrintSupportedFormat(support_format); | |||||
| OpInfo op_info_new; | |||||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | |||||
| GetCommonPatternKernelInfo(op_info_new); | |||||
| MS_LOG(INFO) << "end."; | |||||
| } | |||||
| void TbeKernelSelect::FilterInvalidKernelInfo() { | |||||
| if (kernel_info_list_->empty()) { | |||||
| MS_LOG(INFO) << "Warning: get kernel build info failed."; | |||||
| return; | |||||
| } | |||||
| auto kernel_build_info_iter = kernel_info_list_->begin(); | |||||
| while (kernel_build_info_iter != kernel_info_list_->end()) { | |||||
| if (!FilterInvalidShape(kernel_build_info_iter)) { | |||||
| MS_LOG(INFO) << "Filter invalid shape, filter item info: " << (*kernel_build_info_iter)->ToString(); | |||||
| kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); | |||||
| continue; | |||||
| } | |||||
| if (!TbeCheckSupported(kernel_build_info_iter)) { | |||||
| MS_LOG(INFO) << "Check support shape, filter item info: " << (*kernel_build_info_iter)->ToString(); | |||||
| kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); | |||||
| continue; | |||||
| } | |||||
| kernel_build_info_iter++; | |||||
| } | |||||
| } | |||||
| bool TbeKernelSelect::FilterInvalidShape( | |||||
| const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { | |||||
| MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); | |||||
| auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats(); | |||||
| for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) { | |||||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); | |||||
| auto format = kernel_build_info_inputs_format.at(i); | |||||
| if (!IsShapeMatchFormat(shape, format)) { | |||||
| MS_LOG(INFO) << "The " << i << "th input check failed."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats(); | |||||
| for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) { | |||||
| auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j); | |||||
| auto format = kernel_build_info_outputs_format.at(j); | |||||
| if (!IsShapeMatchFormat(shape, format)) { | |||||
| MS_LOG(INFO) << "The " << j << "th input check failed."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) { | |||||
| if (format == kOpFormat_DEFAULT) { | |||||
| return true; | |||||
| } | |||||
| static std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; | |||||
| // kOpFormat_DEFAULT was handled above; any other format must be a known, registered one. | |||||
| if (kOpFormatList.find(format) == kOpFormatList.end()) { | |||||
| MS_LOG(EXCEPTION) << "Got the unknown format " << format; | |||||
| } | |||||
| // The server does not support formats with the C04 suffix. | |||||
| if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) != | |||||
| kServerNotSupportFormat.end()) { | |||||
| MS_LOG(INFO) << "Warning: Server not support format with C04 suffix."; | |||||
| return false; | |||||
| } | |||||
| // Unsupported shape/format combinations: | |||||
| // 1 NDHWC with shape size != 5 | |||||
| // 2 FRAC_NZ with shape size < 2 | |||||
| // 3 !NDHWC with shape size > 4 | |||||
| if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) || | |||||
| (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) || | |||||
| (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) { | |||||
| MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool TbeKernelSelect::TbeCheckSupported( | |||||
| const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { | |||||
| MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); | |||||
| static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL, | |||||
| parallel::BATCHMATMUL, | |||||
| parallel::TOPK, | |||||
| parallel::IN_TOPK, | |||||
| parallel::PACK, | |||||
| parallel::GATHER_ND, | |||||
| parallel::UNSORTED_SEGMENT_MIN_D, | |||||
| parallel::UNSORTED_SEGMENT_PROD_D, | |||||
| parallel::CAST}; | |||||
| auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_); | |||||
| if (iter == kCheckSupportedOpType.end()) { | |||||
| return true; | |||||
| } | |||||
| MS_LOG(INFO) << "Check support start."; | |||||
| // replace kernel_info with current kernel info | |||||
| auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); | |||||
| nlohmann::json kernel_json; | |||||
| TbeKernelJsonCreator creator(CHECK_SUPPORTED); | |||||
| bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); | |||||
| if (!ret) { | |||||
| MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed."; | |||||
| } | |||||
| ret = TbePythonFuncs::CheckSupported(kernel_json); | |||||
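| // Restore the node's original kernel build info regardless of the check result. | |||||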
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get()); | |||||
| return ret; | |||||
| } | |||||
| void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info, | |||||
| mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) { | |||||
| MS_EXCEPTION_IF_NULL(builder); | |||||
| builder->SetProcessor(AICORE); | |||||
| std::string fusion_type = op_info.fusion_type(); | |||||
| if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) { | |||||
| builder->SetFusionType(tbe::GetFusionType(fusion_type)); | |||||
| } | |||||
| builder->SetOpPattern(op_info.op_pattern()); | |||||
| builder->SetKernelType(TBE_KERNEL); | |||||
| } | |||||
| bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, | |||||
| const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, | |||||
| const std::vector<int> &dyn_input_sizes, std::vector<std::string> *formats, | |||||
| std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types) { | |||||
| MS_EXCEPTION_IF_NULL(formats); | |||||
| MS_EXCEPTION_IF_NULL(device_types); | |||||
| MS_EXCEPTION_IF_NULL(reshape_types); | |||||
| size_t dynamic_input_index = 0; | |||||
| size_t real_io_tensor_index = 0; | |||||
| size_t io_info_index = 0; | |||||
| size_t io_info_num = ios_info.size(); | |||||
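| // Walk the op io infos and the real io tensors in lock step; a dynamic io consumes several real tensors. | |||||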
| for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) { | |||||
| std::shared_ptr<OpIOInfo> io_info_item = ios_info[io_info_index]; | |||||
| auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index); | |||||
| std::string kernel_build_info_format; | |||||
| if (!io_info_item->formats().empty()) { | |||||
| kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index); | |||||
| } | |||||
| std::string io_param_type = io_info_item->param_type(); | |||||
| std::vector<Axis> reshape_type; | |||||
| StringToAxisVector(io_info_item->reshape_type(), &reshape_type); | |||||
| if (io_param_type == kParamTypeDynamic) { | |||||
| // dynamic io | |||||
| if (is_input) { | |||||
| if (dynamic_input_index >= dyn_input_sizes.size()) { | |||||
| MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index | |||||
| << ", dyn_input_sizes size: " << dyn_input_sizes.size(); | |||||
| } | |||||
| int dynamic_input_size = dyn_input_sizes[dynamic_input_index]; | |||||
| for (int i = 0; i < dynamic_input_size; ++i) { | |||||
| device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); | |||||
| formats->emplace_back(kernel_build_info_format); | |||||
| reshape_types->emplace_back(reshape_type); | |||||
| } | |||||
| dynamic_input_index++; | |||||
| real_io_tensor_index += dynamic_input_size; | |||||
| } else { | |||||
| if (ios_info.size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output."; | |||||
| } | |||||
| for (size_t i = 0; i < real_io_tensor_num; ++i) { | |||||
| device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); | |||||
| formats->emplace_back(kernel_build_info_format); | |||||
| reshape_types->emplace_back(reshape_type); | |||||
| } | |||||
| real_io_tensor_index += real_io_tensor_num; | |||||
| } | |||||
| } else if (io_param_type == kParamTypeRequired || io_param_type == kParamTypeOptional) { | |||||
| // required or optional io | |||||
| device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); | |||||
| formats->emplace_back(kernel_build_info_format); | |||||
| reshape_types->emplace_back(reshape_type); | |||||
| real_io_tensor_index++; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type; | |||||
| } | |||||
| } | |||||
| if (io_info_index != io_info_num) { | |||||
| MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num | |||||
| << "), this node may has optional input/output."; | |||||
| } | |||||
| if (real_io_tensor_index != real_io_tensor_num) { | |||||
| std::string io_type = is_input ? "inputs" : "outputs"; | |||||
| MS_LOG(INFO) << node_name_ << "'s " << io_type << " op io info num: " << io_info_num | |||||
| << ", real io tensor num: " << real_io_tensor_num << ", real_io_tensor_index(" | |||||
| << real_io_tensor_index << ") != real_io_tensor_num(" << real_io_tensor_num << ")"; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) { | |||||
| MS_EXCEPTION_IF_NULL(reshape_type_vec); | |||||
| for (const auto &c : reshape_type_str) { | |||||
| switch (c) { | |||||
| case 'N': | |||||
| reshape_type_vec->push_back(kernel::N); | |||||
| break; | |||||
| case 'C': | |||||
| reshape_type_vec->push_back(kernel::C); | |||||
| break; | |||||
| case 'H': | |||||
| reshape_type_vec->push_back(kernel::H); | |||||
| break; | |||||
| case 'W': | |||||
| reshape_type_vec->push_back(kernel::W); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type."; | |||||
| } | |||||
| } | |||||
| } | |||||
| void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, | |||||
| const std::vector<std::vector<std::string>> &support_format_item, size_t index, | |||||
| mindspore::kernel::OpIOInfo *op_io_info_new) { | |||||
| MS_EXCEPTION_IF_NULL(op_io_info_new); | |||||
| op_io_info_new->set_index(op_io_info.index()); | |||||
| op_io_info_new->set_name(op_io_info.name()); | |||||
| op_io_info_new->set_param_type(op_io_info.param_type()); | |||||
| op_io_info_new->set_need_compile(op_io_info.need_compile()); | |||||
| op_io_info_new->set_reshape_type(op_io_info.reshape_type()); | |||||
| op_io_info_new->set_shape(op_io_info.shape()); | |||||
| // dtype | |||||
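| // Repeat the dtype list once per supported format item to form the (format x dtype) cross product. | |||||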
| std::vector<std::string> dtype_new; | |||||
| auto dtype = op_io_info.dtypes(); | |||||
| for (size_t i = 0; i < support_format_item.size(); ++i) { | |||||
| dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end()); | |||||
| } | |||||
| op_io_info_new->set_dtypes(dtype_new); | |||||
| // format | |||||
| std::vector<std::string> format_new; | |||||
| for (const auto &formats : support_format_item) { | |||||
| auto format = formats.at(index); | |||||
| for (size_t j = 0; j < dtype.size(); ++j) { | |||||
| format_new.emplace_back(format); | |||||
| } | |||||
| } | |||||
| op_io_info_new->set_formats(format_new); | |||||
| } | |||||
| std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) { | |||||
| const std::map<std::string, std::string> kDynamicFormatMap = { | |||||
| {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}}; | |||||
| if (op_select_json_item.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Op select ret item is null."; | |||||
| } | |||||
| const char space = ' '; | |||||
| const char sep = ','; | |||||
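| // Append a trailing separator so the loop below also emits the last token. | |||||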
| std::string op_select_tmp = op_select_json_item + ","; | |||||
| std::vector<std::string> ret; | |||||
| auto begin = op_select_tmp.find_first_not_of(space, 0); | |||||
| auto sep_pos = op_select_tmp.find(sep); | |||||
| while (sep_pos != std::string::npos) { | |||||
| auto obj = op_select_tmp.substr(begin, sep_pos - begin); | |||||
| if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) { | |||||
| obj = kDynamicFormatMap.at(obj); | |||||
| } | |||||
| ret.emplace_back(obj); | |||||
| begin = op_select_tmp.find_first_not_of(space, sep_pos + 1); | |||||
| sep_pos = op_select_tmp.find(sep, begin); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| std::string TbeKernelSelect::OpSelectFormat() { | |||||
| nlohmann::json kernel_json; | |||||
| std::string res_json_str; | |||||
| TbeKernelJsonCreator creator(OP_SELECT_FORMAT); | |||||
| bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); | |||||
| if (!ret) { | |||||
| MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed."; | |||||
| } | |||||
| res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json); | |||||
| if (res_json_str.empty()) { | |||||
| MS_LOG(EXCEPTION) << "op select format error."; | |||||
| } | |||||
| MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str; | |||||
| return res_json_str; | |||||
| } | |||||
| void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format, | |||||
| mindspore::kernel::OpInfo *op_info_new) { | |||||
| MS_EXCEPTION_IF_NULL(op_info_new); | |||||
| if (op_info.inputs_ptr().size() != support_format.input_format[0].size() || | |||||
| op_info.outputs_ptr().size() != support_format.output_format[0].size()) { | |||||
| MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size() | |||||
| << ", input support size: " << support_format.input_format[0].size() | |||||
| << ", op info output size: " << op_info.outputs_ptr().size() | |||||
| << ", output support size: " << support_format.output_format[0].size(); | |||||
| } | |||||
| *op_info_new = op_info; | |||||
| op_info_new->ClearInputs(); | |||||
| op_info_new->ClearOutputs(); | |||||
| for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { | |||||
| auto input = op_info.inputs_ptr().at(i); | |||||
| auto input_new = std::make_shared<OpIOInfo>(); | |||||
| CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get()); | |||||
| op_info_new->add_inputs_ptr(input_new); | |||||
| } | |||||
| for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) { | |||||
| auto output = op_info.outputs_ptr().at(j); | |||||
| auto output_new = std::make_shared<OpIOInfo>(); | |||||
| CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get()); | |||||
| op_info_new->add_outputs_ptr(output_new); | |||||
| } | |||||
| } | |||||
| struct SelectOpIOInfo { | |||||
| std::string name; | |||||
| std::vector<std::string> dtypes; | |||||
| std::vector<std::string> formats; | |||||
| }; | |||||
| void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, | |||||
| mindspore::kernel::OpInfo *op_info_new) { | |||||
| MS_EXCEPTION_IF_NULL(op_info_new); | |||||
| auto op_select_json = OpSelectFormat(); | |||||
| if (!op_select_json.empty()) { | |||||
| nlohmann::json json_obj = nlohmann::json::parse(op_select_json); | |||||
| if (!json_obj.is_object()) { | |||||
| MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is: " << op_select_json; | |||||
| } | |||||
| std::vector<SelectOpIOInfo> inputs; | |||||
| std::vector<SelectOpIOInfo> outputs; | |||||
| for (const auto &item : json_obj.items()) { | |||||
| const std::string &item_name = item.key(); | |||||
| bool is_input = (item_name.find(kPrefixInput) != std::string::npos); | |||||
| bool is_output = (item_name.find(kPrefixOutput) != std::string::npos); | |||||
| if (!is_input && !is_output) { | |||||
| MS_LOG(EXCEPTION) << "op select ret json is error."; | |||||
| } | |||||
| if (is_input) { | |||||
| SelectOpIOInfo select_input; | |||||
| select_input.name = item.value().at(kName); | |||||
| std::string input_dtype_item = item.value().at(kDtype); | |||||
| select_input.dtypes = SplitStrToVec(input_dtype_item); | |||||
| std::string input_format_item = item.value().at(kFormat); | |||||
| select_input.formats = SplitStrToVec(input_format_item); | |||||
| inputs.emplace_back(select_input); | |||||
| } else if (is_output) { | |||||
| SelectOpIOInfo select_output; | |||||
| select_output.name = item.value().at(kName); | |||||
| std::string output_dtype_item = item.value().at(kDtype); | |||||
| select_output.dtypes = SplitStrToVec(output_dtype_item); | |||||
| std::string output_format_item = item.value().at(kFormat); | |||||
| select_output.formats = SplitStrToVec(output_format_item); | |||||
| outputs.emplace_back(select_output); | |||||
| } | |||||
| } | |||||
| if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) { | |||||
| MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register."; | |||||
| } | |||||
| *op_info_new = op_info; | |||||
| op_info_new->ClearInputs(); | |||||
| op_info_new->ClearOutputs(); | |||||
| for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { | |||||
| auto input_new = std::make_shared<OpIOInfo>(); | |||||
| CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get()); | |||||
| op_info_new->add_inputs_ptr(input_new); | |||||
| } | |||||
| for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) { | |||||
| auto output_new = std::make_shared<OpIOInfo>(); | |||||
| CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get()); | |||||
| op_info_new->add_outputs_ptr(output_new); | |||||
| } | |||||
| } | |||||
| } | |||||
| void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, | |||||
| const std::vector<std::string> &support_dtype, | |||||
| const std::vector<std::string> &support_format, | |||||
| mindspore::kernel::OpIOInfo *op_io_info_new) { | |||||
| MS_EXCEPTION_IF_NULL(op_io_info_new); | |||||
| op_io_info_new->set_index(op_io_info.index()); | |||||
| op_io_info_new->set_name(op_io_info.name()); | |||||
| op_io_info_new->set_param_type(op_io_info.param_type()); | |||||
| op_io_info_new->set_need_compile(op_io_info.need_compile()); | |||||
| op_io_info_new->set_reshape_type(op_io_info.reshape_type()); | |||||
| op_io_info_new->set_shape(op_io_info.shape()); | |||||
| // dtype | |||||
| std::vector<std::string> dtype_new; | |||||
| for (size_t i = 0; i < support_format.size(); ++i) { | |||||
| dtype_new.insert(dtype_new.end(), support_dtype.begin(), support_dtype.end()); | |||||
| } | |||||
| op_io_info_new->set_dtypes(dtype_new); | |||||
| // format | |||||
| std::vector<std::string> format_new; | |||||
| for (const auto &format : support_format) { | |||||
| for (size_t j = 0; j < support_dtype.size(); ++j) { | |||||
| format_new.emplace_back(format); | |||||
| } | |||||
| } | |||||
| op_io_info_new->set_formats(format_new); | |||||
| } | |||||
| void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) { | |||||
| if (support_format.input_format.size() != support_format.output_format.size()) { | |||||
| MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output(" | |||||
| << support_format.output_format.size() << ") size not match."; | |||||
| } | |||||
| for (size_t i = 0; i < support_format.input_format.size(); ++i) { | |||||
| auto input_items = support_format.input_format.at(i); | |||||
| auto output_items = support_format.output_format.at(i); | |||||
| std::string print_str = "["; | |||||
| for (const auto &input : input_items) { | |||||
| print_str.append(input); | |||||
| print_str.append(", "); | |||||
| } | |||||
| print_str.append("] -->"); | |||||
| for (const auto &output : output_items) { | |||||
| print_str.append(output); | |||||
| print_str.append(", "); | |||||
| } | |||||
| MS_LOG(INFO) << "Support format: " << print_str; | |||||
| } | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,77 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_TBE_KERNEL_SELECT_H | |||||
| #define MINDSPORE_TBE_KERNEL_SELECT_H | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "kernel/oplib/opinfo.h" | |||||
| #include "kernel/kernel_build_info.h" | |||||
| #include "kernel/tbe/tbe_kernel_select/common_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list); | |||||
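| // Builds the list of candidate kernel build infos for a TBE node according to its registered op pattern. | |||||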
| class TbeKernelSelect { | |||||
| using OpInfoPtr = std::shared_ptr<OpInfo>; | |||||
| using KernelBuildInfoIter = std::vector<std::shared_ptr<KernelBuildInfo>>::iterator; | |||||
| public: | |||||
| TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list); | |||||
| ~TbeKernelSelect() = default; | |||||
| void TbeMetadataInfoEx(); | |||||
| private: | |||||
| void GetCommonPatternKernelInfo(const OpInfo &op_info); | |||||
| void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info); | |||||
| void GetAgnosticPatternKernelInfo(const OpInfo &op_info); | |||||
| void GetBroadcastPatternKernelInfo(const OpInfo &op_info); | |||||
| void GetReducePatternKernelInfo(const OpInfo &op_info); | |||||
| void FilterInvalidKernelInfo(); | |||||
| bool FilterInvalidShape(const KernelBuildInfoIter &kernel_build_info_iter); | |||||
| static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format); | |||||
| bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); | |||||
| static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder); | |||||
| bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, | |||||
| const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, const std::vector<int> &dyn_input_sizes, | |||||
| std::vector<std::string> *formats, std::vector<TypeId> *device_types, | |||||
| std::vector<std::vector<Axis>> *reshape_types); | |||||
| static void StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec); | |||||
| static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new); | |||||
| static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, | |||||
| const std::vector<std::vector<std::string>> &support_format_item, size_t index, | |||||
| OpIOInfo *op_io_info_new); | |||||
| // op select(dynamic) | |||||
| void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new); | |||||
| static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector<std::string> &support_dtype, | |||||
| const std::vector<std::string> &support_format, OpIOInfo *op_io_info_new); | |||||
| static std::vector<std::string> SplitStrToVec(const std::string &op_select_json_item); | |||||
| std::string OpSelectFormat(); | |||||
| static void PrintSupportedFormat(const SupportFormat &support_format); | |||||
| private: | |||||
| CNodePtr cnode_ptr_; | |||||
| std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list_; | |||||
| std::string node_name_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_TBE_KERNEL_SELECT_H | |||||
| @@ -216,6 +216,13 @@ constexpr char NEG[] = "Neg"; | |||||
| constexpr char BATCH_MATMUL[] = "BatchMatMul"; | constexpr char BATCH_MATMUL[] = "BatchMatMul"; | ||||
| constexpr char EXPAND_DIMS[] = "ExpandDims"; | constexpr char EXPAND_DIMS[] = "ExpandDims"; | ||||
| constexpr char SQUARE[] = "Square"; | constexpr char SQUARE[] = "Square"; | ||||
| constexpr char BATCHMATMUL[] = "BatchMatMul"; | |||||
| constexpr char TOPK[] = "TopK"; | |||||
| constexpr char IN_TOPK[] = "InTopK"; | |||||
| constexpr char PACK[] = "Pack"; | |||||
| constexpr char GATHER_ND[] = "GatherNd"; | |||||
| constexpr char UNSORTED_SEGMENT_MIN_D[] = "UnsortedSegmentMinD"; | |||||
| constexpr char UNSORTED_SEGMENT_PROD_D[] = "UnsortedSegmentProdD"; | |||||
| // Parallel don't care | // Parallel don't care | ||||
| constexpr char TUPLE_GETITEM[] = "tuple_getitem"; | constexpr char TUPLE_GETITEM[] = "tuple_getitem"; | ||||
| @@ -21,7 +21,6 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "device/ascend/kernel_select_ascend.h" | #include "device/ascend/kernel_select_ascend.h" | ||||
| #include "kernel/kernel_query.h" | #include "kernel/kernel_query.h" | ||||
| #include "kernel/tbe/tbe_kernel_select.h" | |||||
| #include "kernel/oplib/oplib.h" | #include "kernel/oplib/oplib.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| @@ -34,7 +34,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto node_name = AnfAlgo::GetCNodeName(node); | auto node_name = AnfAlgo::GetCNodeName(node); | ||||
| if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) { | |||||
| if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); | auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); | ||||
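The one-character fix above (`||` to `&&`) repairs a classic De Morgan slip: with `||`, the guard is true for every node, since no name can equal both `TransData` and `Cast` at once, so the pass returned early unconditionally. A quick sketch (Python, for brevity) of why the original predicate is a tautology:

```python
# (name != A) or (name != B) with A != B is True for every name, so the
# buggy guard skipped all nodes; the fixed form skips only other ops.
A, B = "TransData", "Cast"
for name in ("TransData", "Cast", "ReLU"):
    buggy = (name != A) or (name != B)    # always True
    fixed = (name != A) and (name != B)   # True only for other ops
    print(name, buggy, fixed)
```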
| @@ -26,12 +26,9 @@ abs_op_info = TBERegOp("Abs") \ | |||||
| .op_pattern("formatAgnostic") \ | .op_pattern("formatAgnostic") \ | ||||
| .input(0, "x", None, "required", None) \ | .input(0, "x", None, "required", None) \ | ||||
| .output(0, "y", True, "required", "all") \ | .output(0, "y", True, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.I32_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
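This is the pattern the rest of the patch repeats: once an op is marked `op_pattern("formatAgnostic")`, the per-format `dtype_format` entries (Default, 5HD, and so on) collapse into one `_None` entry per dtype, and the kernel-select pass fills in the concrete format later. A minimal sketch of the resulting registration shape, using a hypothetical op name (the import path is the one the registration files in this tree use):

```python
from mindspore.ops.op_info_register import TBERegOp, DataType

# Hypothetical format-agnostic unary op, mirroring the Abs registration above:
# only dtypes are enumerated; the empty format is resolved at select time.
my_unary_op_info = TBERegOp("MyUnaryOp") \
    .compute_cost(10) \
    .kernel_name("my_unary_op") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", None, "required", None) \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()
```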
| @@ -23,7 +23,6 @@ abs_grad_op_info = TBERegOp("AbsGrad") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("abs_grad") \ | .kernel_name("abs_grad") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .op_pattern("formatAgnostic") \ | |||||
| .input(0, "y", None, "required", None) \ | .input(0, "y", None, "required", None) \ | ||||
| .input(1, "dy", None, "required", None) \ | .input(1, "dy", None, "required", None) \ | ||||
| .output(0, "z", False, "required", "all") \ | .output(0, "z", False, "required", "all") \ | ||||
| @@ -26,6 +26,7 @@ add_op_info = TBERegOp("Add") \ | |||||
| .input(0, "x1", False, "required", "all") \ | .input(0, "x1", False, "required", "all") \ | ||||
| .input(1, "x2", False, "required", "all") \ | .input(1, "x2", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | ||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | ||||
| @@ -26,17 +26,10 @@ add_n_op_info = TBERegOp("AddN") \ | |||||
| .attr("n", "required", "int", "all") \ | .attr("n", "required", "int", "all") \ | ||||
| .input(0, "x", False, "dynamic", "all") \ | .input(0, "x", False, "dynamic", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ | |||||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||||
| .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \ | |||||
| .op_pattern("broadcast") \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.I32_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
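`op_pattern("broadcast")` gets the same compaction for multi-input elementwise ops; note the ordering convention the patch settles on here, with `op_pattern` placed ahead of the `dtype_format` lines. A hedged sketch of a two-input broadcast registration, again with a hypothetical op name:

```python
from mindspore.ops.op_info_register import TBERegOp, DataType

# Hypothetical broadcast binary op, mirroring the AddN/RealDiv pattern above.
my_bcast_op_info = TBERegOp("MyBroadcastOp") \
    .compute_cost(10) \
    .kernel_name("my_bcast_op") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()
```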
| @@ -29,6 +29,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \ | |||||
| .input(1, "x2", False, "required", "all") \ | .input(1, "x2", False, "required", "all") \ | ||||
| .input(2, "bias", False, "optional", "all") \ | .input(2, "bias", False, "optional", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \ | .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \ | ||||
| @@ -27,6 +27,7 @@ bias_add_grad_op_info = TBERegOp("BiasAdd") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .input(1, "bias", False, "required", "all") \ | .input(1, "bias", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | ||||
| @@ -26,6 +26,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \ | |||||
| .input(0, "x", False, "required", "all", reshape_type="NC") \ | .input(0, "x", False, "required", "all", reshape_type="NC") \ | ||||
| .output(0, "sum", False, "required", "all") \ | .output(0, "sum", False, "required", "all") \ | ||||
| .output(1, "square_sum", False, "required", "all") \ | .output(1, "square_sum", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | ||||
| .get_op_info() | .get_op_info() | ||||
| @@ -32,6 +32,7 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \ | |||||
| .input(5, "batch_mean", False, "required", "all") \ | .input(5, "batch_mean", False, "required", "all") \ | ||||
| .input(6, "batch_variance", False, "required", "all") \ | .input(6, "batch_variance", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all", reshape_type="NC") \ | .output(0, "y", False, "required", "all", reshape_type="NC") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \ | DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| @@ -30,6 +30,7 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \ | |||||
| .input(3, "batch_variance", False, "required", "all") \ | .input(3, "batch_variance", False, "required", "all") \ | ||||
| .output(0, "diff_scale", False, "required", "all") \ | .output(0, "diff_scale", False, "required", "all") \ | ||||
| .output(1, "diff_offset", False, "required", "all") \ | .output(1, "diff_offset", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD) \ | DataType.F32_5HD, DataType.F32_5HD) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| @@ -32,6 +32,7 @@ bn_training_update_v2_op_info = TBERegOp("BNTrainingUpdateV2") \ | |||||
| .output(0, "y", False, "required", "all", reshape_type="NC") \ | .output(0, "y", False, "required", "all", reshape_type="NC") \ | ||||
| .output(1, "batch_mean", False, "required", "all") \ | .output(1, "batch_mean", False, "required", "all") \ | ||||
| .output(2, "batch_variance", False, "required", "all") \ | .output(2, "batch_variance", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, | DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD) \ | DataType.F32_5HD, DataType.F32_5HD) \ | ||||
| @@ -26,32 +26,27 @@ cast_op_info = TBERegOp("Cast") \ | |||||
| .attr("dst_type", "required", "int", "all") \ | .attr("dst_type", "required", "int", "all") \ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.BOOL_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F16_FracZ, DataType.F32_FracZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F16_FracZ) \ | |||||
| .dtype_format(DataType.F32_FracNZ, DataType.F16_FracNZ) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default) \ | |||||
| .op_pattern("formatAgnostic") \ | |||||
| .dtype_format(DataType.BOOL_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.BOOL_None, DataType.U8_None) \ | |||||
| .dtype_format(DataType.BOOL_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.BOOL_None, DataType.I32_None) \ | |||||
| .dtype_format(DataType.I8_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.I8_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.I8_None, DataType.I32_None) \ | |||||
| .dtype_format(DataType.U8_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.U8_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.U8_None, DataType.I32_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.BOOL_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.I8_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.U8_None) \ | |||||
| .dtype_format(DataType.F16_None, DataType.U8_None) \ | |||||
| .dtype_format(DataType.F16_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.F16_None, DataType.I32_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.I32_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -26,6 +26,7 @@ concat_op_info = TBERegOp("Concat") \ | |||||
| .attr("axis", "required", "int", "all") \ | .attr("axis", "required", "int", "all") \ | ||||
| .input(0, "input_values", False, "dynamic", "all") \ | .input(0, "input_values", False, "dynamic", "all") \ | ||||
| .output(0, "output_data", False, "required", "all") \ | .output(0, "output_data", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | ||||
| .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \ | .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \ | ||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | ||||
| @@ -23,6 +23,7 @@ conv2d_op_info = TBERegOp("Conv2D") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("conv2d") \ | .kernel_name("conv2d") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .attr("stride", "required", "listInt", "all") \ | .attr("stride", "required", "listInt", "all") \ | ||||
| .attr("pad_list", "required", "listInt", "all") \ | .attr("pad_list", "required", "listInt", "all") \ | ||||
| .attr("dilation", "required", "listInt", "all") \ | .attr("dilation", "required", "listInt", "all") \ | ||||
| @@ -32,8 +33,7 @@ conv2d_op_info = TBERegOp("Conv2D") \ | |||||
| .input(2, "bias", False, "optional", "all") \ | .input(2, "bias", False, "optional", "all") \ | ||||
| .input(3, "offset_w", False, "optional", "all") \ | .input(3, "offset_w", False, "optional", "all") \ | ||||
| .output(0, "y", True, "required", "all") \ | .output(0, "y", True, "required", "all") \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default, | |||||
| DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.I8_None, DataType.F16_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -27,6 +27,7 @@ drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \ | |||||
| .input(1, "mask", False, "required", "all") \ | .input(1, "mask", False, "required", "all") \ | ||||
| .input(2, "keep_prob", False, "required", "all") \ | .input(2, "keep_prob", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \ | ||||
| .get_op_info() | .get_op_info() | ||||
| @@ -28,9 +28,7 @@ elu_op_info = TBERegOp("Elu") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -26,9 +26,7 @@ erf_op_info = TBERegOp("Erf") \ | |||||
| .op_pattern("formatAgnostic") \ | .op_pattern("formatAgnostic") \ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | ||||
| .get_op_info() | .get_op_info() | ||||
| @@ -26,9 +26,7 @@ erfc_op_info = TBERegOp("Erfc") \ | |||||
| .op_pattern("formatAgnostic") \ | .op_pattern("formatAgnostic") \ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | ||||
| .get_op_info() | .get_op_info() | ||||
| @@ -27,9 +27,7 @@ expm1_op_info = TBERegOp("Expm1") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -27,6 +27,7 @@ fused_mul_add_op_info = TBERegOp("FusedMulAdd") \ | |||||
| .input(1, "x2", False, "required", "all") \ | .input(1, "x2", False, "required", "all") \ | ||||
| .input(2, "x3", False, "required", "all") \ | .input(2, "x3", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | ||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | ||||
| .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ | .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ | ||||
| @@ -32,6 +32,7 @@ layer_norm_op_info = TBERegOp("LayerNorm") \ | |||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .output(1, "mean", False, "required", "all") \ | .output(1, "mean", False, "required", "all") \ | ||||
| .output(2, "variance", False, "required", "all") \ | .output(2, "variance", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | ||||
| DataType.F16_Default, DataType.F16_Default) \ | DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | ||||
| @@ -30,6 +30,7 @@ layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop") | |||||
| .input(3, "mean", False, "required", "all") \ | .input(3, "mean", False, "required", "all") \ | ||||
| .output(0, "pd_gamma", False, "required", "all") \ | .output(0, "pd_gamma", False, "required", "all") \ | ||||
| .output(1, "pd_beta", False, "required", "all") \ | .output(1, "pd_beta", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | ||||
| DataType.F32_Default, DataType.F32_Default) \ | DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | ||||
| @@ -29,6 +29,7 @@ layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \ | |||||
| .input(3, "mean", False, "required", "all") \ | .input(3, "mean", False, "required", "all") \ | ||||
| .input(4, "gamma", False, "required", "all") \ | .input(4, "gamma", False, "required", "all") \ | ||||
| .output(0, "pd_x", False, "required", "all") \ | .output(0, "pd_x", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | ||||
| DataType.F16_Default, DataType.F16_Default) \ | DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, | ||||
| @@ -26,21 +26,8 @@ mul_op_info = TBERegOp("Mul") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .input(1, "y", False, "required", "all") \ | .input(1, "y", False, "required", "all") \ | ||||
| .output(0, "output", False, "required", "all") \ | .output(0, "output", False, "required", "all") \ | ||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||||
| .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ | |||||
| .dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \ | |||||
| .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
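Mul goes one step further than the `_None` pairs: the single `DataType.None_None` line (the pair itself is added to the `DataType` table at the end of this patch) leaves dtype and format both empty, so one entry stands in for the whole cross-product and the kernel's own `op_select_format` decides everything. The arithmetic of the deletion above:

```python
# Mul previously enumerated dtype x format explicitly; one None_None line
# plus op_pattern("dynamicFormat") now replaces the whole cross-product.
dtypes = ["I32", "F16", "F32"]
formats = ["Default", "5HD", "FracZ", "FracNZ", "C1HWNCoC0"]  # suffixes above
print(len(dtypes) * len(formats))  # 15 deleted dtype_format lines
```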
| @@ -26,10 +26,9 @@ realdiv_op_info = TBERegOp("RealDiv") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .input(1, "y", False, "required", "all") \ | .input(1, "y", False, "required", "all") \ | ||||
| .output(0, "z", False, "required", "all") \ | .output(0, "z", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .op_pattern("broadcast") \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -25,6 +25,7 @@ reciprocal_op_info = TBERegOp("Reciprocal") \ | |||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | ||||
| .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \ | .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \ | ||||
| @@ -27,11 +27,11 @@ reduce_mean_op_info = TBERegOp("ReduceMean") \ | |||||
| .attr("keep_dims", "optional", "bool", "all") \ | .attr("keep_dims", "optional", "bool", "all") \ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .op_pattern("reduce") \ | |||||
| .dtype_format(DataType.I8_None, DataType.I8_None) \ | |||||
| .dtype_format(DataType.U8_None, DataType.U8_None) \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
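`op_pattern("reduce")` follows the same recipe for reduction ops: keep the dtype list, drop the hand-enumerated formats, and let the select pass pick a layout consistent with the reduce axes. Purely as an illustration of the kind of constraint such a selector has to honor (this check is an assumption, not code from this patch): a channel-split layout like NC1HWC0 only survives if the channel axis is not among the reduced ones.

```python
# Illustrative only: a reduce op can stay in NC1HWC0 when C (axis 1 in NCHW)
# is not reduced, since reducing C would cross the C1/C0 split.
def reduce_keeps_5hd(reduce_axes, c_axis=1):
    return c_axis not in reduce_axes

print(reduce_keeps_5hd([2, 3]))  # True: spatial reduce
print(reduce_keeps_5hd([1]))     # False: channel reduce
```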
| @@ -24,7 +24,7 @@ relu_grad_v2_op_info = TBERegOp("ReluGradV2") \ | |||||
| .kernel_name("relu_grad_v2") \ | .kernel_name("relu_grad_v2") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .input(0, "gradients", False, "required", "all") \ | .input(0, "gradients", False, "required", "all") \ | ||||
| .input(1, "mask", False, "rerequired", "all") \ | |||||
| .input(1, "mask", False, "required", "all") \ | |||||
| .output(0, "backprops", True, "required", "all") \ | .output(0, "backprops", True, "required", "all") \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \ | .dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \ | .dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \ | ||||
| @@ -27,6 +27,7 @@ select_op_info = TBERegOp("Select") \ | |||||
| .input(1, "x1", False, "required", "all") \ | .input(1, "x1", False, "required", "all") \ | ||||
| .input(2, "x2", False, "required", "all") \ | .input(2, "x2", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | .dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | ||||
| .dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | .dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | ||||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | ||||
| @@ -27,11 +27,8 @@ sign_op_info = TBERegOp("Sign") \ | |||||
| .input(0, "x", None, "required", None) \ | .input(0, "x", None, "required", None) \ | ||||
| .output(0, "y", True, "required", "all") \ | .output(0, "y", True, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | ||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -30,6 +30,7 @@ softmax_grad_ext_op_info = TBERegOp("SoftmaxGradExt") \ | |||||
| .input(1, "x1", False, "required", "all") \ | .input(1, "x1", False, "required", "all") \ | ||||
| .input(2, "x2", False, "required", "all") \ | .input(2, "x2", False, "required", "all") \ | ||||
| .output(0, "y", True, "required", "all") \ | .output(0, "y", True, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, | .dtype_format(DataType.F16_Default, DataType.F16_Default, | ||||
| DataType.F16_Default, DataType.F16_Default) \ | DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, | .dtype_format(DataType.F16_5HD, DataType.F16_5HD, | ||||
| @@ -27,9 +27,7 @@ softplus_op_info = TBERegOp("Softplus") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -28,9 +28,7 @@ softplus_grad_op_info = TBERegOp("SoftplusGrad") \ | |||||
| .input(1, "features", False, "required", "all") \ | .input(1, "features", False, "required", "all") \ | ||||
| .output(0, "backprops", False, "required", "all") \ | .output(0, "backprops", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -27,6 +27,7 @@ split_d_op_info = TBERegOp("Split") \ | |||||
| .attr("output_num", "required", "int", "all") \ | .attr("output_num", "required", "int", "all") \ | ||||
| .input(0, "value", False, "required", "all") \ | .input(0, "value", False, "required", "all") \ | ||||
| .output(0, "output", False, "dynamic", "all") \ | .output(0, "output", False, "dynamic", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | ||||
| .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \ | .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \ | ||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | ||||
| @@ -26,6 +26,7 @@ tensor_add_op_info = TBERegOp("TensorAdd") \ | |||||
| .input(0, "x1", False, "required", "all") \ | .input(0, "x1", False, "required", "all") \ | ||||
| .input(1, "x2", False, "required", "all") \ | .input(1, "x2", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | ||||
| @@ -27,6 +27,7 @@ unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .input(1, "segment_ids", False, "required", "all") \ | .input(1, "segment_ids", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ | .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ | ||||
| .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \ | .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \ | ||||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ | .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ | ||||
| @@ -97,6 +97,7 @@ class RegOp: | |||||
| """ | """ | ||||
| if not isinstance(value, str): | if not isinstance(value, str): | ||||
| raise TypeError("%s value must be str" % str(value)) | raise TypeError("%s value must be str" % str(value)) | ||||
| return True | |||||
| def _is_int(self, value): | def _is_int(self, value): | ||||
| """ | """ | ||||
| @@ -110,6 +111,7 @@ class RegOp: | |||||
| """ | """ | ||||
| if not isinstance(value, int): | if not isinstance(value, int): | ||||
| raise TypeError("%s value must be int" % str(value)) | raise TypeError("%s value must be int" % str(value)) | ||||
| return True | |||||
| def _is_bool(self, value): | def _is_bool(self, value): | ||||
| """ | """ | ||||
| @@ -123,6 +125,7 @@ class RegOp: | |||||
| """ | """ | ||||
| if not isinstance(value, bool): | if not isinstance(value, bool): | ||||
| raise TypeError("%s value must be bool" % str(value)) | raise TypeError("%s value must be bool" % str(value)) | ||||
| return True | |||||
| def _check_param(self, param_list, key_list, fn_list, kwargs): | def _check_param(self, param_list, key_list, fn_list, kwargs): | ||||
| """ | """ | ||||
| @@ -494,6 +497,7 @@ class DataType: | |||||
| The current list below maybe not completed. If necessary, please add it. | The current list below maybe not completed. If necessary, please add it. | ||||
| """ | """ | ||||
| None_None = ("", "") | |||||
| BOOL_None = ("bool", "") | BOOL_None = ("bool", "") | ||||
| BOOL_Default = ("bool", "DefaultFormat") | BOOL_Default = ("bool", "DefaultFormat") | ||||
| BOOL_5HD = ("bool", "NC1HWC0") | BOOL_5HD = ("bool", "NC1HWC0") | ||||
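With `None_None` in place, the `DataType` table has three tiers: fully pinned pairs such as `BOOL_Default`, dtype-only pairs such as `BOOL_None`, and the fully open `None_None`. The pairs themselves make the tiers explicit:

```python
# The three tiers of (dtype, format) pairs after this patch:
None_None    = ("", "")                   # both decided at select time
BOOL_None    = ("bool", "")               # dtype fixed, format decided later
BOOL_Default = ("bool", "DefaultFormat")  # both fixed at registration time

for dtype, fmt in (None_None, BOOL_None, BOOL_Default):
    print("dtype=%r format=%r" % (dtype, fmt))
```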