You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

serializer.cpp 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. #include "megbrain/serialization/serializer.h"
  2. #include "megbrain/gopt/inference.h"
  3. #include "megbrain/opr/utility.h"
  4. namespace mgb {
  5. namespace serialization {
  6. /* ====================== helper impls ====================== */
  7. GraphLoader::LoadResult::~LoadResult() noexcept = default;
  8. std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile(
  9. const ComputingGraph::OutputSpec& outspec) {
  10. auto ret = graph->compile(outspec);
  11. if (graph->options().comp_node_seq_record_level == 2) {
  12. ComputingGraph::assert_destroy(graph);
  13. }
  14. return ret;
  15. }
  16. void GraphLoader::LoadResult::graph_compile_ahead() {
  17. //! when force_output_use_user_specified_memory is set, the output var may
  18. //! be changed by gopt, then the var in LoadResult can not exist, so here
  19. //! just do basic optimize_for_inference ahead, and replace the var in
  20. //! LoadResult
  21. if (graph->options().force_output_use_user_specified_memory) {
  22. auto options = gopt::OptimizeForInferenceOptions{};
  23. auto new_vars = gopt::optimize_for_inference(output_var_list, options);
  24. output_var_list = new_vars;
  25. output_var_map.clear();
  26. for (auto& var : new_vars) {
  27. output_var_map[var.node()->cname()] = var;
  28. }
  29. std::unordered_map<size_t, SymbolVar> var_map_id;
  30. for (auto& var : new_vars) {
  31. bool found = false;
  32. for (auto& old_var_it : output_var_map_id) {
  33. if (old_var_it.second.node()->name() == var.node()->name()) {
  34. found = true;
  35. var_map_id[old_var_it.first] = var;
  36. }
  37. }
  38. mgb_assert(
  39. found, "can't find var name %s when optimize_for_inference. ",
  40. var.node()->cname());
  41. }
  42. }
  43. }
  44. GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
  45. SharedTensorNameMap ret;
  46. for (auto&& i : shared_tensor_id_map()) {
  47. mgb_assert(!i.first.empty(), "name stripped during graph dump");
  48. auto ins = ret.emplace(i.first, &i.second);
  49. mgb_assert(ins.second);
  50. }
  51. return ret;
  52. }
  53. std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file);
  54. std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file);
  55. std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file);
  56. std::unique_ptr<GraphDumper> make_fbs_v2_dumper(
  57. std::unique_ptr<OutputFile> file, int version);
  58. bool is_fbs_file(InputFile& file);
  59. bool is_fbs_v2_file(InputFile& file);
  60. bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) {
  61. #if MGB_ENABLE_GRAD
  62. return opr->same_type<opr::SetGrad>();
  63. #else
  64. return false;
  65. #endif
  66. }
  67. std::unique_ptr<GraphDumper> GraphDumper::make(
  68. std::unique_ptr<OutputFile> file, GraphDumpFormat format, int version) {
  69. switch (format) {
  70. case GraphDumpFormat::FLATBUFFERS:
  71. #if MGB_ENABLE_FBS_SERIALIZATION
  72. return make_fbs_dumper(std::move(file));
  73. #endif
  74. MGB_FALLTHRU
  75. case GraphDumpFormat::FLATBUFFERS_V2:
  76. #if MGB_ENABLE_FBS_SERIALIZATION
  77. return make_fbs_v2_dumper(std::move(file), version);
  78. #endif
  79. MGB_FALLTHRU
  80. default:
  81. mgb_throw(SerializationError, "unsupported serialization format requested");
  82. }
  83. mgb_assert(false, "unreachable");
  84. }
  85. std::unique_ptr<GraphLoader> GraphLoader::make(
  86. std::unique_ptr<InputFile> file, GraphDumpFormat format) {
  87. switch (format) {
  88. case GraphDumpFormat::FLATBUFFERS:
  89. #if MGB_ENABLE_FBS_SERIALIZATION
  90. return make_fbs_loader(std::move(file));
  91. #endif
  92. MGB_FALLTHRU
  93. case GraphDumpFormat::FLATBUFFERS_V2:
  94. #if MGB_ENABLE_FBS_SERIALIZATION
  95. return make_fbs_v2_loader(std::move(file));
  96. #endif
  97. MGB_FALLTHRU
  98. default:
  99. mgb_throw(SerializationError, "unsupported serialization format requested");
  100. }
  101. mgb_assert(false, "unreachable");
  102. }
  103. Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file) {
  104. #if MGB_ENABLE_FBS_SERIALIZATION
  105. if (is_fbs_file(file)) {
  106. return GraphDumpFormat::FLATBUFFERS;
  107. }
  108. if (is_fbs_v2_file(file)) {
  109. return GraphDumpFormat::FLATBUFFERS_V2;
  110. }
  111. #endif
  112. return {};
  113. }
  114. } // namespace serialization
  115. } // namespace mgb