You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

oplib.cc 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/oplib/oplib.h"
  17. #include <pybind11/pybind11.h>
  18. #include <unordered_map>
  19. #include <memory>
  20. #include "utils/log_adapter.h"
  21. #include "utils/overload.h"
  22. #include "utils/context/ms_context.h"
  23. namespace mindspore {
  24. namespace kernel {
  // Keys of the top-level operator registration JSON (read by RegOp / DecodeOpInfo).
  constexpr const char *kImplyType = "imply_type";
  constexpr const char *kOpName = "op_name";
  constexpr const char *kFusionType = "fusion_type";
  constexpr const char *kDtypeFormat = "dtype_format";
  constexpr const char *kAttr = "attr";
  // NOTE: spelled "kIputs" (sic); other code in this file refers to this exact name.
  constexpr const char *kIputs = "inputs";
  constexpr const char *kOutputs = "outputs";

  // Keys read only for TBE operators (DecodeTBESpecificInfo).
  constexpr const char *kAsyncFlag = "async_flag";
  constexpr const char *kBinfileName = "binfile_name";
  constexpr const char *kComputeCost = "compute_cost";
  constexpr const char *kKernelName = "kernel_name";
  constexpr const char *kPartialFlag = "partial_flag";
  constexpr const char *kOpPattern = "op_pattern";
  constexpr const char *kDynamicFormat = "dynamic_format";

  // Recognised "imply_type" strings. kAkg is only produced by ImplTypeToStr;
  // registration JSON selects the kAKG type with "AutoDiff" (kAutodiff).
  constexpr const char *kAiCPU = "AiCPU";
  constexpr const char *kTbe = "TBE";
  constexpr const char *kAkg = "akg";
  constexpr const char *kAutodiff = "AutoDiff";

  // Keys of individual "attr", "inputs" and "outputs" entries.
  constexpr const char *kName = "name";
  constexpr const char *kParamType = "param_type";
  constexpr const char *kDtype = "dtype";
  constexpr const char *kType = "type";
  constexpr const char *kValue = "value";
  constexpr const char *kDefaultValue = "default_value";
  constexpr const char *kIndex = "index";
  constexpr const char *kFormat = "format";
  // Input/output keys consulted only for TBE operators (DecodeInputOutput).
  constexpr const char *kNeedCompile = "need_compile";
  constexpr const char *kShape = "shape";
  constexpr const char *kReshapeType = "reshape_type";
  // Definition of the static registry of decoded operators: DecodeOpInfo appends
  // to it; FindOp and CheckRepetition scan it linearly.
  // NOTE(review): no locking is visible around this vector — confirm that
  // registration and lookup never run concurrently.
  std::vector<std::shared_ptr<OpInfo>> OpLib::op_info_;
  55. std::string ImplTypeToStr(OpImplyType impl_type) {
  56. switch (impl_type) {
  57. case kTBE:
  58. return kTbe;
  59. case kAKG:
  60. return kAkg;
  61. case kAICPU:
  62. return kAiCPU;
  63. default:
  64. return "unknow";
  65. }
  66. }
  67. bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) {
  68. bool ret = false;
  69. try {
  70. auto op_json = nlohmann::json::parse(json_string);
  71. std::string imply_type_string = op_json.at(kImplyType);
  72. std::string op_name = op_json.at(kOpName);
  73. if (imply_type_string == kTbe) {
  74. OpImplyType imply_type = kTBE;
  75. ret = DecodeOpInfo(op_json, imply_type, impl_path);
  76. } else if (imply_type_string == kAutodiff) {
  77. OpImplyType imply_type = kAKG;
  78. ret = DecodeOpInfo(op_json, imply_type, impl_path);
  79. } else if (imply_type_string == kAiCPU) {
  80. OpImplyType imply_type = kAICPU;
  81. ret = DecodeOpInfo(op_json, imply_type, impl_path);
  82. } else {
  83. MS_LOG(ERROR) << "Not support imply_type";
  84. }
  85. if (!ret) {
  86. MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string;
  87. }
  88. } catch (const std::exception &e) {
  89. MS_LOG(ERROR) << "get op json elements failed: " << e.what();
  90. }
  91. return ret;
  92. }
  93. void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
  94. op_info->set_async_flag(obj.at(kAsyncFlag));
  95. op_info->set_binfile_name(obj.at(kBinfileName));
  96. op_info->set_compute_cost(obj.at(kComputeCost));
  97. op_info->set_kernel_name(obj.at(kKernelName));
  98. op_info->set_partial_flag(obj.at(kPartialFlag));
  99. if (obj.find(kOpPattern) != obj.end()) {
  100. op_info->set_op_pattern(obj.at(kOpPattern));
  101. }
  102. if (obj.find(kDynamicFormat) != obj.end()) {
  103. op_info->set_dynamic_format(obj.at(kDynamicFormat));
  104. }
  105. }
  106. bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
  107. const std::string &impl_path) {
  108. std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
  109. MS_EXCEPTION_IF_NULL(op_info);
  110. op_info->set_op_name(obj.at(kOpName));
  111. op_info->set_impl_path(impl_path);
  112. op_info->set_imply_type(imply_type);
  113. op_info->set_fusion_type(obj.at(kFusionType));
  114. if (imply_type == kTBE) {
  115. DecodeTBESpecificInfo(obj, op_info);
  116. }
  117. auto attrs = obj.at(kAttr);
  118. for (const auto &attr : attrs) {
  119. if (!DecodeAttr(attr, imply_type, op_info)) {
  120. MS_LOG(ERROR) << "DecodeAttr Failed";
  121. return false;
  122. }
  123. }
  124. nlohmann::json dtype_format;
  125. if (obj.find(kDtypeFormat) != obj.end()) {
  126. dtype_format = obj.at(kDtypeFormat);
  127. }
  128. auto inputs = obj.at(kIputs);
  129. for (const auto &input : inputs) {
  130. if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
  131. MS_LOG(ERROR) << "DecodeInputOutput Failed";
  132. return false;
  133. }
  134. }
  135. auto outputs = obj.at(kOutputs);
  136. for (const auto &output : outputs) {
  137. if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
  138. MS_LOG(ERROR) << "DecodeInputOutput Failed";
  139. return false;
  140. }
  141. }
  142. if (!GetRefInfo(op_info)) {
  143. MS_LOG(ERROR) << "GetRefInfo Failed";
  144. return false;
  145. }
  146. if (!CheckRepetition(op_info)) {
  147. MS_LOG(ERROR) << "CheckRepetition Failed";
  148. return false;
  149. }
  150. op_info_.push_back(op_info);
  151. return true;
  152. }
  153. bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
  154. const std::shared_ptr<OpInfo> &op_info) {
  155. MS_EXCEPTION_IF_NULL(op_info);
  156. bool ret = true;
  157. try {
  158. std::shared_ptr<OpAttr> op_attr = std::make_shared<OpAttr>();
  159. MS_EXCEPTION_IF_NULL(op_attr);
  160. op_attr->set_name(obj.at(kName));
  161. if (imply_type != kAICPU) {
  162. op_attr->set_param_type(obj.at(kParamType));
  163. }
  164. op_attr->set_type(obj.at(kType));
  165. if (imply_type == kTBE) {
  166. op_attr->set_value(obj.at(kValue));
  167. }
  168. if (obj.find(kDefaultValue) != obj.end()) {
  169. op_attr->set_default_value(obj.at(kDefaultValue));
  170. }
  171. op_info->add_attrs_ptr(op_attr);
  172. } catch (const std::exception &e) {
  173. MS_LOG(ERROR) << "DecodeAttr failed:" << e.what();
  174. ret = false;
  175. }
  176. return ret;
  177. }
  178. bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
  179. size_t index) {
  180. bool ret = true;
  181. try {
  182. std::vector<std::string> dtype;
  183. std::vector<std::string> format;
  184. for (const auto &it : dtype_format) {
  185. dtype.emplace_back(it[index][0]);
  186. format.emplace_back(it[index][1]);
  187. }
  188. op_io->set_dtypes(dtype);
  189. op_io->set_formats(format);
  190. } catch (const std::exception &e) {
  191. MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what();
  192. ret = false;
  193. }
  194. return ret;
  195. }
  196. bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
  197. const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format) {
  198. bool ret = true;
  199. try {
  200. std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
  201. MS_EXCEPTION_IF_NULL(op_io);
  202. op_io->set_index(obj.at(kIndex));
  203. op_io->set_name(obj.at(kName));
  204. if (!dtype_format.empty()) {
  205. if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) {
  206. MS_LOG(ERROR) << "Decode dtype format failed";
  207. return false;
  208. }
  209. } else {
  210. op_io->set_dtypes(obj.at(kDtype));
  211. op_io->set_formats(obj.at(kFormat));
  212. }
  213. if (op_io->dtypes().size() != op_io->formats().size()) {
  214. MS_LOG(ERROR) << "op " << op_io->name() << " dtype size: " << op_io->dtypes()
  215. << " is not equal to format size: " << op_io->formats();
  216. return false;
  217. }
  218. if (obj.find(kParamType) != obj.end()) {
  219. op_io->set_param_type(obj.at(kParamType));
  220. }
  221. if (imply_type == kTBE) {
  222. if (obj.find(kNeedCompile) != obj.end()) {
  223. op_io->set_need_compile(obj.at(kNeedCompile));
  224. }
  225. if (obj.find(kShape) != obj.end()) {
  226. op_io->set_shape(obj.at(kShape));
  227. }
  228. if (obj.find(kReshapeType) != obj.end()) {
  229. op_io->set_reshape_type(obj.at(kReshapeType));
  230. }
  231. }
  232. if (io_type == kInput) {
  233. op_info->add_inputs_ptr(op_io);
  234. } else if (io_type == kOutput) {
  235. op_info->add_outputs_ptr(op_io);
  236. }
  237. } catch (const std::exception &e) {
  238. MS_LOG(ERROR) << "DecodeInputOutput failed" << e.what();
  239. ret = false;
  240. }
  241. return ret;
  242. }
  243. std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) {
  244. auto context = MsContext::GetInstance();
  245. MS_EXCEPTION_IF_NULL(context);
  246. bool is_gpu = (context->device_target() == kGPUDevice);
  247. if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) ||
  248. (!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) {
  249. MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
  250. << ", current op num: " << op_info_.size();
  251. return nullptr;
  252. }
  253. for (const auto &op_info : op_info_) {
  254. MS_EXCEPTION_IF_NULL(op_info);
  255. if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) {
  256. return op_info;
  257. }
  258. }
  259. MS_LOG(DEBUG) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
  260. << ", current op num: " << op_info_.size();
  261. return nullptr;
  262. }
  263. bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) {
  264. MS_EXCEPTION_IF_NULL(op_info);
  265. const auto &output_infos = op_info->outputs_ptr();
  266. const auto &input_infos = op_info->inputs_ptr();
  267. for (size_t out_index = 0; out_index < output_infos.size(); out_index++) {
  268. const auto &out_name = output_infos[out_index]->name();
  269. for (size_t in_index = 0; in_index < input_infos.size(); in_index++) {
  270. const auto &in_name = input_infos[in_index]->name();
  271. if (out_name == in_name) {
  272. if (op_info->has_ref_index(out_index)) {
  273. MS_LOG(ERROR) << "The out_index " << out_index << " is already in ref_info";
  274. return false;
  275. }
  276. op_info->add_ref_pair(out_index, in_index);
  277. MS_LOG(INFO) << "add ref info, op name is " << op_info->op_name() << ", outindex is " << out_index
  278. << ", in_index is " << in_index;
  279. }
  280. }
  281. }
  282. return true;
  283. }
  284. bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) {
  285. MS_EXCEPTION_IF_NULL(op_info);
  286. for (const auto &exist_op_info : op_info_) {
  287. MS_EXCEPTION_IF_NULL(exist_op_info);
  288. if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() &&
  289. exist_op_info->impl_path() != op_info->impl_path()) {
  290. MS_LOG(ERROR) << "Op has already exist, please use other name, op name: " << op_info->op_name()
  291. << " op type: " << ImplTypeToStr(op_info->imply_type());
  292. return false;
  293. }
  294. }
  295. return true;
  296. }
  297. } // namespace kernel
  298. } // namespace mindspore