You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """GraphKernel cost model"""
  16. class Utils:
  17. """Model utils"""
  18. @staticmethod
  19. def get_attr_type(attr):
  20. """Get attr type"""
  21. if isinstance(attr, bool):
  22. return 'bool'
  23. if isinstance(attr, str):
  24. return 'str'
  25. if isinstance(attr, int):
  26. return 'int'
  27. if isinstance(attr, float):
  28. return 'bool'
  29. if isinstance(attr, (list, tuple)):
  30. if not attr:
  31. raise ValueError("Length of attr is 0")
  32. if isinstance(attr[0], int):
  33. return 'listInt'
  34. if isinstance(attr[0], str):
  35. return 'listStr'
  36. raise ValueError("Unknown type of attr: {}".format(attr))
  37. class DataFormat:
  38. """DataFormat"""
  39. DEFAULT = "DefaultFormat"
  40. NC1KHKWHWC0 = "NC1KHKWHWC0"
  41. ND = "ND"
  42. NCHW = "NCHW"
  43. NHWC = "NHWC"
  44. HWCN = "HWCN"
  45. NC1HWC0 = "NC1HWC0"
  46. FRAC_Z = "FracZ"
  47. FRAC_NZ = "FRACTAL_NZ"
  48. C1HWNCOC0 = "C1HWNCoC0"
  49. NC1HWC0_C04 = "NC1HWC0_C04"
  50. FRACTAL_Z_C04 = "FRACTAL_Z_C04"
  51. NDHWC = "NDHWC"
  52. class Config:
  53. R0 = 8.0
  54. UB_SIZE = 256 * 1024
  55. MAX_BLOCK = 32
  56. class PrimLib:
  57. """Prim lib"""
  58. UNKNOWN = 0
  59. ELEMWISE = 1
  60. BROADCAST = 2
  61. REDUCE = 3
  62. TRANSFORM = 4
  63. CONTROL = 5
  64. class Prim:
  65. """Prim"""
  66. def __init__(self, iter_type, calibrate=1, relation_func=None):
  67. self.iter_type = iter_type
  68. self.calibrate = calibrate
  69. self.relation_func = relation_func
  70. if relation_func is None:
  71. self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
  72. def default_elemwise_broadcast_relation(self, op, input_idx):
  73. """Process elemwise and broadcast relation"""
  74. out_shape = op.output.shape
  75. in_shape = op.inputs[input_idx].shape
  76. assert len(out_shape) >= len(in_shape)
  77. axis_relation, elem_relation = [], []
  78. delta = len(out_shape) - len(in_shape)
  79. if delta > 0:
  80. for i in range(0, delta):
  81. axis_relation.append(None)
  82. elem_relation.append(None)
  83. for i, _ in enumerate(in_shape):
  84. axis_relation.append(i)
  85. elem_relation.append(
  86. PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
  87. return axis_relation, elem_relation
  88. def default_reduce_relation(self, op, input_idx):
  89. """Process reduce relation"""
  90. axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
  91. for i in op.attrs['reduce_axis']:
  92. elem_relation[i] = PrimLib.REDUCE
  93. return axis_relation, elem_relation
  94. def unknown_relation(self, op, input_idx):
  95. """Process unknown relation"""
  96. out_shape = op.output.shape
  97. in_shape = op.inputs[input_idx].shape
  98. all_relation = list(range(len(in_shape)))
  99. axis_relation = [all_relation for i in range(0, len(out_shape))]
  100. elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
  101. return axis_relation, elem_relation
  102. default_relation_func = [
  103. unknown_relation,
  104. default_elemwise_broadcast_relation,
  105. default_elemwise_broadcast_relation,
  106. default_reduce_relation,
  107. unknown_relation,
  108. unknown_relation,
  109. ]
  110. primtives = {
  111. 'TensorAdd': Prim(ELEMWISE),
  112. 'Abs': Prim(ELEMWISE),
  113. 'Neg': Prim(ELEMWISE),
  114. 'Mul': Prim(ELEMWISE),
  115. 'Sub': Prim(ELEMWISE),
  116. 'Log': Prim(ELEMWISE),
  117. 'Exp': Prim(ELEMWISE),
  118. 'Rsqrt': Prim(ELEMWISE),
  119. 'Sqrt': Prim(ELEMWISE),
  120. 'RealDiv': Prim(ELEMWISE),
  121. 'Cast': Prim(ELEMWISE),
  122. 'Pow': Prim(ELEMWISE),
  123. 'Minimum': Prim(ELEMWISE),
  124. 'Maximum': Prim(ELEMWISE),
  125. 'Reciprocal': Prim(ELEMWISE),
  126. 'Equal': Prim(ELEMWISE),
  127. 'Greater': Prim(ELEMWISE),
  128. 'GreaterEqual': Prim(ELEMWISE),
  129. 'Less': Prim(ELEMWISE),
  130. 'LessEqual': Prim(ELEMWISE),
  131. 'Square': Prim(ELEMWISE),
  132. 'AddN': Prim(ELEMWISE),
  133. 'Select': Prim(ELEMWISE, 8),
  134. 'ReduceSum': Prim(REDUCE),
  135. 'ReduceMax': Prim(REDUCE),
  136. 'ReduceMin': Prim(REDUCE),
  137. 'make_tuple': Prim(CONTROL),
  138. 'ControlDepend': Prim(CONTROL),
  139. 'Assign': Prim(ELEMWISE),
  140. '@ReduceInit': Prim(ELEMWISE),
  141. }
  142. default_primtive = Prim(UNKNOWN)
  143. @classmethod
  144. def get_prim(cls, op):
  145. prim = cls.primtives.get(op.prim, None)
  146. if prim is None:
  147. print('[WARN] primtive is not registered: ' + op.prim)
  148. prim = cls.default_primtive
  149. return prim
  150. @classmethod
  151. def input_relation(cls, op, input_idx):
  152. return cls.get_prim(op).relation_func(op, input_idx)
  153. @classmethod
  154. def iter_type(cls, op):
  155. return cls.get_prim(op).iter_type
  156. @classmethod
  157. def is_reduce(cls, op):
  158. return cls.get_prim(op).iter_type == cls.REDUCE
  159. @classmethod
  160. def calibrate_iter_size(cls, op, iter_size):
  161. return cls.get_prim(op).calibrate * iter_size
  162. @classmethod
  163. def dtype_bytes(cls, dtype):
  164. bits, unit = 1, 1
  165. for i in range(len(dtype) - 1, 0, -1):
  166. if dtype[i].isdecimal():
  167. bits += int(dtype[i]) * unit
  168. unit *= 10
  169. else:
  170. break
  171. return bits // 8
  172. @classmethod
  173. def inplace_reuse(cls, op, input_idx, start_axis=0):
  174. if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
  175. return False
  176. _, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
  177. for i in range(start_axis, len(elem_relation)):
  178. if elem_relation[i] != cls.ELEMWISE:
  179. return False
  180. return True
  181. class Tensor:
  182. """Tensor"""
  183. PARA_NONE = 0
  184. PARA_INPUT = 1
  185. PARA_OUTPUT = 2
  186. class Buddy:
  187. def __init__(self, leader):
  188. self.members = [leader]
  189. def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
  190. self.name = name
  191. self.shape = shape
  192. self.dtype = dtype
  193. self.data_format = data_format
  194. self.para_type = para_type
  195. self.op = None
  196. self.to_ops = []
  197. self.buddy = None
  198. def __str__(self):
  199. return self.name + str(list(self.shape))
  200. def __repr__(self):
  201. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  202. def get_size(self):
  203. """Get size"""
  204. size = PrimLib.dtype_bytes(self.dtype)
  205. for i in self.shape:
  206. size *= i
  207. return size
  208. def add_buddy(self, tensor):
  209. """Add buddy"""
  210. if self.buddy is None:
  211. self.buddy = self.Buddy(self)
  212. self.buddy.members.append(tensor)
  213. tensor.buddy = self.buddy
  214. class Value:
  215. """Value"""
  216. def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
  217. self.name = name
  218. self.shape = [1]
  219. self.dtype = dtype
  220. self.value = value
  221. self.data_format = data_format
  222. def __str__(self):
  223. return self.name + str(list(self.shape)) + str(self.value)
  224. def __repr__(self):
  225. return "%s.%s%s%s" % (self.name, self.dtype, str(list(self.shape)), str(self.value))
  226. def get_size(self):
  227. return 1
  228. class Operator:
  229. """Operator"""
  230. def __init__(self, primtive, inputs, output, attrs):
  231. self.prim = primtive
  232. self.inputs = inputs
  233. self.output = output
  234. self.attrs = attrs
  235. for t in inputs:
  236. t.to_ops.append(self)
  237. if output.op is None:
  238. output.op = self
  239. self.all_inputs = [] # include Tensor inputs and Value inputs.
  240. def __str__(self):
  241. args = ', '.join([str(t) for t in self.all_inputs])
  242. expr = "%s = %s.%s(%s)" % (
  243. str(self.output), self.prim, self.output.dtype, args)
  244. return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
  245. def __repr__(self):
  246. return str(self)
  247. class Graph:
  248. """Graph"""
  249. def __init__(self, name, ops):
  250. self.name = name
  251. self.ops = ops # in topo order, can not use set
  252. self.outputs = []
  253. def set_processor(self, processor):
  254. """Set processor"""
  255. self.processor = processor
  256. def add(self, ops):
  257. """Add ops"""
  258. if isinstance(ops, Operator):
  259. self.ops.append(ops)
  260. else:
  261. self.ops.extend(ops)
  262. def extract_subgraph(self, graph_name, tensor_names, difference=False):
  263. """Extract subgraph from this graph"""
  264. graph = Graph(graph_name, [])
  265. outputs = set(tensor_names)
  266. if difference:
  267. for op in self.ops:
  268. if op.output.name not in outputs:
  269. graph.add(op)
  270. else:
  271. for op in self.ops:
  272. if op.output.name in outputs:
  273. graph.add(op)
  274. outputs.remove(op.output.name)
  275. for name in outputs:
  276. raise ValueError("invalid input tensor : " + name)
  277. return graph
  278. def deduce_parameters(self):
  279. """Deduce parameters"""
  280. inputs, outputs = [], []
  281. for op in self.ops:
  282. for t in op.inputs:
  283. if t not in inputs and t.op not in self.ops:
  284. inputs.append(t)
  285. if op.output not in outputs:
  286. if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
  287. outputs.append(op.output)
  288. else:
  289. for d in op.output.to_ops:
  290. if d not in self.ops:
  291. outputs.append(op.output)
  292. break
  293. if self.outputs:
  294. outputs = self.outputs
  295. return inputs, outputs
  296. def __str__(self):
  297. inputs, outputs = self.deduce_parameters()
  298. para_str = ', '.join([repr(t) for t in inputs])
  299. out_str = ', '.join([repr(t) for t in outputs])
  300. lines = []
  301. lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
  302. for op in self.ops:
  303. lines.append(' ' + str(op))
  304. lines.append('}')
  305. return '\n'.join(lines)
  306. def __repr__(self):
  307. return str(self)
  308. def dump(self):
  309. """Dump Graph to json"""
  310. attr_name = {'reduce_axis': 'axis'}
  311. inputs, outputs = self.deduce_parameters()
  312. input_desc, output_desc, op_desc = [], [], []
  313. for t in inputs:
  314. input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
  315. 'tensor_name': t.name, 'format': t.data_format}])
  316. for t in outputs:
  317. output_desc.append({'data_type': t.dtype, 'shape': t.shape,
  318. 'tensor_name': t.name, 'format': t.data_format})
  319. for op in self.ops:
  320. attrs, in_desc = [], []
  321. for a in op.attrs:
  322. name = attr_name.get(a, a)
  323. attrs.append(
  324. {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
  325. for t in op.all_inputs:
  326. if isinstance(t, Tensor):
  327. in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
  328. 'tensor_name': t.name, 'format': t.data_format}])
  329. else:
  330. in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
  331. 'tensor_name': t.name, 'format': t.data_format}])
  332. out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
  333. 'tensor_name': op.output.name, 'format': t.data_format}]
  334. op_desc.append({'attr': attrs, 'impl_path': '',
  335. 'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
  336. graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
  337. 'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
  338. 'platform': 'AKG', 'process': self.processor}
  339. return graph_desc
  340. class GraphVisitor:
  341. """Graph visitor"""
  342. def __init__(self, forward=True, once_mode=True):
  343. self.forward = forward
  344. self.once_mode = once_mode
  345. if self.once_mode:
  346. self.visited = set()
  347. def visit_graph(self, graph):
  348. """Visit graph"""
  349. inputs, outputs = graph.deduce_parameters()
  350. if self.forward:
  351. for tensor in inputs:
  352. for op in tensor.to_ops:
  353. self.visit(op)
  354. else:
  355. for tensor in outputs:
  356. if not tensor.to_ops:
  357. self.visit(tensor.op)
  358. def visit(self, op):
  359. """Visit op"""
  360. next_ops = op.output.to_ops if self.forward else [
  361. t.op for t in op.inputs if t.op is not None]
  362. if self.once_mode:
  363. self.visited.add(op)
  364. for n in next_ops:
  365. if n not in self.visited:
  366. self.visit(n)
  367. else:
  368. for n in next_ops:
  369. self.visit(n)
  370. class AlignShape(GraphVisitor):
  371. """Align shape"""
  372. def __init__(self):
  373. super().__init__(once_mode=False)
  374. def visit(self, op):
  375. prim = PrimLib.get_prim(op)
  376. if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
  377. out_dim = len(op.output.shape)
  378. align_dim = out_dim
  379. for t in op.inputs:
  380. if len(t.shape) > align_dim:
  381. align_dim = len(t.shape)
  382. if align_dim > out_dim:
  383. op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
  384. super().visit(op)
  385. class AddControlBuddy(GraphVisitor):
  386. """Add control buddy"""
  387. def __init__(self):
  388. super().__init__()
  389. self.buddies = {} # {op : [ctrl_op]}
  390. def visit(self, op):
  391. if PrimLib.iter_type(op) == PrimLib.CONTROL:
  392. assert len(op.output.to_ops) == 1
  393. owner = op.output.to_ops[0]
  394. if owner in self.buddies:
  395. self.buddies[owner].append(op)
  396. else:
  397. self.buddies[owner] = [op]
  398. if op in self.buddies:
  399. ops = self.buddies.pop(op)
  400. self.buddies[owner].extend(ops)
  401. super().visit(op)
  402. def visit_graph(self, graph):
  403. super().visit_graph(graph)
  404. for owner in self.buddies:
  405. for op in self.buddies[owner]:
  406. owner.add_buddy(op.output)