You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """GraphKernel cost model"""
  16. class Utils:
  17. """Model utils"""
  18. @staticmethod
  19. def get_attr_type(attr):
  20. """Get attr type"""
  21. if isinstance(attr, bool):
  22. return 'bool'
  23. if isinstance(attr, str):
  24. return 'str'
  25. if isinstance(attr, int):
  26. return 'int'
  27. if isinstance(attr, float):
  28. return 'bool'
  29. if isinstance(attr, (list, tuple)):
  30. if not attr:
  31. raise ValueError("Length of attr is 0")
  32. if isinstance(attr[0], int):
  33. return 'listInt'
  34. if isinstance(attr[0], str):
  35. return 'listStr'
  36. raise ValueError("Unknown type of attr: {}".format(attr))
  37. class DataFormat:
  38. """DataFormat"""
  39. DEFAULT = "DefaultFormat"
  40. NC1KHKWHWC0 = "NC1KHKWHWC0"
  41. ND = "ND"
  42. NCHW = "NCHW"
  43. NHWC = "NHWC"
  44. HWCN = "HWCN"
  45. NC1HWC0 = "NC1HWC0"
  46. FRAC_Z = "FracZ"
  47. FRAC_NZ = "FRACTAL_NZ"
  48. C1HWNCOC0 = "C1HWNCoC0"
  49. NC1HWC0_C04 = "NC1HWC0_C04"
  50. FRACTAL_Z_C04 = "FRACTAL_Z_C04"
  51. NDHWC = "NDHWC"
  52. class DataType:
  53. """Data Type"""
  54. FLOAT = "float"
  55. FLOAT16 = "float16"
  56. FLOAT32 = "float32"
  57. FLOAT64 = "float64"
  58. INT = "int"
  59. INT8 = "int8"
  60. INT16 = "int16"
  61. INT32 = "int32"
  62. INT64 = "int64"
  63. UINT = "uint"
  64. UINT8 = "uint8"
  65. UINT16 = "uint16"
  66. UINT32 = "uint32"
  67. UINT64 = "uint64"
  68. BOOL = "bool"
  69. class Config:
  70. R0 = 8.0
  71. UB_SIZE = 256 * 1024
  72. MAX_BLOCK = 32
  73. class PrimLib:
  74. """Prim lib"""
  75. UNKNOWN = 0
  76. RESHAPE = 1
  77. ELEMWISE = 2
  78. BROADCAST = 3
  79. REDUCE = 4
  80. OPAQUE = 5
  81. class Prim:
  82. """Prim"""
  83. def __init__(self, iter_type, calibrate=1, relation_func=None):
  84. self.iter_type = iter_type
  85. self.calibrate = calibrate
  86. self.relation_func = relation_func
  87. if relation_func is None:
  88. self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
  89. def default_reshape_relation(self, op, input_idx):
  90. axis_relation, elem_relation = self.unknown_relation(op, input_idx)
  91. elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
  92. return axis_relation, elem_relation
  93. def default_elemwise_broadcast_relation(self, op, input_idx):
  94. """Process elemwise and broadcast relation"""
  95. out_shape = op.output.shape
  96. in_shape = op.inputs[input_idx].shape
  97. assert len(out_shape) >= len(in_shape)
  98. axis_relation, elem_relation = [], []
  99. delta = len(out_shape) - len(in_shape)
  100. if delta > 0:
  101. for i in range(0, delta):
  102. axis_relation.append(None)
  103. elem_relation.append(None)
  104. for i, _ in enumerate(in_shape):
  105. axis_relation.append(i)
  106. elem_relation.append(
  107. PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
  108. return axis_relation, elem_relation
  109. def default_reduce_relation(self, op, input_idx):
  110. """Process reduce relation"""
  111. axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
  112. for i in op.attrs['reduce_axis']:
  113. elem_relation[i] = PrimLib.REDUCE
  114. return axis_relation, elem_relation
  115. def unknown_relation(self, op, input_idx):
  116. """Process unknown relation"""
  117. out_shape = op.output.shape
  118. in_shape = op.inputs[input_idx].shape
  119. all_relation = list(range(len(in_shape)))
  120. axis_relation = [all_relation for i in range(0, len(out_shape))]
  121. elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
  122. return axis_relation, elem_relation
  123. default_relation_func = [
  124. unknown_relation,
  125. default_reshape_relation,
  126. default_elemwise_broadcast_relation,
  127. default_elemwise_broadcast_relation,
  128. default_reduce_relation,
  129. unknown_relation,
  130. ]
  131. primtives = {
  132. 'Add': Prim(ELEMWISE),
  133. 'Abs': Prim(ELEMWISE),
  134. 'Neg': Prim(ELEMWISE),
  135. 'Mul': Prim(ELEMWISE),
  136. 'Sub': Prim(ELEMWISE),
  137. 'Log': Prim(ELEMWISE),
  138. 'Exp': Prim(ELEMWISE),
  139. 'Rsqrt': Prim(ELEMWISE),
  140. 'Sqrt': Prim(ELEMWISE),
  141. 'RealDiv': Prim(ELEMWISE),
  142. 'Cast': Prim(ELEMWISE),
  143. 'Pow': Prim(ELEMWISE),
  144. 'Minimum': Prim(ELEMWISE),
  145. 'Maximum': Prim(ELEMWISE),
  146. 'Reciprocal': Prim(ELEMWISE),
  147. 'Equal': Prim(ELEMWISE),
  148. 'Greater': Prim(ELEMWISE),
  149. 'GreaterEqual': Prim(ELEMWISE),
  150. 'Less': Prim(ELEMWISE),
  151. 'LessEqual': Prim(ELEMWISE),
  152. 'Square': Prim(ELEMWISE),
  153. 'AddN': Prim(ELEMWISE),
  154. 'Select': Prim(ELEMWISE, 8),
  155. 'ReduceSum': Prim(REDUCE),
  156. 'ReduceMax': Prim(REDUCE),
  157. 'ReduceMin': Prim(REDUCE),
  158. 'Assign': Prim(ELEMWISE),
  159. 'Tanh': Prim(ELEMWISE),
  160. 'InplaceAssign': Prim(ELEMWISE),
  161. '@ReduceInit': Prim(ELEMWISE),
  162. 'Reshape': Prim(RESHAPE),
  163. 'Squeeze': Prim(RESHAPE),
  164. 'Flatten': Prim(RESHAPE),
  165. 'FlattenGrad': Prim(RESHAPE),
  166. 'Transpose': Prim(OPAQUE),
  167. 'Tile': Prim(BROADCAST),
  168. 'BroadcastTo': Prim(BROADCAST),
  169. 'MatMul': Prim(OPAQUE),
  170. 'TransData': Prim(OPAQUE),
  171. }
  172. default_primtive = Prim(UNKNOWN)
  173. @classmethod
  174. def get_prim(cls, op):
  175. prim = cls.primtives.get(op.prim, None)
  176. if prim is None:
  177. print('[WARN] primtive is not registered: ' + op.prim)
  178. prim = cls.default_primtive
  179. return prim
  180. @classmethod
  181. def input_relation(cls, op, input_idx):
  182. return cls.get_prim(op).relation_func(op, input_idx)
  183. @classmethod
  184. def iter_type(cls, op):
  185. return cls.get_prim(op).iter_type
  186. @classmethod
  187. def is_reduce(cls, op):
  188. return cls.get_prim(op).iter_type == cls.REDUCE
  189. @classmethod
  190. def calibrate_iter_size(cls, op, iter_size):
  191. return cls.get_prim(op).calibrate * iter_size
  192. @classmethod
  193. def dtype_bytes(cls, dtype):
  194. bits, unit = 1, 1
  195. for i in range(len(dtype) - 1, 0, -1):
  196. if dtype[i].isdecimal():
  197. bits += int(dtype[i]) * unit
  198. unit *= 10
  199. else:
  200. break
  201. return bits // 8
  202. @classmethod
  203. def inplace_reuse(cls, op, input_idx, start_axis=0):
  204. if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
  205. return False
  206. _, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
  207. for i in range(start_axis, len(elem_relation)):
  208. if elem_relation[i] != cls.ELEMWISE:
  209. return False
  210. return True
  211. class Tensor:
  212. """Tensor"""
  213. PARA_NONE = 0
  214. PARA_INPUT = 1
  215. PARA_OUTPUT = 2
  216. class Buddy:
  217. def __init__(self, leader):
  218. self.members = [leader]
  219. def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
  220. self.name = name
  221. self.shape = shape
  222. self.dtype = dtype
  223. self.data_format = data_format
  224. self.para_type = para_type
  225. self.op = None
  226. self.to_ops = []
  227. self.buddy = None
  228. def __str__(self):
  229. return self.name + str(list(self.shape))
  230. def __repr__(self):
  231. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  232. def get_size(self):
  233. """Get size"""
  234. size = PrimLib.dtype_bytes(self.dtype)
  235. for i in self.shape:
  236. size *= i
  237. return size
  238. def add_buddy(self, tensor):
  239. """Add buddy"""
  240. if self.buddy is None:
  241. self.buddy = self.Buddy(self)
  242. self.buddy.members.append(tensor)
  243. tensor.buddy = self.buddy
  244. class Value:
  245. """Value"""
  246. def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
  247. self.name = name
  248. self.shape = [1]
  249. self.dtype = dtype
  250. self.value = value
  251. self.data_format = data_format
  252. def __str__(self):
  253. return self.name + str(list(self.shape))
  254. def __repr__(self):
  255. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  256. def get_size(self):
  257. return 1
  258. class Operator:
  259. """Operator"""
  260. def __init__(self, primtive, inputs, output, attrs):
  261. self.prim = primtive
  262. self.inputs = inputs
  263. self.output = output
  264. self.attrs = attrs
  265. for t in inputs:
  266. t.to_ops.append(self)
  267. if output.op is None:
  268. output.op = self
  269. self.all_inputs = [] # include Tensor inputs and Value inputs.
  270. def __str__(self):
  271. args = ', '.join([str(t) for t in self.all_inputs])
  272. expr = "%s = %s.%s(%s)" % (
  273. str(self.output), self.prim, self.output.dtype, args)
  274. return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
  275. def __repr__(self):
  276. return str(self)
  277. class Graph:
  278. """Graph"""
  279. def __init__(self, name, ops, stitch_info=None):
  280. self.name = name
  281. self.ops = ops # in topo order, can not use set
  282. self.inputs = []
  283. self.outputs = []
  284. self.stitch_info = stitch_info
  285. def set_processor(self, processor):
  286. """Set processor"""
  287. self.processor = processor
  288. def add(self, ops):
  289. """Add ops"""
  290. if isinstance(ops, Operator):
  291. self.ops.append(ops)
  292. else:
  293. self.ops.extend(ops)
  294. def extract_subgraph(self, graph_name, tensor_names, difference=False):
  295. """Extract subgraph from this graph"""
  296. graph = Graph(graph_name, [])
  297. outputs = set(tensor_names)
  298. if difference:
  299. for op in self.ops:
  300. if op.output.name not in outputs:
  301. graph.add(op)
  302. else:
  303. for op in self.ops:
  304. if op.output.name in outputs:
  305. graph.add(op)
  306. outputs.remove(op.output.name)
  307. for name in outputs:
  308. raise ValueError("invalid input tensor : " + name)
  309. return graph
  310. def deduce_parameters(self):
  311. """Deduce parameters"""
  312. inputs, outputs = [], []
  313. for op in self.ops:
  314. for t in op.inputs:
  315. if t not in inputs and t.op not in self.ops:
  316. inputs.append(t)
  317. if op.output not in outputs:
  318. if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
  319. outputs.append(op.output)
  320. else:
  321. for d in op.output.to_ops:
  322. if d not in self.ops:
  323. outputs.append(op.output)
  324. break
  325. if self.inputs:
  326. inputs = self.inputs
  327. if self.outputs:
  328. outputs = self.outputs
  329. return inputs, outputs
  330. def __str__(self):
  331. inputs, outputs = self.deduce_parameters()
  332. para_str = ', '.join([repr(t) for t in inputs])
  333. out_str = ', '.join([repr(t) for t in outputs])
  334. lines = []
  335. lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
  336. if self.stitch_info:
  337. if self.stitch_info.stitch_ops:
  338. lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops))
  339. if self.stitch_info.stitch_atomic_ops:
  340. lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))
  341. for op in self.ops:
  342. lines.append(' ' + str(op))
  343. lines.append('}')
  344. return '\n'.join(lines)
  345. def __repr__(self):
  346. return str(self)
  347. def dump(self):
  348. """Dump Graph to json"""
  349. attr_name = {'reduce_axis': 'axis'}
  350. inputs, outputs = self.deduce_parameters()
  351. input_desc, output_desc, op_desc = [], [], []
  352. for t in inputs:
  353. input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
  354. 'tensor_name': t.name, 'format': t.data_format}])
  355. for t in outputs:
  356. output_desc.append({'data_type': t.dtype, 'shape': t.shape,
  357. 'tensor_name': t.name, 'format': t.data_format})
  358. for op in self.ops:
  359. attrs, in_desc = [], []
  360. for a in op.attrs:
  361. name = attr_name.get(a, a)
  362. attrs.append(
  363. {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
  364. for t in op.all_inputs:
  365. if isinstance(t, Tensor):
  366. in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
  367. 'tensor_name': t.name, 'format': t.data_format}])
  368. else:
  369. in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
  370. 'tensor_name': t.name, 'format': t.data_format}])
  371. out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
  372. 'tensor_name': op.output.name, 'format': op.output.data_format}]
  373. op_desc.append({'attr': attrs, 'impl_path': '',
  374. 'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
  375. graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
  376. 'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
  377. 'platform': 'AKG', 'process': self.processor}
  378. if self.stitch_info and self.stitch_info.stitch_ops:
  379. buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
  380. if self.stitch_info.stitch_atomic_ops:
  381. buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
  382. graph_desc['buffer_stitch'] = buffer_stitch
  383. return graph_desc
  384. class GraphVisitor:
  385. """Graph visitor"""
  386. def __init__(self, forward=True):
  387. self.forward = forward
  388. def visit_graph(self, graph):
  389. """Visit graph"""
  390. if self.forward:
  391. for op in graph.ops:
  392. self.visit(op)
  393. else:
  394. for i in range(len(graph.ops)-1, -1, -1):
  395. self.visit(graph.ops[i])
  396. class AlignShape(GraphVisitor):
  397. """Align shape"""
  398. def __init__(self):
  399. super().__init__()
  400. def visit(self, op):
  401. """Visit op node"""
  402. prim = PrimLib.get_prim(op)
  403. if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
  404. out_dim = len(op.output.shape)
  405. align_dim = out_dim
  406. for t in op.inputs:
  407. if len(t.shape) > align_dim:
  408. align_dim = len(t.shape)
  409. if align_dim > out_dim:
  410. op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
  411. class AddControlBuddy(GraphVisitor):
  412. """Add control buddy"""
  413. def __init__(self):
  414. super().__init__()
  415. self.buddies = {} # {op : [ctrl_op]}
  416. def visit(self, op):
  417. """Visit op node"""
  418. if op.prim == "MakeTuple":
  419. assert len(op.output.to_ops) == 1
  420. owner = op.output.to_ops[0]
  421. if owner in self.buddies:
  422. self.buddies[owner].append(op)
  423. else:
  424. self.buddies[owner] = [op]
  425. if op in self.buddies:
  426. ops = self.buddies.pop(op)
  427. self.buddies[owner].extend(ops)
  428. def visit_graph(self, graph):
  429. """Visit graph nodes"""
  430. super().visit_graph(graph)
  431. for owner in self.buddies:
  432. for op in self.buddies[owner]:
  433. owner.add_buddy(op.output)
  434. class GraphKernelUnsupportedException(Exception):
  435. def __init__(self, message):
  436. super().__init__()
  437. self.message = message