| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * \file include/lite/pack_model.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
#include <cstdint>
#include <string>
| namespace lite { | |||
//! Feature flags serialized into the packed-model header. Kept as an explicit
//! 32-bit bit-field layout so new flags can be added without growing the header.
struct FeatureBits32 {
    //! set when ModelInfo::algo_policy carries a fastrun cache
    uint32_t is_fast_run_cache : 1;
    //! reserved for new fields
    uint32_t : 31;
};
//! Header stored in front of a packed model. It tells the loader which
//! decryption and parse functions to apply to the model/info/cache payloads.
struct Header {
    std::string name;  //!< model name
    //! model encryption method name, this is used to find the right decryption
    //! method. [ AES_default | RC4_default | SIMPLE_FAST_RC4_default ],
    //! default is NONE.
    std::string model_decryption_method;
    //! info data encryption method name, this is used to find the right
    //! decryption method. [ AES_default | RC4_default |
    //! SIMPLE_FAST_RC4_default ], default is NONE.
    std::string info_decryption_method;
    //! info parse method name.
    std::string info_parse_method = "LITE_default";
    //! fastrun cache parse method name.
    std::string info_cache_parse_method = "LITE_parse_cache";
    //! feature flags (e.g. whether the packed cache is a fastrun cache)
    FeatureBits32 fb32;
};
| class FbsHelper; | |||
//! Packs a (possibly already packed) model file together with an optional
//! json info file and fastrun / binary caches into a single packed-model file.
class ModelPacker {
public:
    //! \param model_path            input model (bare or already packed)
    //! \param packed_model_path     output path for the packed model
    //! \param info_data_path        optional json info file to embed
    //! \param info_algo_policy_path optional fastrun cache / algo policy file
    //! \param info_binary_cache_path optional binary (e.g. opencl) cache file
    ModelPacker(
            std::string model_path, std::string packed_model_path,
            std::string info_data_path = "", std::string info_algo_policy_path = "",
            std::string info_binary_cache_path = "");

    //! Configure the header fields written into the packed model; see Header
    //! for the meaning of the decryption method names.
    void set_header(
            std::string model_decryption_method = "NONE",
            std::string info_decryption_method = "NONE", bool is_fast_run_cache = true);

    //! Serialize everything and write the packed model to packed_model_path.
    void pack_model();

private:
    std::string m_packed_model_path;
    std::string m_info_data_path;
    //! fastrun cache / algo policy
    std::string m_info_algo_policy_path;
    //! binary cache
    std::string m_info_binary_cache_path;
    Header m_header;

    friend class FbsHelper;
    //! NOTE(review): raw owning pointer with no destructor declared --
    //! presumably leaked; confirm whether a unique_ptr was intended.
    FbsHelper* m_fbs_helper;
};
| } // namespace lite | |||
| @@ -1,7 +1,7 @@ | |||
| # BUILD the load and run for lite | |||
| include_directories(PUBLIC | |||
| $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>) | |||
| file(GLOB_RECURSE SOURCES ./*.cpp) | |||
| file(GLOB_RECURSE SOURCES ./*.cpp ${PROJECT_SOURCE_DIR}/lite/src/pack_model/*.cpp) | |||
| add_executable(load_and_run ${SOURCES}) | |||
| target_link_libraries(load_and_run lite_static) | |||
| @@ -43,6 +43,8 @@ public: | |||
| virtual void wait() = 0; | |||
| virtual ~ModelBase() = default; | |||
| virtual const std::string& get_model_path() const = 0; | |||
| }; | |||
| } // namespace lar | |||
| @@ -60,6 +60,8 @@ public: | |||
| //! get algo strategy | |||
| Strategy& get_lite_strategy() { return m_strategy; } | |||
| const std::string& get_model_path() const override { return model_path; } | |||
| private: | |||
| bool share_model_mem; | |||
| bool enable_layout_transform; | |||
| @@ -107,6 +107,8 @@ public: | |||
| std::move(out_file), m_format.val()); | |||
| } | |||
| const std::string& get_model_path() const override { return model_path; } | |||
| private: | |||
| bool share_model_mem; | |||
| std::string model_path; | |||
| @@ -0,0 +1,87 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/model_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "model_options.h" | |||
| #include "device_options.h" | |||
| #include "lite/pack_model.h" | |||
| #include "misc.h" | |||
| #include "models/model_lite.h" | |||
| #include "models/model_mdl.h" | |||
| namespace lar { | |||
| template <typename ModelImpl> | |||
| void PackModelOption::config_model_internel( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelImpl> model) { | |||
| if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
| lite::ModelPacker packer( | |||
| model->get_model_path(), packed_model_dump, pack_info_json, pack_cache, | |||
| pack_binary_cache); | |||
| packer.set_header(pack_info_cryption, pack_model_cryption, is_fast_run_cache); | |||
| packer.pack_model(); | |||
| } | |||
| } | |||
| } // namespace lar | |||
| using namespace lar; | |||
| ////////////////////// PackModel options //////////////////////// | |||
PackModelOption::PackModelOption() {
    //! Pull option values from the gflags command line; empty flags keep the
    //! members' defaults (empty strings).
    m_option_name = "pack_model";
    if (!FLAGS_packed_model_dump.empty())
        packed_model_dump = FLAGS_packed_model_dump;
    if (!FLAGS_pack_info_json.empty())
        pack_info_json = FLAGS_pack_info_json;
    if (!FLAGS_pack_cache.empty())
        pack_cache = FLAGS_pack_cache;
    if (!FLAGS_pack_info_cryption.empty())
        pack_info_cryption = FLAGS_pack_info_cryption;
    if (!FLAGS_pack_model_cryption.empty())
        pack_model_cryption = FLAGS_pack_model_cryption;
    //! NOTE(review): pack_binary_cache is never assigned from any flag, so it
    //! is always empty -- confirm whether a --pack_binary_cache flag was
    //! intended.
}
| bool PackModelOption::is_valid() { | |||
| return !FLAGS_packed_model_dump.empty(); | |||
| } | |||
| std::shared_ptr<OptionBase> PackModelOption::create_option() { | |||
| static std::shared_ptr<PackModelOption> option(new PackModelOption); | |||
| if (PackModelOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
void PackModelOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    //! Dispatch to the templated config_model_internel for the concrete model
    //! type (the macro expands the per-model-type switch).
    CONFIG_MODEL_FUN;
}
////////////////////// PackModel gflags ////////////////////////
//! Command line flags consumed by PackModelOption's constructor.
DEFINE_string(packed_model_dump, "", "The output file path of packed model.");
DEFINE_string(
        pack_info_json, "",
        "An encrypted or not encrypted json format file to pack into the model.");
DEFINE_string(pack_cache, "", "Pack the fastrun cache or algo policy into the model.");
//! NOTE(review): there is no flag for the binary cache
//! (PackModelOption::pack_binary_cache) -- confirm whether one was intended.
DEFINE_string(
        pack_info_cryption, "NONE",
        "The info data encryption method name, this is used to find the right "
        "decryption method. --pack-info-cryption [ AES_default | RC4_default | "
        "SIMPLE_FAST_RC4_default ], default is NONE. See "
        "https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/"
        "pack-lite-model.html for more details.");
DEFINE_string(
        pack_model_cryption, "NONE",
        "The model encryption method name, this is used to find the right decryption "
        "method. --pack-model-cryption [ AES_default | RC4_default | "
        "SIMPLE_FAST_RC4_default ], default is NONE. See "
        "https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/"
        "pack-lite-model.html for more details.");
REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option);
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/model_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| DECLARE_string(packed_model_dump); | |||
| DECLARE_string(pack_info_json); | |||
| DECLARE_string(pack_cache); | |||
| DECLARE_string(pack_info_cryption); | |||
| DECLARE_string(pack_model_cryption); | |||
| namespace lar { | |||
//! Option that packs the loaded model (plus optional json info and fastrun
//! cache) into a single packed-model file after the run finishes.
class PackModelOption : public OptionBase {
public:
    //! the option is active only when --packed_model_dump is given
    static bool is_valid();
    //! static creator; returns nullptr when packing was not requested
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    PackModelOption();
    //! per-model-type implementation invoked by config_model
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>);
    std::string m_option_name;
    std::string packed_model_dump;  //!< output path of the packed model
    std::string pack_info_json;     //!< json info file to embed
    std::string pack_cache;         //!< fastrun cache / algo policy file
    //! NOTE(review): never set from any gflag -- always empty; confirm intent
    std::string pack_binary_cache;
    std::string pack_info_cryption;   //!< info encryption method name
    std::string pack_model_cryption;  //!< model encryption method name
    bool is_fast_run_cache = true;
};
| } // namespace lar | |||
| @@ -119,6 +119,12 @@ class TestNetwork(TestShuffleNet): | |||
| network.load(model_path) | |||
| self.do_forward(network) | |||
| def test_pack_cache_to_model(self): | |||
| model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite") | |||
| network = LiteNetwork() | |||
| network.load(model_path) | |||
| self.do_forward(network) | |||
| def test_network_basic(self): | |||
| network = LiteNetwork() | |||
| network.load(self.model_path) | |||
| @@ -0,0 +1,232 @@ | |||
| /** | |||
| * \file src/pack_model/pack_model.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "lite/pack_model.h" | |||
| #include "../misc.h" | |||
| #if LITE_BUILD_WITH_MGE | |||
| #include "megbrain/utils/infile_persistent_cache.h" | |||
| #endif | |||
| #include <flatbuffers/flatbuffers.h> | |||
| #include "nlohmann/json.hpp" | |||
| #include "pack_model_generated.h" | |||
| namespace lite { | |||
//! Owns the FlatBufferBuilder used to (re)serialize a model and remembers the
//! tables of an input model that is itself already packed.
class FbsHelper {
public:
    FbsHelper() = default;
    //! Reads \p model_path; if the file starts with the "packed_model" tag the
    //! existing header/info/data tables are kept for carry-over on re-packing.
    FbsHelper(ModelPacker* packer, std::string model_path);
    flatbuffers::Offset<model_parse::ModelHeader> build_header();
    flatbuffers::Offset<model_parse::ModelInfo> build_info();
    flatbuffers::Offset<model_parse::ModelData> build_data();
    flatbuffers::FlatBufferBuilder& builder() { return m_builder; }

private:
    //! non-owning back-pointer to the packer whose settings we serialize
    ModelPacker* m_packer;
    flatbuffers::FlatBufferBuilder m_builder;
    //! raw bytes of the input model file (tag + flatbuffer, or a bare model)
    std::vector<uint8_t> m_model_buffer;
    //! the following are non-null only when the input was already packed;
    //! they point into m_model_buffer
    const model_parse::ModelHeader* m_model_header = nullptr;
    const model_parse::ModelInfo* m_model_info = nullptr;
    const model_parse::ModelData* m_model_data = nullptr;
};
| } // namespace lite | |||
| using namespace lite; | |||
| using namespace model_parse; | |||
| std::vector<uint8_t> read_file(std::string path) { | |||
| FILE* fin = fopen(path.c_str(), "rb"); | |||
| LITE_ASSERT(fin, "failed to open %s: %s", path.c_str(), strerror(errno)); | |||
| fseek(fin, 0, SEEK_END); | |||
| size_t size = ftell(fin); | |||
| fseek(fin, 0, SEEK_SET); | |||
| std::vector<uint8_t> buf; | |||
| buf.resize(size); | |||
| auto nr = fread(buf.data(), size, 1, fin); | |||
| LITE_ASSERT(nr == 1); | |||
| fclose(fin); | |||
| return buf; | |||
| } | |||
| FbsHelper::FbsHelper(ModelPacker* packer, std::string model_path) : m_packer(packer) { | |||
| m_model_buffer = read_file(model_path); | |||
| const char* model_ptr = | |||
| static_cast<const char*>(static_cast<void*>(m_model_buffer.data())); | |||
| std::string tag(model_ptr, 12); | |||
| if (tag == "packed_model") { | |||
| uint8_t* buffer = m_model_buffer.data() + 12; | |||
| auto model = GetPackModel(buffer)->models()->Get(0); | |||
| m_model_header = model->header(); | |||
| m_model_info = model->info(); | |||
| m_model_data = model->data(); | |||
| } | |||
| } | |||
flatbuffers::Offset<ModelHeader> FbsHelper::build_header() {
    //! Serialize the header. When re-packing an already packed model the
    //! existing header fields win; otherwise the fields configured via
    //! ModelPacker::set_header() are used.
    flatbuffers::Offset<flatbuffers::String> name, info_decryption_method,
            info_parse_method, model_decryption_method, info_cache_parse_method;
    bool is_fast_run_cache;
    if (m_model_header) {
        //! source: header table of the existing packed model
        auto&& header = m_model_header;
        name = m_builder.CreateSharedString(header->name());
        info_decryption_method =
                m_builder.CreateSharedString(header->info_decryption_method());
        info_parse_method = m_builder.CreateSharedString(header->info_parse_method());
        model_decryption_method =
                m_builder.CreateSharedString(header->model_decryption_method());
        info_cache_parse_method =
                m_builder.CreateSharedString(header->info_cache_parse_method());
        is_fast_run_cache = header->is_fast_run_cache();
    } else {
        //! source: user-configured Header of the packer
        auto&& header = m_packer->m_header;
        name = m_builder.CreateSharedString(header.name);
        info_decryption_method =
                m_builder.CreateSharedString(header.info_decryption_method);
        info_parse_method = m_builder.CreateSharedString(header.info_parse_method);
        model_decryption_method =
                m_builder.CreateSharedString(header.model_decryption_method);
        info_cache_parse_method =
                m_builder.CreateSharedString(header.info_cache_parse_method);
        is_fast_run_cache = header.fb32.is_fast_run_cache;
    }
    return CreateModelHeader(
            m_builder, name, info_decryption_method, info_parse_method,
            model_decryption_method, info_cache_parse_method, is_fast_run_cache);
}
| flatbuffers::Offset<ModelData> FbsHelper::build_data() { | |||
| if (m_model_data) { | |||
| auto data = m_model_data->data()->Data(); | |||
| auto size = m_model_data->data()->size(); | |||
| return CreateModelData(m_builder, m_builder.CreateVector(data, size)); | |||
| } else { | |||
| return CreateModelData(m_builder, m_builder.CreateVector(m_model_buffer)); | |||
| } | |||
| } | |||
flatbuffers::Offset<ModelInfo> FbsHelper::build_info() {
    //! Serialize ModelInfo: the json info blob, the fastrun cache (algo
    //! policy) and the binary cache.
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_data;
    //! a user-supplied info file takes precedence over info already stored in
    //! a packed input model
    if (m_model_info && m_model_info->data() && m_packer->m_info_data_path.empty()) {
        auto data = m_model_info->data()->Data();
        auto size = m_model_info->data()->size();
        fb_data = m_builder.CreateVector(data, size);
    } else if (!m_packer->m_info_data_path.empty()) {
        auto info_data = read_file(m_packer->m_info_data_path);
        fb_data = m_builder.CreateVector(info_data);
    }
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_algo_policy;
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_binary_cache;
    if (m_packer->m_header.fb32.is_fast_run_cache) {
        std::vector<uint8_t> info_algo_policy;
        if (!m_packer->m_info_algo_policy_path.empty()) {
            info_algo_policy = read_file(m_packer->m_info_algo_policy_path);
            if (m_model_info && m_model_info->algo_policy()) {
                //! merge the cache already in the model with the new cache
                //! file: both blobs start with a uint32 category count
                //! (InFilePersistentCache layout), so the merged blob is the
                //! summed count followed by both payloads. NOTE(review):
                //! categories present in both blobs are not deduplicated --
                //! confirm this is acceptable to the cache loader.
                auto cache = m_model_info->algo_policy()->Data();
                auto size = m_model_info->algo_policy()->size();
                uint32_t nr_category_1, nr_category_2, nr_category;
                memcpy(&nr_category_1, cache, sizeof(uint32_t));
                memcpy(&nr_category_2, info_algo_policy.data(), sizeof(uint32_t));
                nr_category = nr_category_1 + nr_category_2;
                std::vector<uint8_t> cache_append;
                cache_append.resize(sizeof(nr_category));
                memcpy(cache_append.data(), &nr_category, sizeof(nr_category));
                cache_append.insert(
                        cache_append.end(), cache + sizeof(nr_category), cache + size);
                cache_append.insert(
                        cache_append.end(),
                        info_algo_policy.begin() + sizeof(nr_category),
                        info_algo_policy.end());
                fb_algo_policy = m_builder.CreateVector(cache_append);
            } else {
                fb_algo_policy = m_builder.CreateVector(info_algo_policy);
            }
        }
#if LITE_BUILD_WITH_MGE
        //! no cache file given: dump the cache collected in this process
        else {
            info_algo_policy = static_cast<mgb::InFilePersistentCache&>(
                                       mgb::PersistentCache::inst())
                                       .dump_cache();
            fb_algo_policy = m_builder.CreateVector(info_algo_policy);
        }
#endif
    }
    ModelInfoBuilder builder(m_builder);
    builder.add_data(fb_data);
    builder.add_algo_policy(fb_algo_policy);
    //! NOTE(review): fb_binary_cache is always a default (null) offset here --
    //! m_info_binary_cache_path is never read; confirm whether packing the
    //! binary cache was intended.
    builder.add_binary_cache(fb_binary_cache);
    return builder.Finish();
}
| ModelPacker::ModelPacker( | |||
| std::string model_path, std::string packed_model_path, | |||
| std::string info_data_path, std::string info_algo_policy_path, | |||
| std::string info_binary_cache_path) | |||
| : m_packed_model_path(packed_model_path), | |||
| m_info_data_path(info_data_path), | |||
| m_info_algo_policy_path(info_algo_policy_path), | |||
| m_info_binary_cache_path(info_binary_cache_path) { | |||
| m_fbs_helper = new FbsHelper(this, model_path); | |||
| } | |||
| void ModelPacker::set_header( | |||
| std::string model_decryption_method, std::string info_decryption_method, | |||
| bool is_fast_run_cache) { | |||
| m_header.model_decryption_method = model_decryption_method; | |||
| m_header.info_decryption_method = info_decryption_method; | |||
| memset(&m_header.fb32, 0, sizeof(m_header.fb32)); | |||
| m_header.fb32.is_fast_run_cache = is_fast_run_cache; | |||
| if (!m_info_data_path.empty()) { | |||
| auto buf = read_file(m_info_data_path); | |||
| std::string json_string( | |||
| static_cast<const char*>(static_cast<void*>(buf.data())), buf.size()); | |||
| auto info = nlohmann::json::parse(json_string); | |||
| m_header.name = info["name"]; | |||
| } | |||
| } | |||
| void ModelPacker::pack_model() { | |||
| auto fb_header = m_fbs_helper->build_header(); | |||
| auto fb_info = m_fbs_helper->build_info(); | |||
| auto fb_data = m_fbs_helper->build_data(); | |||
| ModelBuilder model_builder(m_fbs_helper->builder()); | |||
| model_builder.add_header(fb_header); | |||
| model_builder.add_info(fb_info); | |||
| model_builder.add_data(fb_data); | |||
| auto model = model_builder.Finish(); | |||
| std::vector<flatbuffers::Offset<Model>> models; | |||
| models.emplace_back(model); | |||
| auto fb_models = m_fbs_helper->builder().CreateVector(models); | |||
| PackModelBuilder pack_model_builder(m_fbs_helper->builder()); | |||
| pack_model_builder.add_models(fb_models); | |||
| m_fbs_helper->builder().Finish(pack_model_builder.Finish()); | |||
| FILE* fptr = fopen(m_packed_model_path.c_str(), "wb"); | |||
| std::string packed_model_tag = "packed_model"; | |||
| auto nr_tag = fwrite(packed_model_tag.c_str(), 1, packed_model_tag.size(), fptr); | |||
| LITE_ASSERT(nr_tag == packed_model_tag.size()); | |||
| auto fb_size = m_fbs_helper->builder().GetSize(); | |||
| auto nr_fb = fwrite(m_fbs_helper->builder().GetBufferPointer(), 1, fb_size, fptr); | |||
| LITE_ASSERT(nr_fb == fb_size); | |||
| fclose(fptr); | |||
| } | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * \file src/parse_info/cache_parse.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "lite/global.h" | |||
| #if LITE_BUILD_WITH_MGE | |||
| #include "megbrain/utils/infile_persistent_cache.h" | |||
| #endif | |||
| namespace lite { | |||
| //! The LITE_parse_cache parse info function | |||
| bool parse_info_cache( | |||
| const uint8_t* cache, size_t cache_length, bool is_fast_run_cache = true, | |||
| const uint8_t* binary_cache = nullptr, size_t binary_cache_length = 0) { | |||
| LITE_MARK_USED_VAR(binary_cache); | |||
| LITE_MARK_USED_VAR(binary_cache_length); | |||
| #if LITE_BUILD_WITH_MGE | |||
| if (is_fast_run_cache) { | |||
| mgb::PersistentCache::set_impl( | |||
| std::make_shared<mgb::InFilePersistentCache>(cache, cache_length)); | |||
| } | |||
| #endif | |||
| return true; | |||
| } | |||
| } // namespace lite | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -11,6 +11,7 @@ | |||
| #include "model_parser.h" | |||
| #include "decryption/decrypt_base.h" | |||
| #include "parse_info/cache_parse.h" | |||
| #include "parse_info/parse_info_base.h" | |||
| using namespace lite; | |||
| @@ -41,6 +42,10 @@ void ModelParser::parse_header() { | |||
| m_model_decryption_name = model->header()->model_decryption_method()->c_str(); | |||
| m_info_decryption_name = model->header()->info_decryption_method()->c_str(); | |||
| m_info_parse_func_name = model->header()->info_parse_method()->c_str(); | |||
| if (model->header()->info_cache_parse_method()) | |||
| m_info_cache_parse_func_name = | |||
| model->header()->info_cache_parse_method()->c_str(); | |||
| m_is_fast_run_cache = model->header()->is_fast_run_cache(); | |||
| m_info = model->info(); | |||
| m_model_data = model->data(); | |||
| @@ -54,31 +59,52 @@ bool ModelParser::parse_model_info( | |||
| if (m_is_bare_model || !m_info) { | |||
| return false; | |||
| } | |||
| size_t info_length = m_info->data()->size(); | |||
| const uint8_t* info_data = m_info->data()->Data(); | |||
| //! decryption the info | |||
| auto info_ptr = | |||
| decrypt_memory(info_data, info_length, m_info_decryption_name, info_length); | |||
| //! parse the info | |||
| LITE_LOCK_GUARD(parse_info_static_data().map_mutex); | |||
| auto it_parse = | |||
| parse_info_static_data().parse_info_methods.find(m_info_parse_func_name); | |||
| if (it_parse == parse_info_static_data().parse_info_methods.end()) { | |||
| LITE_THROW(ssprintf( | |||
| "can't find model info parse function %s.", | |||
| m_info_parse_func_name.c_str())); | |||
| //! parse ModelInfo::data | |||
| if (m_info->data()) { | |||
| size_t info_length = m_info->data()->size(); | |||
| const uint8_t* info_data = m_info->data()->Data(); | |||
| //! decryption the info | |||
| auto info_ptr = decrypt_memory( | |||
| info_data, info_length, m_info_decryption_name, info_length); | |||
| //! parse the info | |||
| LITE_LOCK_GUARD(parse_info_static_data().map_mutex); | |||
| auto it_parse = parse_info_static_data().parse_info_methods.find( | |||
| m_info_parse_func_name); | |||
| if (it_parse == parse_info_static_data().parse_info_methods.end()) { | |||
| LITE_THROW(ssprintf( | |||
| "can't find model info parse function %s.", | |||
| m_info_parse_func_name.c_str())); | |||
| } | |||
| auto model_info_parse_func = | |||
| parse_info_static_data().parse_info_methods[m_info_parse_func_name]; | |||
| //! convert for NetworkIOInner to NetworkIO | |||
| if (model_info_parse_func) { | |||
| model_info_parse_func( | |||
| info_ptr.get(), info_length, m_model_name, network_config, | |||
| network_io, isolated_config_map, extra_info); | |||
| } else { | |||
| LITE_THROW(ssprintf( | |||
| "model info parse function of %s is empty", | |||
| m_info_parse_func_name.c_str())); | |||
| } | |||
| } | |||
| auto model_info_parse_func = | |||
| parse_info_static_data().parse_info_methods[m_info_parse_func_name]; | |||
| //! convert for NetworkIOInner to NetworkIO | |||
| if (model_info_parse_func) { | |||
| model_info_parse_func( | |||
| info_ptr.get(), info_length, m_model_name, network_config, network_io, | |||
| isolated_config_map, extra_info); | |||
| } else { | |||
| LITE_THROW(ssprintf( | |||
| "model info parse function of %s is empty", | |||
| m_info_parse_func_name.c_str())); | |||
| //! parse ModelInfo::algo_policy | |||
| if (m_info->algo_policy()) { | |||
| size_t cache_length = m_info->algo_policy()->size(); | |||
| const uint8_t* cache = m_info->algo_policy()->Data(); | |||
| if (m_info_cache_parse_func_name == "LITE_parse_cache") { | |||
| if (m_is_fast_run_cache) { | |||
| parse_info_cache(cache, cache_length); | |||
| } else if (m_info->binary_cache()) { | |||
| size_t binary_cache_length = m_info->binary_cache()->size(); | |||
| const uint8_t* binary_cache = m_info->binary_cache()->Data(); | |||
| parse_info_cache( | |||
| cache, cache_length, m_is_fast_run_cache, binary_cache, | |||
| binary_cache_length); | |||
| } else { | |||
| LITE_THROW("opencl binary cache is not given"); | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| @@ -60,6 +60,8 @@ private: | |||
| std::string m_model_decryption_name; | |||
| //! the function name to parse the model info | |||
| std::string m_info_parse_func_name; | |||
| std::string m_info_cache_parse_func_name; | |||
| bool m_is_fast_run_cache; | |||
| //! if a model is not added json info to the model is not crypted, the | |||
| //! model is a bare model | |||
| bool m_is_bare_model = true; | |||
| @@ -5,10 +5,14 @@ table ModelHeader { | |||
| info_decryption_method:string; | |||
| info_parse_method:string; | |||
| model_decryption_method:string; | |||
| info_cache_parse_method:string; | |||
| is_fast_run_cache:bool; | |||
| } | |||
| table ModelInfo { | |||
| data:[ubyte]; | |||
| algo_policy:[ubyte]; | |||
| binary_cache:[ubyte]; | |||
| } | |||
| table ModelData { | |||
| @@ -970,6 +970,25 @@ TEST(TestNetWork, LoadPackedModel) { | |||
| network->wait(); | |||
| } | |||
| TEST(TestNetWork, LoadPackedCacheModel) { | |||
| auto tensor = get_input_data("./input_data.npy"); | |||
| std::string model_path = "./test_pack_cache_to_model.lite"; | |||
| std::string input_name = "data"; | |||
| NetworkIO IO; | |||
| Config config; | |||
| std::shared_ptr<Network> network = std::make_shared<Network>(config, IO); | |||
| network->load_model(model_path); | |||
| std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name); | |||
| auto src_ptr = tensor->get_memory_ptr(); | |||
| auto src_layout = tensor->get_layout(); | |||
| input_tensor->reset(src_ptr, src_layout); | |||
| network->forward(); | |||
| network->wait(); | |||
| } | |||
| TEST(TestNetWork, GlabalLayoutTransform) { | |||
| auto tensor = get_input_data("./input_data.npy"); | |||
| std::string model_path = "./shufflenet.mge"; | |||
| @@ -216,6 +216,46 @@ void InFilePersistentCache::dump_cache(OutputFile* out_file) { | |||
| } | |||
| } | |||
| } | |||
| std::vector<uint8_t> InFilePersistentCache::dump_cache() { | |||
| std::vector<uint8_t> ret; | |||
| uint32_t nr_category = m_cache.size(); | |||
| ret.resize(sizeof(nr_category)); | |||
| memcpy(ret.data(), &nr_category, sizeof(nr_category)); | |||
| auto write_to_buffer = [&ret](uint32_t val) { | |||
| std::vector<uint8_t> vec(sizeof(val)); | |||
| memcpy(vec.data(), &val, sizeof(val)); | |||
| ret.insert(ret.end(), vec.begin(), vec.end()); | |||
| }; | |||
| for (const auto& cached_category : m_cache) { | |||
| uint32_t category_size = cached_category.first.size(); | |||
| write_to_buffer(category_size); | |||
| std::vector<uint8_t> category( | |||
| cached_category.first.begin(), cached_category.first.end()); | |||
| ret.insert(ret.end(), category.begin(), category.end()); | |||
| uint32_t nr_bobs = cached_category.second.size(); | |||
| write_to_buffer(nr_bobs); | |||
| for (const auto& item : cached_category.second) { | |||
| uint32_t size_first = item.first.size; | |||
| write_to_buffer(size_first); | |||
| ret.insert( | |||
| ret.end(), item.first.data_refhold.get(), | |||
| item.first.data_refhold.get() + size_first); | |||
| uint32_t size_second = item.second.size; | |||
| write_to_buffer(size_second); | |||
| ret.insert( | |||
| ret.end(), item.second.data_refhold.get(), | |||
| item.second.data_refhold.get() + size_second); | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| Maybe<InFilePersistentCache::Blob> InFilePersistentCache::get( | |||
| const std::string& category, const Blob& key) { | |||
| decltype(m_cache.begin()) iter0; | |||
| @@ -71,6 +71,7 @@ public: | |||
| */ | |||
| MGE_WIN_DECLSPEC_FUC void dump_cache(const char* path); | |||
| MGE_WIN_DECLSPEC_FUC void dump_cache(OutputFile* out_file); | |||
| MGE_WIN_DECLSPEC_FUC std::vector<uint8_t> dump_cache(); | |||
| MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( | |||
| const std::string& category, const Blob& key) override; | |||