You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

serializer_oss_v2.cpp 34 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883
  1. #if MGB_ENABLE_FBS_SERIALIZATION
  2. #include "megbrain/comp_node_env.h"
  3. #include "megbrain/opr/io.h"
  4. #include "megbrain/serialization/helper.h"
  5. #include "megbrain/serialization/internal/flatbuffers_helper.h"
  6. #include "megbrain/serialization/internal/schema_v2_generated.h"
  7. #include "megbrain/serialization/metadata.h"
  8. #include "megbrain/serialization/opr_load_dump.h"
  9. #include "megbrain/serialization/oss_opr_load_dump.h"
  10. #include "megbrain/utils/hash_ct.h"
  11. #include "megdnn/tensor_format.h"
  12. #include "serializer_oss_common.h"
  13. #include "megbrain/gopt/framework.h"
  14. namespace mgb {
  15. namespace serialization {
/*!
 * \brief graph pass that replaces every operator which registered a
 * converter (the ``replace_opr`` method) in OprLoadDumpImplV2 with its
 * serialization-compatible equivalent before dumping
 */
class PassConvertToCompatible : public gopt::Pass {
    //! typeinfo -> converter that builds the replacement opr from the
    //! original opr and its (already rewritten) inputs
    ThinHashMap<
            Typeinfo*, thin_function<cg::OperatorNodeBase*(
                               cg::OperatorNodeBase*, const VarNodeArray&)>>
            m_opr_replace_func;
    gopt::VarReplaceCheckFlag m_var_replace_check_flag =
            gopt::VarReplaceCheckFlag::CHECK_ALL;

public:
    const char* name() const override { return "PassConvertToCompatible"; };

    //! control how strictly replacement vars are checked against the
    //! vars they replace
    PassConvertToCompatible& set_var_replace_check_flag(
            gopt::VarReplaceCheckFlag flag) {
        m_var_replace_check_flag = flag;
        return *this;
    }

    void apply(gopt::OptState& state) const override {
        state.set_var_replace_check_flag(m_var_replace_check_flag);
        auto rewriter = state.graph().make_rewriter();
        auto on_opr = [this, &rewriter](cg::OperatorNodeBase* opr) {
            auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
            if (it != m_opr_replace_func.end()) {
                // rebuild the opr on top of the (possibly already replaced)
                // inputs, then remap old outputs onto the new ones
                VarNodeArray new_inp;
                new_inp.clear();
                new_inp.reserve(opr->input().size());
                for (auto i : opr->input()) {
                    new_inp.push_back(rewriter.get_var(i));
                }
                auto new_opr = (it->second)(opr, new_inp);
                auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
                // output counts may differ between the two oprs; only the
                // common prefix is remapped
                for (size_t i = 0; i < std::min(origin_out.size(), cur_out.size());
                     i++) {
                    rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
                }
            } else {
                rewriter.auto_replace_outputs(opr);
            }
        };
        state.graph().iter(on_opr);
        rewriter.apply_inplace();
    }

    //! build a pass whose converter table covers every opr reachable from
    //! \p output_vars; throws if a reachable opr has no V2 registry entry
    static std::unique_ptr<PassConvertToCompatible> make(
            const SymbolVarArray& output_vars) {
        auto ret = std::make_unique<PassConvertToCompatible>();
        // iterate oprs to init
        auto on_opr = [&](cg::OperatorNodeBase* opr) {
            if (!GraphDumper::should_remove_in_dump(opr)) {
                auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                        opr->dyn_typeinfo(), CURRENT_VERSION);
                mgb_throw_if(
                        !registry,
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s, typeinfo %p",
                        opr->dyn_typeinfo()->name, opr->dyn_typeinfo());
                // oprs without a converter are dumped as-is
                if (registry->converter) {
                    ret->m_opr_replace_func[opr->dyn_typeinfo()] = registry->converter;
                }
            }
        };
        cg::DepOprIter dep_opr_iter{on_opr};
        for (auto i : output_vars) {
            dep_opr_iter.add(i.node()->owner_opr());
        }
        return ret;
    };
};
  84. namespace {
  85. fbs::v2::TensorFormat get_flatbuffer_tensor_format_type(
  86. const TensorLayout::Format& format) {
  87. using Type = megdnn::TensorFormat::Type;
  88. switch (format.type()) {
  89. case Type::DEFAULT:
  90. return fbs::v2::TensorFormat::TensorFormat_DefaultTensorFormat;
  91. case Type::IMAGE2D_PACK4:
  92. return fbs::v2::TensorFormat::TensorFormat_Image2DPackedTensorFormat;
  93. case Type::LOWBITS_ALIGNED_TO_BYTE:
  94. return fbs::v2::TensorFormat::TensorFormat_LowbitsAlignedTensorFormat;
  95. default:
  96. mgb_throw(
  97. SerializationError, "invalid tensor format type in serialization.");
  98. }
  99. }
  100. } // namespace
//! serialize \p dtype into the buffer under construction; thin forwarder to
//! the shared fbs helper so dtype encoding stays in one place
flatbuffers::Offset<fbs::DType> GraphDumperOSSV2::build_dtype(DType dtype) {
    return fbs::intl::build_dtype(m_builder, dtype);
}
  104. flatbuffers::Offset<void> GraphDumperOSSV2::build_tensor_format(
  105. const TensorLayout::Format& format) {
  106. using Type = megdnn::TensorFormat::Type;
  107. switch (format.type()) {
  108. case Type::DEFAULT:
  109. return fbs::v2::CreateDefaultTensorFormat(m_builder).Union();
  110. case Type::IMAGE2D_PACK4:
  111. return fbs::v2::CreateImage2DPackedTensorFormat(
  112. m_builder, format.as_impl<megdnn::Image2DPack4TensorFormat>()
  113. .align_axis())
  114. .Union();
  115. case Type::LOWBITS_ALIGNED_TO_BYTE: {
  116. auto size_bite = format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
  117. .size_nbits();
  118. auto align_size_in_bits =
  119. format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
  120. .align_size_in_bits();
  121. return fbs::v2::CreateLowbitsAlignedTensorFormat(
  122. m_builder, size_bite, align_size_in_bits)
  123. .Union();
  124. }
  125. default:
  126. mgb_throw(
  127. SerializationError, "invalid tensor format type in serialization.");
  128. }
  129. }
  130. flatbuffers::Offset<fbs::v2::MiddleTensor> GraphDumperOSSV2::build_middle_tensor(
  131. const SymbolVar var) {
  132. mgb_assert(var.node());
  133. auto fbname = m_builder.CreateSharedString(var.node()->name());
  134. flatbuffers::Offset<fbs::v2::MiddleTensor> serialized_middle_tensor;
  135. if (var.node()->dev_tensor_valid()) {
  136. auto layout = var.node()->layout();
  137. auto fshape =
  138. m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
  139. auto fcomp_node = fbs::v2::CreateCompNode(
  140. m_builder, m_builder.CreateSharedString(
  141. var.node()->comp_node().to_string_logical()));
  142. auto fdtype = build_dtype(layout.dtype);
  143. auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
  144. auto fformat = build_tensor_format(layout.format);
  145. serialized_middle_tensor = fbs::v2::CreateMiddleTensor(
  146. m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat);
  147. }
  148. serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
  149. return serialized_middle_tensor;
  150. }
  151. flatbuffers::Offset<fbs::v2::OutputVar> GraphDumperOSSV2::build_output_var(
  152. const SymbolVar var) {
  153. auto out_node = var.node();
  154. if (m_var2midtensor_id.find(var.node()) == m_var2midtensor_id.end()) {
  155. mgb_assert(m_var_remove_in_dump.find(var.node()) != m_var_remove_in_dump.end());
  156. out_node = m_var_remove_in_dump[var.node()];
  157. }
  158. return fbs::v2::CreateOutputVar(
  159. m_builder, m_var2midtensor_id.at(out_node), var.node()->id());
  160. }
//! walk all oprs reachable from \p endpoints, record the (opr, registry)
//! pairs to be dumped and build the removed-var forwarding table for oprs
//! skipped in the dump
void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
    m_oprs_to_dump.clear();
    // iterate oprs to init
    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        if (should_remove_in_dump(opr)) {
            // removed oprs are assumed to be single-input pass-throughs
            mgb_assert(opr->input().size() == 1);
            // Copy input ID to output
            for (auto i : opr->output()) {
                // follow the chain so consecutive removed oprs all map to
                // the first surviving var
                if (m_var_remove_in_dump.find(opr->input(0)) !=
                    m_var_remove_in_dump.end()) {
                    m_var_remove_in_dump[i] = m_var_remove_in_dump[opr->input(0)];
                } else {
                    m_var_remove_in_dump[i] = opr->input(0);
                }
            }
        } else {
            auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                    opr->dyn_typeinfo(), m_version);
            if (!registry || !registry->dumper) {
                mgb_throw(
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s",
                        opr->dyn_typeinfo()->name);
            }
            mgb_assert(
                    registry->version <= m_version,
                    "The Operator version should less than model version");
            m_oprs_to_dump.emplace_back(opr, registry);
        }
    };
    cg::DepOprIter dep_opr_iter{on_opr};
    for (auto i : endpoints) {
        dep_opr_iter.add(i.node()->owner_opr());
    }
}
  197. flatbuffers::Offset<fbs::v2::Metadata> GraphDumperOSSV2::build_metadata(
  198. const Metadata& metadata) {
  199. auto user_info = m_builder.CreateSharedString(metadata.user_info);
  200. fbs::v2::MetadataBuilder builder(m_builder);
  201. builder.add_is_valid(metadata.is_valid);
  202. builder.add_graph_modified(metadata.graph_modified);
  203. builder.add_optimize_options(metadata.optimize_options);
  204. builder.add_user_info(user_info);
  205. return builder.Finish();
  206. }
