You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

conv_layer.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import te.lang.cce
  16. from te import tvm
  17. from te.platform import CUBE_MKN
  18. from topi import generic
  19. from topi.cce import util
  20. from topi.cce.util import is_v200_version
  21. # pylint: disable=R0912,R0913,R0914,R0915,E1101
  22. # the dim of shape in conv must be 4
  23. PAD_SHAPE_DIM = 2
  24. NONETYPE = type(None)
  25. @util.check_input_type((list, tuple), (list, tuple), str, str, str, (list, int), (list, int),
  26. int, int, (list, tuple), (list, tuple),
  27. str, str, str,
  28. str, str, str,
  29. str, bool, str)
  30. def conv_layer_cce_para_check(shape_in, shape_w, in_dtype, w_dtype, res_dtype, padh, padw,
  31. strideh, stridew, quantize_config, scale_sqrt,
  32. scale_q_dtype, offset_q_dtype, scale_dq_dtype,
  33. scale_rq_dtype, offset_rq_dtype, offset_w_dtype,
  34. offset_pad_dtype, bias, kernel_name):
  35. # conv shape check
  36. util.check_kernel_name(kernel_name)
  37. # conv data type check
  38. util.check_dtype_rule(in_dtype, ['float16', 'int8', 'uint8'])
  39. util.check_dtype_rule(w_dtype, ['float16', 'int8', 'uint8'])
  40. res_dtype_list = ['float16', 'int8', 'uint8']
  41. if is_v200_version():
  42. res_dtype_list.append('int32')
  43. util.check_dtype_rule(res_dtype, res_dtype_list)
  44. util.check_dtype_rule(scale_q_dtype, ['float16'])
  45. util.check_dtype_rule(offset_q_dtype, ['float16'])
  46. util.check_dtype_rule(scale_dq_dtype, ['float16'])
  47. util.check_dtype_rule(scale_rq_dtype, ['float16'])
  48. util.check_dtype_rule(offset_rq_dtype, ['float16'])
  49. util.check_dtype_rule(offset_w_dtype, ['int32'])
  50. util.check_dtype_rule(offset_pad_dtype, ['uint8'])
  51. if not isinstance(bias, bool):
  52. raise RuntimeError("bias dtype should be bool.")
  53. if quantize_config[0] == 0:
  54. if is_v200_version():
  55. util.check_dtype_rule(in_dtype, ('int8',))
  56. util.check_dtype_rule(w_dtype, ('int8',))
  57. util.check_dtype_rule(res_dtype, ('int32',))
  58. else:
  59. util.check_dtype_rule(in_dtype, ['float16'])
  60. util.check_dtype_rule(w_dtype, ['float16'])
  61. util.check_dtype_rule(res_dtype, ['float16'])
  62. if quantize_config[0] == 1:
  63. util.check_dtype_rule(w_dtype, ['int8'])
  64. if quantize_config[1] == 0:
  65. util.check_dtype_rule(in_dtype, ['int8', 'float16'])
  66. util.check_dtype_rule(res_dtype, ['int8', 'float16'])
  67. elif quantize_config[1] == 1:
  68. util.check_dtype_rule(in_dtype, ['uint8', 'float16'])
  69. util.check_dtype_rule(res_dtype, ['uint8', 'float16'])
  70. elif quantize_config[1] == 2:
  71. raise RuntimeError("All Offset mode quantize not support.")
  72. else:
  73. raise RuntimeError("Invalid quantize algorithm.")
  74. # quantize switch on
  75. if quantize_config[0] == 1:
  76. # quantize -> DeQuantize dataflow
  77. if in_dtype == 'float16' and w_dtype == 'int8' and res_dtype == 'float16':
  78. pass
  79. # DeQuantize dataflow
  80. elif (in_dtype in ['int8', 'uint8'] and w_dtype == 'int8' and
  81. res_dtype == 'float16'):
  82. pass
  83. # quantize -> ReQuantize dataflow
  84. elif (in_dtype == 'float16' and w_dtype == 'int8' and res_dtype in
  85. ['int8', 'uint8']):
  86. pass
  87. # ReQuantize dataflow
  88. elif (in_dtype in ['int8', 'uint8'] and w_dtype == 'int8' and res_dtype in
  89. ['int8', 'uint8']):
  90. pass
  91. else:
  92. raise RuntimeError("Not support in/out data type for quantize.")
  93. if quantize_config not in ([1, 0, 0], [1, 1, 0], [1, 0, 1], [1, 1, 1]):
  94. raise RuntimeError("Invalid Quantize Config.")
  95. if scale_sqrt not in ([0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1],
  96. [1, 0, 1], [0, 1, 1], [1, 1, 1]):
  97. raise RuntimeError("Invalid Quantize Config.")
  98. # quantize switch off
  99. elif quantize_config[0] == 0:
  100. if quantize_config != [0, 0, 0]:
  101. raise RuntimeError("Invalid Quantize Config.")
  102. if scale_sqrt != [0, 0, 0]:
  103. raise RuntimeError("Invalid Quantize Config.")
  104. else:
  105. raise RuntimeError("Invalid Quantize Config.")
  106. if isinstance(padh, list):
  107. if len(padh) != PAD_SHAPE_DIM:
  108. raise RuntimeError("Dimension must be %d when padh is a list." % PAD_SHAPE_DIM)
  109. pad_top = padh[0]
  110. pad_bottom = padh[1]
  111. else:
  112. pad_top = padh
  113. pad_bottom = padh
  114. if isinstance(padw, list):
  115. if len(padw) != PAD_SHAPE_DIM:
  116. raise RuntimeError("Dimension must be %d when padw is a list." % PAD_SHAPE_DIM)
  117. pad_left = padw[0]
  118. pad_right = padw[1]
  119. else:
  120. pad_left = padw
  121. pad_right = padw
  122. shape_in, shape_w = te.lang.cce.check_conv_shape(shape_in, shape_w, pad_top, pad_bottom, \
  123. pad_left, pad_right, strideh, \
  124. stridew, in_dtype, w_dtype, res_dtype)
  125. return shape_in, shape_w
  126. @util.check_input_type((list, tuple), (list, tuple), str, str, str, \
  127. (list, int), (list, int), int, int,
  128. (list, NONETYPE), (list, NONETYPE),
  129. str, str, str,
  130. str, str, str, str,
  131. bool, str, bool, bool)
  132. def conv_layer_cce(shape_in, shape_w, in_dtype, w_dtype, res_dtype, padh, padw, strideh, stridew,
  133. quantize_config=None, scale_sqrt=None,
  134. scale_q_dtype='float16', offset_q_dtype='float16', scale_dq_dtype='float16',
  135. scale_rq_dtype='float16', offset_rq_dtype='float16', offset_w_dtype='int32',
  136. offset_pad_dtype='uint8', bias=False, kernel_name="cce_conv", need_build=False,
  137. need_print=False):
  138. """
  139. Parameters
  140. ----------
  141. shape_in : shape of data_in
  142. shape_w : shape of filter
  143. in_dtype : the feature map data type
  144. w_dtype : the weight data type
  145. res_dtype : the result data type
  146. padh: the padding shape in H
  147. padw: the padding shape in weight
  148. strideh: the stride value in H
  149. stridew: the stride value in weight
  150. quantize_config: quantize config table, default [0, 0, 0]
  151. quantize_config[0] - quantize function switch
  152. 0: quantize off
  153. 1: quantize on
  154. quantize_config[1] - quantize_algorithm
  155. 0: non offset
  156. 1: half offset
  157. 2: all offset ( Not supported now )
  158. quantize_config[2] - QuantizeScaleType (for Dequantize/Requantize, quantize always scalar)
  159. 0: scalar
  160. 1: vector
  161. scale_sqrt: scale mode
  162. scale_sqrt[0] - Quantize scale mode
  163. 0: non sqrt
  164. 1: sqrt
  165. scale_sqrt[1] - DeQuantize scale mode
  166. 0: non sqrt
  167. 1: sqrt
  168. scale_sqrt[2] - ReQuantize scale mode
  169. 0: non sqrt
  170. 1: sqrt
  171. scale_q_dtype: Quantize scale data type, default 'float16'
  172. offset_q_dtype: Quantize offset data type, default 'float16'
  173. scale_dq_dtype: DeQuantize scale data type, default 'float16'
  174. scale_rq_dtype: ReQuantize scale data type, default 'float16'
  175. offset_rq_dtype: ReQuantize offset data type, default 'float16'
  176. offset_w_dtype: weight offset data type, default 'int32'
  177. offset_pad_dtype: Quantize Cube offset data type, default 'uint8'
  178. bias: the tag for bias or not
  179. kernel_name : cce kernel name, default value is "cce_conv"
  180. need_build : if need to build CCEC kernel, default value is False
  181. need_print : if need to print the ir, default value is False
  182. Returns
  183. -------
  184. wrapped_tensor
  185. """
  186. # for pylint, otherwise "Dangerous default value [] as argument"
  187. if quantize_config is None:
  188. quantize_config = [0, 0, 0]
  189. if scale_sqrt is None:
  190. scale_sqrt = [0, 0, 0]
  191. in_dtype = in_dtype.lower()
  192. w_dtype = w_dtype.lower()
  193. res_dtype = res_dtype.lower()
  194. scale_q_dtype = scale_q_dtype.lower()
  195. offset_q_dtype = offset_q_dtype.lower()
  196. scale_dq_dtype = scale_dq_dtype.lower()
  197. scale_rq_dtype = scale_rq_dtype.lower()
  198. offset_rq_dtype = offset_rq_dtype.lower()
  199. offset_w_dtype = offset_w_dtype.lower()
  200. offset_pad_dtype = offset_pad_dtype.lower()
  201. mad_dtype = 'float32'
  202. if w_dtype == 'int8':
  203. mad_dtype = 'int32'
  204. shape_in = list(shape_in)
  205. shape_w = list(shape_w)
  206. shape_in, shape_w = conv_layer_cce_para_check(shape_in, shape_w, in_dtype, w_dtype, res_dtype, padh, padw, strideh,
  207. stridew,
  208. quantize_config, scale_sqrt, scale_q_dtype, offset_q_dtype,
  209. scale_dq_dtype,
  210. scale_rq_dtype, offset_rq_dtype, offset_w_dtype, offset_pad_dtype,
  211. bias, kernel_name)
  212. # quantize switch on
  213. if quantize_config[0] == 1:
  214. quantize_turn_on = True
  215. # quantize -> DeQuantize dataflow
  216. if in_dtype == 'float16' and w_dtype == 'int8' and res_dtype == 'float16':
  217. is_quantize = True
  218. is_dequantize = True
  219. is_requantize = False
  220. # DeQuantize dataflow
  221. elif (in_dtype in ['int8', 'uint8'] and w_dtype == 'int8' and
  222. res_dtype == 'float16'):
  223. is_quantize = False
  224. is_dequantize = True
  225. is_requantize = False
  226. # quantize -> ReQuantize dataflow
  227. elif (in_dtype == 'float16' and w_dtype == 'int8' and res_dtype in
  228. ['int8', 'uint8']):
  229. is_quantize = True
  230. is_dequantize = False
  231. is_requantize = True
  232. # ReQuantize dataflow
  233. elif (in_dtype in ['int8', 'uint8'] and w_dtype == 'int8' and res_dtype in
  234. ['int8', 'uint8']):
  235. is_quantize = False
  236. is_dequantize = False
  237. is_requantize = True
  238. else:
  239. raise RuntimeError("Not support in/out data type for quantize.")
  240. # quantize switch off
  241. elif quantize_config[0] == 0:
  242. quantize_turn_on = False
  243. is_quantize = False
  244. is_dequantize = False
  245. is_requantize = False
  246. if quantize_config != [0, 0, 0]:
  247. raise RuntimeError("Invalid Quantize Config.")
  248. if scale_sqrt != [0, 0, 0]:
  249. raise RuntimeError("Invalid Quantize Config.")
  250. else:
  251. raise RuntimeError("Invalid Quantize Config.")
  252. batch_size = shape_in[0]
  253. in_channel = shape_in[1]
  254. feature_map_h = shape_in[2]
  255. feature_map_w = shape_in[3]
  256. block_size_k = CUBE_MKN[in_dtype]['mac'][1]
  257. fmap_shape_nc1hwc0 = (batch_size, (in_channel + block_size_k - 1) // block_size_k,
  258. feature_map_h, feature_map_w, block_size_k)
  259. out_channel = shape_w[0]
  260. in_channel_weight = shape_w[1]
  261. filter_h = shape_w[2]
  262. filter_w = shape_w[3]
  263. block_size_k = CUBE_MKN[w_dtype]['mac'][1]
  264. block_size_n = CUBE_MKN[w_dtype]['mac'][2]
  265. filter_shape_frac_z = (in_channel_weight * filter_h * filter_w // block_size_k,
  266. out_channel // block_size_n, block_size_n, block_size_k)
  267. with tvm.target.cce():
  268. data = tvm.placeholder(
  269. fmap_shape_nc1hwc0, name='Fmap', dtype=in_dtype)
  270. weight = tvm.placeholder(
  271. filter_shape_frac_z, name='Filter', dtype=w_dtype)
  272. bias_tensor = None
  273. scale_q = None
  274. scale_dq = None
  275. scale_rq = None
  276. offset_pad = None
  277. offset_rq = None
  278. offset_q = None
  279. scale_drq = None
  280. # bias or fusion_bias(half offset)
  281. if bias or (quantize_config[1] == 1 and quantize_turn_on):
  282. bias_tensor = tvm.placeholder(
  283. (out_channel,), name='bias_tensor', \
  284. dtype="int32" if quantize_turn_on else res_dtype)
  285. # quantize on
  286. if quantize_turn_on:
  287. quantize_algorithm = quantize_config[1]
  288. if is_quantize:
  289. scale_q = tvm.placeholder(
  290. (CUBE_MKN[scale_q_dtype]['mac'][1],), name='scaleQ', dtype=scale_q_dtype)
  291. if quantize_algorithm == 1:
  292. offset_q = tvm.placeholder(
  293. (CUBE_MKN[offset_q_dtype]['mac'][1],), name='offsetQ', dtype=offset_q_dtype)
  294. if is_dequantize:
  295. scale_dq_shape = (CUBE_MKN[scale_dq_dtype]['mac'][1],) if quantize_config[2] == 0 \
  296. else (out_channel,)
  297. scale_dq = tvm.placeholder(
  298. scale_dq_shape, name='scaleDq', dtype=scale_dq_dtype)
  299. if is_requantize:
  300. scale_rq_shape = (CUBE_MKN[scale_rq_dtype]['mac'][1],) if quantize_config[2] == 0 \
  301. else (out_channel,)
  302. scale_rq = tvm.placeholder(
  303. scale_rq_shape, name='scaleRq', dtype=scale_rq_dtype)
  304. if quantize_algorithm == 1:
  305. offset_rq_shape = (CUBE_MKN[offset_rq_dtype]['mac'][1],)
  306. offset_rq = tvm.placeholder(
  307. offset_rq_shape, name='offsetRq', dtype=offset_rq_dtype)
  308. # need offset_pad , for half offset
  309. if quantize_algorithm == 1:
  310. offset_pad = tvm.placeholder(
  311. (CUBE_MKN[offset_pad_dtype]['mac'][1],), name='offset_pad',
  312. dtype=offset_pad_dtype)
  313. if quantize_algorithm == 0:
  314. if is_quantize:
  315. if is_dequantize:
  316. scale_drq = scale_dq
  317. else:
  318. scale_drq = scale_rq
  319. conv_res = te.lang.cce.conv(
  320. data, weight, {"bias_tensor": bias_tensor,
  321. "scale_q": scale_q,
  322. "offset_q": offset_q,
  323. "scale_drq": scale_drq,
  324. "offset_pad": offset_pad,
  325. "offset_rq": offset_rq,
  326. "quantize_config": quantize_config,
  327. "is_quantize": is_quantize,
  328. "is_dequantize": is_dequantize,
  329. "is_requantize": is_requantize,
  330. "scale_sqrt": scale_sqrt,
  331. "pad_h": padh, "pad_w": padw,
  332. "stride_h": strideh, "stride_w": stridew,
  333. "filter_h": filter_h, "filter_w": filter_w,
  334. "res_dtype": res_dtype, "mad_dtype": mad_dtype},
  335. dsl_flag=False)
  336. if bias:
  337. tensor_list = [data, weight, bias_tensor, scale_q,
  338. scale_drq, conv_res]
  339. else:
  340. tensor_list = [data, weight, scale_q,
  341. scale_drq, conv_res]
  342. else:
  343. if is_dequantize:
  344. scale_drq = scale_dq
  345. else:
  346. scale_drq = scale_rq
  347. conv_res = te.lang.cce.conv(
  348. data, weight, {"bias_tensor": bias_tensor,
  349. "scale_q": scale_q,
  350. "offset_q": offset_q,
  351. "scale_drq": scale_drq,
  352. "offset_pad": offset_pad,
  353. "offset_rq": offset_rq,
  354. "quantize_config": quantize_config,
  355. "is_quantize": is_quantize,
  356. "is_dequantize": is_dequantize,
  357. "is_requantize": is_requantize,
  358. "scale_sqrt": scale_sqrt,
  359. "pad_h": padh, "pad_w": padw,
  360. "stride_h": strideh, "stride_w": stridew,
  361. "filter_h": filter_h, "filter_w": filter_w,
  362. "res_dtype": res_dtype, "mad_dtype": mad_dtype},
  363. dsl_flag=False)
  364. if bias:
  365. tensor_list = [data, weight, bias_tensor,
  366. scale_drq, conv_res]
  367. else:
  368. tensor_list = [data, weight,
  369. scale_drq, conv_res]
  370. # half offset
  371. else:
  372. if is_quantize:
  373. if is_dequantize:
  374. scale_drq = scale_dq
  375. else:
  376. scale_drq = scale_rq
  377. conv_res = te.lang.cce.conv(
  378. data, weight, {"bias_tensor": bias_tensor,
  379. "scale_q": scale_q,
  380. "offset_q": offset_q,
  381. "scale_drq": scale_drq,
  382. "offset_pad": offset_pad,
  383. "offset_rq": offset_rq,
  384. "quantize_config": quantize_config,
  385. "is_quantize": is_quantize,
  386. "is_dequantize": is_dequantize,
  387. "is_requantize": is_requantize,
  388. "scale_sqrt": scale_sqrt,
  389. "pad_h": padh, "pad_w": padw,
  390. "stride_h": strideh, "stride_w": stridew,
  391. "filter_h": filter_h, "filter_w": filter_w,
  392. "res_dtype": res_dtype, "mad_dtype": mad_dtype},
  393. dsl_flag=False)
  394. if is_dequantize:
  395. tensor_list = [data, weight, bias_tensor, scale_q, offset_q,
  396. scale_drq, offset_pad, conv_res]
  397. else:
  398. tensor_list = [data, weight, bias_tensor, scale_q, offset_q,
  399. scale_drq, offset_rq, offset_pad, conv_res]
  400. else:
  401. if is_dequantize:
  402. scale_drq = scale_dq
  403. else:
  404. scale_drq = scale_rq
  405. conv_res = te.lang.cce.conv(
  406. data, weight, {"bias_tensor": bias_tensor,
  407. "scale_q": scale_q,
  408. "offset_q": offset_q,
  409. "scale_drq": scale_drq,
  410. "offset_pad": offset_pad,
  411. "offset_rq": offset_rq,
  412. "quantize_config": quantize_config,
  413. "is_quantize": is_quantize,
  414. "is_dequantize": is_dequantize,
  415. "is_requantize": is_requantize,
  416. "scale_sqrt": scale_sqrt,
  417. "pad_h": padh, "pad_w": padw,
  418. "stride_h": strideh, "stride_w": stridew,
  419. "filter_h": filter_h, "filter_w": filter_w,
  420. "res_dtype": res_dtype, "mad_dtype": mad_dtype},
  421. dsl_flag=False)
  422. if is_dequantize:
  423. tensor_list = [data, weight, bias_tensor,
  424. scale_drq, offset_pad, conv_res]
  425. else:
  426. tensor_list = [data, weight, bias_tensor,
  427. scale_drq, offset_rq, offset_pad, conv_res]
  428. else:
  429. conv_res = te.lang.cce.conv(
  430. data, weight, {"bias_tensor": bias_tensor,
  431. "scale_q": scale_q,
  432. "offset_q": offset_q,
  433. "scale_drq": scale_drq,
  434. "offset_pad": offset_pad,
  435. "offset_rq": offset_rq,
  436. "quantize_config": quantize_config,
  437. "is_quantize": is_quantize,
  438. "is_dequantize": is_dequantize,
  439. "is_requantize": is_requantize,
  440. "scale_sqrt": scale_sqrt,
  441. "pad_h": padh, "pad_w": padw,
  442. "stride_h": strideh, "stride_w": stridew,
  443. "filter_h": filter_h, "filter_w": filter_w,
  444. "res_dtype": res_dtype, "mad_dtype": mad_dtype},
  445. dsl_flag=False)
  446. if bias:
  447. tensor_list = [data, weight, bias_tensor, conv_res]
  448. else:
  449. tensor_list = [data, weight, conv_res]
  450. sch = generic.auto_schedule(conv_res)
  451. config = {
  452. "print_ir": need_print,
  453. "need_build": need_build,
  454. "name": kernel_name,
  455. "tensor_list": tensor_list
  456. }
  457. te.lang.cce.cce_build_code(sch, config)