You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

common.cpp 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. #include "lite_build_config.h"
  2. #if LITE_BUILD_WITH_MGE
  3. #include "common.h"
  4. #include "megdnn/dtype.h"
  5. using namespace lite;
  6. using namespace mgb;
  //! Tag stored as the first byte of a serialized tensor value; it selects how
//! the payload that follows is encoded (see decompressed_tensor_value_loader).
enum class CompressionMethod {
    //! Payload is stored uncompressed and handled by the default mgb loader.
    NO_COMPRESSION = 0,
    //! float32 stride, float32 base, then one uint8 weight per element.
    FLOAT32_STRIDE_FLOAT32_BASE_UINT8_WEIGHTS = 1,
    //! float32 stride, float32 base, then one uint16 weight per element.
    FLOAT32_STRIDE_FLOAT32_BASE_UINT16_WEIGHTS = 2,
};
  12. void lite::decompressed_tensor_value_loader(
  13. void* ptr_, const mgb::TensorLayout& layout,
  14. mgb::serialization::InputFile& fin) {
  15. uint8_t compress_flag;
  16. fin.read(&compress_flag, sizeof(compress_flag));
  17. size_t num_weights = layout.total_nr_elems();
  18. switch (CompressionMethod(compress_flag)) {
  19. case CompressionMethod::NO_COMPRESSION: {
  20. mgb::serialization::GraphLoadConfig::default_tensor_value_loader(
  21. ptr_, layout, fin);
  22. break;
  23. }
  24. case CompressionMethod::FLOAT32_STRIDE_FLOAT32_BASE_UINT8_WEIGHTS: {
  25. if (ptr_) {
  26. float stride, base;
  27. std::vector<uint8_t> weights(num_weights);
  28. fin.read(&stride, sizeof(stride));
  29. fin.read(&base, sizeof(base));
  30. fin.read(weights.data(), num_weights * sizeof(uint8_t));
  31. auto* ptr = static_cast<float*>(ptr_);
  32. for (size_t i = 0; i < num_weights; ++i)
  33. ptr[i] = stride * weights[i] + base;
  34. } else {
  35. fin.skip(sizeof(float) * 2 + num_weights * sizeof(uint8_t));
  36. }
  37. break;
  38. }
  39. case CompressionMethod::FLOAT32_STRIDE_FLOAT32_BASE_UINT16_WEIGHTS: {
  40. if (ptr_) {
  41. float stride, base;
  42. std::vector<uint16_t> weights(num_weights);
  43. fin.read(&stride, sizeof(stride));
  44. fin.read(&base, sizeof(base));
  45. fin.read(weights.data(), num_weights * sizeof(uint16_t));
  46. auto* ptr = static_cast<float*>(ptr_);
  47. for (size_t i = 0; i < num_weights; ++i)
  48. ptr[i] = stride * weights[i] + base;
  49. } else {
  50. fin.skip(sizeof(float) * 2 + num_weights * sizeof(uint16_t));
  51. }
  52. break;
  53. }
  54. default:
  55. LITE_THROW("Unexpected compression method");
  56. }
  57. }
  58. LTensorLayout lite::to_impl_layout(const Layout& layout) {
  59. mgb::TensorLayout mge_layout;
  60. mge_layout.ndim = layout.ndim;
  61. LITE_ASSERT(layout.ndim < TensorShape::MAX_NDIM, "lite layout ndim is to large");
  62. for (size_t i = 0; i < layout.ndim; i++) {
  63. mge_layout.shape[i] = layout.shapes[i];
  64. }
  65. mge_layout.init_contiguous_stride();
  66. switch (layout.data_type) {
  67. case LiteDataType::LITE_FLOAT:
  68. mge_layout.dtype = mgb::dtype::Float32();
  69. break;
  70. #if !MEGDNN_DISABLE_FLOAT16
  71. case LiteDataType::LITE_HALF:
  72. mge_layout.dtype = mgb::dtype::Float16();
  73. break;
  74. #endif
  75. case LiteDataType::LITE_INT:
  76. mge_layout.dtype = mgb::dtype::Int32();
  77. break;
  78. case LiteDataType::LITE_INT8:
  79. mge_layout.dtype = mgb::dtype::Int8();
  80. break;
  81. case LiteDataType::LITE_UINT8:
  82. mge_layout.dtype = mgb::dtype::Uint8();
  83. break;
  84. case LiteDataType::LITE_INT16:
  85. mge_layout.dtype = mgb::dtype::Int16();
  86. break;
  87. case LiteDataType::LITE_UINT16:
  88. mge_layout.dtype = mgb::dtype::Uint16();
  89. break;
  90. default:
  91. LITE_THROW(mgb::ssprintf(
  92. "unsupport dtype in lite enum id is %d.",
  93. static_cast<int>(layout.data_type)));
  94. }
  95. return mge_layout;
  96. }
  97. Layout lite::to_lite_layout(const LTensorLayout& mge_layout) {
  98. Layout layout;
  99. if (!mge_layout.dtype.valid()) {
  100. return layout;
  101. }
  102. layout.ndim = mge_layout.ndim;
  103. LITE_ASSERT(layout.ndim < layout.MAXDIM, "tensor layout ndim is to large");
  104. for (size_t i = 0; i < layout.ndim; i++) {
  105. layout.shapes[i] = mge_layout.shape[i];
  106. }
  107. switch (mge_layout.dtype.enumv()) {
  108. case mgb::DTypeEnum::Float32:
  109. layout.data_type = LiteDataType::LITE_FLOAT;
  110. break;
  111. #if !MEGDNN_DISABLE_FLOAT16
  112. case mgb::DTypeEnum::Float16:
  113. layout.data_type = LiteDataType::LITE_HALF;
  114. break;
  115. #endif
  116. case mgb::DTypeEnum::Int32:
  117. layout.data_type = LiteDataType::LITE_INT;
  118. break;
  119. case mgb::DTypeEnum::Int16:
  120. layout.data_type = LiteDataType::LITE_INT16;
  121. break;
  122. case mgb::DTypeEnum::Uint16:
  123. layout.data_type = LiteDataType::LITE_UINT16;
  124. break;
  125. case mgb::DTypeEnum::Int8:
  126. layout.data_type = LiteDataType::LITE_INT8;
  127. break;
  128. case mgb::DTypeEnum::Uint8:
  129. layout.data_type = LiteDataType::LITE_UINT8;
  130. break;
  131. default:
  132. LITE_THROW(mgb::ssprintf(
  133. "unsupport dtype in lite : %s.", mge_layout.to_string().c_str()));
  134. }
  135. return layout;
  136. }
  137. mgb::CompNode::Locator lite::to_compnode_locator(const LiteDeviceType& device) {
  138. mgb::CompNode::Locator loc;
  139. switch (device) {
  140. case LiteDeviceType::LITE_CPU:
  141. loc.type = mgb::CompNode::DeviceType::CPU;
  142. break;
  143. case LiteDeviceType::LITE_CUDA:
  144. loc.type = mgb::CompNode::DeviceType::CUDA;
  145. break;
  146. case LiteDeviceType::LITE_ATLAS:
  147. loc.type = mgb::CompNode::DeviceType::ATLAS;
  148. break;
  149. case LiteDeviceType::LITE_CAMBRICON:
  150. loc.type = mgb::CompNode::DeviceType::CAMBRICON;
  151. break;
  152. case LiteDeviceType::LITE_DEVICE_DEFAULT:
  153. loc.type = mgb::CompNode::DeviceType::UNSPEC;
  154. break;
  155. default:
  156. LITE_THROW(ssprintf(
  157. "lite unsupported compnode type: enum value: %d.", (int)(device)));
  158. }
  159. return loc;
  160. }
  161. LiteDeviceType lite::get_device_from_locator(const mgb::CompNode::Locator& locator) {
  162. switch (locator.type) {
  163. case mgb::CompNode::DeviceType::CPU:
  164. case mgb::CompNode::DeviceType::MULTITHREAD:
  165. return LiteDeviceType::LITE_CPU;
  166. case mgb::CompNode::DeviceType::CUDA:
  167. return LiteDeviceType::LITE_CUDA;
  168. case mgb::CompNode::DeviceType::ATLAS:
  169. return LiteDeviceType::LITE_ATLAS;
  170. case mgb::CompNode::DeviceType::CAMBRICON:
  171. return LiteDeviceType::LITE_CAMBRICON;
  172. case mgb::CompNode::DeviceType::UNSPEC:
  173. return LiteDeviceType::LITE_DEVICE_DEFAULT;
  174. default:
  175. LITE_THROW(ssprintf(
  176. "lite unsupported compnode type: enum value: %d.",
  177. (int)(locator.type)));
  178. }
  179. }
  180. #endif
  181. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}