//! serialize one operator table.
//!
//! FlatBuffers requires all child offsets (vectors, strings, nested tables)
//! to be created BEFORE the Operator table builder is started, which
//! dictates the two-phase structure of this function.
flatbuffers::Offset<fbs::v2::Operator> GraphDumperOSSV2::build_single_opr(
        cg::OperatorNodeBase* opr, const OprRegistryV2* registry) {
    m_cur_opr = opr;
    ++m_cur_rst.nr_opr;
    using namespace flatbuffers;
    // map inputs to middle-tensor ids; inputs produced by removed oprs are
    // resolved through the forwarding table built in init_oprs_to_dump
    Offset<Vector<uint32_t>> inputs;
    if (m_cur_opr->input().size()) {
        std::vector<uint32_t> v;
        v.reserve(m_cur_opr->input().size());
        for (auto inp : m_cur_opr->input()) {
            if (m_var2midtensor_id.find(inp) != m_var2midtensor_id.end()) {
                v.emplace_back(m_var2midtensor_id.at(inp));
            } else {
                mgb_assert(
                        m_var_remove_in_dump.find(inp) != m_var_remove_in_dump.end(),
                        "when dump the model, the dependence of var is wrong.");
                v.emplace_back(m_var2midtensor_id.at(m_var_remove_in_dump[inp]));
            }
        }
        inputs = m_builder.CreateVector(v);
    }
    // reset per-opr scratch state, then let the registered dumper fill
    // tensors / blobs / params via the OprDumpContext interface
    m_cur_opr_tensor.clear();
    m_blobs.clear();
    m_cur_opr_param.clear();
    m_cur_opr_param_type.clear();
    registry->dumper(*this, *m_cur_opr);
    // comp nodes are dumped only when explicitly set on the opr config
    Offset<Vector<Offset<fbs::v2::CompNode>>> comp_node;
    auto& config = m_cur_opr->config();
    if (config.has_comp_node_set()) {
        std::vector<flatbuffers::Offset<fbs::v2::CompNode>> cns;
        for (const auto& cn : config.comp_node()) {
            cns.emplace_back(fbs::v2::CreateCompNode(
                    m_builder, m_builder.CreateSharedString(cn.to_string_logical())));
        }
        comp_node = m_builder.CreateVector(cns);
    }
    Offset<String> operator_name;
    if (m_config.keep_op_name) {
        operator_name = m_builder.CreateSharedString(m_cur_opr->name());
    }
    auto output_dtype = build_dtype(config.output_dtype());
    // register every non-volatile output as a new middle tensor and record
    // its id on the opr
    Offset<Vector<uint32_t>> outputs;
    if (m_cur_opr->output().size()) {
        std::vector<uint32_t> v;
        v.reserve(m_cur_opr->output().size());
        for (auto out : m_cur_opr->output()) {
            if (!out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                auto fbs_out = build_middle_tensor(out);
                m_model_middle_tensors.push_back(fbs_out);
                m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1;
                v.emplace_back(m_var2midtensor_id.at(out));
            }
        }
        outputs = m_builder.CreateVector(v);
    }
    Offset<Vector<Offset<fbs::v2::Tensor>>> tensors;
    if (m_cur_opr_tensor.size())
        tensors = m_builder.CreateVector(m_cur_opr_tensor);
    //! the blobs data is used by custom data
    //! m_blobs will be filled by the Operator dumper function
    Offset<Vector<Offset<fbs::v2::Blob>>> blobs;
    if (m_blobs.size())
        blobs = m_builder.CreateVector(m_blobs);
    // first param is stored inline in the Operator table; any further params
    // go into the additional_params vectors
    Offset<Vector<uint8_t>> additional_params_type;
    Offset<Vector<Offset<void>>> additional_params;
    auto param_cnt = m_cur_opr_param_type.size();
    if (param_cnt > 1) {
        additional_params_type = m_builder.CreateVectorScalarCast<uint8_t>(
                m_cur_opr_param_type.data() + 1, param_cnt - 1);
        additional_params =
                m_builder.CreateVector(m_cur_opr_param.data() + 1, param_cnt - 1);
    }
    auto opr_type = m_builder.CreateSharedString(registry->name);
    // all child offsets are ready; now assemble the Operator table
    fbs::v2::OperatorBuilder builder(m_builder);
    builder.add_type(opr_type);
    builder.add_type_id(registry->type_id);
    builder.add_inputs(inputs);
    builder.add_outputs(outputs);
    if (m_config.keep_opr_priority) {
        builder.add_priority(opr->node_prop().attribute().priority);
    }
    builder.add_comp_node(comp_node);
    builder.add_opr_version(registry->get_version());
    builder.add_name(operator_name);
    builder.add_output_dtype(output_dtype);
    if (param_cnt > 0) {
        builder.add_param_type(m_cur_opr_param_type[0]);
        builder.add_param(m_cur_opr_param[0]);
    }
    if (param_cnt > 1) {
        builder.add_additional_params_type(additional_params_type);
        builder.add_additional_params(additional_params);
    }
    builder.add_tensors(tensors);
    builder.add_custom_data(blobs);
    m_cur_opr = nullptr;
    return builder.Finish();
}
  305. SymbolVarArray GraphDumperOSSV2::converter_all_opr_to_compatiable(
  306. const SymbolVarArray& output_vars) {
  307. gopt::GraphOptimizer optimizer;
  308. VarNodeArray rets_var;
  309. for (auto& symbolvar : output_vars) {
  310. rets_var.push_back(symbolvar.node());
  311. }
  312. optimizer.add_pass(PassConvertToCompatible::make(output_vars));
  313. optimizer.apply_inplace(rets_var);
  314. SymbolVarArray dst_vars;
  315. for (auto& var : rets_var) {
  316. dst_vars.push_back({var});
  317. }
  318. return dst_vars;
  319. }
