You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

serializer.cpp 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. #include "megbrain/serialization/serializer.h"
  2. #include "megbrain/gopt/inference.h"
  3. #include "megbrain/opr/io.h"
  4. #include "megbrain/opr/tensor_manip.h"
  5. #include "megbrain/opr/utility.h"
  6. namespace {
  7. bool is_opr_memforward_var(mgb::VarNode* var) {
  8. if (var) {
  9. auto opr = var->owner_opr();
  10. if (opr->try_cast_final<mgb::opr::Reshape>() ||
  11. opr->try_cast_final<mgb::opr::Broadcast>() ||
  12. opr->try_cast_final<mgb::opr::Subtensor>() ||
  13. opr->try_cast_final<mgb::opr::AxisAddRemove>() ||
  14. opr->try_cast_final<mgb::opr::Dimshuffle>()) {
  15. return true;
  16. }
  17. };
  18. return false;
  19. }
  20. } // namespace
  21. namespace mgb {
  22. namespace serialization {
  23. /* ====================== helper impls ====================== */
  //! Destructor of LoadResult, defaulted here in the source file rather than at
  //! its declaration.
  //! NOTE(review): presumably kept out of line so that member types only have
  //! to be complete in this translation unit -- confirm against the header.
  24. GraphLoader::LoadResult::~LoadResult() noexcept = default;
  25. std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile(
  26. const ComputingGraph::OutputSpec& outspec) {
  27. auto ret = graph->compile(outspec);
  28. if (graph->options().comp_node_seq_record_level == 2) {
  29. ComputingGraph::assert_destroy(graph);
  30. }
  31. return ret;
  32. }
  33. void GraphLoader::LoadResult::update_output_var_list(
  34. const SymbolVarArray& output_var_array) {
  35. mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
  36. mgb_assert(output_var_array.size() == output_var_list.size());
  37. // replace symvar in output_var_list
  38. for (size_t idx = 0; idx < output_var_array.size(); ++idx) {
  39. out_var_map[output_var_list[idx]] = output_var_array[idx];
  40. output_var_list[idx] = output_var_array[idx];
  41. }
  42. // replace symvar in output_var_map_id
  43. for (auto&& item : output_var_map_id) {
  44. item.second = out_var_map[item.second];
  45. }
  46. // replace symvar in output_var_map
  47. for (auto&& item : output_var_map) {
  48. item.second = out_var_map[item.second].rename(item.first);
  49. }
  50. }
  51. void GraphLoader::LoadResult::graph_compile_ahead() {
  52. //! when force_output_use_user_specified_memory is set, the output var may
  53. //! be changed by gopt, then the var in LoadResult can not exist, so here
  54. //! just do basic optimize_for_inference ahead, and replace the var in
  55. //! LoadResult
  56. if (graph->options().force_output_use_user_specified_memory) {
  57. //! if the output var is like dimshuffle, reshape, it maybe memory forward to
  58. //! the output, so add a Copy operator in the end.
  59. for (auto& var : output_var_list) {
  60. if (is_opr_memforward_var(var.node())) {
  61. std::string name = var.node()->name();
  62. var = opr::Copy::make(var, name);
  63. }
  64. }
  65. auto options = gopt::OptimizeForInferenceOptions{};
  66. auto new_vars = gopt::optimize_for_inference(output_var_list, options);
  67. output_var_list = new_vars;
  68. output_var_map.clear();
  69. for (auto& var : new_vars) {
  70. output_var_map[var.node()->cname()] = var;
  71. }
  72. std::unordered_map<size_t, SymbolVar> var_map_id;
  73. for (auto& var : new_vars) {
  74. bool found = false;
  75. for (auto& old_var_it : output_var_map_id) {
  76. if (old_var_it.second.node()->name() == var.node()->name()) {
  77. found = true;
  78. var_map_id[old_var_it.first] = var;
  79. }
  80. }
  81. mgb_assert(
  82. found, "can't find var name %s when optimize_for_inference. ",
  83. var.node()->cname());
  84. }
  85. output_var_map_id = var_map_id;
  86. }
  87. }
  88. GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
  89. SharedTensorNameMap ret;
  90. for (auto&& i : shared_tensor_id_map()) {
  91. mgb_assert(!i.first.empty(), "name stripped during graph dump");
  92. auto ins = ret.emplace(i.first, &i.second);
  93. mgb_assert(ins.second);
  94. }
  95. return ret;
  96. }
  //! Loader/dumper factories and file probes for the FLATBUFFERS and
  //! FLATBUFFERS_V2 formats. They are only declared here; the definitions are
  //! not part of this section of the file. Below they are called exclusively
  //! from code guarded by MGB_ENABLE_FBS_SERIALIZATION.
  97. std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file);
  98. std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file);
  99. std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file);
  100. std::unique_ptr<GraphDumper> make_fbs_v2_dumper(
  101. std::unique_ptr<OutputFile> file, int version);
  102. bool is_fbs_file(InputFile& file);
  103. bool is_fbs_v2_file(InputFile& file);
  104. bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) {
  105. #if MGB_ENABLE_GRAD
  106. return opr->same_type<opr::SetGrad>();
  107. #else
  108. return false;
  109. #endif
  110. }
  111. std::unique_ptr<GraphDumper> GraphDumper::make(
  112. std::unique_ptr<OutputFile> file, GraphDumpFormat format, int version) {
  113. switch (format) {
  114. case GraphDumpFormat::FLATBUFFERS:
  115. #if MGB_ENABLE_FBS_SERIALIZATION
  116. return make_fbs_dumper(std::move(file));
  117. #endif
  118. MGB_FALLTHRU
  119. case GraphDumpFormat::FLATBUFFERS_V2:
  120. #if MGB_ENABLE_FBS_SERIALIZATION
  121. return make_fbs_v2_dumper(std::move(file), version);
  122. #endif
  123. MGB_FALLTHRU
  124. default:
  125. mgb_throw(SerializationError, "unsupported serialization format requested");
  126. }
  127. mgb_assert(false, "unreachable");
  128. }
  129. std::unique_ptr<GraphLoader> GraphLoader::make(
  130. std::unique_ptr<InputFile> file, GraphDumpFormat format) {
  131. switch (format) {
  132. case GraphDumpFormat::FLATBUFFERS:
  133. #if MGB_ENABLE_FBS_SERIALIZATION
  134. return make_fbs_loader(std::move(file));
  135. #endif
  136. MGB_FALLTHRU
  137. case GraphDumpFormat::FLATBUFFERS_V2:
  138. #if MGB_ENABLE_FBS_SERIALIZATION
  139. return make_fbs_v2_loader(std::move(file));
  140. #endif
  141. MGB_FALLTHRU
  142. default:
  143. mgb_throw(SerializationError, "unsupported serialization format requested");
  144. }
  145. mgb_assert(false, "unreachable");
  146. }
  147. Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file) {
  148. #if MGB_ENABLE_FBS_SERIALIZATION
  149. if (is_fbs_file(file)) {
  150. return GraphDumpFormat::FLATBUFFERS;
  151. }
  152. if (is_fbs_v2_file(file)) {
  153. return GraphDumpFormat::FLATBUFFERS_V2;
  154. }
  155. #endif
  156. return {};
  157. }
  158. } // namespace serialization
  159. } // namespace mgb