You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tbe_kernel_build.h 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_
  17. #define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_
  18. #include <string>
  19. #include <unordered_map>
  20. #include <memory>
  21. #include <map>
  22. #include <utility>
  23. #include <vector>
  24. #include <nlohmann/json.hpp>
  25. #include "ir/dtype.h"
  26. #include "kernel/kernel.h"
  27. #include "pybind11/stl.h"
  28. #include "kernel/oplib/oplib.h"
  29. #include "kernel/tbe/tbe_adapter.h"
  30. namespace mindspore {
  31. namespace kernel {
  32. // kernel operate type used for generate json
  33. class TbeKernelBuild {
  34. public:
  35. static bool GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
  36. std::vector<size_t> *output_size_list);
  37. // Ub Fuison
  38. static bool GenFusionScopeJson(const std::vector<AnfNodePtr> &input_nodes,
  39. const std::vector<AnfNodePtr> &compute_nodes, nlohmann::json *fusion_str,
  40. std::string *fusion_kernel);
  41. static bool GetIOSize(const nlohmann::json &fusion_op_list, const std::vector<AnfNodePtr> &output_nodes,
  42. std::vector<size_t> *input_size_list, std::vector<size_t> *output_size_list);
  43. private:
  44. TbeKernelBuild() = default;
  45. ~TbeKernelBuild() = default;
  46. static bool GenFusionDataInputJson(const std::shared_ptr<mindspore::AnfNode> &data_input, nlohmann::json *data_str,
  47. size_t *index);
  48. static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node,
  49. std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
  50. nlohmann::json *compute_op_str, std::string *fusion_kernel_name, size_t *index);
  51. static bool GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
  52. std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
  53. std::vector<nlohmann::json> *input_desc_list, size_t *index);
  54. static std::vector<size_t> GetDescOutputIndex(const std::vector<int> &output_used_nums);
  55. static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode,
  56. std::vector<nlohmann::json> *output_desc_list);
  57. static void GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
  58. size_t desc_output_idx, nlohmann::json *output_desc);
  59. static void GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t index,
  60. size_t output_index, nlohmann::json *output_desc);
  61. static size_t GetIOSizeImpl(const nlohmann::json &desc);
  62. static bool GetInputLayers(const std::vector<mindspore::AnfNodePtr> &input_nodes,
  63. const std::vector<mindspore::AnfNodePtr> &compute_nodes,
  64. std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers);
  65. static bool IsDynamicInput(const CNodePtr &cnode);
  66. static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input);
  67. };
  68. class TbeKernelJsonCreator {
  69. public:
  70. explicit TbeKernelJsonCreator(kCreaterType creater_type = SINGLE_BUILD) : creater_type_(creater_type) {}
  71. ~TbeKernelJsonCreator() = default;
  72. bool GenTbeSingleKernelJson(const std::shared_ptr<AnfNode> &anf_node, nlohmann::json *kernel_json);
  73. std::string json_name() { return json_name_; }
  74. private:
  75. bool GenTbeInputsJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
  76. nlohmann::json *inputs_json);
  77. bool GenTbeOutputsJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
  78. nlohmann::json *outputs_json);
  79. bool GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
  80. nlohmann::json *attrs_json);
  81. void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
  82. bool GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
  83. const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i,
  84. std::vector<nlohmann::json> *input_list);
  85. bool GenOutputDescJson(const std::shared_ptr<AnfNode> &anf_node,
  86. const std::vector<std::shared_ptr<OpIOInfo>> &outputs_ptr, nlohmann::json *outputs_json);
  87. bool GenInputList(const std::shared_ptr<AnfNode> &anf_node, size_t input_tensor_num,
  88. const std::shared_ptr<OpIOInfo> &input_ptr, size_t *real_input_index, string *op_input_name,
  89. std::vector<nlohmann::json> *input_list);
  90. void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
  91. const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
  92. std::vector<nlohmann::json> *output_list);
  93. kCreaterType creater_type_;
  94. std::string json_name_;
  95. std::string json_info_;
  96. };
  97. } // namespace kernel
  98. } // namespace mindspore
  99. #endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_