You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tbe_kernel_build.h 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_
  17. #define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_
  18. #include <string>
  19. #include <unordered_map>
  20. #include <memory>
  21. #include <map>
  22. #include <utility>
  23. #include <vector>
  24. #include <nlohmann/json.hpp>
  25. #include "ir/dtype.h"
  26. #include "kernel/kernel.h"
  27. #include "pybind11/stl.h"
  28. #include "kernel/oplib/oplib.h"
  29. #include "kernel/tbe/tbe_adapter.h"
  30. namespace mindspore {
  31. namespace kernel {
  32. // kernel operate type used for generate json
  33. class TbeKernelBuild {
  34. enum FusionDataType { kFusionNormal = 0, kFusionAddN, kFusionReLUGradV2 };
  35. public:
  36. static bool GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
  37. std::vector<size_t> *output_size_list);
  38. // Ub Fuison
  39. static bool GenFusionScopeJson(const std::vector<AnfNodePtr> &input_nodes,
  40. const std::vector<AnfNodePtr> &compute_nodes, nlohmann::json *fusion_str,
  41. std::string *fusion_kernel);
  42. static bool GetIOSize(const nlohmann::json &fusion_op_list, const std::vector<AnfNodePtr> &output_nodes,
  43. std::vector<size_t> *input_size_list, std::vector<size_t> *output_size_list);
  44. private:
  45. TbeKernelBuild() = default;
  46. ~TbeKernelBuild() = default;
  47. static bool GenFusionDataInputJson(const std::shared_ptr<mindspore::AnfNode> &data_input,
  48. const std::map<const AnfNodePtr, FusionDataType> &spec_data_input,
  49. nlohmann::json *data_str, size_t *index);
  50. static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node,
  51. std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
  52. nlohmann::json *compute_op_str, std::string *fusion_kernel_name, size_t *index);
  53. static bool GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
  54. std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
  55. std::vector<nlohmann::json> *input_desc_list, size_t *index);
  56. static std::vector<size_t> GetDescOutputIndex(const std::vector<int> &output_used_nums);
  57. static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode,
  58. std::vector<nlohmann::json> *output_desc_list);
  59. static void GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
  60. size_t desc_output_idx, nlohmann::json *output_desc,
  61. FusionDataType fusion_data_type = kFusionNormal);
  62. static void GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t index,
  63. size_t output_index, nlohmann::json *output_desc);
  64. static size_t GetIOSizeImpl(const nlohmann::json &desc);
  65. static bool GetSpecInputLayers(const std::string &op_name, const std::vector<mindspore::AnfNodePtr> &reorder_layer,
  66. std::map<const AnfNodePtr, FusionDataType> *spec_data_input);
  67. static bool GetInputLayers(const std::vector<mindspore::AnfNodePtr> &input_nodes,
  68. const std::vector<mindspore::AnfNodePtr> &compute_nodes,
  69. std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers,
  70. std::map<const AnfNodePtr, FusionDataType> *spec_data_input);
  71. static bool IsDynamicInput(const CNodePtr &cnode);
  72. static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input);
  73. static std::string GetRealOpType(const std::string &origin_type);
  74. };
  75. class TbeKernelJsonCreator {
  76. public:
  77. explicit TbeKernelJsonCreator(kCreaterType creater_type = SINGLE_BUILD) : creater_type_(creater_type) {}
  78. ~TbeKernelJsonCreator() = default;
  79. bool GenTbeSingleKernelJson(const std::shared_ptr<AnfNode> &anf_node, nlohmann::json *kernel_json);
  80. std::string json_name() { return json_name_; }
  81. private:
  82. bool GenTbeInputsJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
  83. nlohmann::json *inputs_json);
  84. bool GenTbeOutputsJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
  85. nlohmann::json *outputs_json);
  86. bool GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
  87. nlohmann::json *attrs_json);
  88. static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
  89. bool GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
  90. const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i,
  91. std::vector<nlohmann::json> *input_list);
  92. bool GenOutputDescJson(const std::shared_ptr<AnfNode> &anf_node,
  93. const std::vector<std::shared_ptr<OpIOInfo>> &outputs_ptr, nlohmann::json *outputs_json);
  94. bool GenInputList(const std::shared_ptr<AnfNode> &anf_node, size_t input_tensor_num,
  95. const std::shared_ptr<OpIOInfo> &input_ptr, size_t *real_input_index, string *op_input_name,
  96. std::vector<nlohmann::json> *input_list);
  97. void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
  98. const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
  99. std::vector<nlohmann::json> *output_list);
  100. std::vector<size_t> GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const;
  101. std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const;
  102. std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
  103. std::vector<size_t> GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const;
  104. std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const;
  105. std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
  106. kCreaterType creater_type_;
  107. std::string json_name_;
  108. std::string json_info_;
  109. };
  110. } // namespace kernel
  111. } // namespace mindspore
  112. #endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_