|
- /**
- * \file src/model_parser.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
- #include "model_parser.h"
- #include "decryption/decrypt_base.h"
- #include "parse_info/parse_info_base.h"
-
- using namespace lite;
- using namespace model_parse;
-
- std::string ModelParser::sm_model_tag = "packed_model";
-
- void ModelParser::parse_header() {
- size_t tag_length = sm_model_tag.size();
-
- //! parse model tag
- const char* ptr = static_cast<char*>(m_model.get());
- std::string tag(static_cast<const char*>(ptr), tag_length);
- if (sm_model_tag == tag) {
- m_is_bare_model = false;
- } else {
- //! if no tag, the model is bare model, return
- m_is_bare_model = true;
- return;
- }
-
- uint8_t* buffer = static_cast<uint8_t*>(m_model.get()) + tag_length;
- auto packed_model = GetPackModel(buffer);
- auto models = packed_model->models();
- LITE_ASSERT(models->size() == 1, "Now only support one model");
- auto model = models->Get(0);
- m_model_name = model->header()->name()->c_str();
- m_model_decryption_name = model->header()->model_decryption_method()->c_str();
- m_info_decryption_name = model->header()->info_decryption_method()->c_str();
- m_info_parse_func_name = model->header()->info_parse_method()->c_str();
-
- m_info = model->info();
- m_model_data = model->data();
- }
-
- bool ModelParser::parse_model_info(
- Config& network_config, NetworkIO& network_io,
- std::unordered_map<std::string, LiteAny>& isolated_config_map,
- std::string& extra_info) const {
- //! no model info, no parse, direct return
- if (m_is_bare_model || !m_info) {
- return false;
- }
- size_t info_length = m_info->data()->size();
- const uint8_t* info_data = m_info->data()->Data();
- //! decryption the info
- auto info_ptr =
- decrypt_memory(info_data, info_length, m_info_decryption_name, info_length);
- //! parse the info
- LITE_LOCK_GUARD(parse_info_static_data().map_mutex);
- auto it_parse =
- parse_info_static_data().parse_info_methods.find(m_info_parse_func_name);
- if (it_parse == parse_info_static_data().parse_info_methods.end()) {
- LITE_THROW(ssprintf(
- "can't find model info parse function %s.",
- m_info_parse_func_name.c_str()));
- }
- auto model_info_parse_func =
- parse_info_static_data().parse_info_methods[m_info_parse_func_name];
- //! convert for NetworkIOInner to NetworkIO
- if (model_info_parse_func) {
- model_info_parse_func(
- info_ptr.get(), info_length, m_model_name, network_config, network_io,
- isolated_config_map, extra_info);
- } else {
- LITE_THROW(ssprintf(
- "model info parse function of %s is empty",
- m_info_parse_func_name.c_str()));
- }
- return true;
- }
-
- std::shared_ptr<void> ModelParser::parse_model(
- size_t& model_length, const Config& config) const {
- if (m_is_bare_model) {
- if (config.bare_model_cryption_name.size() == 0) {
- model_length = m_total_length;
- return m_model;
- } else {
- return decrypt_memory(
- static_cast<uint8_t*>(m_model.get()), m_total_length,
- config.bare_model_cryption_name, model_length);
- }
- }
- LITE_ASSERT(m_model_data, "packed model parse error!");
- model_length = m_model_data->data()->size();
- const uint8_t* model_data = m_model_data->data()->Data();
- LITE_ASSERT(model_length > 0, "The loaded model is of zero length.");
- return decrypt_memory(
- model_data, model_length, m_model_decryption_name, model_length);
- }
-
- std::shared_ptr<void> ModelParser::decrypt_memory(
- const uint8_t* data, size_t length, const std::string decryption_name,
- size_t& result_length) const {
- const uint8_t* memory_ptr = data;
- if (decryption_name == "NONE") {
- result_length = length;
- return std::shared_ptr<void>(const_cast<uint8_t*>(memory_ptr), [](void*) {});
- }
- LITE_LOCK_GUARD(decryption_static_data().map_mutex);
- auto it = decryption_static_data().decryption_methods.find(decryption_name);
- if (it == decryption_static_data().decryption_methods.end()) {
- LITE_THROW(ssprintf(
- "The decryption method %s is not registed yet.",
- decryption_name.c_str()));
- }
- auto&& func = it->second.first;
- auto&& key = it->second.second;
- if (func) {
- auto model_vector = func(memory_ptr, length, *key);
- result_length = model_vector.size();
- auto tmp_model_vector = new std::vector<uint8_t>(std::move(model_vector));
- return std::shared_ptr<void>(
- tmp_model_vector->data(),
- [tmp_model_vector](void*) { delete tmp_model_vector; });
- } else {
- LITE_THROW(ssprintf(
- "No decryption function in %s method.", decryption_name.c_str()));
- }
- }
-
- // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
|