serializer_oss.cpp
/*
 * Dump file layout:
 * [uint32_t fourcc]
 * [uint64_t feature bits, all zero for now]
 * [00 00 00 00]
 * [uint64_t offset to graph from tensor start]
 * [Tensor 1]
 * [Tensor 2]
 * [...]
 * [Tensor N]
 * [SizePrefixed FlatBuffers Graph]
 */
#if MGB_ENABLE_FBS_SERIALIZATION

#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/serializer.h"
#include "serializer_oss_common.h"

#include <flatbuffers/flatbuffers.h>

#include <cerrno>
#include <cinttypes>
#include <cstdio>
#include <cstring>  // memset / memcpy

using namespace mgb;
using namespace mgb::serialization;

namespace {

bool magic_compare = true;

//! feature bits for backward compatibility; default value should be 0
struct FeatureBits64 {
    //! reserved for new fields
    uint64_t : 64;

    static void write(OutputFile& fout) {
        static_assert(sizeof(FeatureBits64) == 8, "bad feature bits");
        FeatureBits64 fb64;
        memset(&fb64, 0, sizeof(fb64));
        fout.write(&fb64, 8);
    }
};

}  // namespace

namespace mgb {
namespace serialization {
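
//! GraphDumperOSS writes the layout documented at the top of this file: a
//! fixed header, then the raw tensor values, and finally a size-prefixed
//! FlatBuffers fbs::Graph that refers to those values by order and size.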
class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
    const std::unique_ptr<OutputFile> m_file;
    flatbuffers::FlatBufferBuilder m_builder;
    DumpConfig m_config;
    DumpResult m_cur_rst;
    size_t m_nr_shared_tensor;

    std::vector<std::pair<cg::OperatorNodeBase*, const OprRegistry*>> m_oprs_to_dump;
    ThinHashMap<VarNode*, size_t> m_var2id;

    //! set of output vars specified by user
    ThinHashSet<VarNode*> m_output_vars;

    std::unordered_set<std::string> m_used_input_names, m_used_param_names;

    //! current opr to be dumped
    cg::OperatorNodeBase* m_cur_opr = nullptr;

    // Will be filled in dump_tensor
    std::vector<flatbuffers::Offset<fbs::Tensor>> m_cur_opr_tensor;
    std::vector<flatbuffers::Offset<fbs::Blob>> m_blobs;
    std::vector<fbs::OperatorParam> m_cur_opr_param_type;
    std::vector<flatbuffers::Offset<void>> m_cur_opr_param;

    void init_oprs_to_dump(const SymbolVarArray& endpoints);
    flatbuffers::Offset<fbs::Metadata> build_metadata(const Metadata& metadata);
    flatbuffers::Offset<fbs::Operator> build_single_opr(
            cg::OperatorNodeBase* opr, const OprRegistry* registry);
    flatbuffers::Offset<fbs::DType> build_dtype(DType dtype);

public:
    GraphDumperOSS(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}

    DumpResult dump(
            const SymbolVarArray& output_vars, const DumpConfig& config = {},
            const Metadata& metadata = {}) override;

    const GraphDumpConfig& config() const override { return m_config; }

    void dump_tensor(
            const std::string& name, const HostTensorND& tensor,
            TensorWriteMethod method) override;

    flatbuffers::FlatBufferBuilder& builder() override { return m_builder; }

    void append_param(uint32_t type, uint32_t value) override {
        static_assert(
                std::is_same<uint32_t, flatbuffers::uoffset_t>::value,
                "append_param depends on uoffset_t being uint32_t");
        static_assert(
                std::is_standard_layout<flatbuffers::Offset<void>>::value,
                "append_param depends on flatbuffers::Offset having "
                "standard memory layout");
        mgb_assert(type != fbs::OperatorParam_NONE);
        m_cur_opr_param_type.emplace_back(static_cast<fbs::OperatorParam>(type));
        m_cur_opr_param.emplace_back(value);
    }

    void dump_buf_with_len(const void* data, uint32_t size) override;

    GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS; }
};
flatbuffers::Offset<fbs::DType> GraphDumperOSS::build_dtype(DType dtype) {
    return fbs::intl::build_dtype(m_builder, dtype);
}
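
// Assign compact IDs to every var that later oprs may reference. Oprs that
// are removed in the dump (see should_remove_in_dump) get no serialized
// entry: their outputs simply alias the ID of their single input. Outputs
// with VOLATILE_CONTENT are skipped entirely and never receive an ID.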
void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) {
    m_oprs_to_dump.clear();
    m_var2id.clear();

    // iterate oprs to init m_var2id
    size_t next_id = 0;
    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        if (should_remove_in_dump(opr)) {
            mgb_assert(opr->input().size() == 1);
            // Copy input ID to output
            auto id = m_var2id.at(opr->input(0));
            for (auto i : opr->output())
                m_var2id[i] = id;
        } else {
            auto registry = OprRegistry::find_by_type(opr->dyn_typeinfo());
            if (!registry || !registry->dumper) {
                mgb_throw(
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s",
                        opr->dyn_typeinfo()->name);
            }
            m_oprs_to_dump.emplace_back(opr, registry);
            for (auto i : opr->output()) {
                if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    m_var2id[i] = next_id++;
                }
            }
        }
    };
    cg::DepOprIter dep_opr_iter{on_opr};
    for (auto i : endpoints) {
        dep_opr_iter.add(i.node()->owner_opr());
    }
}
flatbuffers::Offset<fbs::Metadata> GraphDumperOSS::build_metadata(
        const Metadata& metadata) {
    auto user_info = m_builder.CreateSharedString(metadata.user_info);
    fbs::MetadataBuilder builder(m_builder);
    builder.add_is_valid(metadata.is_valid);
    builder.add_graph_modified(metadata.graph_modified);
    builder.add_user_info(user_info);
    builder.add_optimize_options(metadata.optimize_options);
    return builder.Finish();
}
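
// Serialize one operator. The opr-specific dumper callback fills
// m_cur_opr_tensor / m_blobs / m_cur_opr_param* through the OprDumpContext
// interface; the first param is stored inline in the Operator table, and
// any further params go into the additional_params vectors.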
flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
        cg::OperatorNodeBase* opr, const OprRegistry* registry) {
    m_cur_opr = opr;
    ++m_cur_rst.nr_opr;

    using namespace flatbuffers;
    Offset<Vector<Offset<fbs::CompNode>>> comp_node;
    auto& config = opr->config();
    if (config.has_comp_node_set()) {
        std::vector<flatbuffers::Offset<fbs::CompNode>> cns;
        for (const auto& cn : config.comp_node()) {
            cns.emplace_back(fbs::CreateCompNode(
                    m_builder, m_builder.CreateSharedString(cn.to_string_logical())));
        }
        comp_node = m_builder.CreateVector(cns);
    }

    Offset<Vector<uint32_t>> inputs;
    if (opr->input().size()) {
        std::vector<uint32_t> v;
        v.reserve(opr->input().size());
        for (auto inp : opr->input()) {
            v.emplace_back(m_var2id.at(inp));
        }
        inputs = m_builder.CreateVector(v);
    }

    Offset<String> operator_name;
    if (m_config.keep_op_name) {
        operator_name = m_builder.CreateSharedString(opr->name());
    }

    Offset<Vector<Offset<String>>> output_names;
    if (m_config.keep_var_name >= 2 ||
        (m_config.keep_var_name == 1 &&
         contains_any_in_set(opr->output(), m_output_vars))) {
        std::vector<std::string> onames;
        for (auto i : opr->output()) {
            if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                onames.emplace_back(i->name());
            }
        }
        output_names = m_builder.CreateVectorOfStrings(onames);
    }

    auto output_dtype = build_dtype(config.output_dtype());

    m_cur_opr_tensor.clear();
    m_blobs.clear();
    m_cur_opr_param.clear();
    m_cur_opr_param_type.clear();
    registry->dumper(*this, *opr);

    Offset<Vector<Offset<fbs::Tensor>>> tensors;
    if (m_cur_opr_tensor.size())
        tensors = m_builder.CreateVector(m_cur_opr_tensor);

    Offset<Vector<Offset<fbs::Blob>>> blobs;
    if (m_blobs.size())
        blobs = m_builder.CreateVector(m_blobs);

    Offset<Vector<uint8_t>> additional_params_type;
    Offset<Vector<Offset<void>>> additional_params;
    auto param_cnt = m_cur_opr_param_type.size();
    if (param_cnt > 1) {
        additional_params_type = m_builder.CreateVectorScalarCast<uint8_t>(
                m_cur_opr_param_type.data() + 1, param_cnt - 1);
        additional_params =
                m_builder.CreateVector(m_cur_opr_param.data() + 1, param_cnt - 1);
    }

    fbs::OperatorBuilder builder(m_builder);
    builder.add_type_id(registry->persist_type_id);
    builder.add_inputs(inputs);
    if (m_config.keep_opr_priority) {
        builder.add_priority(opr->node_prop().attribute().priority);
    }
    builder.add_comp_node(comp_node);
    builder.add_output_name(output_names);
    builder.add_name(operator_name);
    builder.add_output_dtype(output_dtype);
    if (param_cnt > 0) {
        builder.add_param_type(m_cur_opr_param_type[0]);
        builder.add_param(m_cur_opr_param[0]);
    }
    if (param_cnt > 1) {
        builder.add_additional_params_type(additional_params_type);
        builder.add_additional_params(additional_params);
    }
    builder.add_tensors(tensors);
    builder.add_blobs(blobs);

    m_cur_opr = nullptr;
    return builder.Finish();
}
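
// Top-level dump entry. The write order here must match the layout comment
// at the top of this file: the fixed header and a placeholder for
// offset_to_fbs go out first, tensor values are appended while the oprs
// are built (via dump_tensor), and finally the size-prefixed fbs::Graph is
// written and the placeholder is patched with the real offset.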
GraphDumper::DumpResult GraphDumperOSS::dump(
        const SymbolVarArray& output_vars, const DumpConfig& config,
        const Metadata& metadata) {
    mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph");

    auto begin_pos = m_file->tell();
    m_config = config;
    m_builder.Reset();
    m_output_vars.clear();
    m_cur_rst = {};
    m_used_input_names.clear();
    m_used_param_names.clear();
    m_nr_shared_tensor = 0;

    // process output vars
    bool keep_output_var_name = m_config.keep_var_name >= 1;
    std::unordered_set<std::string> output_var_names;
    for (auto i : output_vars) {
        mgb_assert(
                !i.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                "can not dump var with VOLATILE_CONTENT flag: %s",
                cg::dump_var_info({i.node()}).c_str());
        if (m_output_vars.insert(i.node()).second && keep_output_var_name) {
            auto name_ins = output_var_names.insert(i.node()->name()).second;
            mgb_assert(name_ins, "duplicated output var name: %s", i.node()->cname());
        }
    }

    // Write magic
    uint32_t magic = MGB_MAGIC;
    m_file->write(&magic, sizeof(magic));

    // write FeatureBits
    FeatureBits64::write(*m_file);

    // Padding
    uint32_t reserved = 0;
    m_file->write(&reserved, sizeof(reserved));

    // Write placeholder for offset_to_fbs
    auto offset_pos = m_file->tell();
    uint64_t offset_to_fbs = 0;
    m_file->write(&offset_to_fbs, sizeof(offset_to_fbs));

    // Dump metadata
    auto fbmeta = build_metadata(metadata);

    // Dump operators
    init_oprs_to_dump(output_vars);
    std::vector<flatbuffers::Offset<fbs::Operator>> oprs;
    for (auto&& i : m_oprs_to_dump) {
        record_opr_dumped(i.second->persist_type_id, i.second->name, 0);
        oprs.emplace_back(build_single_opr(i.first, i.second));
    }
    auto fb_oprs = m_builder.CreateVector(oprs);

    // Dump output vars
    std::vector<fbs::OutputVar> output_vars_idx;
    output_vars_idx.reserve(output_vars.size());
    for (auto i : output_vars) {
        output_vars_idx.emplace_back(m_var2id.at(i.node()), i.node()->id());
    }
    auto fb_output_vars = m_builder.CreateVectorOfStructs(output_vars_idx);

    XXHash content_hash;
    content_hash.update(m_builder.GetCurrentBufferPointer(), m_builder.GetSize());
    auto graph_hash = content_hash.digest();

    fbs::GraphBuilder graph(m_builder);
    graph.add_mgb_version(MGB_VERSION);
    graph.add_hash(graph_hash);
    graph.add_oprs(fb_oprs);
    graph.add_output_vars_idx(fb_output_vars);
    graph.add_nr_shared_tensor(m_nr_shared_tensor);
    graph.add_metadata(fbmeta);
    m_builder.FinishSizePrefixed(graph.Finish(), fbs::GraphIdentifier());

    // Write actual offset_to_fbs
    auto cur = m_file->tell();
    mgb_assert(cur >= offset_pos && cur - offset_pos >= sizeof(offset_to_fbs));
    offset_to_fbs = cur - offset_pos - sizeof(offset_to_fbs);
    m_file->seek(offset_pos);
    m_file->write(&offset_to_fbs, sizeof(offset_to_fbs));
    m_file->seek(cur);

    // Write serialized fbs::Graph
    m_file->write(m_builder.GetBufferPointer(), m_builder.GetSize());

    // Finalize DumpResult
    auto&& ret = m_cur_rst;
    for (size_t i = 0; i < output_vars.size(); i++) {
        ret.outputs.emplace_back(
                keep_output_var_name ? output_vars[i].node()->cname()
                                     : ssprintf("unnamed%zu", i));
    }
    ret.content_hash = graph_hash;
    std::sort(ret.inputs.begin(), ret.inputs.end());
    mgb_assert(ret.nr_opr == m_oprs_to_dump.size());
    ret.tot_bytes = m_file->tell() - begin_pos;
    return ret;
}
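
// Write a tensor value and register it with the current operator.
// VALUE_ANONYMOUS: value only, no name; VALUE_SHARED: counted as a shared
// (param) tensor; META_INPUT: name only, no value; VALUE_INPUT: both name
// and value.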
void GraphDumperOSS::dump_tensor(
        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
    using namespace flatbuffers;
    using Meth = TensorWriteMethod;
    mgb_assert(
            (method == Meth::VALUE_ANONYMOUS) ^ (!name.empty()),
            "name must be non-empty iff tensor is not VALUE_ANONYMOUS");
    bool has_value = method != Meth::META_INPUT;
    bool should_keep_name = true;
    switch (method) {
        case Meth::VALUE_ANONYMOUS:
            should_keep_name = false;
            break;
        case Meth::VALUE_SHARED:
            should_keep_name = m_config.keep_param_name;
            ++m_nr_shared_tensor;
            if (m_config.keep_param_name) {
                mgb_assert(
                        m_used_param_names.insert(name).second,
                        "duplicated VALUE_SHARED tensor name: %s", name.c_str());
                m_cur_rst.params.emplace_back(name);
            }
            break;
        case Meth::META_INPUT:
        case Meth::VALUE_INPUT:
            mgb_assert(!name.empty(), "empty input tensor name");
            mgb_assert(
                    m_used_input_names.insert(name).second,
                    "duplicated input tensor name: %s", name.c_str());
            m_cur_rst.inputs.emplace_back(name);
            break;
    }

    size_t value_size = 0;
    if (has_value) {
        check_tensor_value_valid(name, tensor);
        auto begin = m_file->tell();
        auto&& dumper = m_config.tensor_value_dumper;
        if (dumper) {
            dumper(*m_file, *m_cur_opr, tensor);
        } else {
            m_file->write(tensor.raw_ptr(), tensor.layout().span().high_byte);
        }
        value_size = m_file->tell() - begin;
        m_cur_rst.tensor_value_bytes += value_size;
    }

    auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
    auto shape = m_builder.CreateVectorScalarCast<uint32_t>(
            tensor.shape().shape, tensor.shape().ndim);
    auto comp_node = fbs::CreateCompNode(
            m_builder,
            m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
    auto dtype = build_dtype(tensor.dtype());
    auto serialized_tensor =
            fbs::CreateTensor(m_builder, fbname, shape, comp_node, dtype, value_size);
    m_cur_opr_tensor.emplace_back(serialized_tensor);
}
void GraphDumperOSS::dump_buf_with_len(const void* data, uint32_t size) {
    auto blob = fbs::CreateBlob(
            m_builder, m_builder.CreateVector(static_cast<const uint8_t*>(data), size));
    m_blobs.emplace_back(blob);
}
// ----------------------------- Loader --------------------------------------

class GraphLoaderOSS final : public GraphLoader {
    const LoadConfig* m_cur_load_config = nullptr;
    std::unique_ptr<InputFile> m_file;
    FeatureBits64 m_feature_bits;
    SharedBuffer m_graph_buf{{}, 0};
    const fbs::Graph* m_graph;
    SharedTensorIDMap m_shared_tensor_map;
    uint32_t m_mgb_version = 0;
    uint64_t m_graph_hash = 0;

    class OprLoadContextImpl;
    friend class OprLoadContextImpl;

    void verify();

public:
    GraphLoaderOSS(std::unique_ptr<InputFile> input_file)
            : m_file{std::move(input_file)} {}

    std::unique_ptr<InputFile> reset_file(std::unique_ptr<InputFile> file) override {
        file.swap(m_file);
        return file;
    }

    LoadResult load(const LoadConfig& config, bool rewind) override;

    const SharedTensorIDMap& shared_tensor_id_map() const override {
        mgb_assert(m_graph_hash, "graph not loaded yet");
        return m_shared_tensor_map;
    }

    GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS; }
};
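
// Per-load context. It registers itself as user data on the target
// computing graph via a non-owning aliasing shared_ptr (note the null
// owner in the maker below), so opr loaders running during this load can
// retrieve the active context; the destructor pops it again.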
class GraphLoaderOSS::OprLoadContextImpl final : public OprLoadContextFlatBuffers {
    GraphLoaderOSS* const m_loader;
    size_t m_cur_shared_tensor_idx = 0;
    std::shared_ptr<ComputingGraph> m_graph;
    LoadResult::TensorMap m_tensor_map;
    VarNodeArray m_id2varnode;
    BatchedDeviceValueLoader m_device_value_loader;
    const fbs::Operator* m_current_opr;
    size_t m_cur_opr_tensor_cnt;
    size_t m_cur_opr_blob_cnt;
    size_t m_cur_opr_param_cnt;

    ComputingGraph& graph() override { return *m_graph; }

    const GraphLoadConfig& config() const override {
        return *m_loader->m_cur_load_config;
    }

    void load_tensor_value(
            HostTensorND* dest, const TensorLayout& layout, const fbs::Tensor* tensor);

    std::shared_ptr<HostTensorND> load_tensor() override;

    std::shared_ptr<DeviceTensorND> load_tensor_shared() override;

    void load_single_opr(const fbs::Operator* opr);

public:
    OprLoadContextImpl(GraphLoaderOSS* loader, uint32_t version)
            : OprLoadContextFlatBuffers(version), m_loader{loader} {
        m_graph = loader->m_cur_load_config->comp_graph;
        if (!m_graph) {
            m_graph = ComputingGraph::make();
        }
        auto maker = [this]() {
            return std::shared_ptr<OprLoadContext>{
                    std::shared_ptr<OprLoadContext>{}, this};
        };
        auto got = m_graph->options().user_data.get_user_data_or_create<OprLoadContext>(
                maker);
        mgb_assert(got == this);
    }

    ~OprLoadContextImpl() noexcept {
        auto nr = m_graph->options().user_data.pop_user_data<OprLoadContext>();
        mgb_assert(nr == 1);
    }

    Metadata load_metadata();
    LoadResult load_oprs();
    CompNode load_comp_node(const fbs::CompNode* comp_node);
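
    // Params are consumed in order: the first request is matched against
    // the inline param of the Operator table, and subsequent requests are
    // matched against the additional_params vectors. A nullptr return
    // indicates a type mismatch at the current position.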
    const void* get_next_param(uint32_t enumv) override {
        auto type = static_cast<fbs::OperatorParam>(enumv);
        if (m_cur_opr_param_cnt == 0) {
            m_cur_opr_param_cnt++;
            if (m_current_opr->param_type() == type) {
                return m_current_opr->param();
            }
        } else {
            mgb_assert(
                    m_current_opr->additional_params() &&
                    m_cur_opr_param_cnt - 1 <
                            m_current_opr->additional_params()->size());
            auto i = m_cur_opr_param_cnt++ - 1;
            if (m_current_opr->additional_params_type()->Get(i) == type) {
                return m_current_opr->additional_params()->Get(i);
            }
        }
        return nullptr;
    }

    std::string load_buf_with_len() override {
        mgb_assert(
                m_current_opr->blobs() &&
                m_cur_opr_blob_cnt < m_current_opr->blobs()->size());
        auto blob = m_current_opr->blobs()->Get(m_cur_opr_blob_cnt++);
        mgb_assert(blob && blob->data());
        auto data = blob->data()->data();
        return {reinterpret_cast<const char*>(data), blob->data()->size()};
    }

    SharedBuffer load_shared_buf_with_len() override {
        mgb_assert(
                m_current_opr->blobs() &&
                m_cur_opr_blob_cnt < m_current_opr->blobs()->size());
        auto blob = m_current_opr->blobs()->Get(m_cur_opr_blob_cnt++);
        mgb_assert(blob && blob->data());
        auto size = blob->data()->size();
        std::shared_ptr<uint8_t> shptr{
                new uint8_t[size], [](uint8_t* p) { delete[] p; }};
        memcpy(shptr.get(), blob->data()->data(), size);
        return {std::move(shptr), size};
    }
};
CompNode GraphLoaderOSS::OprLoadContextImpl::load_comp_node(
        const fbs::CompNode* comp_node) {
    mgb_assert(comp_node);
    if (!comp_node->logical_locator())
        return {};
    auto loc = CompNode::Locator::parse(comp_node->logical_locator()->str());
    m_loader->m_cur_load_config->comp_node_mapper(loc);
    return CompNode::load(loc);
}
TensorLayout load_tensor_layout(const fbs::Tensor* tensor) {
    TensorLayout layout;
    if (tensor->shape()) {
        layout.ndim = tensor->shape()->size();
        std::copy(tensor->shape()->begin(), tensor->shape()->end(), layout.shape);
    }
    if (tensor->dtype()) {
        // modify data type inplace for TensorLayout
        layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
    }
    layout.init_contiguous_stride();
    return layout;
}
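
// Read one tensor value from the tensor-data region; tensor->offset()
// appears to be the gap from the current read cursor to the start of this
// value (it is skipped before reading). A custom tensor_value_loader, if
// configured, must consume at most tensor->data_size() bytes and must not
// seek before the start of the value; both conditions are checked below.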
void GraphLoaderOSS::OprLoadContextImpl::load_tensor_value(
        HostTensorND* dest, const TensorLayout& layout, const fbs::Tensor* tensor) {
    auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
    auto&& file = m_loader->m_file;
    auto begin_pos = file->tell();
    file->skip(tensor->offset());
    if (loader) {
        // call custom loader
        void* dest_ptr = nullptr;
        if (dest) {
            dest->dtype(layout.dtype).resize(layout);
            dest_ptr = dest->raw_ptr();
        }
        loader(dest_ptr, layout, *file);
    } else {
        if (dest) {
            file->read_into_tensor(*dest, layout);
        } else {
            file->skip(layout.span().high_byte);
        }
    }
    mgb_throw_if(
            file->tell() < begin_pos, SerializationError,
            "Custom tensor value loader accessed out of range data before "
            "start of data blob");
    auto data_size = tensor->data_size();
    auto consumed_size = file->tell() - begin_pos;
    mgb_throw_if(
            consumed_size > data_size, SerializationError,
            "Custom tensor value loader consumed more data than "
            "available: consumed %zu, has %u",
            consumed_size, data_size);
    if (consumed_size < data_size) {
        mgb_log_warn(
                "Tensor value loader consumed less data than available: "
                "consumed %zu bytes, has %u bytes",
                consumed_size, data_size);
        file->skip(data_size - consumed_size);
    }
}
std::shared_ptr<HostTensorND> GraphLoaderOSS::OprLoadContextImpl::load_tensor() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor);
    auto ret = std::make_shared<HostTensorND>(comp_node, layout);
    if (tensor->data_size()) {
        load_tensor_value(ret.get(), layout, tensor);
    }
    if (tensor->name()) {
        m_tensor_map[tensor->name()->str()] = ret;
    }
    if (auto&& mod = m_loader->m_cur_load_config->tensor_modifier) {
        mod(tensor->name() ? tensor->name()->str() : "", tensor->data_size() != 0,
            *ret);
    }
    return ret;
}
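
// Shared (param) tensors are identified purely by load order: the n-th
// call corresponds to entry n of m_shared_tensor_map. Values are cached
// per memory node, so reloading on the same device reuses the previously
// created DeviceTensorND and only skips over the value bytes in the file.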
std::shared_ptr<DeviceTensorND> GraphLoaderOSS::OprLoadContextImpl::
        load_tensor_shared() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor);
    mgb_assert(tensor->data_size());
    auto&& sh_reg = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
    auto&& sh_ptr_ref = sh_reg.second[comp_node.mem_node()];
    if (sh_ptr_ref) {
        // cached tensor value is valid so we can reuse it
        load_tensor_value(nullptr, layout, tensor);
        if (sh_ptr_ref->comp_node() == comp_node)
            return sh_ptr_ref;
        // same mem node but different comp node, change comp node and share
        // value
        auto ret = std::make_shared<DeviceTensorND>(*sh_ptr_ref);
        ret->comp_node(comp_node);
        return ret;
    }
    if (tensor->name()) {
        sh_reg.first = tensor->name()->str();
    }
    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
        // directly forward CPU memory
        HostTensorND hv{comp_node};
        load_tensor_value(&hv, layout, tensor);
        sh_ptr_ref = std::make_shared<DeviceTensorND>();
        *sh_ptr_ref = DeviceTensorND::make_proxy(hv);
    } else {
        // use lazy load for non-CPU devices
        HostTensorND hv{CompNode::default_cpu()};
        load_tensor_value(&hv, layout, tensor);
        sh_ptr_ref = m_device_value_loader.make(comp_node, std::move(hv));
    }
    return sh_ptr_ref;
}
Metadata GraphLoaderOSS::OprLoadContextImpl::load_metadata() {
    const auto* fbmeta = m_loader->m_graph->metadata();
    Metadata ret;
    if (fbmeta) {
        ret.is_valid = fbmeta->is_valid();
        ret.graph_modified = fbmeta->graph_modified();
        if (fbmeta->user_info()) {
            ret.user_info = fbmeta->user_info()->str();
            ret.has_user_info = true;
        }
        if (fbmeta->optimize_options()) {
            ret.optimize_options = fbmeta->optimize_options();
            ret.optimized_for_inference = true;
        }
    }
    return ret;
}
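
// Build one operator from its serialized form. The registry is looked up
// by versioned type id for current-format files and falls back to the
// unversioned id for legacy MAGIC_V0 files; see magic_compare in load().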
void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(const fbs::Operator* fbopr) {
    m_cur_opr_tensor_cnt = 0;
    m_cur_opr_blob_cnt = 0;
    m_cur_opr_param_cnt = 0;

    OperatorNodeConfig config;
    if (fbopr->output_dtype()) {
        config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
    }
    if (fbopr->name()) {
        config.name(fbopr->name()->str());
    }
    if (fbopr->comp_node()) {
        auto cnt = fbopr->comp_node()->size();
        cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
        for (size_t i = 0; i < cnt; i++) {
            CompNode cn{};
            auto node = fbopr->comp_node()->Get(i);
            if (node) {
                cn = load_comp_node(node);
            }
            comp_node_arr[i] = cn;
        }
        config.comp_node_arr(comp_node_arr);
    }

    const OprRegistry* registry;
    if (magic_compare) {
        registry = OprRegistry::find_by_id(fbopr->type_id());
    } else {
        registry = OprRegistry::find_by_unversioned_id(fbopr->type_id());
    }
    mgb_throw_if(
            !registry, SerializationError,
            "failed to find opr with type %s; use python env "
            "config.dump_registered_oprs() to get a dict that maps from "
            "opr id to opr name",
            std::to_string(fbopr->type_id()).c_str());

    // load inputs
    VarNodeArray inputs;
    if (fbopr->inputs()) {
        inputs.resize(fbopr->inputs()->size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            inputs[i] = m_id2varnode.at(fbopr->inputs()->Get(i));
        }
    }

    // call loader
    auto accessor = registry->loader(*this, inputs, config);
    auto opr = accessor.opr();

    // check opr type; note that:
    // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
    // 2. due to some optimization, an opr may be replaced by ImmutableTensor
    mgb_assert(
            opr && (opr->dyn_typeinfo() == registry->type || !registry->type ||
                    opr->same_type<opr::ImmutableTensor>()),
            "got_type=%s expected_type=%s",
            opr ? opr->dyn_typeinfo()->name : "(null)",
            registry->type ? registry->type->name : "(unknown)");

    // record output vars; read output names
    size_t i = 0;
    for (auto ovar : accessor.output()) {
        if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            m_id2varnode.push_back(ovar);
            if (fbopr->output_name()) {
                ovar->name(fbopr->output_name()->Get(i++)->str());
            }
        }
    }

    opr->node_prop().attribute().priority = fbopr->priority();
}
GraphLoader::LoadResult GraphLoaderOSS::OprLoadContextImpl::load_oprs() {
    // load oprs
    const auto* oprs = m_loader->m_graph->oprs();
    {
        // inplace arith graph optimization is disabled during opr load:
        // it tries to restore the same graph as it was dumped;
        // see test TestSerializer2.LOGEXP for an example
        GraphLoader::ScopedGraphOptDisabler _(m_graph);
        for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) {
            m_current_opr = oprs->Get(i);
            load_single_opr(m_current_opr);
        }
    }

    // batched loading of device values
    m_device_value_loader.apply();

    LoadResult ret;
    ret.graph = m_graph;
    ret.tensor_map = m_tensor_map;

    const auto* outputs = m_loader->m_graph->output_vars_idx();
    ret.output_var_list.resize(outputs->size());
    for (flatbuffers::uoffset_t i = 0; i < outputs->size(); i++) {
        auto out = outputs->Get(i);
        auto var = m_id2varnode.at(out->compact_id());
        ret.output_var_map[var->name()] = var;
        ret.output_var_map_id[out->original_id()] = var;
        ret.output_var_list[i] = var;
    }
    mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
    return ret;
}
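
// Top-level load entry. The file magic selects the header variant:
// MGB_MAGIC files carry an extra 8-byte FeatureBits64 block and use
// versioned opr type ids, while legacy MAGIC_V0 files have neither.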
GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewind) {
    mgb_assert(m_file);
    m_cur_load_config = &config;
    if (rewind) {
        m_file->rewind();
    }

    uint32_t magic;
    m_file->read(&magic, sizeof(magic));
    mgb_throw_if(
            (magic != MGB_MAGIC) && (magic != MAGIC_V0), SerializationError,
            "wrong magic: wanted %#08x or %#08x, actual %#08x (not a valid fbs "
            "model?)",
            MGB_MAGIC, MAGIC_V0, magic);
    if (magic == MGB_MAGIC) {
        // read FeatureBits
        magic_compare = true;
        m_file->read(&m_feature_bits, 8);
    } else {
        magic_compare = false;
    }
    m_file->skip(4);

    uint64_t offset_to_fbs;
    m_file->read(&offset_to_fbs, sizeof(offset_to_fbs));
    auto tensor_begin = m_file->tell();

    // Skip tensor data
    m_file->skip(offset_to_fbs);

    // Read fbs::Graph
    uint32_t size;
    m_file->read(&size, sizeof(size));
    m_graph_buf = m_file->read_shared(size);

    // Rewind back to tensor data
    m_file->rewind();
    m_file->skip(tensor_begin);

    mgb_throw_if(
            !fbs::GraphBufferHasIdentifier(m_graph_buf.data()), SerializationError,
            "invalid fbs model");
    {
        flatbuffers::Verifier verifier(
                static_cast<const uint8_t*>(m_graph_buf.data()), m_graph_buf.size());
        mgb_throw_if(
                !fbs::VerifyGraphBuffer(verifier), SerializationError,
                "model verification failed (invalid or corrupted model?)");
    }

    m_graph = fbs::GetGraph(m_graph_buf.data());
    m_mgb_version = m_graph->mgb_version();
    if (m_graph->mgb_version() > MGB_VERSION) {
        mgb_log_warn(
                "loading model from future runtime: runtime_version=%u "
                "model_version=%u",
                MGB_VERSION, m_graph->mgb_version());
    }

    if (!m_graph_hash) {
        m_graph_hash = m_graph->hash();
        mgb_assert(
                m_graph_hash,
                "invalid graph hash; maybe an error "
                "occurred during graph dump");
    } else {
        mgb_assert(
                m_graph_hash == m_graph->hash(),
                "A GraphLoader instance can be used to load only one graph,"
                " since the tensor values are shared. Previous graph hash "
                "is 0x%llx, current graph hash is 0x%llx.",
                static_cast<unsigned long long>(m_graph_hash),
                static_cast<unsigned long long>(m_graph->hash()));
    }

    if (m_shared_tensor_map.empty()) {
        m_shared_tensor_map.resize(m_graph->nr_shared_tensor());
    } else {
        mgb_assert(m_shared_tensor_map.size() == m_graph->nr_shared_tensor());
    }

    OprLoadContextImpl ctx{this, m_graph->mgb_version()};
    auto metadata = ctx.load_metadata();
    auto result = ctx.load_oprs();
    result.metadata = metadata;

    auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size;
    auto cur = m_file->tell();
    mgb_assert(fbs_end > cur);
    // Skip to Graph end
    m_file->skip(fbs_end - cur);
    result.graph_compile_ahead();
    return result;
}
std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file) {
    return std::make_unique<GraphDumperOSS>(std::move(file));
}

std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file) {
    return std::make_unique<GraphLoaderOSS>(std::move(file));
}
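
// A minimal round-trip sketch, assuming the file-backed InputFile /
// OutputFile::make_fs factories and the default dump/load arguments
// provided elsewhere in this library:
//
//     auto dumper = make_fbs_dumper(OutputFile::make_fs("model.mge"));
//     dumper->dump({output_var});  // SymbolVarArray of graph endpoints
//
//     auto loader = make_fbs_loader(InputFile::make_fs("model.mge"));
//     auto rst = loader->load();   // LoadResult with graph and output vars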
}  // namespace serialization
}  // namespace mgb

#endif