You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """GraphKernel cost model"""
  16. class Utils:
  17. """Model utils"""
  18. @staticmethod
  19. def get_attr_type(attr):
  20. """Get attr type"""
  21. if isinstance(attr, bool):
  22. return 'bool'
  23. if isinstance(attr, str):
  24. return 'str'
  25. if isinstance(attr, int):
  26. return 'int'
  27. if isinstance(attr, float):
  28. return 'bool'
  29. if isinstance(attr, (list, tuple)):
  30. if not attr:
  31. raise ValueError("Length of attr is 0")
  32. if isinstance(attr[0], int):
  33. return 'listInt'
  34. if isinstance(attr[0], str):
  35. return 'listStr'
  36. raise ValueError("Unknown type of attr: {}".format(attr))
  37. class DataFormat:
  38. """DataFormat"""
  39. DEFAULT = "DefaultFormat"
  40. NC1KHKWHWC0 = "NC1KHKWHWC0"
  41. ND = "ND"
  42. NCHW = "NCHW"
  43. NHWC = "NHWC"
  44. HWCN = "HWCN"
  45. NC1HWC0 = "NC1HWC0"
  46. FRAC_Z = "FracZ"
  47. FRAC_NZ = "FRACTAL_NZ"
  48. C1HWNCOC0 = "C1HWNCoC0"
  49. NC1HWC0_C04 = "NC1HWC0_C04"
  50. FRACTAL_Z_C04 = "FRACTAL_Z_C04"
  51. NDHWC = "NDHWC"
  52. class DataType:
  53. """Data Type"""
  54. FLOAT = "float"
  55. FLOAT16 = "float16"
  56. FLOAT32 = "float32"
  57. FLOAT64 = "float64"
  58. INT = "int"
  59. INT8 = "int8"
  60. INT16 = "int16"
  61. INT32 = "int32"
  62. INT64 = "int64"
  63. UINT = "uint"
  64. UINT8 = "uint8"
  65. UINT16 = "uint16"
  66. UINT32 = "uint32"
  67. UINT64 = "uint64"
  68. BOOL = "bool"
  69. class Config:
  70. R0 = 8.0
  71. UB_SIZE = 256 * 1024
  72. MAX_BLOCK = 32
  73. class PrimLib:
  74. """Prim lib"""
  75. UNKNOWN = 0
  76. RESHAPE = 1
  77. ELEMWISE = 2
  78. BROADCAST = 3
  79. REDUCE = 4
  80. TRANSFORM = 5
  81. CONTROL = 6
  82. class Prim:
  83. """Prim"""
  84. def __init__(self, iter_type, calibrate=1, relation_func=None):
  85. self.iter_type = iter_type
  86. self.calibrate = calibrate
  87. self.relation_func = relation_func
  88. if relation_func is None:
  89. self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
  90. def default_reshape_relation(self, op, input_idx):
  91. axis_relation, elem_relation = self.unknown_relation(op, input_idx)
  92. elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
  93. return axis_relation, elem_relation
  94. def default_elemwise_broadcast_relation(self, op, input_idx):
  95. """Process elemwise and broadcast relation"""
  96. out_shape = op.output.shape
  97. in_shape = op.inputs[input_idx].shape
  98. assert len(out_shape) >= len(in_shape)
  99. axis_relation, elem_relation = [], []
  100. delta = len(out_shape) - len(in_shape)
  101. if delta > 0:
  102. for i in range(0, delta):
  103. axis_relation.append(None)
  104. elem_relation.append(None)
  105. for i, _ in enumerate(in_shape):
  106. axis_relation.append(i)
  107. elem_relation.append(
  108. PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
  109. return axis_relation, elem_relation
  110. def default_reduce_relation(self, op, input_idx):
  111. """Process reduce relation"""
  112. axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
  113. for i in op.attrs['reduce_axis']:
  114. elem_relation[i] = PrimLib.REDUCE
  115. return axis_relation, elem_relation
  116. def unknown_relation(self, op, input_idx):
  117. """Process unknown relation"""
  118. out_shape = op.output.shape
  119. in_shape = op.inputs[input_idx].shape
  120. all_relation = list(range(len(in_shape)))
  121. axis_relation = [all_relation for i in range(0, len(out_shape))]
  122. elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
  123. return axis_relation, elem_relation
  124. default_relation_func = [
  125. unknown_relation,
  126. default_reshape_relation,
  127. default_elemwise_broadcast_relation,
  128. default_elemwise_broadcast_relation,
  129. default_reduce_relation,
  130. unknown_relation,
  131. unknown_relation,
  132. ]
  133. primtives = {
  134. 'Add': Prim(ELEMWISE),
  135. 'Abs': Prim(ELEMWISE),
  136. 'Neg': Prim(ELEMWISE),
  137. 'Mul': Prim(ELEMWISE),
  138. 'Sub': Prim(ELEMWISE),
  139. 'Log': Prim(ELEMWISE),
  140. 'Exp': Prim(ELEMWISE),
  141. 'Rsqrt': Prim(ELEMWISE),
  142. 'Sqrt': Prim(ELEMWISE),
  143. 'RealDiv': Prim(ELEMWISE),
  144. 'Cast': Prim(ELEMWISE),
  145. 'Pow': Prim(ELEMWISE),
  146. 'Minimum': Prim(ELEMWISE),
  147. 'Maximum': Prim(ELEMWISE),
  148. 'Reciprocal': Prim(ELEMWISE),
  149. 'Equal': Prim(ELEMWISE),
  150. 'Greater': Prim(ELEMWISE),
  151. 'GreaterEqual': Prim(ELEMWISE),
  152. 'Less': Prim(ELEMWISE),
  153. 'LessEqual': Prim(ELEMWISE),
  154. 'Square': Prim(ELEMWISE),
  155. 'AddN': Prim(ELEMWISE),
  156. 'Select': Prim(ELEMWISE, 8),
  157. 'ReduceSum': Prim(REDUCE),
  158. 'ReduceMax': Prim(REDUCE),
  159. 'ReduceMin': Prim(REDUCE),
  160. 'MakeTuple': Prim(CONTROL),
  161. 'ControlDepend': Prim(CONTROL),
  162. 'Assign': Prim(ELEMWISE),
  163. 'Tanh': Prim(ELEMWISE),
  164. 'ExpandDims': Prim(RESHAPE),
  165. 'InplaceAssign': Prim(ELEMWISE),
  166. '@ReduceInit': Prim(ELEMWISE),
  167. 'Reshape': Prim(RESHAPE),
  168. 'Squeeze': Prim(RESHAPE),
  169. 'Flatten': Prim(RESHAPE),
  170. 'FlattenGrad': Prim(RESHAPE),
  171. 'Transpose': Prim(TRANSFORM),
  172. 'Tile': Prim(BROADCAST),
  173. 'BroadcastTo': Prim(BROADCAST),
  174. }
  175. default_primtive = Prim(UNKNOWN)
  176. @classmethod
  177. def get_prim(cls, op):
  178. prim = cls.primtives.get(op.prim, None)
  179. if prim is None:
  180. print('[WARN] primtive is not registered: ' + op.prim)
  181. prim = cls.default_primtive
  182. return prim
  183. @classmethod
  184. def input_relation(cls, op, input_idx):
  185. return cls.get_prim(op).relation_func(op, input_idx)
  186. @classmethod
  187. def iter_type(cls, op):
  188. return cls.get_prim(op).iter_type
  189. @classmethod
  190. def is_reduce(cls, op):
  191. return cls.get_prim(op).iter_type == cls.REDUCE
  192. @classmethod
  193. def calibrate_iter_size(cls, op, iter_size):
  194. return cls.get_prim(op).calibrate * iter_size
  195. @classmethod
  196. def dtype_bytes(cls, dtype):
  197. bits, unit = 1, 1
  198. for i in range(len(dtype) - 1, 0, -1):
  199. if dtype[i].isdecimal():
  200. bits += int(dtype[i]) * unit
  201. unit *= 10
  202. else:
  203. break
  204. return bits // 8
  205. @classmethod
  206. def inplace_reuse(cls, op, input_idx, start_axis=0):
  207. if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
  208. return False
  209. _, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
  210. for i in range(start_axis, len(elem_relation)):
  211. if elem_relation[i] != cls.ELEMWISE:
  212. return False
  213. return True
  214. class Tensor:
  215. """Tensor"""
  216. PARA_NONE = 0
  217. PARA_INPUT = 1
  218. PARA_OUTPUT = 2
  219. class Buddy:
  220. def __init__(self, leader):
  221. self.members = [leader]
  222. def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
  223. self.name = name
  224. self.shape = shape
  225. self.dtype = dtype
  226. self.data_format = data_format
  227. self.para_type = para_type
  228. self.op = None
  229. self.to_ops = []
  230. self.buddy = None
  231. def __str__(self):
  232. return self.name + str(list(self.shape))
  233. def __repr__(self):
  234. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  235. def get_size(self):
  236. """Get size"""
  237. size = PrimLib.dtype_bytes(self.dtype)
  238. for i in self.shape:
  239. size *= i
  240. return size
  241. def add_buddy(self, tensor):
  242. """Add buddy"""
  243. if self.buddy is None:
  244. self.buddy = self.Buddy(self)
  245. self.buddy.members.append(tensor)
  246. tensor.buddy = self.buddy
  247. class Value:
  248. """Value"""
  249. def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
  250. self.name = name
  251. self.shape = [1]
  252. self.dtype = dtype
  253. self.value = value
  254. self.data_format = data_format
  255. def __str__(self):
  256. return self.name + str(list(self.shape))
  257. def __repr__(self):
  258. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  259. def get_size(self):
  260. return 1
  261. class Operator:
  262. """Operator"""
  263. def __init__(self, primtive, inputs, output, attrs):
  264. self.prim = primtive
  265. self.inputs = inputs
  266. self.output = output
  267. self.attrs = attrs
  268. for t in inputs:
  269. t.to_ops.append(self)
  270. if output.op is None:
  271. output.op = self
  272. self.all_inputs = [] # include Tensor inputs and Value inputs.
  273. def __str__(self):
  274. args = ', '.join([str(t) for t in self.all_inputs])
  275. expr = "%s = %s.%s(%s)" % (
  276. str(self.output), self.prim, self.output.dtype, args)
  277. return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
  278. def __repr__(self):
  279. return str(self)
  280. class Graph:
  281. """Graph"""
  282. def __init__(self, name, ops, stitch_info=None):
  283. self.name = name
  284. self.ops = ops # in topo order, can not use set
  285. self.inputs = []
  286. self.outputs = []
  287. self.stitch_info = stitch_info
  288. def set_processor(self, processor):
  289. """Set processor"""
  290. self.processor = processor
  291. def add(self, ops):
  292. """Add ops"""
  293. if isinstance(ops, Operator):
  294. self.ops.append(ops)
  295. else:
  296. self.ops.extend(ops)
  297. def extract_subgraph(self, graph_name, tensor_names, difference=False):
  298. """Extract subgraph from this graph"""
  299. graph = Graph(graph_name, [])
  300. outputs = set(tensor_names)
  301. if difference:
  302. for op in self.ops:
  303. if op.output.name not in outputs:
  304. graph.add(op)
  305. else:
  306. for op in self.ops:
  307. if op.output.name in outputs:
  308. graph.add(op)
  309. outputs.remove(op.output.name)
  310. for name in outputs:
  311. raise ValueError("invalid input tensor : " + name)
  312. return graph
  313. def deduce_parameters(self):
  314. """Deduce parameters"""
  315. inputs, outputs = [], []
  316. for op in self.ops:
  317. for t in op.inputs:
  318. if t not in inputs and t.op not in self.ops:
  319. inputs.append(t)
  320. if op.output not in outputs:
  321. if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
  322. outputs.append(op.output)
  323. else:
  324. for d in op.output.to_ops:
  325. if d not in self.ops:
  326. outputs.append(op.output)
  327. break
  328. if self.inputs:
  329. inputs = self.inputs
  330. if self.outputs:
  331. outputs = self.outputs
  332. return inputs, outputs
  333. def __str__(self):
  334. inputs, outputs = self.deduce_parameters()
  335. para_str = ', '.join([repr(t) for t in inputs])
  336. out_str = ', '.join([repr(t) for t in outputs])
  337. lines = []
  338. lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
  339. if self.stitch_info:
  340. if self.stitch_info.stitch_ops:
  341. lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops))
  342. if self.stitch_info.stitch_atomic_ops:
  343. lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))
  344. for op in self.ops:
  345. lines.append(' ' + str(op))
  346. lines.append('}')
  347. return '\n'.join(lines)
  348. def __repr__(self):
  349. return str(self)
  350. def dump(self):
  351. """Dump Graph to json"""
  352. attr_name = {'reduce_axis': 'axis'}
  353. inputs, outputs = self.deduce_parameters()
  354. input_desc, output_desc, op_desc = [], [], []
  355. for t in inputs:
  356. input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
  357. 'tensor_name': t.name, 'format': t.data_format}])
  358. for t in outputs:
  359. output_desc.append({'data_type': t.dtype, 'shape': t.shape,
  360. 'tensor_name': t.name, 'format': t.data_format})
  361. for op in self.ops:
  362. attrs, in_desc = [], []
  363. for a in op.attrs:
  364. name = attr_name.get(a, a)
  365. attrs.append(
  366. {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
  367. for t in op.all_inputs:
  368. if isinstance(t, Tensor):
  369. in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
  370. 'tensor_name': t.name, 'format': t.data_format}])
  371. else:
  372. in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
  373. 'tensor_name': t.name, 'format': t.data_format}])
  374. out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
  375. 'tensor_name': op.output.name, 'format': op.output.data_format}]
  376. op_desc.append({'attr': attrs, 'impl_path': '',
  377. 'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
  378. graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
  379. 'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
  380. 'platform': 'AKG', 'process': self.processor}
  381. if self.stitch_info and self.stitch_info.stitch_ops:
  382. buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
  383. if self.stitch_info.stitch_atomic_ops:
  384. buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
  385. graph_desc['buffer_stitch'] = buffer_stitch
  386. return graph_desc
  387. class GraphVisitor:
  388. """Graph visitor"""
  389. def __init__(self, forward=True, once_mode=True):
  390. self.forward = forward
  391. self.once_mode = once_mode
  392. if self.once_mode:
  393. self.visited = set()
  394. def visit_graph(self, graph):
  395. """Visit graph"""
  396. inputs, outputs = graph.deduce_parameters()
  397. if self.forward:
  398. for tensor in inputs:
  399. for op in tensor.to_ops:
  400. self.visit(op)
  401. else:
  402. for tensor in outputs:
  403. if not tensor.to_ops:
  404. self.visit(tensor.op)
  405. def visit(self, op):
  406. """Visit op"""
  407. next_ops = op.output.to_ops if self.forward else [
  408. t.op for t in op.inputs if t.op is not None]
  409. if self.once_mode:
  410. self.visited.add(op)
  411. for n in next_ops:
  412. if n not in self.visited:
  413. self.visit(n)
  414. else:
  415. for n in next_ops:
  416. self.visit(n)
  417. class AlignShape(GraphVisitor):
  418. """Align shape"""
  419. def __init__(self):
  420. super().__init__(once_mode=False)
  421. def visit(self, op):
  422. prim = PrimLib.get_prim(op)
  423. if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
  424. out_dim = len(op.output.shape)
  425. align_dim = out_dim
  426. for t in op.inputs:
  427. if len(t.shape) > align_dim:
  428. align_dim = len(t.shape)
  429. if align_dim > out_dim:
  430. op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
  431. super().visit(op)
  432. class AddControlBuddy(GraphVisitor):
  433. """Add control buddy"""
  434. def __init__(self):
  435. super().__init__()
  436. self.buddies = {} # {op : [ctrl_op]}
  437. def visit(self, op):
  438. if PrimLib.iter_type(op) == PrimLib.CONTROL:
  439. assert len(op.output.to_ops) == 1
  440. owner = op.output.to_ops[0]
  441. if owner in self.buddies:
  442. self.buddies[owner].append(op)
  443. else:
  444. self.buddies[owner] = [op]
  445. if op in self.buddies:
  446. ops = self.buddies.pop(op)
  447. self.buddies[owner].extend(ops)
  448. super().visit(op)
  449. def visit_graph(self, graph):
  450. super().visit_graph(graph)
  451. for owner in self.buddies:
  452. for op in self.buddies[owner]:
  453. owner.add_buddy(op.output)
  454. class GraphKernelUnsupportedException(Exception):
  455. def __init__(self, message):
  456. super().__init__()
  457. self.message = message