You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

model_parser.cpp 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. /**
  2. * \file src/model_parser.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "model_parser.h"
  12. #include "decryption/decrypt_base.h"
  13. #include "parse_info/parse_info_base.h"
  14. using namespace lite;
  15. using namespace model_parse;
  16. std::string ModelParser::sm_model_tag = "packed_model";
  17. void ModelParser::parse_header() {
  18. size_t tag_length = sm_model_tag.size();
  19. //! parse model tag
  20. const char* ptr = static_cast<char*>(m_model.get());
  21. std::string tag(static_cast<const char*>(ptr), tag_length);
  22. if (sm_model_tag == tag) {
  23. m_is_bare_model = false;
  24. } else {
  25. //! if no tag, the model is bare model, return
  26. m_is_bare_model = true;
  27. return;
  28. }
  29. uint8_t* buffer = static_cast<uint8_t*>(m_model.get()) + tag_length;
  30. auto packed_model = GetPackModel(buffer);
  31. auto models = packed_model->models();
  32. LITE_ASSERT(models->size() == 1, "Now only support one model");
  33. auto model = models->Get(0);
  34. m_model_name = model->header()->name()->c_str();
  35. m_model_decryption_name = model->header()->model_decryption_method()->c_str();
  36. m_info_decryption_name = model->header()->info_decryption_method()->c_str();
  37. m_info_parse_func_name = model->header()->info_parse_method()->c_str();
  38. m_info = model->info();
  39. m_model_data = model->data();
  40. }
  41. bool ModelParser::parse_model_info(
  42. Config& network_config, NetworkIO& network_io,
  43. std::unordered_map<std::string, LiteAny>& isolated_config_map,
  44. std::string& extra_info) const {
  45. //! no model info, no parse, direct return
  46. if (m_is_bare_model || !m_info) {
  47. return false;
  48. }
  49. size_t info_length = m_info->data()->size();
  50. const uint8_t* info_data = m_info->data()->Data();
  51. //! decryption the info
  52. auto info_ptr =
  53. decrypt_memory(info_data, info_length, m_info_decryption_name, info_length);
  54. //! parse the info
  55. LITE_LOCK_GUARD(parse_info_static_data().map_mutex);
  56. auto it_parse =
  57. parse_info_static_data().parse_info_methods.find(m_info_parse_func_name);
  58. if (it_parse == parse_info_static_data().parse_info_methods.end()) {
  59. LITE_THROW(ssprintf(
  60. "can't find model info parse function %s.",
  61. m_info_parse_func_name.c_str()));
  62. }
  63. auto model_info_parse_func =
  64. parse_info_static_data().parse_info_methods[m_info_parse_func_name];
  65. //! convert for NetworkIOInner to NetworkIO
  66. if (model_info_parse_func) {
  67. model_info_parse_func(
  68. info_ptr.get(), info_length, m_model_name, network_config, network_io,
  69. isolated_config_map, extra_info);
  70. } else {
  71. LITE_THROW(ssprintf(
  72. "model info parse function of %s is empty",
  73. m_info_parse_func_name.c_str()));
  74. }
  75. return true;
  76. }
  77. std::shared_ptr<void> ModelParser::parse_model(
  78. size_t& model_length, const Config& config) const {
  79. if (m_is_bare_model) {
  80. if (config.bare_model_cryption_name.size() == 0) {
  81. model_length = m_total_length;
  82. return m_model;
  83. } else {
  84. return decrypt_memory(
  85. static_cast<uint8_t*>(m_model.get()), m_total_length,
  86. config.bare_model_cryption_name, model_length);
  87. }
  88. }
  89. LITE_ASSERT(m_model_data, "packed model parse error!");
  90. model_length = m_model_data->data()->size();
  91. const uint8_t* model_data = m_model_data->data()->Data();
  92. LITE_ASSERT(model_length > 0, "The loaded model is of zero length.");
  93. return decrypt_memory(
  94. model_data, model_length, m_model_decryption_name, model_length);
  95. }
  96. std::shared_ptr<void> ModelParser::decrypt_memory(
  97. const uint8_t* data, size_t length, const std::string decryption_name,
  98. size_t& result_length) const {
  99. const uint8_t* memory_ptr = data;
  100. if (decryption_name == "NONE") {
  101. result_length = length;
  102. return std::shared_ptr<void>(const_cast<uint8_t*>(memory_ptr), [](void*) {});
  103. }
  104. LITE_LOCK_GUARD(decryption_static_data().map_mutex);
  105. auto it = decryption_static_data().decryption_methods.find(decryption_name);
  106. if (it == decryption_static_data().decryption_methods.end()) {
  107. LITE_THROW(ssprintf(
  108. "The decryption method %s is not registed yet.",
  109. decryption_name.c_str()));
  110. }
  111. auto&& func = it->second.first;
  112. auto&& key = it->second.second;
  113. if (func) {
  114. auto model_vector = func(memory_ptr, length, *key);
  115. result_length = model_vector.size();
  116. auto tmp_model_vector = new std::vector<uint8_t>(std::move(model_vector));
  117. return std::shared_ptr<void>(
  118. tmp_model_vector->data(),
  119. [tmp_model_vector](void*) { delete tmp_model_vector; });
  120. } else {
  121. LITE_THROW(ssprintf(
  122. "No decryption function in %s method.", decryption_name.c_str()));
  123. }
  124. }
  125. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}