//! dump the graph reachable from \p output_vars into the V2 (pure
//! flatbuffer) format and return statistics about the dump.
//!
//! Order matters: converter pass -> reset state -> metadata -> operators
//! (which also registers middle tensors) -> output vars -> aliases -> model
//! table -> file write.
GraphDumper::DumpResult GraphDumperOSSV2::dump(
        const SymbolVarArray& output_vars, const DumpConfig& config,
        const Metadata& metadata) {
    mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph");
    // unless disabled, rewrite oprs into their serialization-compatible form
    auto new_output_vars = output_vars;
    if (!config.no_change_graph) {
        new_output_vars = converter_all_opr_to_compatiable(output_vars);
    }
    auto begin_pos = m_file->tell();
    m_config = config;
    // reset all per-dump state so the dumper instance can be reused
    m_builder.Reset();
    m_output_vars.clear();
    m_cur_rst = {};
    m_used_input_names.clear();
    m_used_param_names.clear();
    m_var_remove_in_dump.clear();
    m_model_middle_tensors.clear();
    m_var2midtensor_id.clear();
    m_nr_shared_tensor = 0;
    // process output vars
    bool keep_output_var_name = m_config.keep_var_name >= 1;
    std::unordered_set<std::string> output_var_names;
    for (auto i : new_output_vars) {
        mgb_assert(
                !i.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                "can not dump var with VOLATILE_CONTENT flag: %s",
                cg::dump_var_info({i.node()}).c_str());
        if (m_output_vars.insert(i.node()).second && keep_output_var_name) {
            auto name_ins = output_var_names.insert(i.node()->name()).second;
            mgb_assert(name_ins, "duplicated output var name: %s", i.node()->cname());
        }
    }
    // Dump metadata
    auto fbmeta = build_metadata(metadata);
    // Dump operators
    init_oprs_to_dump(new_output_vars);
    std::vector<flatbuffers::Offset<fbs::v2::Operator>> oprs;
    for (auto&& i : m_oprs_to_dump) {
        oprs.emplace_back(build_single_opr(i.first, i.second));
    }
    auto fb_oprs = m_builder.CreateVector(oprs);
    // Dump output vars
    std::vector<flatbuffers::Offset<fbs::v2::OutputVar>> output_vars_idx;
    output_vars_idx.reserve(new_output_vars.size());
    for (auto i : new_output_vars) {
        auto foutput_vars_idx = build_output_var(i);
        output_vars_idx.push_back(foutput_vars_idx);
    }
    auto fb_output_vars = m_builder.CreateVector(output_vars_idx);
    // user-supplied alias names map onto the original var ids
    std::vector<flatbuffers::Offset<fbs::v2::OutputAlias>> output_vars_alias;
    if (m_config.alias_name_map.size() > 0) {
        for (auto&& pair : m_config.alias_name_map) {
            std::string name;
            SymbolVar var;
            std::tie(name, var) = pair;
            auto fbs_name = m_builder.CreateSharedString(name);
            output_vars_alias.push_back(
                    fbs::v2::CreateOutputAlias(m_builder, var.node()->id(), fbs_name));
        }
    }
    auto fbs_output_alias = m_builder.CreateVector(output_vars_alias);
    auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);
    // assemble the root Model table (all child offsets already created)
    fbs::v2::ModelBuilder model(m_builder);
    model.add_mge_version(MGB_VERSION);
    model.add_model_version(m_version);
    model.add_oprs(fb_oprs);
    model.add_middle_tensors(fb_mid_tensor);
    model.add_output_vars_idx(fb_output_vars);
    model.add_output_alias(fbs_output_alias);
    model.add_nr_shared_tensor(m_nr_shared_tensor);
    model.add_metadata(fbmeta);
    // size-prefixed so the loader can read the byte count first
    m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier());
    // Write serialized fbs::Graph
    m_file->write(m_builder.GetBufferPointer(), m_builder.GetSize());
    // Finalize DumpResult
    auto&& ret = m_cur_rst;
    for (size_t i = 0; i < new_output_vars.size(); i++) {
        ret.outputs.emplace_back(
                keep_output_var_name ? new_output_vars[i].node()->cname()
                                     : ssprintf("unnamed%zu", i));
    }
    std::sort(ret.inputs.begin(), ret.inputs.end());
    mgb_assert(ret.nr_opr == m_oprs_to_dump.size());
    ret.tot_bytes = m_file->tell() - begin_pos;
    return ret;
}
//! serialize one tensor on behalf of an operator dumper.
//!
//! \param name tensor name; must be empty iff method is VALUE_ANONYMOUS
//! \param tensor host tensor whose value/layout is dumped
//! \param method controls whether the value is written and how the name is
//!        recorded (shared param / graph input / anonymous value)
void GraphDumperOSSV2::dump_tensor(
        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
    using namespace flatbuffers;
    using Meth = TensorWriteMethod;
    mgb_assert(
            (method == Meth::VALUE_ANONYMOUS) ^ (!name.empty()),
            "name must be non-empty for non Meth::VALUE_ANONYMOUS tensors");
    // META_INPUT dumps metadata only; all other methods include the value
    bool has_value = method != Meth::META_INPUT;
    bool should_keep_name = true;
    switch (method) {
        case Meth::VALUE_ANONYMOUS:
            should_keep_name = false;
            break;
        case Meth::VALUE_SHARED:
            // shared tensors are counted so the loader can pre-size its map;
            // names are kept only when keep_param_name is requested
            should_keep_name = m_config.keep_param_name;
            ++m_nr_shared_tensor;
            if (m_config.keep_param_name) {
                mgb_assert(
                        m_used_param_names.insert(name).second,
                        "duplicated VALUE_SHARED tensor name: %s", name.c_str());
                m_cur_rst.params.emplace_back(name);
            }
            break;
        case Meth::META_INPUT:
        case Meth::VALUE_INPUT:
            mgb_assert(!name.empty(), "empty input tensor name");
            mgb_assert(
                    m_used_input_names.insert(name).second,
                    "duplicated input tensor name: %s", name.c_str());
            m_cur_rst.inputs.emplace_back(name);
            break;
    }
    auto& layout = tensor.layout();
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> data;
    if (has_value) {
        check_tensor_value_valid(name, tensor);
        // user callbacks from the V1 format are not applicable here; warn
        // instead of silently ignoring them
        auto&& dumper = m_config.tensor_value_dumper;
        if (dumper) {
            mgb_log_warn(
                    "serialization v2 format is pure flatbuffer format, not support "
                    "user tensor value dumper callback.");
        }
        data = m_builder.CreateVector(
                reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte);
        m_cur_rst.tensor_value_bytes += layout.span().high_byte;
    }
    auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
    auto fshape = m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
    auto fcomp_node = fbs::v2::CreateCompNode(
            m_builder,
            m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
    auto fdtype = build_dtype(layout.dtype);
    auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
    auto fformat = build_tensor_format(layout.format);
    auto serialized_tensor = fbs::v2::CreateTensor(
            m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat, data);
    m_cur_opr_tensor.emplace_back(serialized_tensor);
}
  464. void GraphDumperOSSV2::dump_buf_with_len(const void* data, uint32_t size) {
  465. auto blob = fbs::v2::CreateBlob(
  466. m_builder, m_builder.CreateVector(static_cast<const uint8_t*>(data), size));
  467. m_blobs.emplace_back(blob);
  468. }
  469. // ----------------------------- Loader --------------------------------------
  470. CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
  471. const fbs::v2::CompNode* comp_node) {
  472. mgb_assert(comp_node);
  473. if (!comp_node->logical_locator())
  474. return {};
  475. auto loc = CompNode::Locator::parse(comp_node->logical_locator()->str());
  476. m_loader->m_cur_load_config->comp_node_mapper(loc);
  477. return CompNode::load(loc);
  478. }
  479. TensorFormat load_tensor_format(
  480. const fbs::v2::TensorFormat fformat_type, const void* fformat,
  481. const CompNode& comp_node) {
  482. switch (fformat_type) {
  483. case fbs::v2::TensorFormat_DefaultTensorFormat:
  484. return megdnn::DefaultTensorFormat::make();
  485. case fbs::v2::TensorFormat_Image2DPackedTensorFormat: {
  486. auto image_format =
  487. static_cast<const fbs::v2::Image2DPackedTensorFormat*>(fformat);
  488. auto handle =
  489. MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle();
  490. return megdnn::Image2DPack4TensorFormat::make(
  491. image_format->align_axis(), handle);
  492. }
  493. case fbs::v2::TensorFormat_LowbitsAlignedTensorFormat: {
  494. auto lowbit_format =
  495. static_cast<const fbs::v2::LowbitsAlignedTensorFormat*>(fformat);
  496. return megdnn::LowbitsAlignedToBytesTensorFormat::make(
  497. lowbit_format->size_nbits());
  498. }
  499. default:
  500. mgb_throw(
  501. SerializationError, "invalid tensor format type in serialization.");
  502. }
  503. }
  504. TensorLayout load_tensor_layout(
  505. const fbs::v2::Tensor* tensor, const CompNode& comp_node) {
  506. TensorLayout layout;
  507. if (tensor->shape()) {
  508. layout.ndim = tensor->shape()->size();
  509. std::copy(tensor->shape()->begin(), tensor->shape()->end(), layout.shape);
  510. }
  511. if (tensor->dtype()) {
  512. // modify data type inplace for TensorLayout
  513. layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
  514. }
  515. if (tensor->format() && tensor->format_type()) {
  516. layout.format =
  517. load_tensor_format(tensor->format_type(), tensor->format(), comp_node);
  518. }
  519. layout.init_contiguous_stride();
  520. return layout;
  521. }
