You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

network.cpp 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. /**
  2. * \file src/gopt/test/network.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "./network.h"
  13. using namespace mgb;
  14. SymbolVar Network::add_conv(
  15. SymbolVar f, size_t output_channels, KernSize kern_size, DType out_dtype,
  16. bool has_relu, Stride stride, Padding padding) {
  17. static int weight_idx = 0;
  18. static int bias_idx = 0;
  19. size_t input_channels = f.node()->shape()[1];
  20. auto weight = add_cvar(
  21. ssprintf("w%d", weight_idx).c_str(),
  22. {output_channels, input_channels, kern_size[0], kern_size[1]});
  23. auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1});
  24. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  25. weight = add_type_cvt(weight, out_dtype);
  26. bias = add_type_cvt(bias, dtype::QuantizedS32{1.f});
  27. }
  28. opr::ConvBias::Param param;
  29. param.stride_h = stride[0], param.stride_w = stride[1];
  30. param.pad_h = padding[0], param.pad_w = padding[1];
  31. if (has_relu) {
  32. param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
  33. } else {
  34. param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
  35. }
  36. auto conv = opr::ConvBias::make(
  37. f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
  38. weight_idx++;
  39. bias_idx++;
  40. return conv;
  41. }
  42. SymbolVar Network::add_group_conv(
  43. SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size,
  44. DType out_dtype, bool has_relu, Stride stride, Padding padding) {
  45. static int weight_idx = 0;
  46. static int bias_idx = 0;
  47. size_t input_channels = f.node()->shape()[1];
  48. auto weight = add_cvar(
  49. ssprintf("w%d", weight_idx).c_str(),
  50. {groups, output_channels / groups, input_channels / groups, kern_size[0],
  51. kern_size[1]});
  52. auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1});
  53. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  54. weight = add_type_cvt(weight, out_dtype);
  55. bias = add_type_cvt(bias, dtype::QuantizedS32{1.f});
  56. }
  57. opr::ConvBias::Param param;
  58. param.sparse = opr::ConvBias::Param::Sparse::GROUP;
  59. param.stride_h = stride[0], param.stride_w = stride[1];
  60. param.pad_h = padding[0], param.pad_w = padding[1];
  61. if (has_relu) {
  62. param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
  63. } else {
  64. param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
  65. }
  66. weight_idx++;
  67. bias_idx++;
  68. SymbolVar conv;
  69. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  70. conv = opr::ConvBias::make(
  71. f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
  72. } else {
  73. conv = opr::ConvBias::make(f, weight, bias, param, {});
  74. }
  75. weight_idx++;
  76. bias_idx++;
  77. return conv;
  78. }
  79. SymbolVar Network::add_deconv(
  80. SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype) {
  81. static int weight_idx = 0;
  82. size_t kernel = ratio * 2 - ratio % 2;
  83. size_t pad = ratio / 2;
  84. size_t input_channels = f.node()->shape()[1];
  85. auto weight = add_cvar(
  86. ssprintf("w%d", weight_idx).c_str(),
  87. {input_channels, output_channels, kernel, kernel});
  88. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  89. weight = add_type_cvt(weight, out_dtype);
  90. }
  91. opr::ConvolutionBackwardData::Param param;
  92. param.stride_h = param.stride_w = ratio;
  93. param.pad_h = param.pad_w = pad;
  94. auto deconv = opr::ConvolutionBackwardData::make(
  95. weight, f, param, {}, OperatorNodeConfig{out_dtype});
  96. weight_idx++;
  97. return deconv;
  98. }
  99. SymbolVar Network::add_elemwise(
  100. const SymbolVarArray inps, DType out_dtype, opr::Elemwise::Param::Mode mode) {
  101. using ElemMode = opr::Elemwise::Param::Mode;
  102. using MultiMode = opr::ElemwiseMultiType::Param::Mode;
  103. static const ThinHashMap<ElemMode, MultiMode> map = {
  104. {ElemMode::ADD, MultiMode::QADD},
  105. {ElemMode::FUSE_ADD_RELU, MultiMode::QFUSE_ADD_RELU}};
  106. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  107. MultiMode alter_mode = map.at(mode);
  108. return opr::ElemwiseMultiType::make(
  109. inps, {alter_mode}, OperatorNodeConfig{out_dtype});
  110. } else {
  111. return opr::Elemwise::make(inps, mode);
  112. }
  113. }
  114. SymbolVar Network::add_pooling(
  115. SymbolVar f, Window window, Stride stride, Padding padding,
  116. opr::Pooling::Param::Mode mode) {
  117. opr::Pooling::Param param;
  118. param.window_h = window[0], param.window_w = window[1];
  119. param.stride_h = stride[0], param.stride_w = stride[1];
  120. param.pad_h = padding[0], param.pad_w = padding[1];
  121. param.mode = mode;
  122. return opr::Pooling::make(f, param);
  123. }
  124. SymbolVar Network::add_type_cvt(SymbolVar f, DType out_dtype) {
  125. return opr::TypeCvt::make(f, out_dtype);
  126. }
  127. SymbolVar Network::add_concat(SymbolVar f, SymbolVar g, int axis) {
  128. return opr::Concat::make({f, g}, axis);
  129. }
  130. SymbolVar mgb::create_block(
  131. Network& network, SymbolVar f_in, size_t stride, size_t num_outputs1,
  132. bool has_proj, DType out_dtype) {
  133. auto proj = f_in;
  134. if (has_proj) {
  135. proj = network.add_conv(
  136. f_in, num_outputs1, {1, 1}, out_dtype, false, {stride, stride});
  137. }
  138. auto f = network.add_conv(
  139. f_in, num_outputs1, {3, 3}, out_dtype, true, {stride, stride}, {1, 1});
  140. f = network.add_conv(f, num_outputs1, {3, 3}, out_dtype, true, {1, 1}, {1, 1});
  141. f = network.add_elemwise({f, proj}, out_dtype, opr::Elemwise::Mode::FUSE_ADD_RELU);
  142. return f;
  143. }
  144. SymbolVar mgb::make_resnet18(Network& network, size_t batch, DType out_dtype) {
  145. auto data = network.add_var("data", {batch, 4, 224, 224});
  146. if (out_dtype.category() == DTypeCategory::QUANTIZED)
  147. data = network.add_type_cvt(data, dtype::QuantizedS8{1.f});
  148. auto first = out_dtype;
  149. if (out_dtype.category() == DTypeCategory::QUANTIZED)
  150. first = dtype::QuantizedS8{1.f};
  151. auto f = network.add_conv(data, 64, {7, 7}, first, true, {2, 2}, {3, 3});
  152. if (out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
  153. out_dtype.enumv() == DTypeEnum::Quantized4Asymm) {
  154. f = network.add_type_cvt(f, out_dtype);
  155. }
  156. f = network.add_pooling(f, {3, 3}, {2, 2}, {1, 1});
  157. using Vector = SmallVector<size_t, 4>;
  158. Vector stages = {2, 2, 2, 2};
  159. Vector mid_outputs = {64, 128, 256, 512};
  160. Vector enable_stride = {0, 1, 1, 1};
  161. for (size_t i = 0; i < 4; ++i) {
  162. auto s = stages[i];
  163. auto o = mid_outputs[i];
  164. auto es = enable_stride[i];
  165. for (size_t j = 0; j < s; ++j) {
  166. size_t stride = !es || j > 0 ? 1 : 2;
  167. bool has_proj = j > 0 ? false : true;
  168. f = create_block(network, f, stride, o, has_proj, out_dtype);
  169. }
  170. }
  171. f = network.add_pooling(
  172. f, {7, 7}, {7, 7}, {0, 0}, opr::Pooling::Param::Mode::AVERAGE);
  173. f = network.add_type_cvt(f, dtype::Float32());
  174. return f;
  175. }
namespace {
//! Build a 4-level feature pyramid from a {batch, 3, 256, 256} input.
//! Stem: three 3x3 convs (strides 2, 1, 2); then four residual stages whose
//! final outputs become the pyramid levels, each cast back to the stem dtype.
SymbolVarArray make_pyramids(Network& network, size_t batch, DType out_dtype) {
    SymbolVarArray pyramids;
    auto data = network.add_var("data", {batch, 3, 256, 256});
    // shift input by -128 (presumably centering uint8-range pixel data —
    // confirm against the data producer)
    data = data + (-128.f);
    if (out_dtype.category() == DTypeCategory::QUANTIZED)
        data = network.add_type_cvt(data, dtype::QuantizedS8{1.f});
    // stem dtype: QuantizedS8 for any quantized target, else out_dtype itself
    auto first = out_dtype;
    if (out_dtype.category() == DTypeCategory::QUANTIZED)
        first = dtype::QuantizedS8{1.f};
    auto f = network.add_conv(data, 16, {3, 3}, first, true, {2, 2}, {1, 1});
    f = network.add_conv(f, 16, {3, 3}, first, true, {1, 1}, {1, 1});
    f = network.add_conv(f, 32, {3, 3}, first, true, {2, 2}, {1, 1});
    // 4-bit dtypes take over only after the stem
    if (out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
        out_dtype.enumv() == DTypeEnum::Quantized4Asymm) {
        f = network.add_type_cvt(f, out_dtype);
    }
    using Vector = SmallVector<size_t, 4>;
    Vector stages = {3, 6, 6, 3};
    Vector mid_outputs = {32, 64, 128, 256};
    Vector enable_stride = {0, 1, 1, 1};
    for (size_t i = 0; i < 4; ++i) {
        auto s = stages[i];
        auto o = mid_outputs[i];
        auto es = enable_stride[i];
        for (size_t j = 0; j < s; ++j) {
            // stride 2 only on the first block of a stride-enabled stage;
            // projection shortcut only on the first block of each stage
            size_t stride = !es || j > 0 ? 1 : 2;
            bool has_proj = j > 0 ? false : true;
            f = create_block(network, f, stride, o, has_proj, out_dtype);
        }
        pyramids.push_back(f);
    }
    // unify all pyramid levels to the stem dtype for the FPN heads
    for (size_t i = 0; i < pyramids.size(); ++i) {
        pyramids[i] = network.add_type_cvt(pyramids[i], first);
    }
    return pyramids;
}
//! Top-down FPN fusion over pyramid levels 3..1 (index i - 2 for i = 5..3):
//! a lateral 1x1 conv per level; from the second level on, the running
//! feature is upsampled x2 by deconv and fused via quantized ADD. Every fused
//! level is appended to `fpn`, which is returned.
SymbolVarArray fusion_pyramids_feature(
        Network& network, SymbolVarArray pyramids, size_t fpn_conv_channels) {
    bool touch = false;  // false until the topmost level seeds the fusion
    SymbolVar x;
    SymbolVarArray fpn;
    for (int i = 5; i >= 3; --i) {
        auto f = network.add_conv(
                pyramids[i - 2], fpn_conv_channels, {1, 1}, dtype::QuantizedS8{1.f},
                false, {1, 1}, {0, 0});
        if (!touch) {
            // topmost level: nothing to fuse with yet
            x = f;
            touch = true;
        } else {
            // NOTE(review): deconv output channels are hard-coded to 16 rather
            // than fpn_conv_channels — confirm this is intentional.
            x = network.add_deconv(x, 2, 16, dtype::QuantizedS8{1.f});
            x = network.add_elemwise(
                    {x, f}, dtype::QuantizedS8{1.f}, opr::Elemwise::Mode::ADD);
        }
        fpn.push_back(x);
    }
    // NOTE(review): the two stride-2 convs below extend the graph from fpn[0]
    // but their result `x` is neither returned nor stored in `fpn`; in this
    // test network the extra operators appear to exist only to enlarge the
    // graph — confirm before relying on their outputs.
    x = fpn[0];
    for (int i = 6; i < 8; ++i) {
        x = network.add_conv(
                x, fpn_conv_channels, {3, 3}, dtype::QuantizedS8{1.f}, true, {2, 2},
                {1, 1});
    }
    return fpn;
}
}  // namespace
  241. SymbolVarArray mgb::make_det(Network& network, size_t batch, DType out_dtype) {
  242. SymbolVarArray outputs;
  243. auto pyramids = make_pyramids(network, batch, out_dtype);
  244. auto fpn_hv = fusion_pyramids_feature(network, pyramids, 16);
  245. auto fpn_plate = fusion_pyramids_feature(network, pyramids, 16);
  246. outputs.insert(outputs.end(), fpn_hv.begin(), fpn_hv.end());
  247. outputs.insert(outputs.end(), fpn_plate.begin(), fpn_plate.end());
  248. return outputs;
  249. }
  250. SymbolVar mgb::bottleneck(
  251. Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
  252. size_t stride, DType out_dtype) {
  253. size_t in_channels = f.node()->shape()[1];
  254. SymbolVar x = f;
  255. if (t != 1) {
  256. x = network.add_conv(
  257. f, input_channels * t, {1, 1}, out_dtype, true, {1, 1}, {0, 0});
  258. }
  259. x = network.add_group_conv(
  260. x, input_channels * t, input_channels * t, {3, 3}, out_dtype, true,
  261. {stride, stride}, {1, 1});
  262. x = network.add_conv(x, channels, {1, 1}, out_dtype, false, {1, 1}, {0, 0});
  263. if (stride == 1 && in_channels == channels)
  264. x = f + x;
  265. return x;
  266. }
  267. SymbolVar mgb::bottleneck_group(
  268. Network& network, SymbolVar f, size_t input_channels, size_t channels,
  269. size_t stages, size_t s, size_t t, DType out_dtype) {
  270. SymbolVar x = f;
  271. for (size_t i = 0; i < stages; ++i) {
  272. size_t stride = i == 0 ? s : 1;
  273. x = bottleneck(network, x, input_channels, channels, t, stride, out_dtype);
  274. input_channels = channels;
  275. }
  276. return x;
  277. }
  278. namespace {
  279. size_t make_divisible(size_t v, size_t divisor) {
  280. size_t min_value = divisor;
  281. size_t new_v = std::max(min_value, (v + divisor / 2) / divisor * divisor);
  282. if (new_v < 0.9 * v)
  283. new_v += divisor;
  284. return new_v;
  285. }
  286. } // namespace
  287. SymbolVar mgb::make_mobilenet_v2(Network& network, size_t batch, DType out_dtype) {
  288. auto data = network.add_var("data", {batch, 3, 224, 224});
  289. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  290. data = network.add_type_cvt(data, dtype::QuantizedS8{1.f});
  291. }
  292. constexpr size_t round_nearest = 8;
  293. auto x = network.add_conv(
  294. data, make_divisible(32, round_nearest), {3, 3}, out_dtype, true, {2, 2},
  295. {1, 1});
  296. x = bottleneck(network, x, 32, make_divisible(16, round_nearest), 1, 1, out_dtype);
  297. x = bottleneck_group(
  298. network, x, 16, make_divisible(24, round_nearest), 2, 2, 6, out_dtype);
  299. x = bottleneck_group(
  300. network, x, 24, make_divisible(32, round_nearest), 3, 2, 6, out_dtype);
  301. x = bottleneck_group(
  302. network, x, 32, make_divisible(64, round_nearest), 4, 2, 6, out_dtype);
  303. x = bottleneck_group(
  304. network, x, 64, make_divisible(96, round_nearest), 3, 1, 6, out_dtype);
  305. x = bottleneck_group(
  306. network, x, 96, make_divisible(160, round_nearest), 3, 2, 6, out_dtype);
  307. x = bottleneck_group(
  308. network, x, 160, make_divisible(320, round_nearest), 1, 1, 6, out_dtype);
  309. x = network.add_conv(
  310. x, make_divisible(1280, round_nearest), {1, 1}, out_dtype, true, {1, 1},
  311. {0, 0});
  312. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  313. x = network.add_type_cvt(x, dtype::Float32());
  314. }
  315. return x;
  316. }
  317. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台