You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

serialization.cc 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "include/api/serialization.h"
  17. #include <fstream>
  18. #include "cxx_api/graph/graph_data.h"
  19. #include "utils/utils.h"
  20. namespace mindspore {
  21. static Buffer ReadFile(const std::string &file) {
  22. Buffer buffer;
  23. if (file.empty()) {
  24. MS_LOG(ERROR) << "Pointer file is nullptr";
  25. return buffer;
  26. }
  27. char real_path_mem[PATH_MAX] = {0};
  28. char *real_path_ret = nullptr;
  29. #if defined(_WIN32) || defined(_WIN64)
  30. real_path_ret = _fullpath(real_path_mem, common::SafeCStr(file), PATH_MAX);
  31. #else
  32. real_path_ret = realpath(common::SafeCStr(file), real_path_mem);
  33. #endif
  34. if (real_path_ret == nullptr) {
  35. MS_LOG(ERROR) << "File: " << file << " is not exist.";
  36. return buffer;
  37. }
  38. std::string real_path(real_path_mem);
  39. std::ifstream ifs(real_path);
  40. if (!ifs.good()) {
  41. MS_LOG(ERROR) << "File: " << real_path << " is not exist";
  42. return buffer;
  43. }
  44. if (!ifs.is_open()) {
  45. MS_LOG(ERROR) << "File: " << real_path << "open failed";
  46. return buffer;
  47. }
  48. ifs.seekg(0, std::ios::end);
  49. size_t size = ifs.tellg();
  50. buffer.ResizeData(size);
  51. if (buffer.DataSize() != size) {
  52. MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path;
  53. ifs.close();
  54. return buffer;
  55. }
  56. ifs.seekg(0, std::ios::beg);
  57. ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  58. ifs.close();
  59. return buffer;
  60. }
  61. Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelType model_type) {
  62. if (model_type == kMindIR) {
  63. auto anf_graph = std::make_shared<FuncGraph>();
  64. return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
  65. } else if (model_type == kOM) {
  66. return Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
  67. }
  68. MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
  69. }
  70. Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
  71. Buffer data = ReadFile(file);
  72. if (data.Data() == nullptr) {
  73. MS_LOG(EXCEPTION) << "Read file " << file << " failed.";
  74. }
  75. if (model_type == kMindIR) {
  76. auto anf_graph = std::make_shared<FuncGraph>();
  77. return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
  78. } else if (model_type == kOM) {
  79. return Graph(std::make_shared<Graph::GraphData>(data, kOM));
  80. }
  81. MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
  82. }
  83. Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
  84. MS_LOG(ERROR) << "Unsupported feature.";
  85. return kMEFailed;
  86. }
  87. Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) {
  88. MS_LOG(ERROR) << "Unsupported feature.";
  89. return kMEFailed;
  90. }
  91. Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
  92. MS_LOG(ERROR) << "Unsupported feature.";
  93. return kMEFailed;
  94. }
  95. Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file) {
  96. MS_LOG(ERROR) << "Unsupported feature.";
  97. return kMEFailed;
  98. }
  99. } // namespace mindspore

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.