
serializer_oss_v2.cpp 34 kB

#if MGB_ENABLE_FBS_SERIALIZATION

#include "megbrain/comp_node_env.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/utils/hash_ct.h"
#include "megdnn/tensor_format.h"

#include "serializer_oss_common.h"

#include "megbrain/gopt/framework.h"

namespace mgb {
namespace serialization {

/*!
 * \brief a graph pass that, before dumping, replaces every opr that has a
 * converter (replace_opr method) registered in OprLoadDumpImplV2 with its
 * compatible counterpart
 */
class PassConvertToCompatible : public gopt::Pass {
    ThinHashMap<
            Typeinfo*, thin_function<cg::OperatorNodeBase*(
                               cg::OperatorNodeBase*, const VarNodeArray&)>>
            m_opr_replace_func;
    gopt::VarReplaceCheckFlag m_var_replace_check_flag =
            gopt::VarReplaceCheckFlag::CHECK_ALL;

public:
    const char* name() const override { return "PassConvertToCompatible"; }

    PassConvertToCompatible& set_var_replace_check_flag(
            gopt::VarReplaceCheckFlag flag) {
        m_var_replace_check_flag = flag;
        return *this;
    }

    void apply(gopt::OptState& state) const override {
        state.set_var_replace_check_flag(m_var_replace_check_flag);
        auto rewriter = state.graph().make_rewriter();

        auto on_opr = [this, &rewriter](cg::OperatorNodeBase* opr) {
            auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
            if (it != m_opr_replace_func.end()) {
                VarNodeArray new_inp;
                new_inp.reserve(opr->input().size());
                for (auto i : opr->input()) {
                    new_inp.push_back(rewriter.get_var(i));
                }
                auto new_opr = (it->second)(opr, new_inp);
                auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
                for (size_t i = 0; i < std::min(origin_out.size(), cur_out.size());
                     i++) {
                    rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
                }
            } else {
                rewriter.auto_replace_outputs(opr);
            }
        };
        state.graph().iter(on_opr);
        rewriter.apply_inplace();
    }

    static std::unique_ptr<PassConvertToCompatible> make(
            const SymbolVarArray& output_vars) {
        auto ret = std::make_unique<PassConvertToCompatible>();
        // iterate oprs to init the replace function map
        auto on_opr = [&](cg::OperatorNodeBase* opr) {
            if (!GraphDumper::should_remove_in_dump(opr)) {
                auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                        opr->dyn_typeinfo(), CURRENT_VERSION);
                mgb_throw_if(
                        !registry,
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s, typeinfo %p",
                        opr->dyn_typeinfo()->name, opr->dyn_typeinfo());
                if (registry->converter) {
                    ret->m_opr_replace_func[opr->dyn_typeinfo()] = registry->converter;
                }
            }
        };
        cg::DepOprIter dep_opr_iter{on_opr};
        for (auto i : output_vars) {
            dep_opr_iter.add(i.node()->owner_opr());
        }
        return ret;
    }
};
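
// PassConvertToCompatible is run by GraphDumperOSSV2::converter_all_opr_to_compatiable
// (below) through gopt::GraphOptimizer; oprs without a registered converter are
// forwarded unchanged via rewriter.auto_replace_outputs.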

namespace {
fbs::v2::TensorFormat get_flatbuffer_tensor_format_type(
        const TensorLayout::Format& format) {
    using Type = megdnn::TensorFormat::Type;
    switch (format.type()) {
        case Type::DEFAULT:
            return fbs::v2::TensorFormat::TensorFormat_DefaultTensorFormat;
        case Type::IMAGE2D_PACK4:
            return fbs::v2::TensorFormat::TensorFormat_Image2DPackedTensorFormat;
        case Type::LOWBITS_ALIGNED_TO_BYTE:
            return fbs::v2::TensorFormat::TensorFormat_LowbitsAlignedTensorFormat;
        default:
            mgb_throw(
                    SerializationError, "invalid tensor format type in serialization.");
    }
}
} // namespace

flatbuffers::Offset<fbs::DType> GraphDumperOSSV2::build_dtype(DType dtype) {
    return fbs::intl::build_dtype(m_builder, dtype);
}

flatbuffers::Offset<void> GraphDumperOSSV2::build_tensor_format(
        const TensorLayout::Format& format) {
    using Type = megdnn::TensorFormat::Type;
    switch (format.type()) {
        case Type::DEFAULT:
            return fbs::v2::CreateDefaultTensorFormat(m_builder).Union();
        case Type::IMAGE2D_PACK4:
            return fbs::v2::CreateImage2DPackedTensorFormat(
                           m_builder, format.as_impl<megdnn::Image2DPack4TensorFormat>()
                                              .align_axis())
                    .Union();
        case Type::LOWBITS_ALIGNED_TO_BYTE: {
            auto size_nbits = format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
                                      .size_nbits();
            auto align_size_in_bits =
                    format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
                            .align_size_in_bits();
            return fbs::v2::CreateLowbitsAlignedTensorFormat(
                           m_builder, size_nbits, align_size_in_bits)
                    .Union();
        }
        default:
            mgb_throw(
                    SerializationError, "invalid tensor format type in serialization.");
    }
}

flatbuffers::Offset<fbs::v2::MiddleTensor> GraphDumperOSSV2::build_middle_tensor(
        const SymbolVar var) {
    mgb_assert(var.node());
    auto fbname = m_builder.CreateSharedString(var.node()->name());
    flatbuffers::Offset<fbs::v2::MiddleTensor> serialized_middle_tensor;
    if (var.node()->dev_tensor_valid()) {
        auto layout = var.node()->layout();
        auto fshape =
                m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
        auto fcomp_node = fbs::v2::CreateCompNode(
                m_builder, m_builder.CreateSharedString(
                                   var.node()->comp_node().to_string_logical()));
        auto fdtype = build_dtype(layout.dtype);
        auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
        auto fformat = build_tensor_format(layout.format);
        serialized_middle_tensor = fbs::v2::CreateMiddleTensor(
                m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat);
    } else {
        // no valid dev tensor: only the name can be recorded
        serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
    }
    return serialized_middle_tensor;
}

flatbuffers::Offset<fbs::v2::OutputVar> GraphDumperOSSV2::build_output_var(
        const SymbolVar var) {
    auto out_node = var.node();
    if (m_var2midtensor_id.find(var.node()) == m_var2midtensor_id.end()) {
        mgb_assert(m_var_remove_in_dump.find(var.node()) != m_var_remove_in_dump.end());
        out_node = m_var_remove_in_dump[var.node()];
    }
    return fbs::v2::CreateOutputVar(
            m_builder, m_var2midtensor_id.at(out_node), var.node()->id());
}

void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
    m_oprs_to_dump.clear();
    // iterate oprs to init
    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        if (should_remove_in_dump(opr)) {
            mgb_assert(opr->input().size() == 1);
            // copy the input ID to all outputs
            for (auto i : opr->output()) {
                if (m_var_remove_in_dump.find(opr->input(0)) !=
                    m_var_remove_in_dump.end()) {
                    m_var_remove_in_dump[i] = m_var_remove_in_dump[opr->input(0)];
                } else {
                    m_var_remove_in_dump[i] = opr->input(0);
                }
            }
        } else {
            auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                    opr->dyn_typeinfo(), CURRENT_VERSION);
            if (!registry || !registry->dumper) {
                mgb_throw(
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s",
                        opr->dyn_typeinfo()->name);
            }
            m_oprs_to_dump.emplace_back(opr, registry);
        }
    };
    cg::DepOprIter dep_opr_iter{on_opr};
    for (auto i : endpoints) {
        dep_opr_iter.add(i.node()->owner_opr());
    }
}

flatbuffers::Offset<fbs::v2::Metadata> GraphDumperOSSV2::build_metadata(
        const Metadata& metadata) {
    auto user_info = m_builder.CreateSharedString(metadata.user_info);
    fbs::v2::MetadataBuilder builder(m_builder);
    builder.add_is_valid(metadata.is_valid);
    builder.add_graph_modified(metadata.graph_modified);
    builder.add_optimize_options(metadata.optimize_options);
    builder.add_user_info(user_info);
    return builder.Finish();
}

flatbuffers::Offset<fbs::v2::Operator> GraphDumperOSSV2::build_single_opr(
        cg::OperatorNodeBase* opr, const OprRegistryV2* registry) {
    m_cur_opr = opr;
    ++m_cur_rst.nr_opr;
    using namespace flatbuffers;

    Offset<Vector<uint32_t>> inputs;
    if (m_cur_opr->input().size()) {
        std::vector<uint32_t> v;
        v.reserve(m_cur_opr->input().size());
        for (auto inp : m_cur_opr->input()) {
            if (m_var2midtensor_id.find(inp) != m_var2midtensor_id.end()) {
                v.emplace_back(m_var2midtensor_id.at(inp));
            } else {
                mgb_assert(
                        m_var_remove_in_dump.find(inp) != m_var_remove_in_dump.end(),
                        "broken var dependency while dumping the model");
                v.emplace_back(m_var2midtensor_id.at(m_var_remove_in_dump[inp]));
            }
        }
        inputs = m_builder.CreateVector(v);
    }

    m_cur_opr_tensor.clear();
    m_blobs.clear();
    m_cur_opr_param.clear();
    m_cur_opr_param_type.clear();
    registry->dumper(*this, *m_cur_opr);

    Offset<Vector<Offset<fbs::v2::CompNode>>> comp_node;
    auto& config = m_cur_opr->config();
    if (config.has_comp_node_set()) {
        std::vector<flatbuffers::Offset<fbs::v2::CompNode>> cns;
        for (const auto& cn : config.comp_node()) {
            cns.emplace_back(fbs::v2::CreateCompNode(
                    m_builder, m_builder.CreateSharedString(cn.to_string_logical())));
        }
        comp_node = m_builder.CreateVector(cns);
    }

    Offset<String> operator_name;
    if (m_config.keep_op_name) {
        operator_name = m_builder.CreateSharedString(m_cur_opr->name());
    }

    auto output_dtype = build_dtype(config.output_dtype());

    Offset<Vector<uint32_t>> outputs;
    if (m_cur_opr->output().size()) {
        std::vector<uint32_t> v;
        v.reserve(m_cur_opr->output().size());
        for (auto out : m_cur_opr->output()) {
            if (!out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                auto fbs_out = build_middle_tensor(out);
                m_model_middle_tensors.push_back(fbs_out);
                m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1;
                v.emplace_back(m_var2midtensor_id.at(out));
            }
        }
        outputs = m_builder.CreateVector(v);
    }

    Offset<Vector<Offset<fbs::v2::Tensor>>> tensors;
    if (m_cur_opr_tensor.size())
        tensors = m_builder.CreateVector(m_cur_opr_tensor);

    //! the blobs hold custom data; m_blobs is filled by the operator's dumper
    //! function invoked above
    Offset<Vector<Offset<fbs::v2::Blob>>> blobs;
    if (m_blobs.size())
        blobs = m_builder.CreateVector(m_blobs);
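
    // The first entry of m_cur_opr_param / m_cur_opr_param_type goes into the
    // Operator table's `param` / `param_type` fields; any further entries are
    // stored in `additional_params` / `additional_params_type`, hence the
    // `data() + 1` offsets below.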
    Offset<Vector<uint8_t>> additional_params_type;
    Offset<Vector<Offset<void>>> additional_params;
    auto param_cnt = m_cur_opr_param_type.size();
    if (param_cnt > 1) {
        additional_params_type = m_builder.CreateVectorScalarCast<uint8_t>(
                m_cur_opr_param_type.data() + 1, param_cnt - 1);
        additional_params =
                m_builder.CreateVector(m_cur_opr_param.data() + 1, param_cnt - 1);
    }

    auto opr_type = m_builder.CreateSharedString(registry->name);
    fbs::v2::OperatorBuilder builder(m_builder);
    builder.add_type(opr_type);
    builder.add_type_id(registry->type_id);
    builder.add_inputs(inputs);
    builder.add_outputs(outputs);
    if (m_config.keep_opr_priority) {
        builder.add_priority(opr->node_prop().attribute().priority);
    }
    builder.add_comp_node(comp_node);
    builder.add_opr_version(registry->get_version());
    builder.add_name(operator_name);
    builder.add_output_dtype(output_dtype);
    if (param_cnt > 0) {
        builder.add_param_type(m_cur_opr_param_type[0]);
        builder.add_param(m_cur_opr_param[0]);
    }
    if (param_cnt > 1) {
        builder.add_additional_params_type(additional_params_type);
        builder.add_additional_params(additional_params);
    }
    builder.add_tensors(tensors);
    builder.add_custom_data(blobs);
    m_cur_opr = nullptr;
    return builder.Finish();
}

SymbolVarArray GraphDumperOSSV2::converter_all_opr_to_compatiable(
        const SymbolVarArray& output_vars) {
    gopt::GraphOptimizer optimizer;
    VarNodeArray rets_var;
    for (auto& symbolvar : output_vars) {
        rets_var.push_back(symbolvar.node());
    }
    optimizer.add_pass(PassConvertToCompatible::make(output_vars));
    optimizer.apply_inplace(rets_var);

    SymbolVarArray dst_vars;
    for (auto& var : rets_var) {
        dst_vars.push_back({var});
    }
    return dst_vars;
}
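
// On-disk layout produced by dump() below (see FinishSizePrefixed and the
// matching size read in GraphLoaderOSSV2::load): a uint32 byte-size prefix,
// then a FlatBuffers buffer tagged with the fbs::v2 model identifier whose
// root is the Model table (oprs, middle tensors, output vars, output aliases,
// shared-tensor count and metadata).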

GraphDumper::DumpResult GraphDumperOSSV2::dump(
        const SymbolVarArray& output_vars, const DumpConfig& config,
        const Metadata& metadata) {
    mgb_throw_if(output_vars.empty(), SerializationError, "can't dump empty graph");
    auto&& new_output_vars = converter_all_opr_to_compatiable(output_vars);

    auto begin_pos = m_file->tell();
    m_config = config;
    m_builder.Reset();

    m_output_vars.clear();
    m_cur_rst = {};
    m_used_input_names.clear();
    m_used_param_names.clear();
    m_var_remove_in_dump.clear();
    m_model_middle_tensors.clear();
    m_var2midtensor_id.clear();
    m_nr_shared_tensor = 0;

    // process output vars
    bool keep_output_var_name = m_config.keep_var_name >= 1;
    std::unordered_set<std::string> output_var_names;
    for (auto i : new_output_vars) {
        mgb_assert(
                !i.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                "can not dump var with VOLATILE_CONTENT flag: %s",
                cg::dump_var_info({i.node()}).c_str());
        if (m_output_vars.insert(i.node()).second && keep_output_var_name) {
            auto name_ins = output_var_names.insert(i.node()->name()).second;
            mgb_assert(name_ins, "duplicated output var name: %s", i.node()->cname());
        }
    }

    // Dump metadata
    auto fbmeta = build_metadata(metadata);

    // Dump operators
    init_oprs_to_dump(new_output_vars);
    std::vector<flatbuffers::Offset<fbs::v2::Operator>> oprs;
    for (auto&& i : m_oprs_to_dump) {
        oprs.emplace_back(build_single_opr(i.first, i.second));
    }
    auto fb_oprs = m_builder.CreateVector(oprs);

    // Dump output vars
    std::vector<flatbuffers::Offset<fbs::v2::OutputVar>> output_vars_idx;
    output_vars_idx.reserve(new_output_vars.size());
    for (auto i : new_output_vars) {
        auto foutput_vars_idx = build_output_var(i);
        output_vars_idx.push_back(foutput_vars_idx);
    }
    auto fb_output_vars = m_builder.CreateVector(output_vars_idx);

    std::vector<flatbuffers::Offset<fbs::v2::OutputAlias>> output_vars_alias;
    if (m_config.alias_name_map.size() > 0) {
        for (auto&& pair : m_config.alias_name_map) {
            std::string name;
            SymbolVar var;
            std::tie(name, var) = pair;
            auto fbs_name = m_builder.CreateSharedString(name);
            output_vars_alias.push_back(
                    fbs::v2::CreateOutputAlias(m_builder, var.node()->id(), fbs_name));
        }
    }
    auto fbs_output_alias = m_builder.CreateVector(output_vars_alias);

    auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);

    fbs::v2::ModelBuilder model(m_builder);
    model.add_mge_version(MGB_VERSION);
    model.add_oprs(fb_oprs);
    model.add_middle_tensors(fb_mid_tensor);
    model.add_output_vars_idx(fb_output_vars);
    model.add_output_alias(fbs_output_alias);
    model.add_nr_shared_tensor(m_nr_shared_tensor);
    model.add_metadata(fbmeta);
    m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier());

    // Write serialized fbs::v2::Model
    m_file->write(m_builder.GetBufferPointer(), m_builder.GetSize());

    // Finalize DumpResult
    auto&& ret = m_cur_rst;
    for (size_t i = 0; i < new_output_vars.size(); i++) {
        ret.outputs.emplace_back(
                keep_output_var_name ? new_output_vars[i].node()->cname()
                                     : ssprintf("unnamed%zu", i));
    }
    std::sort(ret.inputs.begin(), ret.inputs.end());
    mgb_assert(ret.nr_opr == m_oprs_to_dump.size());
    ret.tot_bytes = m_file->tell() - begin_pos;
    return ret;
}
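
// TensorWriteMethod semantics, as enforced below: VALUE_ANONYMOUS stores a value
// with no name; VALUE_SHARED stores the value of a shared (parameter) tensor and
// bumps m_nr_shared_tensor; META_INPUT records a named input without a value;
// VALUE_INPUT records a named input together with its value.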

void GraphDumperOSSV2::dump_tensor(
        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
    using namespace flatbuffers;
    using Meth = TensorWriteMethod;
    mgb_assert(
            (method == Meth::VALUE_ANONYMOUS) ^ (!name.empty()),
            "name must be non-empty iff the tensor is not Meth::VALUE_ANONYMOUS");
    bool has_value = method != Meth::META_INPUT;
    bool should_keep_name = true;
    switch (method) {
        case Meth::VALUE_ANONYMOUS:
            should_keep_name = false;
            break;
        case Meth::VALUE_SHARED:
            should_keep_name = m_config.keep_param_name;
            ++m_nr_shared_tensor;
            if (m_config.keep_param_name) {
                mgb_assert(
                        m_used_param_names.insert(name).second,
                        "duplicated VALUE_SHARED tensor name: %s", name.c_str());
                m_cur_rst.params.emplace_back(name);
            }
            break;
        case Meth::META_INPUT:
        case Meth::VALUE_INPUT:
            mgb_assert(!name.empty(), "empty input tensor name");
            mgb_assert(
                    m_used_input_names.insert(name).second,
                    "duplicated input tensor name: %s", name.c_str());
            m_cur_rst.inputs.emplace_back(name);
            break;
    }

    auto& layout = tensor.layout();
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> data;
    if (has_value) {
        check_tensor_value_valid(name, tensor);
        auto&& dumper = m_config.tensor_value_dumper;
        if (dumper) {
            mgb_log_warn(
                    "the v2 serialization format is pure FlatBuffers and does not "
                    "support a user tensor value dumper callback");
        }
        data = m_builder.CreateVector(
                reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte);
        m_cur_rst.tensor_value_bytes += layout.span().high_byte;
    }

    auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
    auto fshape = m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
    auto fcomp_node = fbs::v2::CreateCompNode(
            m_builder,
            m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
    auto fdtype = build_dtype(layout.dtype);
    auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
    auto fformat = build_tensor_format(layout.format);
    auto serialized_tensor = fbs::v2::CreateTensor(
            m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat, data);
    m_cur_opr_tensor.emplace_back(serialized_tensor);
}

void GraphDumperOSSV2::dump_buf_with_len(const void* data, uint32_t size) {
    auto blob = fbs::v2::CreateBlob(
            m_builder, m_builder.CreateVector(static_cast<const uint8_t*>(data), size));
    m_blobs.emplace_back(blob);
}

// ----------------------------- Loader --------------------------------------

CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
        const fbs::v2::CompNode* comp_node) {
    mgb_assert(comp_node);
    if (!comp_node->logical_locator())
        return {};
    auto loc = CompNode::Locator::parse(comp_node->logical_locator()->str());
    m_loader->m_cur_load_config->comp_node_mapper(loc);
    return CompNode::load(loc);
}

TensorFormat load_tensor_format(
        const fbs::v2::TensorFormat fformat_type, const void* fformat,
        const CompNode& comp_node) {
    switch (fformat_type) {
        case fbs::v2::TensorFormat_DefaultTensorFormat:
            return megdnn::DefaultTensorFormat::make();
        case fbs::v2::TensorFormat_Image2DPackedTensorFormat: {
            auto image_format =
                    static_cast<const fbs::v2::Image2DPackedTensorFormat*>(fformat);
            auto handle =
                    MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle();
            return megdnn::Image2DPack4TensorFormat::make(
                    image_format->align_axis(), handle);
        }
        case fbs::v2::TensorFormat_LowbitsAlignedTensorFormat: {
            auto lowbit_format =
                    static_cast<const fbs::v2::LowbitsAlignedTensorFormat*>(fformat);
            return megdnn::LowbitsAlignedToBytesTensorFormat::make(
                    lowbit_format->size_nbits());
        }
        default:
            mgb_throw(
                    SerializationError, "invalid tensor format type in serialization.");
    }
}

TensorLayout load_tensor_layout(
        const fbs::v2::Tensor* tensor, const CompNode& comp_node) {
    TensorLayout layout;
    if (tensor->shape()) {
        layout.ndim = tensor->shape()->size();
        std::copy(tensor->shape()->begin(), tensor->shape()->end(), layout.shape);
    }
    if (tensor->dtype()) {
        // modify the data type of the TensorLayout in place
        layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
    }
    if (tensor->format() && tensor->format_type()) {
        layout.format =
                load_tensor_format(tensor->format_type(), tensor->format(), comp_node);
    }
    layout.init_contiguous_stride();
    return layout;
}
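
// Note: strides are never serialized; load_tensor_layout() recomputes contiguous
// strides from shape/dtype/format via init_contiguous_stride().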

//! the opr loader is responsible for ensuring that the tensors exist and that
//! their number matches; here we only assert it
std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor, comp_node);
    auto ret = std::make_shared<HostTensorND>(comp_node, layout);

    auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
    if (tensor->data() && tensor->data()->size() > 0) {
        if (loader) {
            mgb_log_warn(
                    "the v2 serialization format is pure FlatBuffers and does not "
                    "support a user tensor value loader callback");
        }
        memcpy(ret->raw_ptr(), tensor->data()->data(), tensor->data()->size());
    }
    if (tensor->name()) {
        m_tensor_map[tensor->name()->str()] = ret;
    }
    if (auto&& mod = m_loader->m_cur_load_config->tensor_modifier) {
        bool has_value = false;
        if (tensor && tensor->data()) {
            has_value = tensor->data()->size() != 0;
        }
        mod(tensor->name() ? tensor->name()->str() : "", has_value, *ret);
    }
    return ret;
}
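
// Shared (parameter) tensors are matched to m_shared_tensor_map purely by order:
// m_cur_shared_tensor_idx advances once per load_tensor_shared() call, mirroring
// the order in which the dumper emitted VALUE_SHARED tensors.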

std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
        load_tensor_shared() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor, comp_node);
    mgb_assert(tensor->data());
    auto&& shared_pair = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
    auto&& shared_tensor_ref = shared_pair.second[comp_node.mem_node()];
    if (shared_tensor_ref) {
        if (shared_tensor_ref->comp_node() == comp_node)
            return shared_tensor_ref;
        // same mem node but different comp node: change comp node and share
        // the value
        auto ret = std::make_shared<DeviceTensorND>(*shared_tensor_ref);
        ret->comp_node(comp_node);
        return ret;
    }

    if (tensor->name()) {
        shared_pair.first = tensor->name()->str();
    }
    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
        // directly forward CPU memory
        HostTensorND hv{comp_node};
        if (tensor->data() && tensor->data()->size() > 0) {
            hv.dtype(layout.dtype).resize(layout);
            memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
        }
        shared_tensor_ref = std::make_shared<DeviceTensorND>();
        *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
    } else {
        // use lazy load for non-CPU devices
        HostTensorND hv{CompNode::default_cpu()};
        if (tensor->data() && tensor->data()->size() > 0) {
            hv.dtype(layout.dtype).resize(layout);
            memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
        }
        shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
    }
    return shared_tensor_ref;
}

Metadata GraphLoaderOSSV2::OprLoadContextImpl::load_metadata() {
    const auto* fbmeta = m_loader->m_model->metadata();
    Metadata ret;
    if (fbmeta) {
        ret.is_valid = fbmeta->is_valid();
        ret.graph_modified = fbmeta->graph_modified();
        if (fbmeta->user_info()) {
            ret.user_info = fbmeta->user_info()->str();
            ret.has_user_info = true;
        }
        if (fbmeta->optimize_options()) {
            ret.optimize_options = fbmeta->optimize_options();
            ret.optimized_for_inference = true;
        }
    }
    return ret;
}

void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
        const fbs::v2::Operator* fbopr) {
    m_cur_opr_tensor_cnt = 0;
    m_cur_opr_blob_cnt = 0;
    m_cur_opr_param_cnt = 0;

    OperatorNodeConfig config;
    if (fbopr->output_dtype()) {
        config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
    }
    if (fbopr->name()) {
        config.name(fbopr->name()->str());
    }
    if (fbopr->comp_node()) {
        auto cnt = fbopr->comp_node()->size();
        cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
        for (size_t i = 0; i < cnt; i++) {
            CompNode cn{};
            auto node = fbopr->comp_node()->Get(i);
            if (node) {
                cn = load_comp_node(node);
            }
            comp_node_arr[i] = cn;
        }
        config.comp_node_arr(comp_node_arr);
    }

    //! the opr version must be present in the model
    uint8_t opr_version = fbopr->opr_version();
    auto type_id = fbopr->type_id();
    const OprRegistryV2* registry =
            OprRegistryV2::versioned_find_by_id(type_id, opr_version);
    mgb_throw_if(
            !registry, SerializationError,
            "failed to find opr with type %s, use python env "
            "config.dump_registered_oprs() to get a dict that maps from "
            "opr id to opr name",
            fbopr->type()->str().c_str());

    // load inputs
    VarNodeArray inputs;
    if (fbopr->inputs()) {
        inputs.resize(fbopr->inputs()->size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            inputs[i] = m_id2varnode.at(fbopr->inputs()->Get(i));
        }
    }

    // call loader
    auto accessor = registry->loader(*this, inputs, config);
    auto opr = accessor.opr();

    // check opr type; note that:
    // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
    // 2. due to some optimization, an opr may be replaced by ImmutableTensor
    mgb_assert(
            opr && (opr->dyn_typeinfo() == registry->type || !registry->type ||
                    opr->same_type<opr::ImmutableTensor>()),
            "got_type=%s expected_type=%s", opr ? opr->dyn_typeinfo()->name : nullptr,
            registry->type->name);

    // record output vars; read output names
    size_t i = 0;
    for (auto ovar : accessor.output()) {
        if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            m_id2varnode.push_back(ovar);
            if (fbopr->outputs()) {
                auto id = fbopr->outputs()->Get(i);
                mgb_assert(
                        m_id2varnode.size() - 1 == id,
                        "id2var is %zu, fbs get id is %u\n", m_id2varnode.size() - 1,
                        id);
                if (m_middle_tensors.size() > id) {
                    auto name = m_middle_tensors[id]->name()->str();
                    ovar->name(name);
                }
            }
            i++;
        }
    }
    opr->node_prop().attribute().priority = fbopr->priority();
}

GraphLoader::LoadResult GraphLoaderOSSV2::OprLoadContextImpl::load_oprs() {
    // load oprs
    const auto* oprs = m_loader->m_model->oprs();
    {
        // inplace arith graph optimization is disabled during opr load:
        // it tries to restore the same graph as it was dumped;
        // see test TestSerializer2.LOGEXP for an example
        GraphLoader::ScopedGraphOptDisabler _(m_graph);
        for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) {
            m_current_opr = oprs->Get(i);
            load_single_opr(m_current_opr);
        }
    }

    // batched loading of device values
    m_device_value_loader.apply();

    LoadResult ret;
    ret.graph = m_graph;
    ret.tensor_map = m_tensor_map;

    const auto* outputs = m_loader->m_model->output_vars_idx();
    ret.output_var_list.resize(outputs->size());
    for (flatbuffers::uoffset_t i = 0; i < outputs->size(); i++) {
        auto out = outputs->Get(i);
        auto var = m_id2varnode.at(out->compact_id());
        ret.output_var_map[var->name()] = var;
        ret.output_var_map_id[out->original_id()] = var;
        ret.output_var_list[i] = var;
    }
    mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
    return ret;
}

void GraphLoaderOSSV2::OprLoadContextImpl::load_middle_tensor() {
    auto model = m_loader->m_model;
    if (model->middle_tensors()) {
        for (unsigned int i = 0; i < model->middle_tensors()->size(); i++) {
            m_middle_tensors.push_back(model->middle_tensors()->Get(i));
        }
    }
}
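
// load() mirrors dump(): read the uint32 size prefix, check the identifier,
// verify the buffer, then replay middle tensors, metadata and oprs in dump order.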

GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool rewind) {
    mgb_assert(m_file);
    m_cur_load_config = &config;
    if (rewind) {
        m_file->rewind();
    }

    // read the size-prefixed fbs::v2::Model buffer
    uint32_t size;
    m_file->read(&size, sizeof(size));
    m_model_buf = m_file->read_shared(size);
    mgb_throw_if(
            !fbs::v2::ModelBufferHasIdentifier(m_model_buf.data()), SerializationError,
            "invalid fbs model");
    {
        flatbuffers::Verifier verifier(
                static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
        mgb_throw_if(
                !fbs::v2::VerifyModelBuffer(verifier), SerializationError,
                "model verification failed (invalid or corrupted model?)");
    }

    m_model = fbs::v2::GetModel(m_model_buf.data());
    m_mgb_version = m_model->mge_version();
    if (m_model->mge_version() > MGB_VERSION) {
        mgb_log_warn(
                "loading model from a future runtime: runtime version=%u "
                "model version=%u",
                MGB_VERSION, m_model->mge_version());
    }

    if (m_shared_tensor_map.empty()) {
        m_shared_tensor_map.resize(m_model->nr_shared_tensor());
    } else {
        mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
    }

    OprLoadContextImpl ctx{this, m_model->mge_version()};
    ctx.load_middle_tensor();
    auto metadata = ctx.load_metadata();
    auto result = ctx.load_oprs();
    result.metadata = metadata;

    if (m_model->output_alias() && m_model->output_alias()->size() > 0) {
        auto nr_alias = m_model->output_alias()->size();
        result.output_var_list.resize(nr_alias);
        for (size_t i = 0; i < nr_alias; i++) {
            auto output_alias = m_model->output_alias()->Get(i);
            std::string name = output_alias->name()->str();
            size_t id = output_alias->id();
            result.output_var_map[name] = result.output_var_map_id[id];
            result.output_var_list[i] = result.output_var_map_id[id];
        }
    }

    m_model_loaded = true;
    result.graph_compile_ahead();
    return result;
}

std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file) {
    return std::make_unique<GraphDumperOSSV2>(std::move(file));
}

std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) {
    return std::make_unique<GraphLoaderOSSV2>(std::move(file));
}
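
// The FlatBuffers file identifier sits at byte offset 4 of the buffer, i.e. at
// offset 8 of the file once the uint32 size prefix is counted, so the 25 bytes
// read below are more than enough to cover it.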

bool is_fbs_v2_file(InputFile& file) {
    constexpr size_t identifier_length = 25;
    char identifier[identifier_length];
    file.read(identifier, identifier_length);
    file.skip(-identifier_length);
    //! skip the uint32 size prefix of the file before checking the identifier
    return fbs::v2::ModelBufferHasIdentifier(identifier + sizeof(uint32_t));
}

} // namespace serialization
} // namespace mgb

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}