You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

network.cpp 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. #include "./network.h"
  2. #include "megbrain/opr/tensor_manip.h"
  3. using namespace mgb;
  4. SymbolVar Network::add_conv(
  5. SymbolVar f, size_t output_channels, KernSize kern_size, DType out_dtype,
  6. bool has_relu, Stride stride, Padding padding) {
  7. static int weight_idx = 0;
  8. static int bias_idx = 0;
  9. size_t input_channels = f.node()->shape()[1];
  10. auto weight = add_cvar(
  11. ssprintf("w%d", weight_idx).c_str(),
  12. {output_channels, input_channels, kern_size[0], kern_size[1]});
  13. auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1});
  14. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  15. weight = add_type_cvt(weight, out_dtype);
  16. bias = add_type_cvt(bias, dtype::QuantizedS32{1.f});
  17. }
  18. opr::ConvBias::Param param;
  19. param.stride_h = stride[0], param.stride_w = stride[1];
  20. param.pad_h = padding[0], param.pad_w = padding[1];
  21. if (has_relu) {
  22. param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
  23. } else {
  24. param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
  25. }
  26. SymbolVar conv;
  27. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  28. conv = opr::ConvBias::make(
  29. f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
  30. } else {
  31. conv = opr::ConvBias::make(f, weight, bias, param, {});
  32. }
  33. weight_idx++;
  34. bias_idx++;
  35. return conv;
  36. }
  37. SymbolVar Network::add_group_conv(
  38. SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size,
  39. DType out_dtype, bool has_relu, Stride stride, Padding padding) {
  40. static int weight_idx = 0;
  41. static int bias_idx = 0;
  42. size_t input_channels = f.node()->shape()[1];
  43. auto weight = add_cvar(
  44. ssprintf("w%d", weight_idx).c_str(),
  45. {groups, output_channels / groups, input_channels / groups, kern_size[0],
  46. kern_size[1]});
  47. auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1});
  48. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  49. weight = add_type_cvt(weight, out_dtype);
  50. bias = add_type_cvt(bias, dtype::QuantizedS32{1.f});
  51. }
  52. opr::ConvBias::Param param;
  53. param.sparse = opr::ConvBias::Param::Sparse::GROUP;
  54. param.stride_h = stride[0], param.stride_w = stride[1];
  55. param.pad_h = padding[0], param.pad_w = padding[1];
  56. if (has_relu) {
  57. param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
  58. } else {
  59. param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
  60. }
  61. weight_idx++;
  62. bias_idx++;
  63. SymbolVar conv;
  64. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  65. conv = opr::ConvBias::make(
  66. f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
  67. } else {
  68. conv = opr::ConvBias::make(f, weight, bias, param, {});
  69. }
  70. weight_idx++;
  71. bias_idx++;
  72. return conv;
  73. }
  74. SymbolVar Network::add_deconv(
  75. SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype) {
  76. static int weight_idx = 0;
  77. size_t kernel = ratio * 2 - ratio % 2;
  78. size_t pad = ratio / 2;
  79. size_t input_channels = f.node()->shape()[1];
  80. auto weight = add_cvar(
  81. ssprintf("w%d", weight_idx).c_str(),
  82. {input_channels, output_channels, kernel, kernel});
  83. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  84. weight = add_type_cvt(weight, out_dtype);
  85. }
  86. opr::ConvolutionBackwardData::Param param;
  87. param.stride_h = param.stride_w = ratio;
  88. param.pad_h = param.pad_w = pad;
  89. auto deconv = opr::ConvolutionBackwardData::make(
  90. weight, f, param, {}, OperatorNodeConfig{out_dtype});
  91. weight_idx++;
  92. return deconv;
  93. }
  94. SymbolVar Network::add_elemwise(
  95. const SymbolVarArray inps, DType out_dtype, opr::Elemwise::Param::Mode mode) {
  96. using ElemMode = opr::Elemwise::Param::Mode;
  97. using MultiMode = opr::ElemwiseMultiType::Param::Mode;
  98. static const ThinHashMap<ElemMode, MultiMode> map = {
  99. {ElemMode::ADD, MultiMode::QADD},
  100. {ElemMode::FUSE_ADD_RELU, MultiMode::QFUSE_ADD_RELU}};
  101. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  102. MultiMode alter_mode = map.at(mode);
  103. return opr::ElemwiseMultiType::make(
  104. inps, {alter_mode}, OperatorNodeConfig{out_dtype});
  105. } else {
  106. return opr::Elemwise::make(inps, mode);
  107. }
  108. }
  109. SymbolVar Network::add_pooling(
  110. SymbolVar f, Window window, Stride stride, Padding padding,
  111. opr::Pooling::Param::Mode mode) {
  112. opr::Pooling::Param param;
  113. param.window_h = window[0], param.window_w = window[1];
  114. param.stride_h = stride[0], param.stride_w = stride[1];
  115. param.pad_h = padding[0], param.pad_w = padding[1];
  116. param.mode = mode;
  117. return opr::Pooling::make(f, param);
  118. }
  119. SymbolVar Network::add_type_cvt(SymbolVar f, DType out_dtype) {
  120. return opr::TypeCvt::make(f, out_dtype);
  121. }
