You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

gen_json_data.py 25 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """generate numpy data for composite json"""
  15. import json
  16. import logging
  17. import inspect
  18. import numpy as np
  19. import subprocess
  20. from tests.common.gen_random import random_gaussian
  21. def get_attr(attr_desc, attr_type):
  22. """get op attr by type"""
  23. for attr in attr_desc:
  24. if attr["name"] == attr_type:
  25. return attr["value"]
  26. logging.warning("attr {} not found, please check.".format(attr_type))
  27. return None
  28. class CodePrinter(object):
  29. """print numpy file"""
  30. def __init__(self, out_file):
  31. self.fout_ = open(out_file, 'w')
  32. self.indent_ = 0
  33. def __del__(self):
  34. self.fout_.close()
  35. def out(self, data, new_line=False):
  36. """write data"""
  37. if new_line:
  38. self.fout_.write("\n")
  39. for i in range(0, self.indent_):
  40. self.fout_.write(' ')
  41. if isinstance(data, str):
  42. self.fout_.write(data)
  43. else:
  44. self.fout_.write(str(data))
  45. def null_line(self):
  46. """add null line"""
  47. self.fout_.write("\n")
  48. def close(self):
  49. """close file"""
  50. self.fout_.close()
  51. def get_input(desc):
  52. """get input values"""
  53. value = desc.get('value', None)
  54. return value if value is not None else desc['tensor_name']
  55. def reduce_str(inputs, output, attr, op_type):
  56. """gen sum string"""
  57. axis = []
  58. keepdims = False
  59. axis_value = get_attr(attr, "axis")
  60. if axis_value:
  61. axis = list(axis_value) if isinstance(axis_value, (list, tuple)) else [axis_value]
  62. keepdims_value = get_attr(attr, "keep_dims")
  63. keepdims = keepdims_value if keepdims_value else keepdims
  64. if axis == []:
  65. s = "%s = np.%s(%s.astype(np.float32) if %s.dtype == np.float16 else %s, keepdims=%s).astype(%s.dtype)" % (
  66. output[0]['tensor_name'], op_type, get_input(inputs[0][0]), get_input(inputs[0][0]),
  67. get_input(inputs[0][0]), keepdims, get_input(inputs[0][0]))
  68. else:
  69. s = "%s = np.%s(%s.astype(np.float32) if %s.dtype == np.float16 else %s, axis=tuple(%s), keepdims=%s).astype(%s.dtype); %s = np.reshape(%s, %s) " %\
  70. (output[0]['tensor_name'], op_type, get_input(inputs[0][0]), get_input(inputs[0][0]),
  71. get_input(inputs[0][0]), axis, keepdims, get_input(inputs[0][0]),
  72. output[0]['tensor_name'], output[0]['tensor_name'], output[0]['shape'])
  73. return s
  74. def cast_str(inputs, output, attr):
  75. """gen cast string"""
  76. dst_type = get_attr(attr, "dst_type")
  77. s = "%s = np.array(%s).astype(np.%s) if isinstance(%s, (float, int)) else %s.astype(np.%s)" % (
  78. output[0]['tensor_name'], get_input(inputs[0][0]), dst_type, get_input(inputs[0][0]),
  79. get_input(inputs[0][0]), dst_type)
  80. return s
  81. def broadcast_str(inputs, output, attr):
  82. """gen broadcast string"""
  83. dst_shape = get_attr(attr, "shape")
  84. s = "%s = np.broadcast_to(%s, %s)" % (
  85. output[0]["tensor_name"], get_input(inputs[0][0]), dst_shape)
  86. return s
  87. def transpose_str(inputs, output, attr):
  88. """gen transpose string"""
  89. axes = None
  90. axes_value = get_attr(attr, "perm")
  91. axes = axes_value if axes_value else axes
  92. s = "%s = np.transpose(%s, axes=%s)" % (
  93. output[0]['tensor_name'], get_input(inputs[0][0]), axes)
  94. return s
  95. def trans_data_two2fractal(input_, src_format, dst_format):
  96. shape = list(input_.shape)
  97. dtype = input_.dtype
  98. if dtype == "float32":
  99. input_ = input_.astype(np.float16)
  100. if src_format == "DefaultFormat" or src_format == "NCHW":
  101. m, n = shape[-2], shape[-1]
  102. m1, n1 = m // 16, n // 16
  103. m0, n0 = 16, 16
  104. needPad = m % 16 != 0 or n % 16 != 0
  105. if needPad:
  106. pad_m, pad_n = (m + 15) // 16 * 16, (n + 15) // 16 * 16
  107. pad_shape = [x for x in shape]
  108. pad_shape[-1] = pad_n
  109. pad_shape[-2] = pad_m
  110. pad_input = np.zeros(pad_shape).astype(dtype)
  111. if len(shape) == 2:
  112. pad_input[:m, :n] = input_
  113. elif len(shape) == 3:
  114. pad_input[:, :m, :n] = input_
  115. elif len(shape) == 4:
  116. pad_input[:, :, :m, :n] = input_
  117. m1, n1 = pad_m // 16, pad_n // 16
  118. reshape_shape = shape[:-2] + [m1, m0, n1, n0]
  119. reshape_input = pad_input.reshape(reshape_shape)
  120. else:
  121. reshape_shape = shape[:-2] + [m1, m0, n1, n0]
  122. reshape_input = input_.reshape(reshape_shape)
  123. if dst_format == "FRACTAL_NZ":
  124. transpose_axis = [2, 0, 1, 3]
  125. else:
  126. raise ValueError("dst_fromat %s is not suppored when src_format is %s" % (
  127. dst_format, src_format))
  128. transpose_axis = [x + len(shape) - 2 for x in transpose_axis]
  129. transpose_axis = [x for x in range(len(shape) - 2)] + transpose_axis
  130. bench_mark = reshape_input.transpose(transpose_axis).astype('float16')
  131. return bench_mark
  132. raise ValueError("src_format %s is not supported!" % src_format)
  133. def trans_data_fractal2two(input_, src_format, dst_format, shape_origin):
  134. shape_origin = [int(_) for _ in shape_origin]
  135. shape = list(input_.shape)
  136. n1, m1, m0, n0 = shape[-4:]
  137. new_shape = shape[:-4] + [m1 * m0, n1 * n0]
  138. tranpose_axis = [1, 2, 0, 3]
  139. tranpose_axis = [x + len(shape) - 4 for x in tranpose_axis]
  140. tranpose_axis = [i for i in range(len(shape) - 4)] + tranpose_axis
  141. bench_mark = input_.transpose(tranpose_axis).reshape(new_shape)
  142. if new_shape != shape_origin:
  143. if len(shape_origin) == 2:
  144. bench_mark = bench_mark[:shape_origin[0], :shape_origin[1]]
  145. elif len(shape_origin) == 3:
  146. bench_mark = bench_mark[:, shape_origin[0], :shape_origin[1]]
  147. elif len(shape_origin) == 4:
  148. bench_mark = bench_mark[:, :, shape_origin[0], :shape_origin[1]]
  149. return bench_mark
  150. def get_trans_data_str(input_name, output_name, ori_shape, src_format, dst_format):
  151. support_formats = [("DefaultFormat", "FRACTAL_NZ"),
  152. ("NCHW", "FRACTAL_NZ"),
  153. ("FRACTAL_NZ", "DefaultFormat"),
  154. ("FRACTAL_NZ", "NCHW")]
  155. if (src_format, dst_format) not in support_formats:
  156. raise ValueError("src_format %s and dst_format %s is not supported!" %
  157. (src_format, dst_format))
  158. if (src_format == 'DefaultFormat' or src_format == "NCHW") and dst_format == 'FRACTAL_NZ':
  159. res = "%s \n%s = %s(%s, '%s', '%s')" % (inspect.getsource(trans_data_two2fractal),
  160. output_name, trans_data_two2fractal.__name__, input_name,
  161. src_format, dst_format)
  162. elif src_format == 'FRACTAL_NZ' and (dst_format == 'DefaultFormat' or dst_format == "NCHW"):
  163. res = "%s \n%s = %s(%s, '%s', '%s', %s)" % (inspect.getsource(trans_data_fractal2two),
  164. output_name, trans_data_fractal2two.__name__, input_name,
  165. src_format, dst_format, ori_shape)
  166. else:
  167. raise ValueError("src_format(%s) and dst_format(%s) is not supported!" % (src_format, dst_format))
  168. return res
  169. def trans_data_dsl(inputs, output, attr):
  170. src_format = get_attr(attr, "src_format")
  171. dst_format = get_attr(attr, "dst_format")
  172. ori_shape = output[0]['shape']
  173. input_name = get_input(inputs[0][0])
  174. output_name = output[0]['tensor_name']
  175. return get_trans_data_str(input_name, output_name, ori_shape, src_format, dst_format)
  176. def np_matmul_str(inputs, output, attr):
  177. trans_a = get_attr(attr, "transpose_a")
  178. trans_b = get_attr(attr, "transpose_b")
  179. input_0 = inputs[0][0]
  180. input_1 = inputs[1][0]
  181. res = ""
  182. if input_0['data_type'] == "float16":
  183. tmp = "%s = %s.astype(np.float32)" % (get_input(input_0), get_input(input_0))
  184. res = res + tmp + "\n"
  185. if input_1['data_type'] == "float16":
  186. tmp = "%s = %s.astype(np.float32)" % (get_input(input_1), get_input(input_1))
  187. res = res + tmp + "\n"
  188. if trans_a and trans_b:
  189. res += "%s = np.dot(np.swapaxes(%s, -1, -2), np.swapaxes(%s, -1, -2))" %\
  190. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  191. elif trans_a:
  192. res += "%s = np.dot(np.swapaxes(%s, -1, -2), %s)" %\
  193. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  194. elif trans_b:
  195. res += "%s = np.dot(%s, np.swapaxes(%s, -1, -2))" %\
  196. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  197. else:
  198. res += "%s = np.dot(%s, %s)" %\
  199. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  200. if output[0]['data_type'] == "float16":
  201. res += "\n" + "%s = %s.astype(np.float16)" % (output[0]['tensor_name'], output[0]['tensor_name'])
  202. return res
  203. def batchmatmul_str(inputs, output, attr):
  204. trans_a = get_attr(attr, "transpose_a")
  205. trans_b = get_attr(attr, "transpose_b")
  206. if trans_a and trans_b:
  207. res = "%s = np.matmul(np.swapaxes(%s, -1, -2), np.swapaxes(%s, -1, -2))" %\
  208. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  209. elif trans_a:
  210. res = "%s = np.matmul(np.swapaxes(%s, -1, -2), %s)" %\
  211. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  212. elif trans_b:
  213. res = "%s = np.matmul(%s, np.swapaxes(%s, -1, -2))" %\
  214. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  215. else:
  216. res = "%s = np.matmul(%s, %s)" %\
  217. (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]))
  218. return res
  219. def convert_fracal_shape(ori_shape, fractal):
  220. ori_shape = tuple(ori_shape)
  221. if fractal == "zN":
  222. return ori_shape[:-4] + (ori_shape[-2] * ori_shape[-3], ori_shape[-1] * ori_shape[-4])
  223. if fractal == "zZ":
  224. return ori_shape[:-4] + (ori_shape[-4] * ori_shape[-2], ori_shape[-3] * ori_shape[-1])
  225. def matmul_str(inputs, output, attr):
  226. left_format = get_attr(attr, "left_format")
  227. right_format = get_attr(attr, "right_format")
  228. trans_a = get_attr(attr, "transpose_a")
  229. trans_b = get_attr(attr, "transpose_b")
  230. left_input = inputs[0][0]
  231. right_input = inputs[1][0]
  232. output_name = output[0]['tensor_name']
  233. output_format = output[0]['format']
  234. output_shape = output[0]['shape']
  235. right_ori_shape = convert_fracal_shape(right_input['shape'], right_format)
  236. left_input_name = get_input(left_input)
  237. right_input_name = get_input(right_input)
  238. res = ''
  239. if left_format == 'FRACTAL_NZ':
  240. left_ori_shape = convert_fracal_shape(left_input['shape'], "zN")
  241. left_trans_str = get_trans_data_str(left_input_name, left_input_name, left_ori_shape, left_format, 'DefaultFormat')
  242. res = res + left_trans_str + "\n"
  243. if right_format == 'FRACTAL_NZ':
  244. right_ori_shape = convert_fracal_shape(right_input['shape'], "zN")
  245. right_trans_str = get_trans_data_str(right_input_name, right_input_name, right_ori_shape, right_format, 'DefaultFormat')
  246. res = res + right_trans_str + "\n"
  247. matmul_str = np_matmul_str(inputs, output, attr)
  248. res = res + matmul_str + "\n"
  249. has_bias = (len(inputs) > 2)
  250. if has_bias:
  251. bias=inputs[2][0]
  252. bias_shape = right_ori_shape[-2] if trans_b else right_ori_shape[-1]
  253. if bias['shape'][0] != bias_shape:
  254. res += "%s = random_gaussian([%s, ], miu=1, sigma=0.1).astype(np.%s) \n" % (get_input(bias), str(bias_shape), bias['data_type'])
  255. res += "%s = np.add(%s, %s)\n" % (output_name, output_name, get_input(bias))
  256. if output_format != 'DefaultFormat':
  257. output_trans_str = get_trans_data_str(output_name, output_name, output_shape, 'DefaultFormat', output_format)
  258. res = res + output_trans_str + "\n"
  259. return res
