You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

model.py 19 kB

5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """GraphKernel cost model"""
  16. class Utils:
  17. """Model utils"""
  18. def __init__(self):
  19. pass
  20. @staticmethod
  21. def get_attr_type(attr):
  22. """Get attr type"""
  23. if isinstance(attr, bool):
  24. return 'bool'
  25. if isinstance(attr, str):
  26. return 'str'
  27. if isinstance(attr, int):
  28. return 'int'
  29. if isinstance(attr, float):
  30. return 'bool'
  31. if isinstance(attr, (list, tuple)):
  32. if not attr:
  33. raise ValueError("Length of attr is 0")
  34. if isinstance(attr[0], int):
  35. return 'listInt'
  36. if isinstance(attr[0], str):
  37. return 'listStr'
  38. raise ValueError("Unknown type of attr: {}".format(attr))
  39. class DataFormat:
  40. """DataFormat"""
  41. DEFAULT = "DefaultFormat"
  42. NC1KHKWHWC0 = "NC1KHKWHWC0"
  43. ND = "ND"
  44. NCHW = "NCHW"
  45. NHWC = "NHWC"
  46. HWCN = "HWCN"
  47. NC1HWC0 = "NC1HWC0"
  48. FRAC_Z = "FracZ"
  49. FRAC_NZ = "FRACTAL_NZ"
  50. C1HWNCOC0 = "C1HWNCoC0"
  51. NC1HWC0_C04 = "NC1HWC0_C04"
  52. FRACTAL_Z_C04 = "FRACTAL_Z_C04"
  53. NDHWC = "NDHWC"
  54. def __init__(self):
  55. pass
  56. class DataType:
  57. """Data Type"""
  58. FLOAT = "float"
  59. FLOAT16 = "float16"
  60. FLOAT32 = "float32"
  61. FLOAT64 = "float64"
  62. INT = "int"
  63. INT8 = "int8"
  64. INT16 = "int16"
  65. INT32 = "int32"
  66. INT64 = "int64"
  67. UINT = "uint"
  68. UINT8 = "uint8"
  69. UINT16 = "uint16"
  70. UINT32 = "uint32"
  71. UINT64 = "uint64"
  72. BOOL = "bool"
  73. def __init__(self):
  74. pass