/// Concatenate `f` and `g` along dimension `axis`.
SymbolVar Network::add_concat(SymbolVar f, SymbolVar g, int axis) {
    return opr::Concat::make({f, g}, axis);
}
/// Permute the axes of `f` according to `pattern` (Dimshuffle semantics).
SymbolVar Network::add_dimshuffle(SymbolVar f, std::vector<int> pattern) {
    return opr::Dimshuffle::make(f, pattern);
}
/// Remove axis 0 of `f` via an AxisAddRemove op
/// (presumably axis 0 has extent 1 at the call sites — not checked here).
SymbolVar Network::add_axisaddremove(SymbolVar f) {
    return opr::AxisAddRemove::make(
            f, {{opr::AxisAddRemove::AxisDesc::Method::REMOVE, {0}}});
}
/// Slice `f` on axis 0 over the interval [0, end) — begin is a constant 0 and
/// end/step are None, so this is effectively a full-range (identity) slice
/// expressed as a Subtensor op.
SymbolVar Network::add_subtensor(SymbolVar f) {
    using AIdx = opr::indexing::AxisIndexer;
    return opr::Subtensor::make(
            f, {AIdx::make_interval(0, f.make_scalar(0), None, None)});
}
  137. SymbolVar Network::add_reshape(SymbolVar f) {
  138. auto shp = opr::GetVarShape::make(f);
  139. return opr::Reshape::make(f, shp);
  140. }
  141. SymbolVar Network::add_broadcast(SymbolVar f) {
  142. auto shp = opr::GetVarShape::make(f);
  143. return opr::Broadcast::make(f, shp);
  144. }