# Dispatch table: composite-op name -> callable(inputs, output, attr) that
# returns python statement text reproducing the op with numpy.  Simple
# elementwise ops are inline lambdas; ops needing attribute handling
# delegate to the *_str helpers above.
# NOTE(review): "OneHot" emits np.one_hot, which does not exist in stock
# numpy -- presumably the exec environment provides it; verify before use.
op_dsl = {
    "ReduceSum": lambda inputs, output, attr: reduce_str(inputs, output, attr, "sum"),
    "ReduceMax": lambda inputs, output, attr: reduce_str(inputs, output, attr, "max"),
    "ReduceMin": lambda inputs, output, attr: reduce_str(inputs, output, attr, "min"),
    "Tanh": lambda inputs, output, attr: "%s = np.tanh(%s)" %
            (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Mul": lambda inputs, output, attr: "%s = np.multiply(%s, %s)" %
           (output[0]['tensor_name'], get_input(
               inputs[0][0]), get_input(inputs[1][0])),
    "Pow": lambda inputs, output, attr: "%s = np.power(%s, %s)" %
           (output[0]['tensor_name'], get_input(
               inputs[0][0]), get_input(inputs[1][0])),
    "Sub": lambda inputs, output, attr: "%s = np.subtract(%s, %s)" %
           (output[0]['tensor_name'], get_input(
               inputs[0][0]), get_input(inputs[1][0])),
    "TensorAdd": lambda inputs, output, attr: "%s = np.add(%s, %s)" %
                 (output[0]['tensor_name'], get_input(
                     inputs[0][0]), get_input(inputs[1][0])),
    "Add": lambda inputs, output, attr: "%s = np.add(%s, %s)" %
           (output[0]['tensor_name'], get_input(
               inputs[0][0]), get_input(inputs[1][0])),
    "Rsqrt": lambda inputs, output, attr: "%s = 1.0/np.sqrt(%s)" %
             (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Neg": lambda inputs, output, attr: "%s = np.negative(%s)" %
           (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Exp": lambda inputs, output, attr: "%s = np.exp(%s)" %
           (output[0]['tensor_name'], get_input(inputs[0][0])),
    "RealDiv": lambda inputs, output, attr: "%s = np.divide(%s, %s)" %
               (output[0]['tensor_name'], get_input(
                   inputs[0][0]), get_input(inputs[1][0])),
    "Minimum": lambda inputs, output, attr: "%s = np.minimum(%s, %s)" %
               (output[0]['tensor_name'], get_input(
                   inputs[0][0]), get_input(inputs[1][0])),
    "Maximum": lambda inputs, output, attr: "%s = np.maximum(%s, %s)" %
               (output[0]['tensor_name'], get_input(
                   inputs[0][0]), get_input(inputs[1][0])),
    "Log": lambda inputs, output, attr: "%s = np.log(%s)" %
           (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Sqrt": lambda inputs, output, attr: "%s = np.sqrt(%s)" %
            (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Cast": lambda inputs, output, attr: cast_str(inputs, output, attr),
    "Reshape": lambda inputs, output, attr: "%s = np.reshape(%s, %s)" %
               (output[0]['tensor_name'], get_input(inputs[0][0]), get_attr(attr, "shape")),
    # NOTE(review): attr[0]/attr[1] are read positionally here -- presumably
    # depth/axis for one-hot; confirm against the json producer.
    "OneHot": lambda inputs, output, attr: "%s = np.one_hot(%s, %s, %s, %s, %s, %s)" %
              (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]), get_input(inputs[2][0]),
               attr[0]['value'], attr[1]['value'], inputs[0][0]['data_type']),
    "ZerosLike": lambda inputs, output, attr: "%s = np.zeros_like(%s)" %
                 (output[0]['tensor_name'], get_input(inputs[0][0])),
    "AddN": lambda inputs, output, attr: "%s = %s" %
            (output[0]['tensor_name'], ' + '.join([get_input(inputs[0][i])
                                                   for i in range(0, len(inputs[0]))])),
    "Tile": lambda inputs, output, attr: "%s = np.tile(%s, %s)" %
            (output[0]['tensor_name'], get_input(
                inputs[0][0]), get_attr(attr, "multiples")),
    "Reciprocal": lambda inputs, output, attr: "%s = np.divide(1.0, %s)" %
                  (output[0]['tensor_name'], get_input(inputs[0][0])),
    "Equal": lambda inputs, output, attr: "%s = np.equal(%s, %s)" %
             (output[0]['tensor_name'], get_input(
                 inputs[0][0]), get_input(inputs[1][0])),
    "GreaterEqual": lambda inputs, output, attr: "%s = np.greater_equal(%s, %s)" %
                    (output[0]['tensor_name'], get_input(
                        inputs[0][0]), get_input(inputs[1][0])),
    "Select": lambda inputs, output, attr: "%s = np.where(%s, %s, %s)" %
              (output[0]['tensor_name'], get_input(inputs[0][0]),
               get_input(inputs[1][0]), get_input(inputs[2][0])),
    # InplaceAssign overwrites input 0 with input 1, then aliases the output
    # to input 2.
    "InplaceAssign": lambda inputs, output, attr: "%s = %s; %s = %s" %
                     (get_input(inputs[0][0]), get_input(inputs[1][0]),
                      output[0]['tensor_name'], get_input(inputs[2][0])),
    "Greater": lambda inputs, output, attr: "%s = np.greater(%s, %s)" %
               (output[0]['tensor_name'], get_input(
                   inputs[0][0]), get_input(inputs[1][0])),
    "SelectGT": lambda inputs, output, attr: "%s = np.where(%s > %s, %s, %s)" %
                (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]),
                 get_input(inputs[2][0]), get_input(inputs[3][0])),
    "SelectLT": lambda inputs, output, attr: "%s = np.where(%s < %s, %s, %s)" %
                (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0]),
                 get_input(inputs[2][0]), get_input(inputs[3][0])),
    "Abs": lambda inputs, output, attr: "%s = np.absolute(%s)" %
           (output[0]['tensor_name'], get_input(inputs[0][0])),
    "LessEqual": lambda inputs, output, attr: "%s = np.less_equal(%s, %s)" %
                 (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0])),
    "Less": lambda inputs, output, attr: "%s = np.less(%s, %s)" %
            (output[0]['tensor_name'], get_input(inputs[0][0]), get_input(inputs[1][0])),
    # EquivFormat is a pure aliasing op (no data movement in numpy).
    "EquivFormat": lambda inputs, output, attr: "%s = %s" %
                   (output[0]['tensor_name'], get_input(inputs[0][0])),
    "ExpandDims": lambda inputs, output, attr: "%s = np.expand_dims(%s, %s)" %
                  (output[0]['tensor_name'], get_input(inputs[0][0]), get_attr(attr, "axis")),
    "Transpose": lambda inputs, output, attr: transpose_str(inputs, output, attr),
    "TransData": trans_data_dsl,
    "BroadcastTo": lambda inputs, output, attr: broadcast_str(inputs, output, attr),
    "BatchMatMul": lambda inputs, output, attr: batchmatmul_str(inputs, output, attr),
    # Assign overwrites input 0 with input 1 and aliases the output to input 1.
    "Assign": lambda inputs, output, attr: "%s = %s; %s = %s" %
              (get_input(inputs[0][0]), get_input(inputs[1][0]), output[0]['tensor_name'],
               get_input(inputs[1][0])),
    "MatMul": lambda inputs, output, attr: matmul_str(inputs, output, attr)
}
def gen_json_data(op_desc):
    """Generating test data for composite json.

    Parses the composite-op json in ``op_desc``, generates random numpy
    inputs, writes a throwaway python script that replays the op graph with
    numpy (one statement per op, via ``op_dsl``), executes it to obtain the
    expected outputs, and returns:

        input_for_mod   -- list of input arrays plus NaN-filled output
                           placeholders, in kernel-argument order
        expect          -- expected output arrays produced by the script
        output_indexes  -- negative (or input) indexes into input_for_mod
                           identifying which entries are real outputs
    """
    desc = json.loads(op_desc)
    input_for_mod = []
    input_dict = {}
    input_order = {}
    output_indexes = []
    expect = []
    # Derive a unique-ish file suffix from the op name's trailing hash;
    # fall back to a timestamp when the name has no '_' parts.
    op_name = desc.get("op")
    if len(op_name.split("_")) > 0:
        op_hash = op_name.split("_")[-1]
    else:
        import time
        op_hash = str(time.time())
    uni_file_name = "json_data_" + op_hash + ".py"
    p = CodePrinter(uni_file_name)
    idx = 0
    # Collect input which should be processed by atomic clean.
    # Such inputs must start zeroed (atomic-add accumulators), so they are
    # generated as zeros instead of random data below.
    clean_input = []
    sum_out = None
    for op in desc["op_desc"]:
        if op["name"] == "ReduceSum":
            # NOTE(review): only the *presence* of the enable_atomic_add
            # attr is checked, not its value -- confirm that is intended.
            for a in op["attr"]:
                if a["name"] == "enable_atomic_add":
                    sum_out = op["output_desc"][0]["tensor_name"]
                    break
        elif op["name"] == "InplaceAssign":
            if not sum_out:
                continue
            # The InplaceAssign target that receives the atomic sum is the
            # buffer that must be zero-initialized.
            if op["input_desc"][1][0]["tensor_name"] == sum_out:
                clean_input.append(op["input_desc"][0][0]["tensor_name"])
    # Generate the input arrays and emit "name = np.array(...)" bindings
    # into the generated script.
    for input_desc in desc["input_desc"] if desc["input_desc"] is not None else []:
        shape = [1] if not input_desc[0]["shape"] else input_desc[0]["shape"]
        dtype = input_desc[0]["data_type"]
        tensor_name = input_desc[0]["tensor_name"]
        if tensor_name in clean_input:
            item = np.zeros(shape).astype(dtype)
        else:
            item = random_gaussian(shape, miu=1, sigma=0.1).astype(dtype)
        input_for_mod.append(item)
        input_order[tensor_name] = idx
        input_dict[tensor_name] = item
        # The generated script reads values back out of input_dict (in scope
        # when exec() runs below).
        p.out("%s = np.array(input_dict.get('%s'))" % (tensor_name, tensor_name),
              new_line=False if idx == 0 else True)
        idx += 1
    inplace_assign_write = []
    fake_output_tensors = []
    elemwise_op_list = ["TensorAdd", "Add", "RealDiv", "Mul", "Minimum", "Maximum", "Sub"]
    # Emit one numpy statement per op, in graph order.
    for op in desc["op_desc"]:
        dsl_fun = op_dsl.get(op["name"], None)
        if op["name"] in ("InplaceAssign", "Assign"):
            if op["name"] == "InplaceAssign":
                # A fake output exists only to carry the inplace write; it
                # is excluded from the expected-output list later.
                fake_output = False
                for attr in op["attr"]:
                    if attr["name"] == "fake_output":
                        fake_output = attr["value"]
                if fake_output:
                    fake_output_tensors.append(op["output_desc"][0]["tensor_name"])
            inplace_assign_write.append(op["input_desc"][0][0]["tensor_name"])
        elif op["name"] in elemwise_op_list and "format" in op["output_desc"][0] and \
                op["output_desc"][0]["format"] == "FRACTAL_NZ":
            # Mixed-layout elementwise op: a DefaultFormat operand must be
            # reshaped so it broadcasts against the FRACTAL_NZ operand's
            # tiled [.., n1, m1, m0, n0] layout.
            need_reshape = False
            if op["input_desc"][0][0]["format"] == "DefaultFormat" and \
                    op["input_desc"][1][0]["format"] == "FRACTAL_NZ":
                fractal_tensor = op["input_desc"][1][0]
                default_tensor = op["input_desc"][0][0]
                need_reshape = True
            elif op["input_desc"][0][0]["format"] == "FRACTAL_NZ" and \
                    op["input_desc"][1][0]["format"] == "DefaultFormat":
                fractal_tensor = op["input_desc"][0][0]
                default_tensor = op["input_desc"][1][0]
                need_reshape = True
            if need_reshape:
                shape_fractal = fractal_tensor["shape"]
                shape_default = default_tensor["shape"]
                # Logical 2D shape implied by the fractal tiling
                # (m1*m0, n1*n0 from the trailing four axes).
                orig_shape = shape_fractal[:-4] + [shape_fractal[-3] * shape_fractal[-2]] + [shape_fractal[-4] * shape_fractal[-1]]
                shape_tmp = []
                shape_out = []
                diff_dims = len(orig_shape) - len(shape_default)
                # shape_tmp: default shape left-padded with 1s to the
                # fractal's logical rank; shape_out: broadcast result shape.
                for i in range(diff_dims):
                    shape_tmp.append(1)
                    shape_out.append(orig_shape[i])
                for i in range(len(shape_default)):
                    shape_tmp.append(shape_default[i])
                    if orig_shape[i + diff_dims] == 1:
                        shape_out.append(shape_default[i])
                    else:
                        shape_out.append(orig_shape[i + diff_dims])
                # shape_new: the default tensor's shape re-expressed in the
                # tiled layout so numpy broadcasting lines up per-axis.
                shape_new = []
                for i in range(len(shape_out) - 2):
                    shape_new.append(shape_out[i])
                if shape_tmp[-2] == 1 and shape_tmp[-1] == 1:
                    # Scalar-like operand: broadcasts over all four tile axes.
                    shape_new.extend([1, 1, 1, 1])
                elif shape_tmp[-2] == 1 and shape_tmp[-1] == shape_default[-1]:
                    # Row vector: spreads along the n1/n0 tile axes.
                    shape_new.extend([shape_fractal[-4], 1, 1, shape_fractal[-1]])
                elif shape_tmp[-2] == shape_default[-2] and shape_tmp[-1] == 1:
                    # Column vector: spreads along the m1/m0 tile axes.
                    shape_new.extend([1, shape_fractal[-3], shape_fractal[-2], 1])
                if "value" in default_tensor:
                    # Constant operand: materialize it directly at the tiled
                    # shape instead of reshaping a named tensor.
                    sent_reshape_tensor = "%s = np.full(%s, %s, np.%s)" \
                        % (default_tensor["tensor_name"], shape_new, default_tensor["value"],
                           default_tensor["data_type"])
                else:
                    sent_reshape_tensor = "%s = np.reshape(%s, %s)" \
                        % (default_tensor["tensor_name"], default_tensor["tensor_name"], tuple(shape_new))
                p.out(sent_reshape_tensor, True)
        if dsl_fun is None:
            logging.info("[%s] is not support for %s", op["name"], op)
            continue
        sent = dsl_fun(op['input_desc'], op['output_desc'], op['attr'])
        logging.debug(sent)
        p.out(sent, True)
    # Output placeholders: NaN-filled arrays appended after the inputs, with
    # negative indexes (relative to the end) recorded in output_indexes.
    idx = 0
    out_nums = len(desc["output_desc"])
    for output_desc in desc["output_desc"]:
        shape = [1] if not output_desc["shape"] else output_desc["shape"]
        dtype = output_desc["data_type"]
        item = np.full(shape, np.nan, dtype)
        input_for_mod.append(item)
        tensor_name = output_desc["tensor_name"]
        if tensor_name not in fake_output_tensors:
            real_idx = idx - out_nums
            output_indexes.append(real_idx)
            p.out("expect.append(%s)" % (tensor_name), True)
        idx += 1
    # Add inplace tensors to expect, and add their index to output_indexes.
    if inplace_assign_write:
        inplace_tensors = "["
        inplace_tensors_index = []
        for tensor_name in inplace_assign_write:
            inplace_tensors_index.append(input_order[tensor_name])
            inplace_tensors += "{}, ".format(tensor_name)
        inplace_tensors += "]"
        p.out("inplace_tensors = {}".format(inplace_tensors), True)
        p.out("expect.extend(inplace_tensors)", True)
        output_indexes.extend(inplace_tensors_index)
    p.close()
    # Run the generated script in this frame so it can read input_dict and
    # mutate the local ``expect`` list in place.
    # NOTE(review): exec on generated-from-json code is a code-injection
    # risk if op_desc is untrusted; confirm inputs are trusted.
    with open(uni_file_name, 'r') as f:
        sent = f.read()
    exec(sent)
    # NOTE(review): os.remove(uni_file_name) would be portable and avoid the
    # shell; left as-is to preserve behavior.
    subprocess.run("rm %s" % uni_file_name, shell=True)
    return input_for_mod, expect, output_indexes