You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_json_data.py 23 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """generate numpy data for composite json"""
  15. import json
  16. import logging
  17. import inspect
  18. import numpy as np
  19. from tests.common.gen_random import random_gaussian
  20. def get_attr(attr_desc, attr_type):
  21. """get op attr by type"""
  22. for attr in attr_desc:
  23. if attr["name"] == attr_type:
  24. return attr["value"]
  25. logging.warning("attr {} not found, please check.".format(attr_type))
  26. return None
  27. class CodePrinter(object):
  28. """print numpy file"""
  29. def __init__(self, out_file):
  30. self.fout_ = open(out_file, 'w')
  31. self.indent_ = 0
  32. def __del__(self):
  33. self.fout_.close()
  34. def out(self, data, new_line=False):
  35. """write data"""
  36. if new_line:
  37. self.fout_.write("\n")
  38. for i in range(0, self.indent_):
  39. self.fout_.write(' ')
  40. if isinstance(data, str):
  41. self.fout_.write(data)
  42. else:
  43. self.fout_.write(str(data))
  44. def null_line(self):
  45. """add null line"""
  46. self.fout_.write("\n")
  47. def close(self):
  48. """close file"""
  49. self.fout_.close()
  50. def get_input(desc):
  51. """get input values"""
  52. value = desc.get('value', None)
  53. return value if value is not None else desc['tensor_name']
  54. def reduce_str(inputs, output, attr, op_type):
  55. """gen sum string"""
  56. axis = []
  57. keepdims = False
  58. axis_value = get_attr(attr, "axis")
  59. if axis_value:
  60. axis = list(axis_value) if isinstance(axis_value, (list, tuple)) else [axis_value]
  61. keepdims_value = get_attr(attr, "keep_dims")
  62. keepdims = keepdims_value if keepdims_value else keepdims
  63. if axis == []:
  64. s = "%s = np.%s(%s.astype(np.float32) if %s.dtype == np.float16 else %s, keepdims=%s).astype(%s.dtype)" % (
  65. output[0]['tensor_name'], op_type, get_input(inputs[0][0]), get_input(inputs[0][0]),
  66. get_input(inputs[0][0]), keepdims, get_input(inputs[0][0]))
  67. else:
  68. s = "%s = np.%s(%s.astype(np.float32) if %s.dtype == np.float16 else %s, axis=tuple(%s), keepdims=%s).astype(%s.dtype); %s = np.reshape(%s, %s) " %\
  69. (output[0]['tensor_name'], op_type, get_input(inputs[0][0]), get_input(inputs[0][0]),
  70. get_input(inputs[0][0]), axis, keepdims, get_input(inputs[0][0]),
  71. output[0]['tensor_name'], output[0]['tensor_name'], output[0]['shape'])
  72. return s
  73. def cast_str(inputs, output, attr):
  74. """gen cast string"""
  75. dst_type = get_attr(attr, "dst_type")
  76. s = "%s = np.array(%s).astype(np.%s) if isinstance(%s, (float, int)) else %s.astype(np.%s)" % (
  77. output[0]['tensor_name'], get_input(inputs[0][0]), dst_type, get_input(inputs[0][0]),
  78. get_input(inputs[0][0]), dst_type)
  79. return s
  80. def broadcast_str(inputs, output, attr):
  81. """gen broadcast string"""
  82. dst_shape = get_attr(attr, "shape")
  83. s = "%s = np.broadcast_to(%s, %s)" % (
  84. output[0]["tensor_name"], get_input(inputs[0][0]), dst_shape)
  85. return s
  86. def transpose_str(inputs, output, attr):
  87. """gen transpose string"""
  88. axes = None
  89. axes_value = get_attr(attr, "perm")
  90. axes = axes_value if axes_value else axes
  91. s = "%s = np.transpose(%s, axes=%s)" % (
  92. output[0]['tensor_name'], get_input(inputs[0][0]), axes)
  93. return s
  94. def trans_data_two2fractal(input_, src_format, dst_format):
  95. shape = list(input_.shape)
  96. dtype = input_.dtype
  97. if dtype == "float32":
  98. input_ = input_.astype(np.float16)
  99. if src_format == "DefaultFormat" or src_format == "NCHW":
  100. m, n = shape[-2], shape[-1]
  101. m1, n1 = m // 16, n // 16
  102. m0, n0 = 16, 16
  103. needPad = m % 16 != 0 or n % 16 != 0
  104. if needPad:
  105. pad_m, pad_n = (m + 15) // 16 * 16, (n + 15) // 16 * 16
  106. pad_shape = [x for x in shape]
  107. pad_shape[-1] = pad_n
  108. pad_shape[-2] = pad_m
  109. pad_input = np.zeros(pad_shape).astype(dtype)
  110. if len(shape) == 2:
  111. pad_input[:m, :n] = input_
  112. elif len(shape) == 3:
  113. pad_input[:, :m, :n] = input_
  114. elif len(shape) == 4:
  115. pad_input[:, :, :m, :n] = input_
  116. m1, n1 = pad_m // 16, pad_n // 16
  117. reshape_shape = shape[:-2] + [m1, m0, n1, n0]
  118. reshape_input = pad_input.reshape(reshape_shape)
  119. else:
  120. reshape_shape = shape[:-2] + [m1, m0, n1, n0]
  121. reshape_input = input_.reshape(reshape_shape)
  122. if dst_format == "FRACTAL_NZ":
  123. transpose_axis = [2, 0, 1, 3]
  124. else:
  125. raise ValueError("dst_fromat %s is not suppored when src_format is %s" % (
  126. dst_format, src_format))
  127. transpose_axis = [x + len(shape) - 2 for x in transpose_axis]
  128. transpose_axis = [x for x in range(len(shape) - 2)] + transpose_axis
  129. bench_mark = reshape_input.transpose(transpose_axis).astype('float16')
  130. return bench_mark
  131. raise ValueError("src_format %s is not supported!" % src_format)
  132. def trans_data_fractal2two(input_, src_format, dst_format, shape_origin):
  133. shape_origin = [int(_) for _ in shape_origin]
  134. shape = list(input_.shape)
  135. n1, m1, m0, n0 = shape[-4:]
  136. new_shape = shape[:-4] + [m1 * m0, n1 * n0]
  137. tranpose_axis = [1, 2, 0, 3]
  138. tranpose_axis = [x + len(shape) - 4 for x in tranpose_axis]
  139. tranpose_axis = [i for i in range(len(shape) - 4)] + tranpose_axis
  140. bench_mark = input_.transpose(tranpose_axis).reshape(new_shape)
  141. if new_shape != shape_origin:
  142. if len(shape_origin) == 2:
  143. bench_mark = bench_mark[:shape_origin[0], :shape_origin[1]]
  144. elif len(shape_origin) == 3:
  145. bench_mark = bench_mark[:, shape_origin[0], :shape_origin[1]]
  146. elif len(shape_origin) == 4:
  147. bench_mark = bench_mark[:, :, shape_origin[0], :shape_origin[1]]
  148. return bench_mark
  149. def get_trans_data_str(input_name, output_name, ori_shape, src_format, dst_format):
  150. support_formats = [("DefaultFormat", "FRACTAL_NZ"),
  151. ("NCHW", "FRACTAL_NZ"),
  152. ("FRACTAL_NZ", "DefaultFormat"),
  153. ("FRACTAL_NZ", "NCHW")]
  154. if (src_format, dst_format) not in support_formats:
  155. raise ValueError("src_format %s and dst_format %s is not supported!" %
  156. (src_format, dst_format))
  157. if (src_format == 'DefaultFormat' or src_format == "NCHW") and dst_format == 'FRACTAL_NZ':
  158. res = "%s \n%s = %s(%s, '%s', '%s')" % (inspect.getsource(trans_data_two2fractal),
  159. output_name, trans_data_two2fractal.__name__, input_name,
  160. src_format, dst_format)
  161. elif src_format == 'FRACTAL_NZ' and (dst_format == 'DefaultFormat' or dst_format == "NCHW"):
  162. res = "%s \n%s = %s(%s, '%s', '%s', %s)" % (inspect.getsource(trans_data_fractal2two),
  163. output_name, trans_data_fractal2two.__name__, input_name,
  164. src_format, dst_format, ori_shape)
  165. else:
  166. raise ValueError("src_format(%s) and dst_format(%s) is not supported!" % (src_format, dst_format))
  167. return res
  168. def trans_data_dsl(inputs, output, attr):
  169. src_format = get_attr(attr, "src_format")
  170. dst_format = get_attr(attr, "dst_format")
  171. ori_shape = output[0]['shape']
  172. input_name = get_input(inputs[0][0])
  173. output_name = output[0]['tensor_name']
  174. return get_trans_data_str(input_name, output_name, ori_shape, src_format, dst_format)
  175. def batchmatmul_str(inputs, output, attr):
  176. trans_a = get_attr(attr, "transpose_a")
  177. trans_b = get_attr(attr, "transpose_b")
  178. if trans_a and trans_b:
  179. res = "%s = np.matmul(np.swapaxes(%s, -1, -2), np.swapaxes(%s, -1, -2))" %\
  180. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  181. elif trans_a:
  182. res = "%s = np.matmul(np.swapaxes(%s, -1, -2), %s)" %\
  183. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  184. elif trans_b:
  185. res = "%s = np.matmul(%s, np.swapaxes(%s, -1, -2))" %\
  186. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  187. else:
  188. res = "%s = np.matmul(%s, %s)" %\
  189. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  190. return res
  191. def convert_fracal_shape(ori_shape, fractal):
  192. ori_shape = tuple(ori_shape)
  193. if fractal == "zN":
  194. return ori_shape[:-4] + (ori_shape[-2] * ori_shape[-3], ori_shape[-1] * ori_shape[-4])
  195. if fractal == "zZ":
  196. return ori_shape[:-4] + (ori_shape[-4] * ori_shape[-2], ori_shape[-3] * ori_shape[-1])
  197. def matmul_str(inputs, output, attr):
  198. left_format = get_attr(attr, "left_format")
  199. right_format = get_attr(attr, "right_format")
  200. left_input = inputs[0][0]
  201. right_input = inputs[1][0]
  202. output_name = output[0]['tensor_name']
  203. output_format = output[0]['format']
  204. output_shape = output[0]['shape']
  205. right_ori_shape = convert_fracal_shape(right_input['shape'], right_format)
  206. left_input_name = get_input(left_input)
  207. right_input_name = get_input(right_input)
  208. res = ''
  209. if left_format == 'FRACTAL_NZ':
  210. left_ori_shape = convert_fracal_shape(left_input['shape'], "zN")
  211. left_trans_str = get_trans_data_str(left_input_name, left_input_name, left_ori_shape, left_format, 'DefaultFormat')
  212. res = res + left_trans_str + "\n"
  213. if right_format == 'FRACTAL_NZ':
  214. right_ori_shape = convert_fracal_shape(right_input['shape'], "zN")
  215. right_trans_str = get_trans_data_str(right_input_name, right_input_name, right_ori_shape, right_format, 'DefaultFormat')
  216. res = res + right_trans_str + "\n"
  217. matmul_str = batchmatmul_str(inputs, output, attr)
  218. res = res + matmul_str + "\n"
  219. if output_format != 'DefaultFormat':
  220. output_trans_str = get_trans_data_str(output_name, output_name, output_shape, 'DefaultFormat', output_format)
  221. res = res + output_trans_str + "\n"
  222. return res