/// Insert an explicit Copy op (same value, distinct var node).
SymbolVar Network::add_copy(SymbolVar f) {
    return opr::Copy::make(f);
}
  148. SymbolVar mgb::create_block(
  149. Network& network, SymbolVar f_in, size_t stride, size_t num_outputs1,
  150. bool has_proj, DType out_dtype) {
  151. auto proj = f_in;
  152. if (has_proj) {
  153. proj = network.add_conv(
  154. f_in, num_outputs1, {1, 1}, out_dtype, false, {stride, stride});
  155. }
  156. auto f = network.add_conv(
  157. f_in, num_outputs1, {3, 3}, out_dtype, true, {stride, stride}, {1, 1});
  158. f = network.add_conv(f, num_outputs1, {3, 3}, out_dtype, true, {1, 1}, {1, 1});
  159. f = network.add_elemwise({f, proj}, out_dtype, opr::Elemwise::Mode::FUSE_ADD_RELU);
  160. return f;
  161. }
  162. SymbolVar mgb::make_resnet18(Network& network, size_t batch, DType out_dtype) {
  163. auto data = network.add_var("data", {batch, 4, 224, 224});
  164. if (out_dtype.category() == DTypeCategory::QUANTIZED)
  165. data = network.add_type_cvt(data, dtype::QuantizedS8{1.f});
  166. auto first = out_dtype;
  167. if (out_dtype.category() == DTypeCategory::QUANTIZED)
  168. first = dtype::QuantizedS8{1.f};
  169. auto f = network.add_conv(data, 64, {7, 7}, first, true, {2, 2}, {3, 3});
  170. if (out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
  171. out_dtype.enumv() == DTypeEnum::Quantized4Asymm) {
  172. f = network.add_type_cvt(f, out_dtype);
  173. }
  174. f = network.add_pooling(f, {3, 3}, {2, 2}, {1, 1});
  175. using Vector = SmallVector<size_t, 4>;
  176. Vector stages = {2, 2, 2, 2};
  177. Vector mid_outputs = {64, 128, 256, 512};
  178. Vector enable_stride = {0, 1, 1, 1};
  179. for (size_t i = 0; i < 4; ++i) {
  180. auto s = stages[i];
  181. auto o = mid_outputs[i];
  182. auto es = enable_stride[i];
  183. for (size_t j = 0; j < s; ++j) {
  184. size_t stride = !es || j > 0 ? 1 : 2;
  185. bool has_proj = j > 0 ? false : true;
  186. f = create_block(network, f, stride, o, has_proj, out_dtype);
  187. }
  188. }
  189. f = network.add_pooling(
  190. f, {7, 7}, {7, 7}, {0, 0}, opr::Pooling::Param::Mode::AVERAGE);
  191. f = network.add_type_cvt(f, dtype::Float32());
  192. return f;
  193. }
  194. namespace {
  195. SymbolVarArray make_pyramids(Network& network, size_t batch, DType out_dtype) {
  196. SymbolVarArray pyramids;
  197. auto data = network.add_var("data", {batch, 3, 256, 256});
  198. data = data + (-128.f);
  199. if (out_dtype.category() == DTypeCategory::QUANTIZED)
  200. data = network.add_type_cvt(data, dtype::QuantizedS8{1.f});
  201. auto first = out_dtype;
  202. if (out_dtype.category() == DTypeCategory::QUANTIZED)
  203. first = dtype::QuantizedS8{1.f};
  204. auto f = network.add_conv(data, 16, {3, 3}, first, true, {2, 2}, {1, 1});
  205. f = network.add_conv(f, 16, {3, 3}, first, true, {1, 1}, {1, 1});
  206. f = network.add_conv(f, 32, {3, 3}, first, true, {2, 2}, {1, 1});
  207. if (out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
  208. out_dtype.enumv() == DTypeEnum::Quantized4Asymm) {
  209. f = network.add_type_cvt(f, out_dtype);
  210. }
  211. using Vector = SmallVector<size_t, 4>;
  212. Vector stages = {3, 6, 6, 3};
  213. Vector mid_outputs = {32, 64, 128, 256};
  214. Vector enable_stride = {0, 1, 1, 1};
  215. for (size_t i = 0; i < 4; ++i) {
  216. auto s = stages[i];
  217. auto o = mid_outputs[i];
  218. auto es = enable_stride[i];
  219. for (size_t j = 0; j < s; ++j) {
  220. size_t stride = !es || j > 0 ? 1 : 2;
  221. bool has_proj = j > 0 ? false : true;
  222. f = create_block(network, f, stride, o, has_proj, out_dtype);
  223. }
  224. pyramids.push_back(f);
  225. }
  226. for (size_t i = 0; i < pyramids.size(); ++i) {
  227. pyramids[i] = network.add_type_cvt(pyramids[i], first);
  228. }
  229. return pyramids;
  230. }
/// Build an FPN-style top-down pathway over three pyramid levels:
/// lateral 1x1 conv per level, 2x deconv upsample of the running feature,
/// and quantized add to merge. Returns the merged features top-down.
SymbolVarArray fusion_pyramids_feature(
        Network& network, SymbolVarArray pyramids, size_t fpn_conv_channels) {
    bool touch = false;
    SymbolVar x;
    SymbolVarArray fpn;
    // i = 5..3 selects pyramids[3], pyramids[2], pyramids[1] (top-down).
    for (int i = 5; i >= 3; --i) {
        auto f = network.add_conv(
                pyramids[i - 2], fpn_conv_channels, {1, 1}, dtype::QuantizedS8{1.f},
                false, {1, 1}, {0, 0});
        if (!touch) {
            // Topmost level: nothing to merge with yet.
            x = f;
            touch = true;
        } else {
            // Upsample the running feature 2x, then merge with the lateral conv.
            x = network.add_deconv(x, 2, 16, dtype::QuantizedS8{1.f});
            x = network.add_elemwise(
                    {x, f}, dtype::QuantizedS8{1.f}, opr::Elemwise::Mode::ADD);
        }
        fpn.push_back(x);
    }
    x = fpn[0];
    // NOTE(review): these two extra strided convs are added to the graph but
    // `x` is never stored or returned afterwards — the branch is dangling.
    // Confirm this is intentional (e.g. only meant to exercise op creation).
    for (int i = 6; i < 8; ++i) {
        x = network.add_conv(
                x, fpn_conv_channels, {3, 3}, dtype::QuantizedS8{1.f}, true, {2, 2},
                {1, 1});
    }
    return fpn;
}
  258. } // namespace
  259. SymbolVarArray mgb::make_det(Network& network, size_t batch, DType out_dtype) {
  260. SymbolVarArray outputs;
  261. auto pyramids = make_pyramids(network, batch, out_dtype);
  262. auto fpn_hv = fusion_pyramids_feature(network, pyramids, 16);
  263. auto fpn_plate = fusion_pyramids_feature(network, pyramids, 16);
  264. outputs.insert(outputs.end(), fpn_hv.begin(), fpn_hv.end());
  265. outputs.insert(outputs.end(), fpn_plate.begin(), fpn_plate.end());
  266. return outputs;
  267. }
  268. SymbolVar mgb::bottleneck(
  269. Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
  270. size_t stride, DType out_dtype) {
  271. size_t in_channels = f.node()->shape()[1];
  272. SymbolVar x = f;
  273. if (t != 1) {
  274. x = network.add_conv(
  275. f, input_channels * t, {1, 1}, out_dtype, true, {1, 1}, {0, 0});
  276. }
  277. x = network.add_group_conv(
  278. x, input_channels * t, input_channels * t, {3, 3}, out_dtype, true,
  279. {stride, stride}, {1, 1});
  280. x = network.add_conv(x, channels, {1, 1}, out_dtype, false, {1, 1}, {0, 0});
  281. if (stride == 1 && in_channels == channels)
  282. x = f + x;
  283. return x;
  284. }
  285. SymbolVar mgb::bottleneck_group(
  286. Network& network, SymbolVar f, size_t input_channels, size_t channels,
  287. size_t stages, size_t s, size_t t, DType out_dtype) {
  288. SymbolVar x = f;
  289. for (size_t i = 0; i < stages; ++i) {
  290. size_t stride = i == 0 ? s : 1;
  291. x = bottleneck(network, x, input_channels, channels, t, stride, out_dtype);
  292. input_channels = channels;
  293. }
  294. return x;
  295. }
  296. namespace {
  297. size_t make_divisible(size_t v, size_t divisor) {
  298. size_t min_value = divisor;
  299. size_t new_v = std::max(min_value, (v + divisor / 2) / divisor * divisor);
  300. if (new_v < 0.9 * v)
  301. new_v += divisor;
  302. return new_v;
  303. }
  304. } // namespace
  305. SymbolVar mgb::make_mobilenet_v2(Network& network, size_t batch, DType out_dtype) {
  306. auto data = network.add_var("data", {batch, 3, 224, 224});
  307. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  308. data = network.add_type_cvt(data, dtype::QuantizedS8{1.f});
  309. }
  310. constexpr size_t round_nearest = 8;
  311. auto x = network.add_conv(
  312. data, make_divisible(32, round_nearest), {3, 3}, out_dtype, true, {2, 2},
  313. {1, 1});
  314. x = bottleneck(network, x, 32, make_divisible(16, round_nearest), 1, 1, out_dtype);
  315. x = bottleneck_group(
  316. network, x, 16, make_divisible(24, round_nearest), 2, 2, 6, out_dtype);
  317. x = bottleneck_group(
  318. network, x, 24, make_divisible(32, round_nearest), 3, 2, 6, out_dtype);
  319. x = bottleneck_group(
  320. network, x, 32, make_divisible(64, round_nearest), 4, 2, 6, out_dtype);
  321. x = bottleneck_group(
  322. network, x, 64, make_divisible(96, round_nearest), 3, 1, 6, out_dtype);
  323. x = bottleneck_group(
  324. network, x, 96, make_divisible(160, round_nearest), 3, 2, 6, out_dtype);
  325. x = bottleneck_group(
  326. network, x, 160, make_divisible(320, round_nearest), 1, 1, 6, out_dtype);
  327. x = network.add_conv(
  328. x, make_divisible(1280, round_nearest), {1, 1}, out_dtype, true, {1, 1},
  329. {0, 0});
  330. if (out_dtype.category() == DTypeCategory::QUANTIZED) {
  331. x = network.add_type_cvt(x, dtype::Float32());
  332. }
  333. return x;
  334. }
  335. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}