You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

acl_graph_parser_util.cc 37 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. * http://www.apache.org/licenses/LICENSE-2.0
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS,
  9. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. * See the License for the specific language governing permissions and
  11. * limitations under the License.
  12. */
  13. #include "parser/common/acl_graph_parser_util.h"
  14. #include <dlfcn.h>
  15. #include <regex.h>
  16. #include <cstdlib>
  17. #include <ctime>
  18. #include <fstream>
  19. #include "common/string_util.h"
  20. #include "common/util.h"
  21. #include "common/util/error_manager/error_manager.h"
  22. #include "external/ge/ge_api_types.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "framework/omg/parser/parser_types.h"
  25. #include "ge/ge_api_types.h"
  26. #include "google/protobuf/io/coded_stream.h"
  27. #include "google/protobuf/io/zero_copy_stream_impl.h"
  28. #include "graph/debug/ge_attr_define.h"
  29. #include "graph/opsproto_manager.h"
  30. #include "graph/utils/type_utils.h"
  31. #include "omg/parser/parser_inner_ctx.h"
  32. #include "parser/common/register_tbe.h"
  33. #include "tbe_plugin_loader.h"
  34. using google::protobuf::io::CodedInputStream;
  35. using google::protobuf::io::FileInputStream;
  36. using google::protobuf::io::ZeroCopyInputStream;
  37. using namespace ge::parser;
  38. namespace {
  39. const std::string kGraphDefaultName = "domi_default";
  40. /// The maximum length of the file.
  41. /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1
  42. const int kMaxFileSizeLimit = INT_MAX;
  43. const int kMaxBuffSize = 256;
  44. const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
  45. const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M
  46. static string GetSoPath() {
  47. Dl_info dl_info;
  48. if (dladdr(reinterpret_cast<void *>(&GetSoPath), &dl_info) == 0) {
  49. GELOGW("Failed to read so_path!");
  50. return string();
  51. } else {
  52. std::string so_path = dl_info.dli_fname;
  53. char path[PATH_MAX] = {0};
  54. if (so_path.length() >= PATH_MAX) {
  55. GELOGW("File path is too long!");
  56. return string();
  57. }
  58. if (realpath(so_path.c_str(), path) == nullptr) {
  59. GELOGW("Failed to get realpath of %s", so_path.c_str());
  60. return string();
  61. }
  62. so_path = path;
  63. so_path = so_path.substr(0, so_path.rfind('/') + 1);
  64. return so_path;
  65. }
  66. }
  67. static void GetOpsProtoPath(string &opsproto_path) {
  68. GELOGD("Start to get ops proto path schedule.");
  69. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  70. if (path_env != nullptr) {
  71. string path = path_env;
  72. string file_path = ge::parser::RealPath(path.c_str());
  73. if (file_path.empty()) {
  74. GELOGE(ge::FAILED, "File path %s is invalid.", path.c_str());
  75. return;
  76. }
  77. opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
  78. GELOGI("Get opsproto so path from env : %s", path.c_str());
  79. return;
  80. }
  81. string path_base = GetSoPath();
  82. GELOGI("path_base is %s", path_base.c_str());
  83. path_base = path_base.substr(0, path_base.rfind('/'));
  84. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  85. opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
  86. }
  87. static void GetAclParams(const std::map<ge::AscendString, ge::AscendString> &parser_params, const string &key,
  88. string &value) {
  89. for (auto &ele : parser_params) {
  90. const char *key_ascend = ele.first.GetString();
  91. if (key_ascend == nullptr) {
  92. GELOGW("Input options key is null, Please check!");
  93. continue;
  94. }
  95. string key_str = key_ascend;
  96. if (key == key_str) {
  97. const char *value_ascend = ele.second.GetString();
  98. if (value_ascend == nullptr) {
  99. value = "";
  100. } else {
  101. value = value_ascend;
  102. }
  103. return;
  104. }
  105. }
  106. value = "";
  107. return;
  108. }
  109. static bool CheckDigitStr(std::string &str) {
  110. for (char c : str) {
  111. if (!isdigit(c)) {
  112. GELOGE(domi::FAILED, "Value[%s] is not positive integer", str.c_str());
  113. return false;
  114. }
  115. }
  116. return true;
  117. }
  118. } // namespace
  119. namespace ge {
  120. static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_param) {
  121. if ((s == "true") || (s == "false")) {
  122. return true;
  123. } else {
  124. ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {atc_param, s});
  125. GELOGE(PARAM_INVALID, "Input parameter[%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str());
  126. return false;
  127. }
  128. }
  129. static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) {
  130. int32_t out_size = op_desc->GetOutputsSize();
  131. if (index < 0 || index >= out_size) {
  132. GELOGE(domi::FAILED,
  133. "out_node [%s] output index:%d must be smaller "
  134. "than node output size:%d and can not be negative!",
  135. op_desc->GetName().c_str(), index, out_size);
  136. std::string fail_reason = "output index:" + to_string(index) +
  137. " must be smaller than output size:" + to_string(out_size) + " and can not be negative!";
  138. ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"},
  139. {"out_nodes", op_desc->GetName(), fail_reason});
  140. return domi::FAILED;
  141. }
  142. return domi::SUCCESS;
  143. }
  144. domi::Status AclGrphParseUtil::LoadOpsProtoLib() {
  145. string opsproto_path;
  146. GetOpsProtoPath(opsproto_path);
  147. GELOGI("Get opsproto path is %s", opsproto_path.c_str());
  148. OpsProtoManager *manager = OpsProtoManager::Instance();
  149. map<string, string> option_tmp;
  150. option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  151. bool is_proto_init = manager->Initialize(option_tmp);
  152. if (!is_proto_init) {
  153. GELOGE(FAILED, "Load ops_proto lib failed, ops proto path is invalid.");
  154. return FAILED;
  155. }
  156. return SUCCESS;
  157. }
  158. void AclGrphParseUtil::SaveCustomCaffeProtoPath() {
  159. GELOGD("Enter save custom caffe proto path.");
  160. std::string path_base = GetSoPath();
  161. path_base = path_base.substr(0, path_base.rfind('/'));
  162. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  163. ge::GetParserContext().caffe_proto_path = path_base + "include/proto/";
  164. string custom_op_path;
  165. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  166. if (path_env != nullptr) {
  167. std::string path = path_env;
  168. custom_op_path = path + "/framework/custom/caffe/";
  169. GELOGI("Get custom proto path from env : %s", path_env);
  170. GetParserContext().custom_proto_path = custom_op_path;
  171. return;
  172. }
  173. custom_op_path = path_base + "ops/framework/custom/caffe/";
  174. ge::GetParserContext().custom_proto_path = custom_op_path;
  175. return;
  176. }
  177. // Initialize PARSER, load custom op plugin
  178. // options will be used later for parser decoupling
  179. domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, std::string> &options) {
  180. GELOGT(TRACE_INIT, "AclParserInitialize start");
  181. // check init status
  182. if (parser_initialized) {
  183. GELOGW("AclParserInitialize is called more than once");
  184. return SUCCESS;
  185. }
  186. // load custom op plugin
  187. TBEPluginLoader::Instance().LoadPluginSo(options);
  188. // load and save custom op proto for prediction
  189. (void)LoadOpsProtoLib();
  190. SaveCustomCaffeProtoPath();
  191. auto op_registry = domi::OpRegistry::Instance();
  192. if (op_registry == nullptr) {
  193. GELOGE(FAILED, "Get OpRegistry instance failed");
  194. return FAILED;
  195. }
  196. auto it = options.find(ge::FRAMEWORK_TYPE);
  197. if (it == options.end()) {
  198. GELOGE(FAILED, "Can not find ge.frameworkType in options");
  199. return FAILED;
  200. }
  201. std::string fmk_type = it->second;
  202. std::vector<OpRegistrationData> registrationDatas = op_registry->registrationDatas;
  203. GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size());
  204. for (OpRegistrationData &reg_data : registrationDatas) {
  205. if (std::to_string(reg_data.GetFrameworkType()) == fmk_type) {
  206. (void)OpRegistrationTbe::Instance()->Finalize(reg_data, false);
  207. (void)domi::OpRegistry::Instance()->Register(reg_data);
  208. }
  209. }
  210. // set init status
  211. if (!parser_initialized) {
  212. // Initialize success, first time calling initialize
  213. parser_initialized = true;
  214. }
  215. GELOGT(TRACE_STOP, "AclParserInitialize finished");
  216. return SUCCESS;
  217. }
  218. void AclGrphParseUtil::SetDefaultFormat() {
  219. if (ge::GetParserContext().type == domi::TENSORFLOW) {
  220. ge::GetParserContext().format = domi::DOMI_TENSOR_NHWC;
  221. } else {
  222. ge::GetParserContext().format = domi::DOMI_TENSOR_NCHW;
  223. }
  224. }
  225. domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
  226. try {
  227. // parse output node
  228. if (!out_nodes.empty()) {
  229. ge::GetParserContext().out_nodes_map.clear();
  230. ge::GetParserContext().user_out_nodes.clear();
  231. ge::GetParserContext().user_out_nodes_top_vec.clear();
  232. vector<string> nodes_v = StringUtils::Split(out_nodes, ';');
  233. for (const string &node : nodes_v) {
  234. vector<string> key_value_v = StringUtils::Split(node, ':');
  235. if (key_value_v.size() != 2) { // The size must be 2.
  236. if (key_value_v.size() == 1 && ge::GetParserContext().type == domi::CAFFE) {
  237. ge::GetParserContext().user_out_nodes_top_vec.push_back(node);
  238. continue;
  239. }
  240. ErrorManager::GetInstance().ATCReportErrMessage(
  241. "E10001", {"parameter", "value", "reason"},
  242. {"out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""});
  243. GELOGE(PARAM_INVALID,
  244. "The input format of out_nodes is invalid, the correct format is "
  245. "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.",
  246. node.c_str());
  247. return PARAM_INVALID;
  248. }
  249. if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) {
  250. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  251. {"out_nodes", out_nodes, "is not all index or top_name"});
  252. GELOGE(PARAM_INVALID, "This out_nodes str must be all index or top_name, while the actual input is %s",
  253. out_nodes.c_str());
  254. return PARAM_INVALID;
  255. }
  256. // stoi: The method may throw an exception: invalid_argument/out_of_range
  257. if (!CheckDigitStr(key_value_v[1])) {
  258. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  259. {"out_nodes", out_nodes, "is not positive integer"});
  260. GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str());
  261. return PARAM_INVALID;
  262. }
  263. auto iter = ge::GetParserContext().out_nodes_map.find(key_value_v[0]);
  264. int32_t index = stoi(StringUtils::Trim(key_value_v[1]));
  265. GELOGD("Get output info: node[%s] and index[%d]", key_value_v[0].c_str(), index);
  266. if (iter != ge::GetParserContext().out_nodes_map.end()) {
  267. iter->second.emplace_back(index);
  268. } else {
  269. std::vector<int32_t> index_v;
  270. index_v.emplace_back(index);
  271. ge::GetParserContext().out_nodes_map.emplace(key_value_v[0], index_v);
  272. }
  273. ge::GetParserContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index));
  274. }
  275. }
  276. } catch (std::invalid_argument &) {
  277. GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
  278. ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"out_nodes", out_nodes});
  279. return PARAM_INVALID;
  280. } catch (std::out_of_range &) {
  281. GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
  282. ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes});
  283. return PARAM_INVALID;
  284. }
  285. return SUCCESS;
  286. }
  287. domi::Status AclGrphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_output_fp16) {
  288. if (is_output_fp16.empty()) {
  289. return SUCCESS;
  290. }
  291. vector<domiTensorFormat_t> &output_formats = ge::GetParserContext().output_formats;
  292. output_formats.clear();
  293. vector<string> node_format_vec = StringUtils::Split(is_output_fp16, ',');
  294. for (auto &is_fp16 : node_format_vec) {
  295. StringUtils::Trim(is_fp16);
  296. if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) {
  297. GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]",
  298. is_output_fp16.c_str());
  299. return PARAM_INVALID;
  300. }
  301. if (is_fp16 == "false") {
  302. output_formats.push_back(DOMI_TENSOR_ND);
  303. } else if (is_fp16 == "true") {
  304. output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0);
  305. }
  306. }
  307. return SUCCESS;
  308. }
  309. domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) {
  310. ge::GetParserContext().enable_scope_fusion_passes.clear();
  311. if (enable_scope_fusion_passes.empty()) {
  312. return SUCCESS;
  313. }
  314. ge::GetParserContext().enable_scope_fusion_passes = enable_scope_fusion_passes;
  315. return SUCCESS;
  316. }
  317. void AclGrphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec,
  318. const string &fp16_nodes_name, uint32_t index, OpDescPtr &op_desc) {
  319. if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) {
  320. if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) {
  321. GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str());
  322. if (!AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_FORMAT, TypeUtils::FormatToSerialString(FORMAT_NC1HWC0))) {
  323. GELOGW("This node [%s] set NC1HWC0 failed", fp16_nodes_name.c_str());
  324. }
  325. }
  326. }
  327. }
  328. domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes,
  329. const string &is_input_adjust_hw_layout) {
  330. GE_CHECK_NOTNULL(graph);
  331. vector<string> adjust_fp16_format_vec;
  332. if (!is_input_adjust_hw_layout.empty()) {
  333. adjust_fp16_format_vec = StringUtils::Split(is_input_adjust_hw_layout, ',');
  334. for (auto &s : adjust_fp16_format_vec) {
  335. StringUtils::Trim(s);
  336. if (!CheckInputTrueOrFalse(s, "is_input_adjust_hw_layout")) {
  337. GELOGE(PARAM_INVALID, "Invalid Param, is_input_adjust_hw_layout only support true/false: but is [%s]",
  338. is_input_adjust_hw_layout.c_str());
  339. return PARAM_INVALID;
  340. }
  341. }
  342. }
  343. if (input_fp16_nodes.empty()) {
  344. return SUCCESS;
  345. }
  346. GELOGI("The input_fp16_nodes is set %s", input_fp16_nodes.c_str());
  347. vector<string> input_fp16_nodes_vec = StringUtils::Split(input_fp16_nodes, ';');
  348. for (uint32_t i = 0; i < input_fp16_nodes_vec.size(); ++i) {
  349. ge::NodePtr node = graph->FindNode(input_fp16_nodes_vec[i]);
  350. if (node == nullptr) {
  351. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
  352. {"input_fp16_nodes", input_fp16_nodes_vec[i]});
  353. GELOGE(PARAM_INVALID, "Input parameter[input_fp16_nodes]'s opname[%s] is not exist in model",
  354. input_fp16_nodes_vec[i].c_str());
  355. return PARAM_INVALID;
  356. }
  357. auto op_desc = node->GetOpDesc();
  358. GE_CHECK_NOTNULL(op_desc);
  359. if (op_desc->GetType() != ge::parser::DATA) {
  360. ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"},
  361. {"input_fp16_nodes", input_fp16_nodes_vec[i]});
  362. GELOGE(PARAM_INVALID, "Input parameter[input_fp16_nodes]'s opname[%s] is not a input opname",
  363. input_fp16_nodes_vec[i].c_str());
  364. return PARAM_INVALID;
  365. }
  366. AddAttrsForInputNodes(adjust_fp16_format_vec, input_fp16_nodes_vec[i], i, op_desc);
  367. }
  368. return SUCCESS;
  369. }
  370. void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
  371. std::vector<std::string> &output_nodes_name) {
  372. output_nodes_name.clear();
  373. if (ge::GetParserContext().out_top_names.empty()) {
  374. // tf process, no top name.
  375. for (const auto output_node_info : output_nodes_info) {
  376. std::string node_name = output_node_info.first->GetName();
  377. int32_t index = output_node_info.second;
  378. output_nodes_name.push_back(node_name + ":" + std::to_string(index));
  379. }
  380. return;
  381. }
  382. // caffe process, need add top name after node_name:index
  383. for (size_t i = 0; i < output_nodes_info.size(); ++i) {
  384. std::string node_name = output_nodes_info[i].first->GetName();
  385. int32_t index = output_nodes_info[i].second;
  386. if (i < ge::GetParserContext().out_top_names.size()) {
  387. output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" +
  388. ge::GetParserContext().out_top_names[i]);
  389. } else {
  390. GELOGW("Get top name of node [%s] fail.", node_name.c_str());
  391. output_nodes_name.push_back(node_name + ":" + std::to_string(index));
  392. }
  393. }
  394. }
  395. domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node,
  396. std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
  397. ge::OpDescPtr tmpDescPtr = node->GetOpDesc();
  398. if (tmpDescPtr == nullptr) {
  399. GELOGE(domi::FAILED, "Get outnode op desc fail.");
  400. return domi::FAILED;
  401. }
  402. size_t size = tmpDescPtr->GetOutputsSize();
  403. if (node->GetType() != ge::parser::NETOUTPUT) {
  404. for (size_t index = 0; index < size; ++index) {
  405. output_nodes_info.push_back(std::make_pair(node, index));
  406. GELOGD("Get output leaf node:%s.", node->GetName().c_str());
  407. }
  408. } else {
  409. const auto in_anchors = node->GetAllInDataAnchors();
  410. for (auto in_anchor : in_anchors) {
  411. auto out_anchor = in_anchor->GetPeerOutAnchor();
  412. if (out_anchor == nullptr) {
  413. GELOGE(domi::FAILED, "Get leaf node op desc fail.");
  414. return domi::FAILED;
  415. }
  416. auto out_node = out_anchor->GetOwnerNode();
  417. output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx()));
  418. }
  419. }
  420. return SUCCESS;
  421. }
  422. domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
  423. std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
  424. std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes;
  425. if (ge::GetParserContext().type == domi::CAFFE && !default_out_nodes.empty()) {
  426. for (uint32_t i = 0; i < default_out_nodes.size(); ++i) {
  427. ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first);
  428. if (out_node == nullptr) {
  429. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
  430. {"out_nodes", default_out_nodes[i].first});
  431. GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", default_out_nodes[i].first.c_str());
  432. return domi::FAILED;
  433. }
  434. output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second));
  435. GELOGD("Get default output node:%s.", out_node->GetName().c_str());
  436. }
  437. return domi::SUCCESS;
  438. }
  439. for (ge::NodePtr node : compute_graph->GetDirectNode()) {
  440. if (!node->GetInAllNodes().empty() && node->GetOutAllNodes().empty()) {
  441. Status ret = GetOutputLeaf(node, output_nodes_info);
  442. GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Find leaf fail.");
  443. }
  444. }
  445. return domi::SUCCESS;
  446. }
  447. domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph,
  448. const std::map<AscendString, AscendString> &parser_params) {
  449. ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  450. GE_CHECK_NOTNULL(compute_graph);
  451. std::vector<std::pair<std::string, int32_t>> user_out_nodes = ge::GetParserContext().user_out_nodes;
  452. std::vector<domiTensorFormat_t> output_formats = ge::GetParserContext().output_formats;
  453. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info;
  454. std::vector<std::string> output_nodes_name;
  455. // User declared outputs
  456. for (uint32_t i = 0; i < user_out_nodes.size(); ++i) {
  457. ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first);
  458. if (out_node == nullptr) {
  459. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
  460. {"out_nodes", user_out_nodes[i].first});
  461. GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str());
  462. return domi::FAILED;
  463. }
  464. auto op_desc = out_node->GetOpDesc();
  465. GE_CHECK_NOTNULL(op_desc);
  466. if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) {
  467. GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str());
  468. return domi::FAILED;
  469. }
  470. // add user_define_output_nodes attr.
  471. (void)ge::AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_OUTPUT_NODES, "true");
  472. if (i < output_formats.size()) {
  473. if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) {
  474. GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str());
  475. vector<string> output_fp16_5hd_vec;
  476. (void)ge::AttrUtils::GetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec);
  477. output_fp16_5hd_vec.push_back(std::to_string(user_out_nodes[i].second) + ":" + "NC1HWC0");
  478. (void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec);
  479. }
  480. }
  481. output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second));
  482. }
  483. // default output node (leaf)
  484. if (user_out_nodes.empty()) {
  485. if (GetDefaultOutInfo(compute_graph, output_nodes_info) != SUCCESS) {
  486. GELOGE(domi::FAILED, "Get default output info failed.");
  487. return domi::FAILED;
  488. }
  489. }
  490. GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name);
  491. compute_graph->SetGraphOutNodesInfo(output_nodes_info);
  492. ge::GetParserContext().net_out_nodes = output_nodes_name;
  493. GELOGI("Set graph %s output node success.", graph.GetName().c_str());
  494. return domi::SUCCESS;
  495. }
  496. domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) {
  497. for (auto &ele : parser_params) {
  498. const char *key_ascend = ele.first.GetString();
  499. if (key_ascend == nullptr) {
  500. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
  501. {"parser_params", "null AscendString"});
  502. GELOGE(PARAM_INVALID, "Input options key is null, Please check!");
  503. return PARAM_INVALID;
  504. }
  505. string key_str = key_ascend;
  506. auto it = ge::ir_option::ir_parser_suppported_options.find(key_str);
  507. if (it == ge::ir_option::ir_parser_suppported_options.end()) {
  508. ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"parser_params", key_str});
  509. GELOGE(PARAM_INVALID, "Input options include unsupported option(%s).Please check!", key_ascend);
  510. return PARAM_INVALID;
  511. }
  512. }
  513. return SUCCESS;
  514. }
  515. domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
  516. string &graph_name) {
  517. GELOGI("Parse graph user options start.");
  518. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, return PARAM_INVALID,
  519. "Parse paragrams invalid.");
  520. // support paragrams: out_nodes, is_output_adjust_hw_layout, output, enable_scope_fusion_passes
  521. SetDefaultFormat();
  522. string out_nodes;
  523. GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes);
  524. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputNodes(out_nodes) != SUCCESS, return PARAM_INVALID,
  525. "Parse out_nodes failed");
  526. string is_output_adjust_hw_layout;
  527. GetAclParams(parser_params, ge::ir_option::IS_OUTPUT_ADJUST_HW_LAYOUT, is_output_adjust_hw_layout);
  528. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS,
  529. return PARAM_INVALID, "Parse is_output_adjust_hw_layout failed");
  530. string tmp_name;
  531. GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name);
  532. graph_name = tmp_name.empty() ? (kGraphDefaultName + "_" + ge::parser::CurrentTimeInStr()) : tmp_name;
  533. string enable_scope_fusion_passes;
  534. GetAclParams(parser_params, ge::ir_option::ENABLE_SCOPE_FUSION_PASSES, enable_scope_fusion_passes);
  535. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclEnableScope(enable_scope_fusion_passes) != SUCCESS, return PARAM_INVALID,
  536. "Parse enable_scope_fusion_passes failed");
  537. return SUCCESS;
  538. }
  539. domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,
  540. const std::map<AscendString, AscendString> &parser_params) {
  541. // support paragrams: input_fp16_nodes, is_input_adjust_hw_layout,
  542. ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
  543. GE_CHECK_NOTNULL(compute_graph);
  544. string input_fp16_nodes;
  545. GetAclParams(parser_params, ge::ir_option::INPUT_FP16_NODES, input_fp16_nodes);
  546. string is_input_adjust_hw_layout;
  547. GetAclParams(parser_params, ge::ir_option::IS_INPUT_ADJUST_HW_LAYOUT, is_input_adjust_hw_layout);
  548. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
  549. ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS,
  550. return PARAM_INVALID, "Parse input_fp16_nodes failed");
  551. return SUCCESS;
  552. }
  553. namespace parser {
  554. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) {
  555. if (path == nullptr) {
  556. GELOGE(ge::FAILED, "path pointer is NULL.");
  557. return "";
  558. }
  559. if (strlen(path) >= PATH_MAX) {
  560. ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)});
  561. GELOGE(ge::FAILED, "Path[%s] len is too long, it must be less than %d", path, PATH_MAX);
  562. return "";
  563. }
  564. // Nullptr is returned when the path does not exist or there is no permission
  565. // Return absolute path when path is accessible
  566. std::string res;
  567. char resolved_path[PATH_MAX] = {0};
  568. if (realpath(path, resolved_path) != nullptr) {
  569. res = resolved_path;
  570. }
  571. return res;
  572. }
  573. // Get file length
  574. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) {
  575. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), return -1, "input_file path is null.");
  576. std::string real_path = RealPath(input_file.c_str());
  577. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str());
  578. unsigned long long file_length = 0;
  579. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
  580. ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"},
  581. {input_file, strerror(errno)});
  582. return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno));
  583. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0),
  584. ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file});
  585. return -1, "File[%s] size is 0, not valid.", input_file.c_str());
  586. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit,
  587. ErrorManager::GetInstance().ATCReportErrMessage(
  588. "E19016", {"filepath", "filesize", "maxlen"},
  589. {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)});
  590. return -1, "File[%s] size %lld is out of limit: %d.",
  591. input_file.c_str(), file_length, kMaxFileSizeLimit);
  592. return static_cast<long>(file_length);
  593. }
  594. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() {
  595. struct timeval tv{};
  596. int ret = gettimeofday(&tv, nullptr);
  597. GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret);
  598. auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds
  599. return static_cast<uint64_t>(total_use_time);
  600. }
  601. static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) {
  602. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr,
  603. return false, "incorrect parameter. nullptr == proto");
  604. coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit, kWarningThreshold);
  605. return proto->ParseFromCodedStream(&coded_stream);
  606. }
  607. /** @ingroup domi_common
  608. * @brief Read all data from binary file
  609. * @param [in] file_name File path
  610. * @param [out] buffer The address of the output memory, which needs to be released by the caller
  611. * @param [out] length Output memory size
  612. * @return false fail
  613. * @return true success
  614. */
  615. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer,
  616. int &length) {
  617. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file is nullptr");
  618. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), return false, "incorrect parameter. buffer is nullptr");
  619. std::string real_path = RealPath(file_name);
  620. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "file path '%s' not valid", file_name);
  621. std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate);
  622. if (!file.is_open()) {
  623. GELOGE(ge::FAILED, "Read file %s failed.", file_name);
  624. return false;
  625. }
  626. length = static_cast<int>(file.tellg());
  627. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); return false, "file length <= 0");
  628. file.seekg(0, std::ios::beg);
  629. *buffer = new(std::nothrow) char[length]();
  630. GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close(), "new an object failed.");
  631. file.read(*buffer, length);
  632. file.close();
  633. return true;
  634. }
  635. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) {
  636. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr),
  637. return false,
  638. "Input parameter file or proto is nullptr!");
  639. std::string real_path = RealPath(file);
  640. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
  641. return false, "pb file path '%s' not valid", file);
  642. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");
  643. std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary);
  644. if (!fs.is_open()) {
  645. ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"});
  646. GELOGE(ge::FAILED, "Open real path[%s] failed.", file);
  647. return false;
  648. }
  649. google::protobuf::io::IstreamInputStream istream(&fs);
  650. google::protobuf::io::CodedInputStream coded_stream(&istream);
  651. bool ret = ReadProtoFromCodedInputStream(coded_stream, proto);
  652. fs.close();
  653. if (!ret) {
  654. ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file});
  655. GELOGE(ge::FAILED, "Parse file[%s] failed.", file);
  656. return ret;
  657. }
  658. return ret;
  659. }
  660. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) {
  661. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), return false,
  662. "incorrect parameter. proto is nullptr || data is nullptr || size is 0");
  663. google::protobuf::io::CodedInputStream coded_stream(reinterpret_cast<uint8_t *>(const_cast<void *>(data)), size);
  664. return ReadProtoFromCodedInputStream(coded_stream, proto);
  665. }
  666. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file,
  667. google::protobuf::Message *message) {
  668. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), return false,
  669. "incorrect parameter. nullptr == file || nullptr == message");
  670. std::string real_path = RealPath(file);
  671. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
  672. ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"},
  673. {file, strerror(errno)});
  674. return false, "Path[%s]'s realpath is empty, errmsg[%s]", file,
  675. strerror(errno));
  676. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");
  677. std::ifstream fs(real_path.c_str(), std::ifstream::in);
  678. if (!fs.is_open()) {
  679. ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file});
  680. GELOGE(ge::FAILED,
  681. "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), file);
  682. return false;
  683. }
  684. google::protobuf::io::IstreamInputStream input(&fs);
  685. bool ret = google::protobuf::TextFormat::Parse(&input, message);
  686. GE_IF_BOOL_EXEC(!ret,
  687. ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file});
  688. GELOGE(ret, "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, "
  689. "please check whether the file is a valid protobuf format file.", file));
  690. fs.close();
  691. return ret;
  692. }
  693. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size,
  694. google::protobuf::Message *message) {
  695. GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), return false,
  696. "incorrect parameter. data is nullptr || message is nullptr");
  697. std::string str(data, static_cast<size_t>(size));
  698. std::istringstream fs(str);
  699. google::protobuf::io::IstreamInputStream input(&fs);
  700. bool ret = google::protobuf::TextFormat::Parse(&input, message);
  701. GE_IF_BOOL_EXEC(
  702. !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file."));
  703. return ret;
  704. }
  705. ///
  706. /// @brief get the Original Type of FrameworkOp
  707. /// @param [in] node
  708. /// @param [out] type
  709. /// @return Status
  710. ///
  711. Status GetOriginalType(const ge::NodePtr &node, string &type) {
  712. GE_CHECK_NOTNULL(node);
  713. type = node->GetType();
  714. GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS);
  715. GE_CHECK_NOTNULL(node->GetOpDesc());
  716. bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  717. if (!ret) {
  718. GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str());
  719. return INTERNAL_ERROR;
  720. }
  721. GELOGD("Get FrameWorkOp original type [%s]", type.c_str());
  722. return SUCCESS;
  723. }
  724. FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) {
  725. char ebuff[kMaxBuffSize];
  726. regex_t reg;
  727. int cflags = REG_EXTENDED | REG_NOSUB;
  728. int ret = regcomp(&reg, mode.c_str(), cflags);
  729. if (ret) {
  730. regerror(ret, &reg, ebuff, kMaxBuffSize);
  731. GELOGW("regcomp failed, reason: %s", ebuff);
  732. regfree(&reg);
  733. return true;
  734. }
  735. ret = regexec(&reg, str.c_str(), 0, nullptr, 0);
  736. if (ret) {
  737. regerror(ret, &reg, ebuff, kMaxBuffSize);
  738. GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff);
  739. regfree(&reg);
  740. return false;
  741. }
  742. regfree(&reg);
  743. return true;
  744. }
  745. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() {
  746. std::time_t now = std::time(nullptr);
  747. std::tm *ptm = std::localtime(&now);
  748. if (ptm == nullptr) {
  749. GELOGE(ge::FAILED, "Localtime failed.");
  750. return "";
  751. }
  752. const int kTimeBufferLen = 32;
  753. char buffer[kTimeBufferLen + 1] = {0};
  754. // format: 20171122042550
  755. std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm);
  756. return std::string(buffer);
  757. }
  758. } // namespace parser
  759. } // namespace ge