class PrimLib:
    """Primitive library.

    Classifies operator primitives by iteration type and provides, per
    primitive, a relation function mapping each output axis/element to the
    corresponding input axes/elements (used by the fusion cost model).
    """
    # Iteration-type codes. Note: the order of entries in
    # Prim.default_relation_func below must match these values.
    UNKNOWN = 0
    RESHAPE = 1
    ELEMWISE = 2
    BROADCAST = 3
    REDUCE = 4
    OPAQUE = 5

    def __init__(self):
        pass

    class Prim:
        """Per-primitive record: iteration type, cost calibration factor,
        and the relation function used to relate output axes to input axes."""

        def __init__(self, iter_type, calibrate=1, relation_func=None):
            self.iter_type = iter_type
            self.calibrate = calibrate  # multiplier applied to the iter size
            self.relation_func = relation_func
            if relation_func is None:
                # Dispatch to the default relation function for this iter_type.
                self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)

        def default_reshape_relation(self, op, input_idx):
            """Process reshape relation: every output element relates to the
            whole input, tagged RESHAPE."""
            axis_relation, elem_relation = self.unknown_relation(op, input_idx)
            elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
            return axis_relation, elem_relation

        def default_elemwise_broadcast_relation(self, op, input_idx):
            """Process elemwise and broadcast relation.

            Aligns the input shape to the output shape from the right; leading
            output axes with no input counterpart get None, matching axes get
            ELEMWISE when sizes agree and BROADCAST otherwise.
            """
            out_shape = op.output.shape
            in_shape = op.inputs[input_idx].shape
            if len(out_shape) < len(in_shape):
                raise ValueError("input/output size is abnormal")
            axis_relation, elem_relation = [], []
            delta = len(out_shape) - len(in_shape)
            if delta > 0:
                # Output has extra leading axes: no corresponding input axis.
                for i in range(0, delta):
                    axis_relation.append(None)
                    elem_relation.append(None)
            for i, _ in enumerate(in_shape):
                axis_relation.append(i)
                elem_relation.append(
                    PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
            return axis_relation, elem_relation

        def default_reduce_relation(self, op, input_idx):
            """Process reduce relation: like elemwise/broadcast, but axes
            listed in op.attrs['reduce_axis'] are tagged REDUCE."""
            axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
            for i in op.attrs['reduce_axis']:
                elem_relation[i] = PrimLib.REDUCE
            return axis_relation, elem_relation

        def unknown_relation(self, op, input_idx):
            """Process unknown relation: every output axis may depend on
            every input axis."""
            out_shape = op.output.shape
            in_shape = op.inputs[input_idx].shape
            all_relation = list(range(len(in_shape)))
            axis_relation = [all_relation for i in range(0, len(out_shape))]
            elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
            return axis_relation, elem_relation

        # Indexed by iter_type (UNKNOWN..OPAQUE); keep in sync with the
        # iteration-type codes defined on PrimLib above.
        default_relation_func = [
            unknown_relation,
            default_reshape_relation,
            default_elemwise_broadcast_relation,
            default_elemwise_broadcast_relation,
            default_reduce_relation,
            unknown_relation,
        ]

    # Registry of known primitives and their iteration behavior.
    # (Name 'primtives' [sic] kept for compatibility with external callers.)
    primtives = {
        'Add': Prim(ELEMWISE),
        'Abs': Prim(ELEMWISE),
        'Neg': Prim(ELEMWISE),
        'Mul': Prim(ELEMWISE),
        'Sub': Prim(ELEMWISE),
        'Log': Prim(ELEMWISE),
        'IsNan': Prim(ELEMWISE),
        'IsInf': Prim(ELEMWISE),
        'IsFinite': Prim(ELEMWISE),
        'Exp': Prim(ELEMWISE),
        'Rsqrt': Prim(ELEMWISE),
        'Sqrt': Prim(ELEMWISE),
        'Div': Prim(ELEMWISE),
        'FloorDiv': Prim(ELEMWISE),
        'RealDiv': Prim(ELEMWISE),
        'Mod': Prim(ELEMWISE),
        'Floor': Prim(ELEMWISE),
        'FloorMod': Prim(ELEMWISE),
        'Erf': Prim(ELEMWISE),
        'Erfc': Prim(ELEMWISE),
        'Cast': Prim(ELEMWISE),
        'Pow': Prim(ELEMWISE),
        'Minimum': Prim(ELEMWISE),
        'Maximum': Prim(ELEMWISE),
        'Reciprocal': Prim(ELEMWISE),
        'Equal': Prim(ELEMWISE),
        'NotEqual': Prim(ELEMWISE),
        'Greater': Prim(ELEMWISE),
        'GreaterEqual': Prim(ELEMWISE),
        'Less': Prim(ELEMWISE),
        'LessEqual': Prim(ELEMWISE),
        'LogicalNot': Prim(ELEMWISE),
        'LogicalAnd': Prim(ELEMWISE),
        'LogicalOr': Prim(ELEMWISE),
        'Square': Prim(ELEMWISE),
        'AddN': Prim(ELEMWISE),
        'Select': Prim(ELEMWISE, 8),  # calibrate=8: Select costs more per element
        'ReduceSum': Prim(REDUCE),
        'ReduceMax': Prim(REDUCE),
        'ReduceMin': Prim(REDUCE),
        'Argmax': Prim(REDUCE),
        'Argmin': Prim(REDUCE),
        'Assign': Prim(ELEMWISE),
        'Sign': Prim(ELEMWISE),
        'Sin': Prim(ELEMWISE),
        'Cos': Prim(ELEMWISE),
        'Asin': Prim(ELEMWISE),
        'ACos': Prim(ELEMWISE),
        'Tanh': Prim(ELEMWISE),
        'Asinh': Prim(ELEMWISE),
        'Acosh': Prim(ELEMWISE),
        'InplaceAssign': Prim(ELEMWISE),
        '@ReduceInit': Prim(ELEMWISE),
        'Reshape': Prim(RESHAPE),
        'Squeeze': Prim(RESHAPE),
        'Flatten': Prim(RESHAPE),
        'FlattenGrad': Prim(RESHAPE),
        'Transpose': Prim(OPAQUE),
        'Tile': Prim(BROADCAST),
        'BroadcastTo': Prim(BROADCAST),
        'StridedSlice': Prim(OPAQUE),
        'MatMul': Prim(OPAQUE),
        'TransData': Prim(OPAQUE),
        'BatchMatMul': Prim(OPAQUE),
        'UnPadAkg': Prim(OPAQUE),
        'PadAkg': Prim(OPAQUE),
        'Conv2D': Prim(OPAQUE),
        'CReal': Prim(ELEMWISE),
        'CImag': Prim(ELEMWISE),
        'Complex': Prim(ELEMWISE),
        'Atan': Prim(ELEMWISE),
        'Atan2': Prim(ELEMWISE),
        'Expm1': Prim(ELEMWISE),
        'TensorScatterAdd': Prim(OPAQUE),
        'Gather': Prim(OPAQUE),
        'GatherNd': Prim(OPAQUE),
        'UnsortedSegmentSum': Prim(OPAQUE),
        'UserDefined': Prim(OPAQUE),
    }

    # Fallback record for primitives missing from the registry.
    default_primtive = Prim(UNKNOWN)

    @classmethod
    def get_prim(cls, op):
        """Get op's Prim record, falling back to default_primtive with a warning."""
        prim = cls.primtives.get(op.prim, None)
        if prim is None:
            print('[WARN] primtive is not registered: ' + op.prim)
            prim = cls.default_primtive
        return prim

    @classmethod
    def input_relation(cls, op, input_idx):
        """Get op's input_relation according to input_idx."""
        return cls.get_prim(op).relation_func(op, input_idx)

    @classmethod
    def iter_type(cls, op):
        """Get op's iter type."""
        return cls.get_prim(op).iter_type

    @classmethod
    def is_reduce(cls, op):
        """Check whether op's iter type is reduce."""
        return cls.get_prim(op).iter_type == cls.REDUCE

    @classmethod
    def calibrate_iter_size(cls, op, iter_size):
        """Return iter_size scaled by the primitive's calibration factor."""
        return cls.get_prim(op).calibrate * iter_size

    @classmethod
    def dtype_bytes(cls, dtype):
        """Get dtype width in bytes by parsing the trailing digits of the
        dtype name (e.g. 'float32' -> 4, 'int8' -> 1).

        Digits are read right-to-left; the initial bits=1 makes the integer
        division round correctly (e.g. 1+6+10=17 for float16 -> 17//8 = 2).
        NOTE(review): dtype names with no trailing digits (e.g. 'bool')
        yield 0 — presumably acceptable to callers; confirm.
        """
        bits, unit = 1, 1
        for i in range(len(dtype) - 1, 0, -1):
            if dtype[i].isdecimal():
                bits += int(dtype[i]) * unit
                unit *= 10
            else:
                break
        return bits // 8

    @classmethod
    def inplace_reuse(cls, op, input_idx, start_axis=0):
        """Check whether op's output can reuse the input_idx-th input buffer:
        the output must be no wider than the input and purely elementwise
        with it from start_axis onward."""
        if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
            return False
        _, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
        for i in range(start_axis, len(elem_relation)):
            if elem_relation[i] != cls.ELEMWISE:
                return False
        return True
  263. class Tensor:
  264. """Tensor"""
  265. PARA_NONE = 0
  266. PARA_INPUT = 1
  267. PARA_OUTPUT = 2
  268. class Buddy:
  269. """Buddy"""
  270. def __init__(self, leader):
  271. self.members = [leader]
  272. def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
  273. self.name = name
  274. self.shape = shape
  275. self.dtype = dtype
  276. self.data_format = data_format
  277. self.para_type = para_type
  278. self.op = None
  279. self.to_ops = []
  280. self.buddy = None
  281. def __str__(self):
  282. return self.name + str(list(self.shape))
  283. def __repr__(self):
  284. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  285. def get_size(self):
  286. """Get size"""
  287. size = PrimLib.dtype_bytes(self.dtype)
  288. for i in self.shape:
  289. size *= i
  290. return size
  291. def add_buddy(self, tensor):
  292. """Add buddy"""
  293. if self.buddy is None:
  294. self.buddy = self.Buddy(self)
  295. self.buddy.members.append(tensor)
  296. tensor.buddy = self.buddy
  297. class Value:
  298. """Value"""
  299. def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
  300. self.name = name
  301. self.shape = [1]
  302. self.dtype = dtype
  303. self.value = value
  304. self.data_format = data_format
  305. def __str__(self):
  306. return self.name + str(list(self.shape))
  307. def __repr__(self):
  308. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  309. def get_size(self):
  310. """Get size"""
  311. return 1
  312. class Operator:
  313. """Operator"""
  314. def __init__(self, primtive, inputs, output, attrs):
  315. self.prim = primtive
  316. self.inputs = inputs
  317. self.output = output
  318. self.attrs = attrs
  319. for t in inputs:
  320. t.to_ops.append(self)
  321. if output.op is None:
  322. output.op = self
  323. self.all_inputs = [] # include Tensor inputs and Value inputs.
  324. def __str__(self):
  325. args = ', '.join([str(t) for t in self.all_inputs])
  326. expr = "%s = %s.%s(%s) id:%s" % (
  327. str(self.output), self.prim, self.output.dtype, args, id(self))
  328. return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
  329. def __repr__(self):
  330. return str(self)
class Graph:
    """A computation (sub)graph: a topologically ordered list of Operators
    plus optional stitch/recompute metadata, serializable to AKG json."""

    def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
        self.name = name
        self.ops = ops  # in topo order, can not use set
        self.inputs = []   # explicit parameter override; see deduce_parameters
        self.outputs = []  # explicit output override; see deduce_parameters
        self.stitch_info = stitch_info
        self.recompute_ops = recompute_ops
        self.processor = ""  # target backend name, set via set_processor

    def set_processor(self, processor):
        """Set processor."""
        self.processor = processor

    def add(self, ops):
        """Add a single Operator or an iterable of Operators to the graph."""
        if isinstance(ops, Operator):
            self.ops.append(ops)
        else:
            self.ops.extend(ops)

    def extract_subgraph(self, graph_name, tensor_names, difference=False):
        """Extract subgraph from this graph.

        If difference is False, keep ops whose output name is in tensor_names
        (raising if any name is never matched); if True, keep the complement.
        """
        graph = Graph(graph_name, [])
        outputs = set(tensor_names)
        if difference:
            for op in self.ops:
                if op.output.name not in outputs:
                    graph.add(op)
        else:
            for op in self.ops:
                if op.output.name in outputs:
                    graph.add(op)
                    outputs.remove(op.output.name)
            # Any name left over was never produced by this graph.
            for name in outputs:
                raise ValueError("invalid input tensor : " + name)
        return graph

    def deduce_parameters(self):
        """Deduce graph parameters: inputs are tensors produced outside the
        graph; outputs are tensors marked PARA_OUTPUT or consumed outside.
        Explicit self.inputs/self.outputs, when set, take precedence."""
        inputs, outputs = [], []
        for op in self.ops:
            for t in op.inputs:
                if t not in inputs and t.op not in self.ops:
                    inputs.append(t)
            if op.output in outputs:
                continue
            if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
                outputs.append(op.output)
                continue
            # Output escapes the graph if any consumer lies outside it.
            if any([succ not in self.ops for succ in op.output.to_ops]):
                outputs.append(op.output)
        if self.inputs:
            inputs = self.inputs
        if self.outputs:
            outputs = self.outputs
        return inputs, outputs

    def __str__(self):
        inputs, outputs = self.deduce_parameters()
        para_str = ', '.join([repr(t) for t in inputs])
        out_str = ', '.join([repr(t) for t in outputs])
        lines = []
        lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
        if self.stitch_info:
            if self.stitch_info.stitch_ops:
                lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops))
            if self.stitch_info.stitch_atomic_ops:
                lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))
        for op in self.ops:
            lines.append(' ' + str(op))
        lines.append('}')
        return '\n'.join(lines)

    def __repr__(self):
        return str(self)

    def dump(self):
        """Dump Graph to a json-serializable dict in the AKG composite format."""
        # Rename attrs on serialization (GraphKernel name -> AKG name).
        attr_name = {'reduce_axis': 'axis'}
        inputs, outputs = self.deduce_parameters()
        input_desc, output_desc, op_desc = [], [], []
        for t in inputs:
            input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
                                'tensor_name': t.name, 'format': t.data_format}])
        for t in outputs:
            output_desc.append({'data_type': t.dtype, 'shape': t.shape,
                                'tensor_name': t.name, 'format': t.data_format})
        for op in self.ops:
            attrs, in_desc = [], []
            for a in op.attrs:
                name = attr_name.get(a, a)
                attrs.append(
                    {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
            for t in op.all_inputs:
                if isinstance(t, Tensor):
                    in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
                                     'tensor_name': t.name, 'format': t.data_format}])
                else:
                    # Value operands additionally carry their literal value.
                    in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
                                     'tensor_name': t.name, 'format': t.data_format}])
            out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
                         'tensor_name': op.output.name, 'format': op.output.data_format}]
            op_desc.append({'attr': attrs, 'impl_path': '',
                            'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
        graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
                      'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
                      'platform': 'AKG', 'process': self.processor}
        if self.stitch_info and self.stitch_info.stitch_ops:
            buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
            if self.stitch_info.stitch_atomic_ops:
                buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
            graph_desc['buffer_stitch'] = buffer_stitch
        return graph_desc
  439. class GraphVisitor:
  440. """Graph visitor"""
  441. def __init__(self, forward=True):
  442. self.forward = forward
  443. def visit_graph(self, graph):
  444. """Visit graph"""
  445. if self.forward:
  446. for op in graph.ops:
  447. self.visit(op)
  448. else:
  449. for i in range(len(graph.ops)-1, -1, -1):
  450. self.visit(graph.ops[i])
  451. class AlignShape(GraphVisitor):
  452. """Align shape"""
  453. def __init__(self):
  454. super(AlignShape, self).__init__()
  455. def visit(self, op):
  456. """Visit op node"""
  457. prim = PrimLib.get_prim(op)
  458. if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
  459. out_dim = len(op.output.shape)
  460. align_dim = out_dim
  461. for t in op.inputs:
  462. if len(t.shape) > align_dim:
  463. align_dim = len(t.shape)
  464. if align_dim > out_dim:
  465. op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
  466. class AddControlBuddy(GraphVisitor):
  467. """Add control buddy"""
  468. def __init__(self):
  469. super(AddControlBuddy, self).__init__()
  470. self.buddies = {} # {op : [ctrl_op]}
  471. def visit(self, op):
  472. """Visit op node"""
  473. if op.prim == "MakeTuple":
  474. if len(op.output.to_ops) != 1:
  475. raise ValueError("operator's output size is abnormal")
  476. owner = op.output.to_ops[0]
  477. if owner in self.buddies:
  478. self.buddies[owner].append(op)
  479. else:
  480. self.buddies[owner] = [op]
  481. if op in self.buddies:
  482. ops = self.buddies.pop(op)
  483. self.buddies[owner].extend(ops)
  484. def visit_graph(self, graph):
  485. """Visit graph nodes"""
  486. super(AddControlBuddy, self).visit_graph(graph)
  487. for owner in self.buddies:
  488. for op in self.buddies[owner]:
  489. owner.add_buddy(op.output)
  490. class GraphKernelUnsupportedException(Exception):
  491. """"GraphKernel Unsupported Exception"""
  492. def __init__(self, message):
  493. super(GraphKernelUnsupportedException, self).__init__()
  494. self.message = message