You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """GraphKernel cost model"""
  16. class Utils:
  17. """Model utils"""
  18. @staticmethod
  19. def get_attr_type(attr):
  20. """Get attr type"""
  21. if isinstance(attr, bool):
  22. return 'bool'
  23. if isinstance(attr, str):
  24. return 'str'
  25. if isinstance(attr, int):
  26. return 'int'
  27. if isinstance(attr, float):
  28. return 'bool'
  29. if isinstance(attr, (list, tuple)):
  30. if not attr:
  31. raise ValueError("Length of attr is 0")
  32. if isinstance(attr[0], int):
  33. return 'listInt'
  34. if isinstance(attr[0], str):
  35. return 'listStr'
  36. raise ValueError("Unknown type of attr: {}".format(attr))
  37. class DataFormat:
  38. """DataFormat"""
  39. DEFAULT = "DefaultFormat"
  40. NC1KHKWHWC0 = "NC1KHKWHWC0"
  41. ND = "ND"
  42. NCHW = "NCHW"
  43. NHWC = "NHWC"
  44. HWCN = "HWCN"
  45. NC1HWC0 = "NC1HWC0"
  46. FRAC_Z = "FracZ"
  47. FRAC_NZ = "FRACTAL_NZ"
  48. C1HWNCOC0 = "C1HWNCoC0"
  49. NC1HWC0_C04 = "NC1HWC0_C04"
  50. FRACTAL_Z_C04 = "FRACTAL_Z_C04"
  51. NDHWC = "NDHWC"
  52. class Config:
  53. R0 = 8.0
  54. UB_SIZE = 256 * 1024
  55. MAX_BLOCK = 32
  56. class PrimLib:
  57. """Prim lib"""
  58. UNKNOWN = 0
  59. RESHAPE = 1
  60. ELEMWISE = 2
  61. BROADCAST = 3
  62. REDUCE = 4
  63. TRANSFORM = 5
  64. CONTROL = 6
  65. class Prim:
  66. """Prim"""
  67. def __init__(self, iter_type, calibrate=1, relation_func=None):
  68. self.iter_type = iter_type
  69. self.calibrate = calibrate
  70. self.relation_func = relation_func
  71. if relation_func is None:
  72. self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
  73. def default_reshape_relation(self, op, input_idx):
  74. axis_relation, elem_relation = self.unknown_relation(op, input_idx)
  75. elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
  76. return axis_relation, elem_relation
  77. def default_elemwise_broadcast_relation(self, op, input_idx):
  78. """Process elemwise and broadcast relation"""
  79. out_shape = op.output.shape
  80. in_shape = op.inputs[input_idx].shape
  81. assert len(out_shape) >= len(in_shape)
  82. axis_relation, elem_relation = [], []
  83. delta = len(out_shape) - len(in_shape)
  84. if delta > 0:
  85. for i in range(0, delta):
  86. axis_relation.append(None)
  87. elem_relation.append(None)
  88. for i, _ in enumerate(in_shape):
  89. axis_relation.append(i)
  90. elem_relation.append(
  91. PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
  92. return axis_relation, elem_relation
  93. def default_reduce_relation(self, op, input_idx):
  94. """Process reduce relation"""
  95. axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
  96. for i in op.attrs['reduce_axis']:
  97. elem_relation[i] = PrimLib.REDUCE
  98. return axis_relation, elem_relation
  99. def unknown_relation(self, op, input_idx):
  100. """Process unknown relation"""
  101. out_shape = op.output.shape
  102. in_shape = op.inputs[input_idx].shape
  103. all_relation = list(range(len(in_shape)))
  104. axis_relation = [all_relation for i in range(0, len(out_shape))]
  105. elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
  106. return axis_relation, elem_relation
  107. default_relation_func = [
  108. unknown_relation,
  109. default_reshape_relation,
  110. default_elemwise_broadcast_relation,
  111. default_elemwise_broadcast_relation,
  112. default_reduce_relation,
  113. unknown_relation,
  114. unknown_relation,
  115. ]
  116. primtives = {
  117. 'TensorAdd': Prim(ELEMWISE),
  118. 'Abs': Prim(ELEMWISE),
  119. 'Neg': Prim(ELEMWISE),
  120. 'Mul': Prim(ELEMWISE),
  121. 'Sub': Prim(ELEMWISE),
  122. 'Log': Prim(ELEMWISE),
  123. 'Exp': Prim(ELEMWISE),
  124. 'Rsqrt': Prim(ELEMWISE),
  125. 'Sqrt': Prim(ELEMWISE),
  126. 'RealDiv': Prim(ELEMWISE),
  127. 'Cast': Prim(ELEMWISE),
  128. 'Pow': Prim(ELEMWISE),
  129. 'Minimum': Prim(ELEMWISE),
  130. 'Maximum': Prim(ELEMWISE),
  131. 'Reciprocal': Prim(ELEMWISE),
  132. 'Equal': Prim(ELEMWISE),
  133. 'Greater': Prim(ELEMWISE),
  134. 'GreaterEqual': Prim(ELEMWISE),
  135. 'Less': Prim(ELEMWISE),
  136. 'LessEqual': Prim(ELEMWISE),
  137. 'Square': Prim(ELEMWISE),
  138. 'AddN': Prim(ELEMWISE),
  139. 'Select': Prim(ELEMWISE, 8),
  140. 'ReduceSum': Prim(REDUCE),
  141. 'ReduceMax': Prim(REDUCE),
  142. 'ReduceMin': Prim(REDUCE),
  143. 'make_tuple': Prim(CONTROL),
  144. 'ControlDepend': Prim(CONTROL),
  145. 'Assign': Prim(ELEMWISE),
  146. 'Tanh': Prim(ELEMWISE),
  147. 'ExpandDims': Prim(RESHAPE),
  148. 'InplaceAssign': Prim(ELEMWISE),
  149. '@ReduceInit': Prim(ELEMWISE),
  150. 'Reshape': Prim(RESHAPE),
  151. 'Squeeze': Prim(RESHAPE),
  152. 'Flatten': Prim(RESHAPE),
  153. 'FlattenGrad': Prim(RESHAPE),
  154. 'Transpose': Prim(TRANSFORM),
  155. 'Tile': Prim(BROADCAST),
  156. 'BroadcastTo': Prim(BROADCAST),
  157. }
  158. default_primtive = Prim(UNKNOWN)
  159. @classmethod
  160. def get_prim(cls, op):
  161. prim = cls.primtives.get(op.prim, None)
  162. if prim is None:
  163. print('[WARN] primtive is not registered: ' + op.prim)
  164. prim = cls.default_primtive
  165. return prim
  166. @classmethod
  167. def input_relation(cls, op, input_idx):
  168. return cls.get_prim(op).relation_func(op, input_idx)
  169. @classmethod
  170. def iter_type(cls, op):
  171. return cls.get_prim(op).iter_type
  172. @classmethod
  173. def is_reduce(cls, op):
  174. return cls.get_prim(op).iter_type == cls.REDUCE
  175. @classmethod
  176. def calibrate_iter_size(cls, op, iter_size):
  177. return cls.get_prim(op).calibrate * iter_size
  178. @classmethod
  179. def dtype_bytes(cls, dtype):
  180. bits, unit = 1, 1
  181. for i in range(len(dtype) - 1, 0, -1):
  182. if dtype[i].isdecimal():
  183. bits += int(dtype[i]) * unit
  184. unit *= 10
  185. else:
  186. break
  187. return bits // 8
  188. @classmethod
  189. def inplace_reuse(cls, op, input_idx, start_axis=0):
  190. if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
  191. return False
  192. _, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
  193. for i in range(start_axis, len(elem_relation)):
  194. if elem_relation[i] != cls.ELEMWISE:
  195. return False
  196. return True
  197. class Tensor:
  198. """Tensor"""
  199. PARA_NONE = 0
  200. PARA_INPUT = 1
  201. PARA_OUTPUT = 2
  202. class Buddy:
  203. def __init__(self, leader):
  204. self.members = [leader]
  205. def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
  206. self.name = name
  207. self.shape = shape
  208. self.dtype = dtype
  209. self.data_format = data_format
  210. self.para_type = para_type
  211. self.op = None
  212. self.to_ops = []
  213. self.buddy = None
  214. def __str__(self):
  215. return self.name + str(list(self.shape))
  216. def __repr__(self):
  217. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  218. def get_size(self):
  219. """Get size"""
  220. size = PrimLib.dtype_bytes(self.dtype)
  221. for i in self.shape:
  222. size *= i
  223. return size
  224. def add_buddy(self, tensor):
  225. """Add buddy"""
  226. if self.buddy is None:
  227. self.buddy = self.Buddy(self)
  228. self.buddy.members.append(tensor)
  229. tensor.buddy = self.buddy
  230. class Value:
  231. """Value"""
  232. def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
  233. self.name = name
  234. self.shape = [1]
  235. self.dtype = dtype
  236. self.value = value
  237. self.data_format = data_format
  238. def __str__(self):
  239. return self.name + str(list(self.shape))
  240. def __repr__(self):
  241. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  242. def get_size(self):
  243. return 1
  244. class Operator:
  245. """Operator"""
  246. def __init__(self, primtive, inputs, output, attrs):
  247. self.prim = primtive
  248. self.inputs = inputs
  249. self.output = output
  250. self.attrs = attrs
  251. for t in inputs:
  252. t.to_ops.append(self)
  253. if output.op is None:
  254. output.op = self
  255. self.all_inputs = [] # include Tensor inputs and Value inputs.
  256. def __str__(self):
  257. args = ', '.join([str(t) for t in self.all_inputs])
  258. expr = "%s = %s.%s(%s)" % (
  259. str(self.output), self.prim, self.output.dtype, args)
  260. return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
  261. def __repr__(self):
  262. return str(self)
  263. class Graph:
  264. """Graph"""
  265. def __init__(self, name, ops):
  266. self.name = name
  267. self.ops = ops # in topo order, can not use set
  268. self.inputs = []
  269. self.outputs = []
  270. def set_processor(self, processor):
  271. """Set processor"""
  272. self.processor = processor
  273. def add(self, ops):
  274. """Add ops"""
  275. if isinstance(ops, Operator):
  276. self.ops.append(ops)
  277. else:
  278. self.ops.extend(ops)
  279. def extract_subgraph(self, graph_name, tensor_names, difference=False):
  280. """Extract subgraph from this graph"""
  281. graph = Graph(graph_name, [])
  282. outputs = set(tensor_names)
  283. if difference:
  284. for op in self.ops:
  285. if op.output.name not in outputs:
  286. graph.add(op)
  287. else:
  288. for op in self.ops:
  289. if op.output.name in outputs:
  290. graph.add(op)
  291. outputs.remove(op.output.name)
  292. for name in outputs:
  293. raise ValueError("invalid input tensor : " + name)
  294. return graph
  295. def deduce_parameters(self):
  296. """Deduce parameters"""
  297. inputs, outputs = [], []
  298. for op in self.ops:
  299. for t in op.inputs:
  300. if t not in inputs and t.op not in self.ops:
  301. inputs.append(t)
  302. if op.output not in outputs:
  303. if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
  304. outputs.append(op.output)
  305. else:
  306. for d in op.output.to_ops:
  307. if d not in self.ops:
  308. outputs.append(op.output)
  309. break
  310. if self.inputs:
  311. inputs = self.inputs
  312. if self.outputs:
  313. outputs = self.outputs
  314. return inputs, outputs
  315. def __str__(self):
  316. inputs, outputs = self.deduce_parameters()
  317. para_str = ', '.join([repr(t) for t in inputs])
  318. out_str = ', '.join([repr(t) for t in outputs])
  319. lines = []
  320. lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
  321. for op in self.ops:
  322. lines.append(' ' + str(op))
  323. lines.append('}')
  324. return '\n'.join(lines)
  325. def __repr__(self):
  326. return str(self)
  327. def dump(self):
  328. """Dump Graph to json"""
  329. attr_name = {'reduce_axis': 'axis'}
  330. inputs, outputs = self.deduce_parameters()
  331. input_desc, output_desc, op_desc = [], [], []
  332. for t in inputs:
  333. input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
  334. 'tensor_name': t.name, 'format': t.data_format}])
  335. for t in outputs:
  336. output_desc.append({'data_type': t.dtype, 'shape': t.shape,
  337. 'tensor_name': t.name, 'format': t.data_format})
  338. for op in self.ops:
  339. attrs, in_desc = [], []
  340. for a in op.attrs:
  341. name = attr_name.get(a, a)
  342. attrs.append(
  343. {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
  344. for t in op.all_inputs:
  345. if isinstance(t, Tensor):
  346. in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
  347. 'tensor_name': t.name, 'format': t.data_format}])
  348. else:
  349. in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
  350. 'tensor_name': t.name, 'format': t.data_format}])
  351. out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
  352. 'tensor_name': op.output.name, 'format': t.data_format}]
  353. op_desc.append({'attr': attrs, 'impl_path': '',
  354. 'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
  355. graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
  356. 'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
  357. 'platform': 'AKG', 'process': self.processor}
  358. return graph_desc
  359. class GraphVisitor:
  360. """Graph visitor"""
  361. def __init__(self, forward=True, once_mode=True):
  362. self.forward = forward
  363. self.once_mode = once_mode
  364. if self.once_mode:
  365. self.visited = set()
  366. def visit_graph(self, graph):
  367. """Visit graph"""
  368. inputs, outputs = graph.deduce_parameters()
  369. if self.forward:
  370. for tensor in inputs:
  371. for op in tensor.to_ops:
  372. self.visit(op)
  373. else:
  374. for tensor in outputs:
  375. if not tensor.to_ops:
  376. self.visit(tensor.op)
  377. def visit(self, op):
  378. """Visit op"""
  379. next_ops = op.output.to_ops if self.forward else [
  380. t.op for t in op.inputs if t.op is not None]
  381. if self.once_mode:
  382. self.visited.add(op)
  383. for n in next_ops:
  384. if n not in self.visited:
  385. self.visit(n)
  386. else:
  387. for n in next_ops:
  388. self.visit(n)
  389. class AlignShape(GraphVisitor):
  390. """Align shape"""
  391. def __init__(self):
  392. super().__init__(once_mode=False)
  393. def visit(self, op):
  394. prim = PrimLib.get_prim(op)
  395. if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
  396. out_dim = len(op.output.shape)
  397. align_dim = out_dim
  398. for t in op.inputs:
  399. if len(t.shape) > align_dim:
  400. align_dim = len(t.shape)
  401. if align_dim > out_dim:
  402. op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
  403. super().visit(op)
  404. class AddControlBuddy(GraphVisitor):
  405. """Add control buddy"""
  406. def __init__(self):
  407. super().__init__()
  408. self.buddies = {} # {op : [ctrl_op]}
  409. def visit(self, op):
  410. if PrimLib.iter_type(op) == PrimLib.CONTROL:
  411. assert len(op.output.to_ops) == 1
  412. owner = op.output.to_ops[0]
  413. if owner in self.buddies:
  414. self.buddies[owner].append(op)
  415. else:
  416. self.buddies[owner] = [op]
  417. if op in self.buddies:
  418. ops = self.buddies.pop(op)
  419. self.buddies[owner].extend(ops)
  420. super().visit(op)
  421. def visit_graph(self, graph):
  422. super().visit_graph(graph)
  423. for owner in self.buddies:
  424. for op in self.buddies[owner]:
  425. owner.add_buddy(op.output)