You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

serialization.cc 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "include/api/serialization.h"
  17. #include <fstream>
  18. #include <sstream>
  19. #include "cxx_api/graph/graph_data.h"
  20. #include "utils/log_adapter.h"
  21. #include "mindspore/core/load_mindir/load_model.h"
  22. #include "utils/crypto.h"
  23. namespace mindspore {
  24. static Status RealPath(const std::string &file, std::string *realpath_str) {
  25. MS_EXCEPTION_IF_NULL(realpath_str);
  26. char real_path_mem[PATH_MAX] = {0};
  27. char *real_path_ret = nullptr;
  28. #if defined(_WIN32) || defined(_WIN64)
  29. real_path_ret = _fullpath(real_path_mem, common::SafeCStr(file), PATH_MAX);
  30. #else
  31. real_path_ret = realpath(common::SafeCStr(file), real_path_mem);
  32. #endif
  33. if (real_path_ret == nullptr) {
  34. return Status(kMEInvalidInput, "File: " + file + " does not exist.");
  35. }
  36. *realpath_str = real_path_mem;
  37. return kSuccess;
  38. }
  39. static Buffer ReadFile(const std::string &file) {
  40. Buffer buffer;
  41. if (file.empty()) {
  42. MS_LOG(ERROR) << "Pointer file is nullptr";
  43. return buffer;
  44. }
  45. std::string real_path;
  46. auto status = RealPath(file, &real_path);
  47. if (status != kSuccess) {
  48. MS_LOG(ERROR) << status.GetErrDescription();
  49. return buffer;
  50. }
  51. std::ifstream ifs(real_path);
  52. if (!ifs.good()) {
  53. MS_LOG(ERROR) << "File: " << real_path << " does not exist";
  54. return buffer;
  55. }
  56. if (!ifs.is_open()) {
  57. MS_LOG(ERROR) << "File: " << real_path << " open failed";
  58. return buffer;
  59. }
  60. (void)ifs.seekg(0, std::ios::end);
  61. size_t size = static_cast<size_t>(ifs.tellg());
  62. buffer.ResizeData(size);
  63. if (buffer.DataSize() != size) {
  64. MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path;
  65. ifs.close();
  66. return buffer;
  67. }
  68. (void)ifs.seekg(0, std::ios::beg);
  69. (void)ifs.read(reinterpret_cast<char *>(buffer.MutableData()), static_cast<std::streamsize>(size));
  70. ifs.close();
  71. return buffer;
  72. }
  73. Key::Key(const char *dec_key, size_t key_len) {
  74. len = 0;
  75. if (key_len >= max_key_len) {
  76. MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
  77. return;
  78. }
  79. auto sec_ret = memcpy_s(key, max_key_len, dec_key, key_len);
  80. if (sec_ret != EOK) {
  81. MS_LOG(ERROR) << "memcpy_s failed, src_len = " << key_len << ", dst_len = " << max_key_len << ", ret = " << sec_ret;
  82. return;
  83. }
  84. len = key_len;
  85. }
  86. Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
  87. const Key &dec_key, const std::vector<char> &dec_mode) {
  88. std::stringstream err_msg;
  89. if (graph == nullptr) {
  90. err_msg << "Output args graph is nullptr.";
  91. MS_LOG(ERROR) << err_msg.str();
  92. return Status(kMEInvalidInput, err_msg.str());
  93. }
  94. if (model_type == kMindIR) {
  95. FuncGraphPtr anf_graph = nullptr;
  96. try {
  97. if (dec_key.len > dec_key.max_key_len) {
  98. err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
  99. MS_LOG(ERROR) << err_msg.str();
  100. return Status(kMEInvalidInput, err_msg.str());
  101. } else if (dec_key.len == 0) {
  102. if (IsCipherFile(reinterpret_cast<const unsigned char *>(model_data))) {
  103. err_msg << "Load model failed. The model_data may be encrypted, please pass in correct key.";
  104. MS_LOG(ERROR) << err_msg.str();
  105. return Status(kMEInvalidInput, err_msg.str());
  106. } else {
  107. anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
  108. }
  109. } else {
  110. size_t plain_data_size;
  111. auto plain_data = mindspore::Decrypt(&plain_data_size, reinterpret_cast<const unsigned char *>(model_data),
  112. data_size, dec_key.key, dec_key.len, CharToString(dec_mode));
  113. if (plain_data == nullptr) {
  114. err_msg << "Load model failed. Please check the valid of dec_key and dec_mode.";
  115. MS_LOG(ERROR) << err_msg.str();
  116. return Status(kMEInvalidInput, err_msg.str());
  117. }
  118. anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(plain_data.get()), plain_data_size);
  119. }
  120. } catch (const std::exception &) {
  121. err_msg << "Load model failed. Please check the valid of dec_key and dec_mode.";
  122. MS_LOG(ERROR) << err_msg.str();
  123. return Status(kMEInvalidInput, err_msg.str());
  124. }
  125. *graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
  126. return kSuccess;
  127. } else if (model_type == kOM) {
  128. *graph = Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
  129. return kSuccess;
  130. }
  131. err_msg << "Unsupported ModelType " << model_type;
  132. MS_LOG(ERROR) << err_msg.str();
  133. return Status(kMEInvalidInput, err_msg.str());
  134. }
  135. Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
  136. return Load(file, model_type, graph, Key{}, StringToChar(kDecModeAesGcm));
  137. }
  138. Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
  139. const std::vector<char> &dec_mode) {
  140. std::stringstream err_msg;
  141. if (graph == nullptr) {
  142. err_msg << "Output args graph is nullptr.";
  143. MS_LOG(ERROR) << err_msg.str();
  144. return Status(kMEInvalidInput, err_msg.str());
  145. }
  146. std::string file_path;
  147. auto status = RealPath(CharToString(file), &file_path);
  148. if (status != kSuccess) {
  149. MS_LOG(ERROR) << status.GetErrDescription();
  150. return status;
  151. }
  152. if (model_type == kMindIR) {
  153. FuncGraphPtr anf_graph;
  154. if (dec_key.len > dec_key.max_key_len) {
  155. err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
  156. MS_LOG(ERROR) << err_msg.str();
  157. return Status(kMEInvalidInput, err_msg.str());
  158. } else if (dec_key.len == 0 && IsCipherFile(file_path)) {
  159. err_msg << "Load model failed. The file may be encrypted, please pass in correct key.";
  160. MS_LOG(ERROR) << err_msg.str();
  161. return Status(kMEInvalidInput, err_msg.str());
  162. } else {
  163. anf_graph =
  164. LoadMindIR(file_path, false, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode));
  165. }
  166. if (anf_graph == nullptr) {
  167. err_msg << "Load model failed. Please check the valid of dec_key and dec_mode";
  168. MS_LOG(ERROR) << err_msg.str();
  169. return Status(kMEInvalidInput, err_msg.str());
  170. }
  171. *graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
  172. return kSuccess;
  173. } else if (model_type == kOM) {
  174. Buffer data = ReadFile(file_path);
  175. if (data.Data() == nullptr) {
  176. err_msg << "Read file " << file_path << " failed.";
  177. MS_LOG(ERROR) << err_msg.str();
  178. return Status(kMEInvalidInput, err_msg.str());
  179. }
  180. *graph = Graph(std::make_shared<Graph::GraphData>(data, kOM));
  181. return kSuccess;
  182. }
  183. err_msg << "Unsupported ModelType " << model_type;
  184. MS_LOG(ERROR) << err_msg.str();
  185. return Status(kMEInvalidInput, err_msg.str());
  186. }
  187. Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
  188. std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) {
  189. std::stringstream err_msg;
  190. if (graphs == nullptr) {
  191. err_msg << "Output args graph is nullptr.";
  192. MS_LOG(ERROR) << err_msg.str();
  193. return Status(kMEInvalidInput, err_msg.str());
  194. }
  195. if (files.size() == 1) {
  196. std::vector<Graph> result(files.size());
  197. auto ret = Load(files[0], model_type, &result[0], dec_key, dec_mode);
  198. *graphs = std::move(result);
  199. return ret;
  200. }
  201. std::vector<std::string> files_path;
  202. for (const auto &file : files) {
  203. std::string file_path;
  204. auto status = RealPath(CharToString(file), &file_path);
  205. if (status != kSuccess) {
  206. MS_LOG(ERROR) << status.GetErrDescription();
  207. return status;
  208. }
  209. files_path.emplace_back(std::move(file_path));
  210. }
  211. if (model_type == kMindIR) {
  212. if (dec_key.len > dec_key.max_key_len) {
  213. err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
  214. MS_LOG(ERROR) << err_msg.str();
  215. return Status(kMEInvalidInput, err_msg.str());
  216. }
  217. auto anf_graphs =
  218. LoadMindIRs(files_path, false, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode));
  219. if (anf_graphs.size() != files_path.size()) {
  220. err_msg << "Load model failed, " << files_path.size() << " files got " << anf_graphs.size() << " graphs.";
  221. MS_LOG(ERROR) << err_msg.str();
  222. return Status(kMEInvalidInput, err_msg.str());
  223. }
  224. std::vector<Graph> results;
  225. for (size_t i = 0; i < anf_graphs.size(); ++i) {
  226. if (anf_graphs[i] == nullptr) {
  227. if (dec_key.len == 0 && IsCipherFile(files_path[i])) {
  228. err_msg << "Load model failed. The file " << files_path[i] << " be encrypted, please pass in correct key.";
  229. } else {
  230. err_msg << "Load model " << files_path[i] << " failed.";
  231. }
  232. MS_LOG(ERROR) << err_msg.str();
  233. return Status(kMEInvalidInput, err_msg.str());
  234. }
  235. results.emplace_back(std::make_shared<Graph::GraphData>(anf_graphs[i], kMindIR));
  236. }
  237. *graphs = std::move(results);
  238. return kSuccess;
  239. }
  240. err_msg << "Unsupported ModelType " << model_type;
  241. MS_LOG(ERROR) << err_msg.str();
  242. return Status(kMEInvalidInput, err_msg.str());
  243. }
  244. Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model *) {
  245. MS_LOG(ERROR) << "Unsupported feature.";
  246. return kMEFailed;
  247. }
  248. Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
  249. MS_LOG(ERROR) << "Unsupported feature.";
  250. return kMEFailed;
  251. }
  252. Status Serialization::ExportModel(const Model &, ModelType, const std::string &, QuantizationType, bool) {
  253. MS_LOG(ERROR) << "Unsupported feature.";
  254. return kMEFailed;
  255. }
  256. } // namespace mindspore