# Dispatch table: composite-op name -> callable building the numpy source
# line(s) that emulate the op. Every entry takes (inputs, output, attr) from
# the json op description and returns a string of python code to be written
# into the generated script.
op_dsl = {
    "ReduceSum": lambda inputs, output, attr: reduce_str(inputs, output, attr, "sum"),
    "ReduceMax": lambda inputs, output, attr: reduce_str(inputs, output, attr, "max"),
    "ReduceMin": lambda inputs, output, attr: reduce_str(inputs, output, attr, "min"),
    "Tanh": lambda inputs, output, attr: "%s = np.tanh(%s)" %
            (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Mul": lambda inputs, output, attr: "%s = np.multiply(%s, %s)" %
           (output[0]['tensor_name'], get_input(
               inputs[0][0]), get_input(inputs[1][0])),
    "Pow": lambda inputs, output, attr: "%s = np.power(%s, %s)" %
           (output[0]['tensor_name'], get_input(
               inputs[0][0]), get_input(inputs[1][0])),
    "Sub": lambda inputs, output, attr: "%s = np.subtract(%s, %s)" %
           (output[0]['tensor_name'], get_input(
               inputs[0][0]), get_input(inputs[1][0])),
    "TensorAdd": lambda inputs, output, attr: "%s = np.add(%s, %s)" %
                 (output[0]['tensor_name'], get_input(
                     inputs[0][0]), get_input(inputs[1][0])),
    "Add": lambda inputs, output, attr: "%s = np.add(%s, %s)" %
           (output[0]['tensor_name'], get_input(
               inputs[0][0]), get_input(inputs[1][0])),
    "Rsqrt": lambda inputs, output, attr: "%s = 1.0/np.sqrt(%s)" %
             (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Neg": lambda inputs, output, attr: "%s = np.negative(%s)" %
           (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Exp": lambda inputs, output, attr: "%s = np.exp(%s)" %
           (output[0]['tensor_name'], get_input(inputs[0][0])),
    "RealDiv": lambda inputs, output, attr: "%s = np.divide(%s, %s)" %
               (output[0]['tensor_name'], get_input(
                   inputs[0][0]), get_input(inputs[1][0])),
    "Minimum": lambda inputs, output, attr: "%s = np.minimum(%s, %s)" %
               (output[0]['tensor_name'], get_input(
                   inputs[0][0]), get_input(inputs[1][0])),
    "Maximum": lambda inputs, output, attr: "%s = np.maximum(%s, %s)" %
               (output[0]['tensor_name'], get_input(
                   inputs[0][0]), get_input(inputs[1][0])),
    "Log": lambda inputs, output, attr: "%s = np.log(%s)" %
           (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Sqrt": lambda inputs, output, attr: "%s = np.sqrt(%s)" %
            (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Cast": lambda inputs, output, attr: cast_str(inputs, output, attr),
    "Reshape": lambda inputs, output, attr: "%s = np.reshape(%s, %s)" %
               (output[0]['tensor_name'], get_input(inputs[0][0]), get_attr(attr, "shape")),
    # NOTE(review): numpy has no np.one_hot — this generated line presumably
    # relies on a helper being available when the script runs; verify.
    "OneHot": lambda inputs, output, attr: "%s = np.one_hot(%s, %s, %s, %s, %s, %s)" %
              (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]),
               get_input(inputs[2][0]),
               attr[0]['value'], attr[1]['value'], inputs[0][0]['data_type']),
    "ZerosLike": lambda inputs, output, attr: "%s = np.zeros_like(%s)" %
                 (output[0]['tensor_name'], get_input(inputs[0][0])),
    # AddN sums all tensors of the first input group with '+'.
    "AddN": lambda inputs, output, attr: "%s = %s" %
            (output[0]['tensor_name'], ' + '.join([get_input(inputs[0][i])
                                                   for i in range(0, len(inputs[0]))])),
    "Tile": lambda inputs, output, attr: "%s = np.tile(%s, %s)" %
            (output[0]['tensor_name'], get_input(
                inputs[0][0]), get_attr(attr, "multiples")),
    "Reciprocal": lambda inputs, output, attr: "%s = np.divide(1.0, %s)" %
                  (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Equal": lambda inputs, output, attr: "%s = np.equal(%s, %s)" %
             (output[0]['tensor_name'], get_input(
                 inputs[0][0]), get_input(inputs[1][0])),
    "GreaterEqual": lambda inputs, output, attr: "%s = np.greater_equal(%s, %s)" %
                    (output[0]['tensor_name'], get_input(
                        inputs[0][0]), get_input(inputs[1][0])),
    "Select": lambda inputs, output, attr: "%s = np.where(%s, %s, %s)" %
              (output[0]['tensor_name'], get_input(inputs[0][0]),
               get_input(inputs[1][0]), get_input(inputs[2][0])),
    # InplaceAssign: store input1 into input0's variable, then alias the
    # output name to input2.
    "InplaceAssign": lambda inputs, output, attr: "%s = %s; %s = %s" %
                     (get_input(inputs[0][0]), get_input(inputs[1][0]),
                      output[0]['tensor_name'], get_input(inputs[2][0])),
    "Greater": lambda inputs, output, attr: "%s = np.greater(%s, %s)" %
               (output[0]['tensor_name'], get_input(
                   inputs[0][0]), get_input(inputs[1][0])),
    "SelectGT": lambda inputs, output, attr: "%s = np.where(%s > %s, %s, %s)" %
                (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]),
                 get_input(inputs[2][0]), get_input(inputs[3][0])),
    "SelectLT": lambda inputs, output, attr: "%s = np.where(%s < %s, %s, %s)" %
                (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]),
                 get_input(inputs[2][0]), get_input(inputs[3][0])),
    "Abs": lambda inputs, output, attr: "%s = np.absolute(%s)" %
           (output[0]['tensor_name'], get_input(inputs[0][0])),
    "LessEqual": lambda inputs, output, attr: "%s = np.less_equal(%s, %s)" %
                 (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0])),
    # EquivFormat is a layout no-op numerically: plain aliasing.
    "EquivFormat": lambda inputs, output, attr: "%s = %s" %
                   (output[0]['tensor_name'], get_input(inputs[0][0])),
    "ExpandDims": lambda inputs, output, attr: "%s = np.expand_dims(%s, %s)" %
                  (output[0]['tensor_name'], get_input(inputs[0][0]), get_attr(attr, "axis")),
    "Transpose": lambda inputs, output, attr: transpose_str(inputs, output, attr),
    "TransData": trans_data_dsl,
    "BroadcastTo": lambda inputs, output, attr: broadcast_str(inputs, output, attr),
    "BatchMatMul": lambda inputs, output, attr: batchmatmul_str(inputs, output, attr),
    # Assign: store input1 into input0's variable, output aliases input1.
    "Assign": lambda inputs, output, attr: "%s = %s; %s = %s" %
              (get_input(inputs[0][0]), get_input(inputs[1][0]), output[0]['tensor_name'],
               get_input(inputs[1][0])),
    "MatMul": lambda inputs, output, attr: matmul_str(inputs, output, attr)
}
def gen_json_data(op_desc):
    """Generate test data for a composite json description.

    Builds numpy input arrays, writes a 'json_data.py' script that replays
    the op sequence with numpy, exec()s that script to compute the expected
    outputs, and returns (input_for_mod, expect, output_indexes).

    Args:
        op_desc: json string describing inputs, ops and outputs.

    Returns:
        input_for_mod: list of numpy arrays (inputs followed by nan-filled
            output placeholders).
        expect: expected output arrays, filled in by the exec'd script.
        output_indexes: indexes (negative, relative to the end of
            input_for_mod) of the real outputs, plus inplace-written inputs.
    """
    desc = json.loads(op_desc)
    input_for_mod = []
    input_dict = {}
    input_order = {}
    output_indexes = []
    expect = []
    p = CodePrinter('json_data.py')
    idx = 0
    # Collect input which should be processed by atomic clean.
    # Inputs that a ReduceSum-with-atomic-add later InplaceAssigns into must
    # start zeroed instead of random.
    clean_input = []
    sum_out = None
    for op in desc["op_desc"]:
        if op["name"] == "ReduceSum":
            for a in op["attr"]:
                if a["name"] == "enable_atomic_add":
                    sum_out = op["output_desc"][0]["tensor_name"]
                    break
        elif op["name"] == "InplaceAssign":
            if not sum_out:
                continue
            if op["input_desc"][1][0]["tensor_name"] == sum_out:
                clean_input.append(op["input_desc"][0][0]["tensor_name"])
    # Materialize every input and emit a line binding its name in the script.
    for input_desc in desc["input_desc"] if desc["input_desc"] is not None else []:
        shape = [1] if not input_desc[0]["shape"] else input_desc[0]["shape"]
        dtype = input_desc[0]["data_type"]
        tensor_name = input_desc[0]["tensor_name"]
        if tensor_name in clean_input:
            item = np.zeros(shape).astype(dtype)
        else:
            item = random_gaussian(shape, miu=1, sigma=0.1).astype(dtype)
        input_for_mod.append(item)
        input_order[tensor_name] = idx
        input_dict[tensor_name] = item
        p.out("%s = np.array(input_dict.get('%s'))" % (tensor_name, tensor_name),
              new_line=False if idx == 0 else True)
        idx += 1
    inplace_assign_write = []
    fake_output_tensors = []
    elemwise_op_list = ["TensorAdd", "Add", "RealDiv", "Mul", "Minimum", "Maximum", "Sub"]
    for op in desc["op_desc"]:
        dsl_fun = op_dsl.get(op["name"], None)
        if op["name"] in ("InplaceAssign", "Assign"):
            if op["name"] == "InplaceAssign":
                fake_output = False
                for attr in op["attr"]:
                    if attr["name"] == "fake_output":
                        fake_output = attr["value"]
                if fake_output:
                    fake_output_tensors.append(op["output_desc"][0]["tensor_name"])
            inplace_assign_write.append(op["input_desc"][0][0]["tensor_name"])
        elif op["name"] in elemwise_op_list and "format" in op["output_desc"][0] and \
                op["output_desc"][0]["format"] == "FRACTAL_NZ":
            # Elementwise op mixing a DefaultFormat operand with a FRACTAL_NZ
            # one: reshape the default-format operand so numpy broadcasting
            # matches the fractal layout.
            need_reshape = False
            if op["input_desc"][0][0]["format"] == "DefaultFormat" and \
                    op["input_desc"][1][0]["format"] == "FRACTAL_NZ":
                fractal_tensor = op["input_desc"][1][0]
                default_tensor = op["input_desc"][0][0]
                need_reshape = True
            elif op["input_desc"][0][0]["format"] == "FRACTAL_NZ" and \
                    op["input_desc"][1][0]["format"] == "DefaultFormat":
                fractal_tensor = op["input_desc"][0][0]
                default_tensor = op["input_desc"][1][0]
                need_reshape = True
            if need_reshape:
                shape_fractal = fractal_tensor["shape"]
                shape_default = default_tensor["shape"]
                # 2D-layout shape implied by the fractal shape
                # (..., n1, m1, m0, n0) -> (..., m1*m0, n1*n0).
                orig_shape = shape_fractal[:-4] + [shape_fractal[-3] * shape_fractal[-2]] + [shape_fractal[-4] * shape_fractal[-1]]
                shape_tmp = []
                shape_out = []
                diff_dims = len(orig_shape) - len(shape_default)
                # Left-pad the default shape with 1s up to orig_shape's rank.
                for i in range(diff_dims):
                    shape_tmp.append(1)
                    shape_out.append(orig_shape[i])
                for i in range(len(shape_default)):
                    shape_tmp.append(shape_default[i])
                    if orig_shape[i + diff_dims] == 1:
                        shape_out.append(shape_default[i])
                    else:
                        shape_out.append(orig_shape[i + diff_dims])
                # Broadcast-compatible fractal shape for the default tensor.
                shape_new = []
                for i in range(len(shape_out) - 2):
                    shape_new.append(shape_out[i])
                if shape_tmp[-2] == 1 and shape_tmp[-1] == 1:
                    shape_new.extend([1, 1, 1, 1])
                elif shape_tmp[-2] == 1 and shape_tmp[-1] == shape_default[-1]:
                    shape_new.extend([shape_fractal[-4], 1, 1, shape_fractal[-1]])
                elif shape_tmp[-2] == shape_default[-2] and shape_tmp[-1] == 1:
                    shape_new.extend([1, shape_fractal[-3], shape_fractal[-2], 1])
                if "value" in default_tensor:
                    sent_reshape_tensor = "%s = np.full(%s, %s, np.%s)" \
                        % (default_tensor["tensor_name"], shape_new, default_tensor["value"],
                           default_tensor["data_type"])
                else:
                    sent_reshape_tensor = "%s = np.reshape(%s, %s)" \
                        % (default_tensor["tensor_name"], default_tensor["tensor_name"], tuple(shape_new))
                p.out(sent_reshape_tensor, True)
        if dsl_fun is None:
            logging.info("[%s] is not support for %s", op["name"], op)
            continue
        sent = dsl_fun(op['input_desc'], op['output_desc'], op['attr'])
        logging.debug(sent)
        p.out(sent, True)
    # Emit nan-filled placeholders for outputs; real (non-fake) outputs are
    # recorded by their negative index relative to the end of input_for_mod.
    idx = 0
    out_nums = len(desc["output_desc"])
    for output_desc in desc["output_desc"]:
        shape = [1] if not output_desc["shape"] else output_desc["shape"]
        dtype = output_desc["data_type"]
        item = np.full(shape, np.nan, dtype)
        input_for_mod.append(item)
        tensor_name = output_desc["tensor_name"]
        if tensor_name not in fake_output_tensors:
            real_idx = idx - out_nums
            output_indexes.append(real_idx)
            p.out("expect.append(%s)" % (tensor_name), True)
        idx += 1
    # Add inplace tensors to expect, and add their index to output_indexes.
    if inplace_assign_write:
        inplace_tensors = "["
        inplace_tensors_index = []
        for tensor_name in inplace_assign_write:
            inplace_tensors_index.append(input_order[tensor_name])
            inplace_tensors += "{}, ".format(tensor_name)
        inplace_tensors += "]"
        p.out("inplace_tensors = {}".format(inplace_tensors), True)
        p.out("expect.extend(inplace_tensors)", True)
        output_indexes.extend(inplace_tensors_index)
    p.close()
    # Run the generated script; it reads input_dict/np and fills `expect`.
    # NOTE(review): exec of generated code — op_desc must come from a
    # trusted source.
    with open("json_data.py", 'r') as f:
        sent = f.read()
    exec(sent)
    return input_for_mod, expect, output_indexes

AKG(Auto Kernel Generator)对深度神经网络中的算子进行优化,并提供特定模式下的算子自动融合功能。AKG与MindSpore的图算融合功能协同工作,可提升在不同硬件后端上运行网络的性能。