You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_parser.h 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. /**
  2. * \file src/model_parser.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
#pragma once
#include "../network_impl_base.h"
#include "lite/global.h"
#include <flatbuffers/flatbuffers.h>
#include "pack_model_generated.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
  17. namespace lite {
  18. /*!
  19. * \brief parse the model and decyt
  20. */
  21. class ModelParser {
  22. public:
  23. ModelParser(std::shared_ptr<void> model_ptr, size_t model_length)
  24. : m_model(model_ptr), m_total_length(model_length) {
  25. //! parse the header
  26. parse_header();
  27. }
  28. //! parse the Info part of the model, update the network_config and
  29. //! network_io
  30. bool parse_model_info(
  31. Config& network_config, NetworkIO& network_io,
  32. std::unordered_map<std::string, LiteAny>& isolated_config_map,
  33. std::string& extra_info) const;
  34. //! parse the model and decrypt the model
  35. std::shared_ptr<void> parse_model(size_t& model_length, const Config& config) const;
  36. private:
  37. //! parse the header of the model and store the model related information
  38. //! to the menber data
  39. void parse_header();
  40. //! decrypt a memory with length of length and decryption method name
  41. //! decrypt_name
  42. std::shared_ptr<void> decrypt_memory(
  43. const uint8_t* data, size_t length, const std::string decryption_name,
  44. size_t& result_length) const;
  45. private:
  46. std::string m_model_name;
  47. //! the info and model decryption method name, the
  48. //! decryption func can be found through this name
  49. std::string m_info_decryption_name;
  50. std::string m_model_decryption_name;
  51. //! the function name to parse the model info
  52. std::string m_info_parse_func_name;
  53. //! if a model is not added json info to the model is not crypted, the
  54. //! model is a bare model
  55. bool m_is_bare_model = true;
  56. const model_parse::ModelInfo* m_info = nullptr;
  57. const model_parse::ModelData* m_model_data = nullptr;
  58. std::shared_ptr<void> m_model;
  59. size_t m_total_length;
  60. static std::string sm_model_tag;
  61. };
  62. } // namespace lite
  63. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}