//! the opr loader should make sure the exist of tensors and the number of
//! tensor, here just assert it.
//!
//! Loads the next tensor of the current operator into a fresh host tensor,
//! registers it in the name -> tensor map, and applies the user's
//! tensor_modifier hook if one is configured.
std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    // tensors are consumed in dump order via a running counter
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor, comp_node);
    auto ret = std::make_shared<HostTensorND>(comp_node, layout);
    // V1-style value-loader callbacks are not supported in the V2 format;
    // warn rather than silently ignore
    auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
    if (tensor->data() && tensor->data()->size() > 0) {
        if (loader) {
            mgb_log_warn(
                    "serialization v2 format is pure flatbuffer format, not support "
                    "user tensor value loader callback.");
        }
        memcpy(ret->raw_ptr(), tensor->data()->data(), tensor->data()->size());
    }
    if (tensor->name()) {
        m_tensor_map[tensor->name()->str()] = ret;
    }
    // let the user inspect/overwrite the loaded value (e.g. weight injection)
    if (auto&& mod = m_loader->m_cur_load_config->tensor_modifier) {
        bool has_value = false;
        if (tensor && tensor->data()) {
            has_value = tensor->data()->size() != 0;
        }
        mod(tensor->name() ? tensor->name()->str() : "", has_value, *ret);
    }
    return ret;
}
//! load the next shared (parameter) tensor, deduplicating per memory node so
//! multiple graphs loaded from the same file share one device copy.
std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
        load_tensor_shared() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor, comp_node);
    mgb_assert(tensor->data());
    // shared tensors are matched to dump order via a running index
    auto&& shared_pair = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
    auto&& shared_tensor_ref = shared_pair.second[comp_node.mem_node()];
    if (shared_tensor_ref) {
        if (shared_tensor_ref->comp_node() == comp_node)
            return shared_tensor_ref;
        // same mem node but different comp node, change comp node and share
        // value
        auto ret = std::make_shared<DeviceTensorND>(*shared_tensor_ref);
        ret->comp_node(comp_node);
        return ret;
    }
    if (tensor->name()) {
        shared_pair.first = tensor->name()->str();
    }
    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
        // directly forward CPU memory
        HostTensorND hv{comp_node};
        if (tensor->data() && tensor->data()->size() > 0) {
            hv.dtype(layout.dtype).resize(layout);
            memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
        }
        shared_tensor_ref = std::make_shared<DeviceTensorND>();
        *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
    } else {
        // use lazy load for non-CPU devices
        HostTensorND hv{CompNode::default_cpu()};
        if (tensor->data() && tensor->data()->size() > 0) {
            hv.dtype(layout.dtype).resize(layout);
            memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
        }
        // actual host->device copies are batched in load_oprs() via
        // m_device_value_loader.apply()
        shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
    }
    return shared_tensor_ref;
}
  596. Metadata GraphLoaderOSSV2::OprLoadContextImpl::load_metadata() {
  597. const auto* fbmeta = m_loader->m_model->metadata();
  598. Metadata ret;
  599. if (fbmeta) {
  600. ret.is_valid = fbmeta->is_valid();
  601. ret.graph_modified = fbmeta->graph_modified();
  602. if (fbmeta->user_info()) {
  603. ret.user_info = fbmeta->user_info()->str();
  604. ret.has_user_info = true;
  605. }
  606. if (fbmeta->optimize_options()) {
  607. ret.optimize_options = fbmeta->optimize_options();
  608. ret.optimized_for_inference = true;
  609. }
  610. }
  611. return ret;
  612. }
