
common.cpp

/**
 * \file src/mge/common.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "lite_build_config.h"

#if LITE_BUILD_WITH_MGE
#include "common.h"
#include "megdnn/dtype.h"

using namespace lite;
using namespace mgb;

// Compression tag stored as the first byte of a serialized tensor value.
// The two compressed formats store a float32 stride, a float32 base, and
// the quantized weights; each weight decodes as `stride * w + base`.
enum class CompressionMethod {
    NO_COMPRESSION = 0,
    FLOAT32_STRIDE_FLOAT32_BASE_UINT8_WEIGHTS = 1,
    FLOAT32_STRIDE_FLOAT32_BASE_UINT16_WEIGHTS = 2,
};
// Load a tensor value that may have been written with one of the simple
// stride/base quantization schemes above; `ptr_` may be null, in which
// case the payload is skipped instead of decoded.
void lite::decompressed_tensor_value_loader(
        void* ptr_, const mgb::TensorLayout& layout,
        mgb::serialization::InputFile& fin) {
    uint8_t compress_flag;
    fin.read(&compress_flag, sizeof(compress_flag));
    size_t num_weights = layout.total_nr_elems();
    switch (CompressionMethod(compress_flag)) {
        case CompressionMethod::NO_COMPRESSION: {
            mgb::serialization::GraphLoadConfig::default_tensor_value_loader(
                    ptr_, layout, fin);
            break;
        }
        case CompressionMethod::FLOAT32_STRIDE_FLOAT32_BASE_UINT8_WEIGHTS: {
            if (ptr_) {
                float stride, base;
                std::vector<uint8_t> weights(num_weights);
                fin.read(&stride, sizeof(stride));
                fin.read(&base, sizeof(base));
                fin.read(weights.data(), num_weights * sizeof(uint8_t));
                auto* ptr = static_cast<float*>(ptr_);
                for (size_t i = 0; i < num_weights; ++i)
                    ptr[i] = stride * weights[i] + base;
            } else {
                // No destination buffer: skip over the compressed payload.
                fin.skip(sizeof(float) * 2 + num_weights * sizeof(uint8_t));
            }
            break;
        }
        case CompressionMethod::FLOAT32_STRIDE_FLOAT32_BASE_UINT16_WEIGHTS: {
            if (ptr_) {
                float stride, base;
                std::vector<uint16_t> weights(num_weights);
                fin.read(&stride, sizeof(stride));
                fin.read(&base, sizeof(base));
                fin.read(weights.data(), num_weights * sizeof(uint16_t));
                auto* ptr = static_cast<float*>(ptr_);
                for (size_t i = 0; i < num_weights; ++i)
                    ptr[i] = stride * weights[i] + base;
            } else {
                fin.skip(sizeof(float) * 2 + num_weights * sizeof(uint16_t));
            }
            break;
        }
        default:
            LITE_THROW("Unexpected compression method");
    }
}
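// A minimal sketch (not part of the original source) of the matching writer
// side for FLOAT32_STRIDE_FLOAT32_BASE_UINT8_WEIGHTS, to make the stream
// format above concrete: base is the tensor minimum, stride maps the value
// range onto [0, 255], and the loader reconstructs `stride * q + base`.
// The helper name and out-parameters are illustrative only.
static void compress_uint8_weights_sketch(
        const float* src, size_t num_weights, float& stride, float& base,
        std::vector<uint8_t>& quantized) {
    float min_v = src[0], max_v = src[0];
    for (size_t i = 1; i < num_weights; ++i) {
        min_v = src[i] < min_v ? src[i] : min_v;
        max_v = src[i] > max_v ? src[i] : max_v;
    }
    base = min_v;
    stride = max_v > min_v ? (max_v - min_v) / 255.f : 1.f;
    quantized.resize(num_weights);
    for (size_t i = 0; i < num_weights; ++i) {
        // Round to the nearest quantization step so decoding stays within
        // half a step of the original value.
        quantized[i] = static_cast<uint8_t>((src[i] - base) / stride + 0.5f);
    }
}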
LTensorLayout lite::to_impl_layout(const Layout& layout) {
    mgb::TensorLayout mge_layout;
    mge_layout.ndim = layout.ndim;
    LITE_ASSERT(layout.ndim < TensorShape::MAX_NDIM, "lite layout ndim is too large");
    for (size_t i = 0; i < layout.ndim; i++) {
        mge_layout.shape[i] = layout.shapes[i];
    }
    mge_layout.init_contiguous_stride();
    switch (layout.data_type) {
        case LiteDataType::LITE_FLOAT:
            mge_layout.dtype = mgb::dtype::Float32();
            break;
#if !MEGDNN_DISABLE_FLOAT16
        case LiteDataType::LITE_HALF:
            mge_layout.dtype = mgb::dtype::Float16();
            break;
#endif
        case LiteDataType::LITE_INT:
            mge_layout.dtype = mgb::dtype::Int32();
            break;
        case LiteDataType::LITE_INT8:
            mge_layout.dtype = mgb::dtype::Int8();
            break;
        case LiteDataType::LITE_UINT8:
            mge_layout.dtype = mgb::dtype::Uint8();
            break;
        case LiteDataType::LITE_INT16:
            mge_layout.dtype = mgb::dtype::Int16();
            break;
        case LiteDataType::LITE_UINT16:
            mge_layout.dtype = mgb::dtype::Uint16();
            break;
        default:
            LITE_THROW(mgb::ssprintf(
                    "unsupported dtype in lite, enum id is %d.",
                    static_cast<int>(layout.data_type)));
    }
    return mge_layout;
}
Layout lite::to_lite_layout(const LTensorLayout& mge_layout) {
    Layout layout;
    if (!mge_layout.dtype.valid()) {
        return layout;
    }
    layout.ndim = mge_layout.ndim;
    LITE_ASSERT(layout.ndim < layout.MAXDIM, "tensor layout ndim is too large");
    for (size_t i = 0; i < layout.ndim; i++) {
        layout.shapes[i] = mge_layout.shape[i];
    }
    switch (mge_layout.dtype.enumv()) {
        case mgb::DTypeEnum::Float32:
            layout.data_type = LiteDataType::LITE_FLOAT;
            break;
#if !MEGDNN_DISABLE_FLOAT16
        case mgb::DTypeEnum::Float16:
            layout.data_type = LiteDataType::LITE_HALF;
            break;
#endif
        case mgb::DTypeEnum::Int32:
            layout.data_type = LiteDataType::LITE_INT;
            break;
        case mgb::DTypeEnum::Int16:
            layout.data_type = LiteDataType::LITE_INT16;
            break;
        case mgb::DTypeEnum::Uint16:
            layout.data_type = LiteDataType::LITE_UINT16;
            break;
        case mgb::DTypeEnum::Int8:
            layout.data_type = LiteDataType::LITE_INT8;
            break;
        case mgb::DTypeEnum::Uint8:
            layout.data_type = LiteDataType::LITE_UINT8;
            break;
        default:
            LITE_THROW(mgb::ssprintf(
                    "unsupported dtype in lite: %s.", mge_layout.to_string().c_str()));
    }
    return layout;
}
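// A minimal usage sketch (illustrative, not part of the original source):
// the two converters above are intended to be inverses for supported
// dtypes, so a lite Layout survives a round trip through LTensorLayout.
static void layout_round_trip_sketch() {
    Layout lite_layout;
    lite_layout.ndim = 4;
    lite_layout.shapes[0] = 1;
    lite_layout.shapes[1] = 3;
    lite_layout.shapes[2] = 224;
    lite_layout.shapes[3] = 224;
    lite_layout.data_type = LiteDataType::LITE_FLOAT;
    LTensorLayout mge_layout = lite::to_impl_layout(lite_layout);
    Layout restored = lite::to_lite_layout(mge_layout);
    LITE_ASSERT(
            restored.ndim == 4 && restored.data_type == LiteDataType::LITE_FLOAT,
            "round trip should preserve shape and dtype");
}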
mgb::CompNode::Locator lite::to_compnode_locator(const LiteDeviceType& device) {
    mgb::CompNode::Locator loc;
    switch (device) {
        case LiteDeviceType::LITE_CPU:
            loc.type = mgb::CompNode::DeviceType::CPU;
            break;
        case LiteDeviceType::LITE_CUDA:
            loc.type = mgb::CompNode::DeviceType::CUDA;
            break;
        case LiteDeviceType::LITE_ATLAS:
            loc.type = mgb::CompNode::DeviceType::ATLAS;
            break;
        case LiteDeviceType::LITE_DEVICE_DEFAULT:
            loc.type = mgb::CompNode::DeviceType::UNSPEC;
            break;
        default:
            LITE_THROW(ssprintf(
                    "lite unsupported compnode type: enum value: %d.", (int)(device)));
    }
    return loc;
}
LiteDeviceType lite::get_device_from_locator(const mgb::CompNode::Locator& locator) {
    switch (locator.type) {
        case mgb::CompNode::DeviceType::CPU:
        case mgb::CompNode::DeviceType::MULTITHREAD:
            return LiteDeviceType::LITE_CPU;
        case mgb::CompNode::DeviceType::CUDA:
            return LiteDeviceType::LITE_CUDA;
        case mgb::CompNode::DeviceType::ATLAS:
            return LiteDeviceType::LITE_ATLAS;
        case mgb::CompNode::DeviceType::UNSPEC:
            return LiteDeviceType::LITE_DEVICE_DEFAULT;
        default:
            LITE_THROW(ssprintf(
                    "lite unsupported compnode type: enum value: %d.",
                    (int)(locator.type)));
    }
}
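// A minimal sketch (illustrative) of the inverse relationship between the
// two device mappings above: for the directly supported device types
// (CPU, CUDA, ATLAS) converting to a locator and back yields the input.
static void device_round_trip_sketch() {
    auto loc = lite::to_compnode_locator(LiteDeviceType::LITE_CUDA);
    LITE_ASSERT(
            lite::get_device_from_locator(loc) == LiteDeviceType::LITE_CUDA,
            "device mapping should round trip");
}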
  185. #endif
  186. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU installed along with its driver. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
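Whether a GPU can actually be used is still a runtime question. Below is a minimal sketch of probing for a CUDA device, assuming MegBrain's `mgb::CompNode::get_device_count` API together with the `to_compnode_locator` converter from common.cpp above; the function name `cuda_device_present` is hypothetical.

    #include "megbrain/comp_node.h"

    // Returns true if the runtime can see at least one CUDA device.
    bool cuda_device_present() {
        auto loc = lite::to_compnode_locator(LiteDeviceType::LITE_CUDA);
        return mgb::CompNode::get_device_count(loc.type) > 0;
    }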