You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

oplib.cc 15 kB

5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/oplib/oplib.h"
  17. #include <memory>
  18. #include <map>
  19. #include <fstream>
  20. #include "utils/log_adapter.h"
  21. #include "utils/overload.h"
  22. #include "utils/ms_context.h"
  23. namespace mindspore {
  24. namespace kernel {
  // String constants used in this file while decoding op-info JSON records.

  // Keys looked up on the op-level JSON object, plus the strings that the "op_pattern"
  // value is matched against (kFormatAgnostic, kBroadcast, kReduce).
  // kDynamicFormat is not referenced in the visible part of this file.
  constexpr const char *kImplyType = "imply_type";
  constexpr const char *kOpName = "op_name";
  constexpr const char *kFusionType = "fusion_type";
  constexpr const char *kAsyncFlag = "async_flag";
  constexpr const char *kBinfileName = "binfile_name";
  constexpr const char *kComputeCost = "compute_cost";
  constexpr const char *kKernelName = "kernel_name";
  constexpr const char *kPartialFlag = "partial_flag";
  constexpr const char *kReshapeType = "reshape_type";
  constexpr const char *kValueDepend = "value_depend";
  constexpr const char *kOpPattern = "op_pattern";
  constexpr const char *kIsDynamicFormat = "is_dynamic_format";
  constexpr const char *kDynamicFormat = "dynamicFormat";
  constexpr const char *kFormatAgnostic = "formatAgnostic";
  constexpr const char *kNeedCheckSupported = "need_check_supported";
  constexpr const char *kBroadcast = "broadcast";
  constexpr const char *kReduce = "reduce";
  constexpr const char *kDynamicShape = "dynamic_shape";
  constexpr const char *kDynamicCompileStatic = "dynamic_compile_static";
  constexpr const char *kDtypeFormat = "dtype_format";
  constexpr const char *kAttr = "attr";
  constexpr const char *kIputs = "inputs";
  constexpr const char *kOutputs = "outputs";

  // Processor names (kAiCore, kCUDA) and implementation-type names as they appear in
  // the "imply_type" field and in log output.
  constexpr const char *kAiCPU = "AiCPU";
  constexpr const char *kAiCore = "AiCore";
  constexpr const char *kCUDA = "CUDA";
  constexpr const char *kTbe = "TBE";
  constexpr const char *kAkg = "AKG";
  constexpr const char *kCpu = "CPU";
  constexpr const char *kGpu = "GPU";

  // Keys looked up on the per-attribute and per-input/output JSON objects.
  constexpr const char *kName = "name";
  constexpr const char *kParamType = "param_type";
  constexpr const char *kDtype = "dtype";
  constexpr const char *kType = "type";
  constexpr const char *kValue = "value";
  constexpr const char *kDefaultValue = "default_value";
  constexpr const char *kIndex = "index";
  constexpr const char *kFormat = "format";
  constexpr const char *kNeedCompile = "need_compile";
  constexpr const char *kShape = "shape";
  constexpr const char *kProcessor = "processor";
  66. std::multimap<std::string, std::shared_ptr<OpInfo>> OpLib::op_info_;
  67. static std::string ImplTypeToStr(OpImplyType impl_type) {
  68. switch (impl_type) {
  69. case kTBE:
  70. return kTbe;
  71. case kAKG:
  72. return kAkg;
  73. case kAICPU:
  74. return kAiCPU;
  75. case kCPU:
  76. return kCpu;
  77. case kGPU:
  78. return kGpu;
  79. default:
  80. return "unknown";
  81. }
  82. }
  83. bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) {
  84. bool ret = false;
  85. try {
  86. auto op_json = nlohmann::json::parse(json_string);
  87. std::string imply_type_string = op_json.at(kImplyType);
  88. std::string op_name = op_json.at(kOpName);
  89. if (imply_type_string == kTbe) {
  90. OpImplyType imply_type = kTBE;
  91. ret = DecodeOpInfo(op_json, imply_type, impl_path);
  92. } else if (imply_type_string == kAkg) {
  93. OpImplyType imply_type = kAKG;
  94. ret = DecodeOpInfo(op_json, imply_type, impl_path);
  95. } else if (imply_type_string == kAiCPU) {
  96. OpImplyType imply_type = kAICPU;
  97. ret = DecodeOpInfo(op_json, imply_type, impl_path);
  98. } else if (imply_type_string == kCpu) {
  99. OpImplyType imply_type = kCPU;
  100. ret = DecodeOpInfo(op_json, imply_type, impl_path);
  101. } else if (imply_type_string == kGpu) {
  102. OpImplyType imply_type = kGPU;
  103. ret = DecodeOpInfo(op_json, imply_type, impl_path);
  104. } else {
  105. MS_LOG(ERROR) << "Not support imply_type";
  106. }
  107. if (!ret) {
  108. MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string;
  109. }
  110. } catch (const std::exception &e) {
  111. MS_LOG(ERROR) << "get op json elements failed: " << e.what();
  112. }
  113. return ret;
  114. }
  115. void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
  116. const std::map<std::string, kernel::OpPattern> kOpPatternMap = {
  117. {kFormatAgnostic, kFormatAgnosticPattern}, {kBroadcast, kBroadcastPattern}, {kReduce, kReducePattern}};
  118. MS_EXCEPTION_IF_NULL(op_info);
  119. op_info->set_async_flag(obj.at(kAsyncFlag));
  120. op_info->set_binfile_name(obj.at(kBinfileName));
  121. op_info->set_compute_cost(obj.at(kComputeCost));
  122. op_info->set_kernel_name(obj.at(kKernelName));
  123. op_info->set_partial_flag(obj.at(kPartialFlag));
  124. op_info->set_need_check_supported(obj.at(kNeedCheckSupported));
  125. if (obj.find(kDynamicShape) != obj.end()) {
  126. op_info->set_dynamic_shape(obj.at(kDynamicShape));
  127. }
  128. if (obj.find(kDynamicCompileStatic) != obj.end()) {
  129. op_info->set_dynamic_compile_static_(obj.at(kDynamicCompileStatic));
  130. }
  131. if (obj.find(kIsDynamicFormat) != obj.end()) {
  132. op_info->set_is_dynamic_format(obj.at(kIsDynamicFormat));
  133. }
  134. if (obj.find(kOpPattern) != obj.end()) {
  135. std::string op_pattern = obj.at(kOpPattern);
  136. auto find_iter = kOpPatternMap.find(op_pattern);
  137. if (find_iter == kOpPatternMap.end()) {
  138. if (!op_pattern.empty()) {
  139. MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern;
  140. }
  141. op_info->set_op_pattern(kCommonPattern);
  142. } else {
  143. op_info->set_op_pattern(find_iter->second);
  144. }
  145. }
  146. }
  147. void OpLib::DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
  148. MS_EXCEPTION_IF_NULL(op_info);
  149. op_info->set_processor(obj.at(kProcessor));
  150. }
  151. bool OpLib::RegOpFromLocalInfo() {
  152. static bool has_load = false;
  153. if (has_load) {
  154. return true;
  155. }
  156. MS_LOG(INFO) << "Start";
  157. has_load = true;
  158. std::string dir = common::GetEnv("MINDSPORE_OP_INFO_PATH");
  159. if (dir.empty()) {
  160. MS_LOG(INFO) << "MindSpore op info path does not been set. use op info from python pass.";
  161. return true;
  162. }
  163. char real_path[PATH_MAX] = {0};
  164. if (dir.size() >= PATH_MAX) {
  165. MS_LOG(ERROR) << "Invalid environment variable 'MINDSPORE_OP_INFO_PATH', the path length should be smaller than "
  166. << PATH_MAX << ", but got " << dir;
  167. return false;
  168. }
  169. #if defined(_WIN32) || defined(_WIN64)
  170. if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) {
  171. MS_LOG(ERROR) << "Op info path is invalid: " << dir;
  172. return false;
  173. }
  174. #else
  175. if (realpath(common::SafeCStr(dir), real_path) == nullptr) {
  176. MS_LOG(ERROR) << "Invalid environment variable 'MINDSPORE_OP_INFO_PATH', the path is: " << dir
  177. << ". Please check (1) whether the path exists, (2) whether the path has the access permission, "
  178. << "(3) whether the path is too long. ";
  179. return false;
  180. }
  181. if (strlen(real_path) >= PATH_MAX) {
  182. MS_LOG(ERROR) << "Invalid environment variable 'MINDSPORE_OP_INFO_PATH', the absolute path length should be smaller"
  183. << " than " << PATH_MAX << ", but got " << real_path;
  184. return false;
  185. }
  186. #endif
  187. MS_LOG(INFO) << "Start to read op info from local file.";
  188. std::ifstream file(real_path);
  189. if (!file.is_open()) {
  190. MS_LOG(ERROR) << "Find op info file failed.";
  191. return false;
  192. }
  193. std::string line;
  194. while (getline(file, line)) {
  195. if (!line.empty()) {
  196. (void)OpLib::RegOp(line, "");
  197. }
  198. }
  199. file.close();
  200. MS_LOG(INFO) << "End";
  201. return true;
  202. }
  203. bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
  204. const std::string &impl_path) {
  205. std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
  206. MS_EXCEPTION_IF_NULL(op_info);
  207. op_info->set_op_name(obj.at(kOpName));
  208. op_info->set_impl_path(impl_path);
  209. op_info->set_imply_type(imply_type);
  210. op_info->set_fusion_type(obj.at(kFusionType));
  211. if (imply_type == kTBE) {
  212. DecodeTBESpecificInfo(obj, op_info);
  213. } else if (imply_type == kAKG) {
  214. DecodeAKGSpecificInfo(obj, op_info);
  215. }
  216. auto attrs = obj.at(kAttr);
  217. for (const auto &attr : attrs) {
  218. if (!DecodeAttr(attr, imply_type, op_info)) {
  219. MS_LOG(ERROR) << "DecodeAttr Failed";
  220. return false;
  221. }
  222. }
  223. nlohmann::json dtype_format;
  224. if (obj.find(kDtypeFormat) != obj.end()) {
  225. dtype_format = obj.at(kDtypeFormat);
  226. }
  227. auto inputs = obj.at(kIputs);
  228. for (const auto &input : inputs) {
  229. if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
  230. MS_LOG(ERROR) << "DecodeInputOutput Failed";
  231. return false;
  232. }
  233. }
  234. auto outputs = obj.at(kOutputs);
  235. for (const auto &output : outputs) {
  236. if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
  237. MS_LOG(ERROR) << "DecodeInputOutput Failed";
  238. return false;
  239. }
  240. }
  241. if (CheckRepetition(op_info)) {
  242. MS_LOG(WARNING) << "This op info has been already registered. op name: " << op_info->op_name()
  243. << ", impl type: " << ImplTypeToStr(op_info->imply_type())
  244. << ", impl path: " << op_info->impl_path();
  245. return true;
  246. }
  247. if (!GetRefInfo(op_info)) {
  248. MS_LOG(ERROR) << "GetRefInfo Failed";
  249. return false;
  250. }
  251. op_info_.emplace(op_info->op_name(), op_info);
  252. return true;
  253. }
  254. bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
  255. const std::shared_ptr<OpInfo> &op_info) {
  256. MS_EXCEPTION_IF_NULL(op_info);
  257. bool ret = true;
  258. try {
  259. std::shared_ptr<OpAttr> op_attr = std::make_shared<OpAttr>();
  260. MS_EXCEPTION_IF_NULL(op_attr);
  261. op_attr->set_name(obj.at(kName));
  262. if (imply_type != kAICPU) {
  263. op_attr->set_param_type(obj.at(kParamType));
  264. }
  265. op_attr->set_type(obj.at(kType));
  266. if (imply_type == kTBE) {
  267. op_attr->set_value(obj.at(kValue));
  268. }
  269. if (obj.find(kDefaultValue) != obj.end()) {
  270. op_attr->set_default_value(obj.at(kDefaultValue));
  271. }
  272. op_info->add_attrs_ptr(op_attr);
  273. } catch (const std::exception &e) {
  274. MS_LOG(ERROR) << "DecodeAttr failed:" << e.what();
  275. ret = false;
  276. }
  277. return ret;
  278. }
  279. bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
  280. size_t index) {
  281. MS_EXCEPTION_IF_NULL(op_io);
  282. bool ret = true;
  283. try {
  284. std::vector<std::string> dtype;
  285. std::vector<std::string> format;
  286. for (const auto &it : dtype_format) {
  287. dtype.emplace_back(it[index][0]);
  288. format.emplace_back(it[index][1]);
  289. }
  290. op_io->set_dtypes(dtype);
  291. op_io->set_formats(format);
  292. } catch (const std::exception &e) {
  293. MS_LOG(ERROR) << "DecodeDtypeFormat failed" << e.what();
  294. ret = false;
  295. }
  296. return ret;
  297. }
  298. bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
  299. const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format) {
  300. MS_EXCEPTION_IF_NULL(op_info);
  301. bool ret = true;
  302. try {
  303. std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
  304. MS_EXCEPTION_IF_NULL(op_io);
  305. op_io->set_index(obj.at(kIndex));
  306. op_io->set_name(obj.at(kName));
  307. if (!dtype_format.empty()) {
  308. if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) {
  309. MS_LOG(ERROR) << "Decode dtype format failed";
  310. return false;
  311. }
  312. } else {
  313. op_io->set_dtypes(obj.at(kDtype));
  314. op_io->set_formats(obj.at(kFormat));
  315. }
  316. if (op_io->dtypes().size() != op_io->formats().size()) {
  317. MS_LOG(ERROR) << "op " << op_io->name() << " dtype size: " << op_io->dtypes()
  318. << " is not equal to format size: " << op_io->formats();
  319. return false;
  320. }
  321. if (obj.find(kParamType) != obj.end()) {
  322. op_io->set_param_type(obj.at(kParamType));
  323. }
  324. if (imply_type == kTBE) {
  325. if (obj.find(kNeedCompile) != obj.end()) {
  326. op_io->set_need_compile(obj.at(kNeedCompile));
  327. }
  328. if (obj.find(kShape) != obj.end()) {
  329. op_io->set_shape(obj.at(kShape));
  330. }
  331. if (obj.find(kReshapeType) != obj.end()) {
  332. op_io->set_reshape_type(obj.at(kReshapeType));
  333. }
  334. if (obj.find(kValueDepend) != obj.end()) {
  335. op_io->set_value_depend(obj.at(kValueDepend));
  336. }
  337. }
  338. if (io_type == kInput) {
  339. op_info->add_inputs_ptr(op_io);
  340. } else if (io_type == kOutput) {
  341. op_info->add_outputs_ptr(op_io);
  342. }
  343. } catch (const std::exception &e) {
  344. MS_LOG(ERROR) << "DecodeInputOutput failed" << e.what();
  345. ret = false;
  346. }
  347. return ret;
  348. }
  349. std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type, bool is_dynamic_shape) {
  350. if (!OpLib::RegOpFromLocalInfo()) {
  351. MS_LOG(INFO) << "Warning reg local op info failed.";
  352. }
  353. auto context = MsContext::GetInstance();
  354. MS_EXCEPTION_IF_NULL(context);
  355. bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
  356. if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) {
  357. MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
  358. << ", current op num: " << op_info_.size();
  359. return nullptr;
  360. }
  361. std::string target_processor = is_gpu ? kCUDA : kAiCore;
  362. for (auto [iter, end] = op_info_.equal_range(op_name); iter != end; ++iter) {
  363. auto &op_info = (*iter).second;
  364. MS_EXCEPTION_IF_NULL(op_info);
  365. if (op_info->imply_type() != imply_type) {
  366. continue;
  367. }
  368. if (imply_type == kAKG && op_info->processor() != target_processor) {
  369. continue;
  370. }
  371. if (is_dynamic_shape && !op_info->dynamic_shape()) {
  372. continue;
  373. }
  374. return op_info;
  375. }
  376. MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
  377. << ", current op num: " << op_info_.size() << " is_dynamic_shape:" << is_dynamic_shape;
  378. return nullptr;
  379. }
  380. bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) {
  381. MS_EXCEPTION_IF_NULL(op_info);
  382. const auto &output_infos = op_info->outputs_ptr();
  383. const auto &input_infos = op_info->inputs_ptr();
  384. for (size_t out_index = 0; out_index < output_infos.size(); out_index++) {
  385. MS_EXCEPTION_IF_NULL(output_infos[out_index]);
  386. const auto &out_name = output_infos[out_index]->name();
  387. for (size_t in_index = 0; in_index < input_infos.size(); in_index++) {
  388. MS_EXCEPTION_IF_NULL(input_infos[in_index]);
  389. const auto &in_name = input_infos[in_index]->name();
  390. if (out_name == in_name) {
  391. if (op_info->has_ref_index(out_index)) {
  392. MS_LOG(ERROR) << "The out_index " << out_index << " is already in ref_info";
  393. return false;
  394. }
  395. op_info->add_ref_pair(out_index, in_index);
  396. }
  397. }
  398. }
  399. return true;
  400. }
  401. bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) {
  402. MS_EXCEPTION_IF_NULL(op_info);
  403. for (auto [iter, end] = op_info_.equal_range(op_info->op_name()); iter != end; ++iter) {
  404. auto &exist_op_info = (*iter).second;
  405. MS_EXCEPTION_IF_NULL(exist_op_info);
  406. if (exist_op_info->equals_to(op_info)) {
  407. return true;
  408. }
  409. }
  410. return false;
  411. }
  412. } // namespace kernel
  413. } // namespace mindspore