You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_load_dump.cpp 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. /**
  2. * \file src/serialization/impl/opr_load_dump.cpp
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  7. *
  8. */
  9. #include "megbrain/serialization/opr_load_dump.h"
  10. #include "megbrain/opr/param_defs.h"
  11. #include "megbrain/serialization/file.h"
  12. #include "megbrain/serialization/helper.h"
  13. using namespace mgb;
  14. using namespace serialization;
  // NOTE(review): presumably provides the out-of-line type-info object for
  // OprLoadContext that the class declares in its header -- confirm against
  // the macro definition.
  15. MGB_TYPEINFO_OBJ_IMPL(OprLoadContext);
  16. OprLoader OprLoadContext::make_opr_loader(const std::string& id) {
  17. auto&& maker = config().opr_loader_maker;
  18. mgb_throw_if(
  19. !maker, SerializationError,
  20. "opr_loader_maker not set in LoadConfig; but opr loader with "
  21. "id %s is needed",
  22. id.c_str());
  23. return maker(id);
  24. }
  25. template <>
  26. void OprDumpContextRawPOD::write_param(const DType& param) {
  27. if (m_check_param_tag) {
  28. uint32_t tag = megdnn::param::FakeSerializedDType::TAG;
  29. write_raw(&tag, sizeof(tag));
  30. }
  31. serialization::serialize_dtype(
  32. param, [this](const void* data, size_t len) { write_raw(data, len); });
  33. }
  34. template <>
  35. DType OprLoadContextRawPOD::read_param() {
  36. if (m_check_param_tag) {
  37. uint32_t tag;
  38. read_raw(&tag, sizeof(tag));
  39. mgb_throw_if(
  40. tag != megdnn::param::FakeSerializedDType::TAG, MegBrainError,
  41. "ERROR tag");
  42. }
  43. return serialization::deserialize_dtype(
  44. [this](void* data, size_t len) { read_raw(data, len); });
  45. }
  46. std::string OprLoadContextRawPOD::load_buf_with_len() {
  47. std::string ret;
  48. uint32_t size;
  49. read_raw(&size, sizeof(size));
  50. ret.resize(size);
  51. read_raw(&ret[0], size);
  52. return ret;
  53. }
  54. SharedBuffer OprLoadContextRawPOD::load_shared_buf_with_len() {
  55. uint32_t size;
  56. read_raw(&size, sizeof(size));
  57. return load_shared_buf(size);
  58. }
  59. void GraphDumpConfig::default_tensor_value_dumper(
  60. OutputFile& fout, const cg::OperatorNodeBase& /*opr*/,
  61. const HostTensorND& tensor) {
  62. auto size = tensor.layout().span().high_byte;
  63. fout.write(tensor.raw_ptr(), size);
  64. }
  65. void GraphLoadConfig::default_tensor_value_loader(
  66. void* ptr, const TensorLayout& layout, InputFile& fin) {
  67. auto sz = layout.span().high_byte;
  68. if (ptr) {
  69. fin.read(ptr, sz);
  70. } else {
  71. fin.skip(sz);
  72. }
  73. }
  74. SharedBuffer OprLoadContextRawPOD::load_shared_buf(size_t size) {
  75. std::shared_ptr<uint8_t> shptr{new uint8_t[size], [](uint8_t* p) { delete[] p; }};
  76. read_raw(shptr.get(), size);
  77. return {std::move(shptr), size};
  78. }
  79. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}