//! deserialize one operator: rebuild its config, resolve its registered
//! loader by (type_id, version), wire up inputs by middle-tensor id, and
//! record/rename its outputs.
void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
        const fbs::v2::Operator* fbopr) {
    // per-opr cursors used by load_tensor/load_buf/param accessors
    m_cur_opr_tensor_cnt = 0;
    m_cur_opr_blob_cnt = 0;
    m_cur_opr_param_cnt = 0;
    OperatorNodeConfig config;
    if (fbopr->output_dtype()) {
        config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
    }
    if (fbopr->name()) {
        config.name(fbopr->name()->str());
    }
    if (fbopr->comp_node()) {
        auto cnt = fbopr->comp_node()->size();
        cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
        for (size_t i = 0; i < cnt; i++) {
            CompNode cn{};
            auto node = fbopr->comp_node()->Get(i);
            if (node) {
                cn = load_comp_node(node);
            }
            comp_node_arr[i] = cn;
        }
        config.comp_node_arr(comp_node_arr);
    }
    //! opr version must be exist
    // NOTE(review): opr_version is narrowed to uint8_t here — confirm the
    // schema field cannot exceed 255
    uint8_t opr_version = fbopr->opr_version();
    auto type_id = fbopr->type_id();
    const OprRegistryV2* registry =
            OprRegistryV2::versioned_find_by_id(type_id, opr_version);
    mgb_throw_if(
            !registry, SerializationError,
            "failed to find opr with type %s and version %d.",
            fbopr->type()->str().c_str(), opr_version);
    // load inputs
    VarNodeArray inputs;
    if (fbopr->inputs()) {
        inputs.resize(fbopr->inputs()->size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            inputs[i] = m_id2varnode.at(fbopr->inputs()->Get(i));
        }
    }
    // call loader
    auto accessor = registry->loader(*this, inputs, config);
    auto opr = accessor.opr();
    // check opr type; note that:
    // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
    // 2. due to some optimization, an opr may be replaced by ImmutableTensor
    // NOTE(review): on failure the %s args may be nullptr (opr null or
    // registry->type null), which is undefined for printf-style formatting
    mgb_assert(
            opr && (opr->dyn_typeinfo() == registry->type || !registry->type ||
                    opr->same_type<opr::ImmutableTensor>()),
            "got_type=%s expected_type=%s", opr ? opr->dyn_typeinfo()->name : nullptr,
            registry->type->name);
    // record output vars; read output names
    size_t i = 0;
    for (auto ovar : accessor.output()) {
        if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            // non-volatile outputs are assigned ids in dump order; the
            // assert checks load stays in sync with the dumped ids
            m_id2varnode.push_back(ovar);
            if (fbopr->outputs()) {
                auto id = fbopr->outputs()->Get(i);
                mgb_assert(
                        m_id2varnode.size() - 1 == fbopr->outputs()->Get(i),
                        "id2var is %zu, fbs get id is %d\n", m_id2varnode.size() - 1,
                        fbopr->outputs()->Get(i));
                // restore the var name recorded in the middle-tensor table
                if (m_middle_tensors.size() > i) {
                    auto name = m_middle_tensors[id]->name()->str();
                    ovar->name(name);
                }
            }
            i++;
        }
    }
    opr->node_prop().attribute().priority = fbopr->priority();
}
//! load every operator of the model in dump order, then resolve the graph
//! outputs from the dumped output-var table.
GraphLoader::LoadResult GraphLoaderOSSV2::OprLoadContextImpl::load_oprs() {
    // load oprs
    const auto* oprs = m_loader->m_model->oprs();
    {
        // inplace arith graph optimization is disabled during opr load
        // it tries to restore the same graph as it was dumped
        // see test TestSerializer2.LOGEXP for example
        GraphLoader::ScopedGraphOptDisabler _(m_graph);
        for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) {
            m_current_opr = oprs->Get(i);
            load_single_opr(m_current_opr);
        }
    }
    // batched loading device values
    m_device_value_loader.apply();
    LoadResult ret;
    ret.graph = m_graph;
    ret.tensor_map = m_tensor_map;
    // map dumped output entries back to the loaded var nodes, keyed both by
    // name and by the var id from dump time
    const auto* outputs = m_loader->m_model->output_vars_idx();
    ret.output_var_list.resize(outputs->size());
    for (flatbuffers::uoffset_t i = 0; i < outputs->size(); i++) {
        auto out = outputs->Get(i);
        auto var = m_id2varnode.at(out->compact_id());
        ret.output_var_map[var->name()] = var;
        ret.output_var_map_id[out->original_id()] = var;
        ret.output_var_list[i] = var;
    }
    // every shared tensor dumped must have been consumed exactly once
    mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
    return ret;
}
  717. void GraphLoaderOSSV2::OprLoadContextImpl::load_middle_tensor() {
  718. auto model = m_loader->m_model;
  719. if (model->middle_tensors()) {
  720. for (unsigned int i = 0; i < m_loader->m_model->middle_tensors()->size(); i++) {
  721. m_middle_tensors.push_back(model->middle_tensors()->Get(i));
  722. }
  723. }
  724. }
GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool rewind) {
    //! Top-level entry: read the fbs::v2 model from m_file, verify the
    //! buffer, then rebuild the graph (metadata + all oprs) and return the
    //! load result. \p rewind restarts reading from the beginning of the file.
    mgb_assert(m_file);
    m_cur_load_config = &config;
    if (rewind) {
        m_file->rewind();
    }
    // Read fbs::Graph: a uint32 byte-size prefix followed by the model buffer
    uint32_t size;
    m_file->read(&size, sizeof(size));
    m_model_buf = m_file->read_shared(size);
    // quick identifier check before paying for full verification
    mgb_throw_if(
            !fbs::v2::ModelBufferHasIdentifier(m_model_buf.data()), SerializationError,
            "invalid fbs model");
    {
        // full structural verification of the flatbuffer before any access
        flatbuffers::Verifier verifier(
                static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
        mgb_throw_if(
                !fbs::v2::VerifyModelBuffer(verifier), SerializationError,
                "model verification failed (invalid or corrupted model?)");
    }
    m_model = fbs::v2::GetModel(m_model_buf.data());
    m_mgb_version = m_model->mge_version();
    m_model_version = m_model->model_version();
    // warn (but still attempt the load) when the model was produced by a
    // newer runtime or a newer serialization-format version than ours
    if (m_model->mge_version() > MGB_VERSION) {
        mgb_log_warn(
                "loading model from future runtime: version=%u "
                "model_version=%u",
                MGB_VERSION, m_model->mge_version());
    }
    if (m_model_version > CURRENT_VERSION) {
        mgb_log_warn(
                "The model dump in the future version %d, try to load it, maybe case "
                "load error in %d version.",
                m_model_version, CURRENT_VERSION);
    }
    // shared-tensor slots persist across repeated load() calls on the same
    // loader; allocate on first load, afterwards the count must match
    if (m_shared_tensor_map.empty()) {
        m_shared_tensor_map.resize(m_model->nr_shared_tensor());
    } else {
        mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
    }
    OprLoadContextImpl ctx{this, m_model->mge_version()};
    ctx.load_middle_tensor();
    auto metadata = ctx.load_metadata();
    auto result = ctx.load_oprs();
    result.metadata = metadata;
    // apply output aliases (if any): they rename and re-order the output
    // list keyed by the dumped original var ids
    if (m_model->output_alias() && m_model->output_alias()->size() > 0) {
        auto nr_alias = m_model->output_alias()->size();
        result.output_var_list.resize(nr_alias);
        for (size_t i = 0; i < nr_alias; i++) {
            auto output_alias = m_model->output_alias()->Get(i);
            std::string name = output_alias->name()->str();
            size_t id = output_alias->id();
            result.output_var_map[name] = result.output_var_map_id[id];
            result.output_var_list[i] = result.output_var_map_id[id];
        }
    }
    m_model_loaded = true;
    result.graph_compile_ahead();
    return result;
}
  785. std::unique_ptr<GraphDumper> make_fbs_v2_dumper(
  786. std::unique_ptr<OutputFile> file, int version) {
  787. return std::make_unique<GraphDumperOSSV2>(std::move(file), version);
  788. }
  789. std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) {
  790. return std::make_unique<GraphLoaderOSSV2>(std::move(file));
  791. }
  792. bool is_fbs_v2_file(InputFile& file) {
  793. constexpr size_t identifier_length = 25;
  794. char identifier[identifier_length];
  795. file.read(identifier, identifier_length);
  796. file.skip(-identifier_length);
  797. //! skip the size in prefix of the file
  798. return fbs::v2::ModelBufferHasIdentifier(identifier + sizeof(uint32_t));
  799. }
}  // namespace serialization
}  // namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}