/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_
#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_

// NOTE(review): the seven standard-library header names below are empty in this
// copy -- the text between the angle brackets appears to have been stripped.
// Restore them from upstream before compiling.
#include
#include
#include
#include
#include
#include
#include

#include "ir/dtype.h"
#include "kernel/kernel.h"
#include "pybind11/stl.h"
#include "kernel/oplib/oplib.h"
#include "kernel/tbe/tbe_adapter.h"

// NOTE(review): template argument lists throughout this header (on std::vector,
// std::shared_ptr, and the nested std::vector<...>::iterator parameters) also
// appear to have been stripped in this copy. The element types mentioned in the
// comments below are inferred from parameter names and must be confirmed
// against upstream.

namespace mindspore {
namespace kernel {
// Static helpers that generate nlohmann::json descriptions for TBE kernel builds
// and read input/output sizes back out of that JSON. Per the member names, this
// covers both single-kernel JSON and UB-fusion scope JSON.
// The constructor and destructor are private, so this class is not meant to be
// instantiated; use its static members directly.
class TbeKernelBuild {
 public:
  // Fills *input_size_list and *output_size_list from a single-kernel JSON
  // description. Presumably returns false if the JSON cannot be parsed -- confirm
  // against the .cc.
  static bool GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list,
                        std::vector *output_size_list);
  // UB fusion
  // Builds the JSON for one fusion scope from its data-input nodes and compute
  // nodes. The JSON is written to *fusion_str and (presumably) the generated fusion
  // kernel name to *fusion_kernel. Returns whether generation succeeded.
  static bool GenFusionScopeJson(const std::vector &input_nodes,
                                 const std::vector &compute_nodes, nlohmann::json *fusion_str,
                                 std::string *fusion_kernel);
  // Fusion-scope overload of GetIOSize: derives the size lists from a fusion op
  // list JSON together with the scope's output nodes.
  static bool GetIOSize(const nlohmann::json &fusion_op_list, const std::vector &output_nodes,
                        std::vector *input_size_list, std::vector *output_size_list);

 private:
  TbeKernelBuild() = default;
  ~TbeKernelBuild() = default;
  // Emits the JSON entry for one fusion data input into *data_str. *index is an
  // in/out counter -- presumably advanced per emitted entry; confirm in the .cc.
  static bool GenFusionDataInputJson(const std::shared_ptr &data_input, nlohmann::json *data_str,
                                     size_t *index);
  // Emits the JSON for one compute node of a fusion scope into *compute_op_str.
  // It also updates *fusion_kernel_name. layer_iter walks the per-layer input
  // groups produced by GetInputLayers (presumably).
  static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node,
                                   std::vector>::iterator *layer_iter,
                                   nlohmann::json *compute_op_str, std::string *fusion_kernel_name,
                                   size_t *index);
  // Builds the input-descriptor list of a fusion compute node into *input_desc_list.
  static bool GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
                                        std::vector>::iterator *layer_iter,
                                        std::vector *input_desc_list, size_t *index);
  // Maps per-output usage counts (output_used_nums) to descriptor output indices.
  // The exact mapping is defined in the .cc.
  static std::vector GetDescOutputIndex(const std::vector &output_used_nums);
  // Builds the output-descriptor list of a fusion compute node into *output_desc_list.
  static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode,
                                         std::vector *output_desc_list);
  // Writes one output descriptor for output node_out_idx of anf_node into
  // *output_desc, labelled with desc_output_idx.
  static void GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx,
                          size_t desc_output_idx, nlohmann::json *output_desc);
  // Writes an output descriptor for an output slot that is reused. Presumably this
  // is for a node output consumed more than once -- confirm in the .cc.
  static void GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index,
                                  size_t output_index, nlohmann::json *output_desc);
  // Computes the size of one input/output from its JSON descriptor; shared by
  // both GetIOSize overloads (presumably).
  static size_t GetIOSizeImpl(const nlohmann::json &desc);
  // Groups the fusion scope's inputs into layers, one group per compute node
  // (presumably), and stores them in *input_layers.
  static bool GetInputLayers(const std::vector &input_nodes,
                             const std::vector &compute_nodes,
                             std::vector> *input_layers);
  // Reports whether cnode has a dynamic (variable-count) input, per the name.
  static bool IsDynamicInput(const CNodePtr &cnode);
  // Returns the number of optional inputs of cnode. is_dynamic_input changes how
  // they are counted.
  static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input);
};

// Builds the nlohmann::json description of a single TBE kernel from an ANF node
// and its op info. creater_type selects the build mode and defaults to
// SINGLE_BUILD; the other kCreaterType values are declared elsewhere.
class TbeKernelJsonCreator {
 public:
  explicit TbeKernelJsonCreator(kCreaterType creater_type = SINGLE_BUILD) : creater_type_(creater_type) {}
  ~TbeKernelJsonCreator() = default;
  // Generates the complete kernel JSON for anf_node into *kernel_json. Returns
  // whether generation succeeded.
  bool GenTbeSingleKernelJson(const std::shared_ptr &anf_node, nlohmann::json *kernel_json);
  // Returns a copy of json_name_, which is presumably set by GenTbeSingleKernelJson.
  // NOTE(review): could be a const member function returning const std::string &.
  std::string json_name() { return json_name_; }

 private:
  // Builds the "inputs" section of the kernel JSON from anf_node and op_info.
  bool GenTbeInputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info,
                        nlohmann::json *inputs_json);
  // Builds the "outputs" section of the kernel JSON from anf_node and op_info.
  bool GenTbeOutputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info,
                         nlohmann::json *outputs_json);
  // Builds the attribute section of the kernel JSON from anf_node and op_info.
  bool GenTbeAttrJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info,
                      nlohmann::json *attrs_json);
  // Converts one attribute value, interpreted according to the type string, into *attr_obj.
  void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
  // Appends the descriptor for input input_i (real index real_input_index) to
  // *input_list. The meaning of `value` is not evident from this header -- see the .cc.
  bool GenInputDescJson(const std::shared_ptr &anf_node, size_t real_input_index, bool value,
                        const std::shared_ptr &input_ptr, const string &op_input_name, size_t input_i,
                        std::vector *input_list);
  // Builds output descriptors for every entry of outputs_ptr into *outputs_json.
  bool GenOutputDescJson(const std::shared_ptr &anf_node, const std::vector> &outputs_ptr,
                         nlohmann::json *outputs_json);
  // Builds the descriptor list for one declared input covering input_tensor_num
  // tensors. It advances *real_input_index and reports the op input name via
  // *op_input_name.
  bool GenInputList(const std::shared_ptr &anf_node, size_t input_tensor_num,
                    const std::shared_ptr &input_ptr, size_t *real_input_index, string *op_input_name,
                    std::vector *input_list);
  // Builds the descriptor list for one declared output covering output_obj_num
  // objects, advancing *output_idx.
  void GenOutputList(const std::shared_ptr &anf_node, const size_t &output_obj_num,
                     const std::shared_ptr &output_ptr, size_t *output_idx,
                     std::vector *output_list);
  kCreaterType creater_type_;  // build mode chosen at construction
  std::string json_name_;      // returned by json_name()
  std::string json_info_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_