You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

serialization.cc 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "include/api/serialization.h"
  17. #include <fstream>
  18. #include "cxx_api/graph/graph_data.h"
  19. #include "utils/log_adapter.h"
  20. #include "mindspore/core/load_mindir/load_model.h"
  21. namespace mindspore {
  22. static Buffer ReadFile(const std::string &file) {
  23. Buffer buffer;
  24. if (file.empty()) {
  25. MS_LOG(ERROR) << "Pointer file is nullptr";
  26. return buffer;
  27. }
  28. char real_path_mem[PATH_MAX] = {0};
  29. char *real_path_ret = nullptr;
  30. #if defined(_WIN32) || defined(_WIN64)
  31. real_path_ret = _fullpath(real_path_mem, common::SafeCStr(file), PATH_MAX);
  32. #else
  33. real_path_ret = realpath(common::SafeCStr(file), real_path_mem);
  34. #endif
  35. if (real_path_ret == nullptr) {
  36. MS_LOG(ERROR) << "File: " << file << " is not exist.";
  37. return buffer;
  38. }
  39. std::string real_path(real_path_mem);
  40. std::ifstream ifs(real_path);
  41. if (!ifs.good()) {
  42. MS_LOG(ERROR) << "File: " << real_path << " is not exist";
  43. return buffer;
  44. }
  45. if (!ifs.is_open()) {
  46. MS_LOG(ERROR) << "File: " << real_path << "open failed";
  47. return buffer;
  48. }
  49. ifs.seekg(0, std::ios::end);
  50. size_t size = ifs.tellg();
  51. buffer.ResizeData(size);
  52. if (buffer.DataSize() != size) {
  53. MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path;
  54. ifs.close();
  55. return buffer;
  56. }
  57. ifs.seekg(0, std::ios::beg);
  58. ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  59. ifs.close();
  60. return buffer;
  61. }
  62. Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) {
  63. if (graph == nullptr) {
  64. MS_LOG(ERROR) << "Output args graph is nullptr.";
  65. return kMEInvalidInput;
  66. }
  67. if (model_type == kMindIR) {
  68. FuncGraphPtr anf_graph = nullptr;
  69. try {
  70. anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
  71. } catch (const std::exception &) {
  72. MS_LOG(ERROR) << "Load model failed.";
  73. return kMEInvalidInput;
  74. }
  75. *graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
  76. return kSuccess;
  77. } else if (model_type == kOM) {
  78. *graph = Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
  79. return kSuccess;
  80. }
  81. MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
  82. return kMEInvalidInput;
  83. }
  84. Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
  85. if (graph == nullptr) {
  86. MS_LOG(ERROR) << "Output args graph is nullptr.";
  87. return kMEInvalidInput;
  88. }
  89. std::string file_path = CharToString(file);
  90. if (model_type == kMindIR) {
  91. FuncGraphPtr anf_graph = LoadMindIR(file_path);
  92. if (anf_graph == nullptr) {
  93. MS_LOG(ERROR) << "Load model failed.";
  94. return kMEInvalidInput;
  95. }
  96. *graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
  97. return kSuccess;
  98. } else if (model_type == kOM) {
  99. Buffer data = ReadFile(file_path);
  100. if (data.Data() == nullptr) {
  101. MS_LOG(ERROR) << "Read file " << file_path << " failed.";
  102. return kMEInvalidInput;
  103. }
  104. *graph = Graph(std::make_shared<Graph::GraphData>(data, kOM));
  105. return kSuccess;
  106. }
  107. MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
  108. return kMEInvalidInput;
  109. }
  110. Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
  111. MS_LOG(ERROR) << "Unsupported feature.";
  112. return kMEFailed;
  113. }
  114. Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) {
  115. MS_LOG(ERROR) << "Unsupported feature.";
  116. return kMEFailed;
  117. }
  118. Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
  119. MS_LOG(ERROR) << "Unsupported feature.";
  120. return kMEFailed;
  121. }
  122. Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file) {
  123. MS_LOG(ERROR) << "Unsupported feature.";
  124. return kMEFailed;
  125. }
  126. } // namespace mindspore