You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

model.py 19 kB

5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """GraphKernel cost model"""
  16. class Utils:
  17. """Model utils"""
  18. def __init__(self):
  19. pass
  20. @staticmethod
  21. def get_attr_type(attr):
  22. """Get attr type"""
  23. if isinstance(attr, bool):
  24. return 'bool'
  25. if isinstance(attr, str):
  26. return 'str'
  27. if isinstance(attr, int):
  28. return 'int'
  29. if isinstance(attr, float):
  30. return 'bool'
  31. if isinstance(attr, (list, tuple)):
  32. if not attr:
  33. raise ValueError("Length of attr is 0")
  34. if isinstance(attr[0], int):
  35. return 'listInt'
  36. if isinstance(attr[0], str):
  37. return 'listStr'
  38. raise ValueError("Unknown type of attr: {}".format(attr))
class DataFormat:
    """Canonical tensor data-format (layout) name constants."""
    # NOTE(review): names mirror Ascend/AKG layout identifiers -- confirm
    # against the backend before adding new entries.
    DEFAULT = "DefaultFormat"
    NC1KHKWHWC0 = "NC1KHKWHWC0"
    ND = "ND"
    NCHW = "NCHW"
    NHWC = "NHWC"
    HWCN = "HWCN"
    NC1HWC0 = "NC1HWC0"
    FRAC_Z = "FracZ"
    FRAC_NZ = "FRACTAL_NZ"
    C1HWNCOC0 = "C1HWNCoC0"
    NC1HWC0_C04 = "NC1HWC0_C04"
    FRACTAL_Z_C04 = "FRACTAL_Z_C04"
    NDHWC = "NDHWC"

    def __init__(self):
        pass
class DataType:
    """Canonical data-type name constants used throughout the cost model."""
    FLOAT = "float"
    FLOAT16 = "float16"
    FLOAT32 = "float32"
    FLOAT64 = "float64"
    INT = "int"
    INT8 = "int8"
    INT16 = "int16"
    INT32 = "int32"
    INT64 = "int64"
    UINT = "uint"
    UINT8 = "uint8"
    UINT16 = "uint16"
    UINT32 = "uint32"
    UINT64 = "uint64"
    BOOL = "bool"

    def __init__(self):
        pass
class PrimLib:
    """Registry of primitive-op metadata for the GraphKernel cost model.

    Each primitive is tagged with an iteration type (how its output
    iteration space relates to its inputs) and a relation function that
    maps an input's axes onto the output's axes.
    """

    # Iteration-type categories; these double as indices into
    # Prim.default_relation_func below.
    UNKNOWN = 0
    RESHAPE = 1
    ELEMWISE = 2
    BROADCAST = 3
    REDUCE = 4
    OPAQUE = 5

    def __init__(self):
        pass

    class Prim:
        """Per-primitive record: iteration type, cost calibration and relation func."""

        def __init__(self, iter_type, calibrate=1, relation_func=None):
            self.iter_type = iter_type
            # calibrate scales the iteration size in cost estimates
            # (see PrimLib.calibrate_iter_size).
            self.calibrate = calibrate
            self.relation_func = relation_func
            if relation_func is None:
                # Fall back to the default relation function for this iter type.
                self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)

        def default_reshape_relation(self, op, input_idx):
            """Process reshape relation: every output axis is marked RESHAPE."""
            axis_relation, elem_relation = self.unknown_relation(op, input_idx)
            elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
            return axis_relation, elem_relation

        def default_elemwise_broadcast_relation(self, op, input_idx):
            """Process elemwise and broadcast relation.

            Aligns the input's axes to the trailing output axes; an axis
            whose extents match is ELEMWISE, otherwise it BROADCASTs.
            """
            out_shape = op.output.shape
            in_shape = op.inputs[input_idx].shape
            if len(out_shape) < len(in_shape):
                raise ValueError("input/output size is abnormal")
            axis_relation, elem_relation = [], []
            delta = len(out_shape) - len(in_shape)
            if delta > 0:
                # Leading output axes have no corresponding input axis.
                for i in range(0, delta):
                    axis_relation.append(None)
                    elem_relation.append(None)
            for i, _ in enumerate(in_shape):
                axis_relation.append(i)
                elem_relation.append(
                    PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
            return axis_relation, elem_relation

        def default_reduce_relation(self, op, input_idx):
            """Process reduce relation: elemwise alignment, then mark reduced axes."""
            axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
            for i in op.attrs['reduce_axis']:
                elem_relation[i] = PrimLib.REDUCE
            return axis_relation, elem_relation

        def unknown_relation(self, op, input_idx):
            """Process unknown relation: every output axis may depend on every input axis."""
            out_shape = op.output.shape
            in_shape = op.inputs[input_idx].shape
            all_relation = list(range(len(in_shape)))
            axis_relation = [all_relation for i in range(0, len(out_shape))]
            elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
            return axis_relation, elem_relation

        # Indexed by iteration type:
        # UNKNOWN, RESHAPE, ELEMWISE, BROADCAST, REDUCE, OPAQUE.
        default_relation_func = [
            unknown_relation,
            default_reshape_relation,
            default_elemwise_broadcast_relation,
            default_elemwise_broadcast_relation,
            default_reduce_relation,
            unknown_relation,
        ]

    # Registered primitives. NOTE: the 'primtives' spelling (and
    # 'default_primtive' below) is part of the existing public surface;
    # renaming would break external callers.
    primtives = {
        'Add': Prim(ELEMWISE),
        'Abs': Prim(ELEMWISE),
        'Neg': Prim(ELEMWISE),
        'Mul': Prim(ELEMWISE),
        'Sub': Prim(ELEMWISE),
        'Log': Prim(ELEMWISE),
        'IsNan': Prim(ELEMWISE),
        'IsInf': Prim(ELEMWISE),
        'IsFinite': Prim(ELEMWISE),
        'Exp': Prim(ELEMWISE),
        'Rsqrt': Prim(ELEMWISE),
        'Sqrt': Prim(ELEMWISE),
        'Div': Prim(ELEMWISE),
        'FloorDiv': Prim(ELEMWISE),
        'RealDiv': Prim(ELEMWISE),
        'Mod': Prim(ELEMWISE),
        'Floor': Prim(ELEMWISE),
        'FloorMod': Prim(ELEMWISE),
        'Erf': Prim(ELEMWISE),
        'Erfc': Prim(ELEMWISE),
        'Cast': Prim(ELEMWISE),
        'Pow': Prim(ELEMWISE),
        'Minimum': Prim(ELEMWISE),
        'Maximum': Prim(ELEMWISE),
        'Reciprocal': Prim(ELEMWISE),
        'Equal': Prim(ELEMWISE),
        'NotEqual': Prim(ELEMWISE),
        'Greater': Prim(ELEMWISE),
        'GreaterEqual': Prim(ELEMWISE),
        'Less': Prim(ELEMWISE),
        'LessEqual': Prim(ELEMWISE),
        'LogicalNot': Prim(ELEMWISE),
        'LogicalAnd': Prim(ELEMWISE),
        'LogicalOr': Prim(ELEMWISE),
        'Square': Prim(ELEMWISE),
        'AddN': Prim(ELEMWISE),
        # Select carries a higher calibration weight than plain elemwise ops.
        'Select': Prim(ELEMWISE, 8),
        'ReduceSum': Prim(REDUCE),
        'ReduceMax': Prim(REDUCE),
        'ReduceMin': Prim(REDUCE),
        'Argmax': Prim(REDUCE),
        'Argmin': Prim(REDUCE),
        'Assign': Prim(ELEMWISE),
        'Sign': Prim(ELEMWISE),
        'Sin': Prim(ELEMWISE),
        'Cos': Prim(ELEMWISE),
        'Asin': Prim(ELEMWISE),
        'ACos': Prim(ELEMWISE),
        'Tanh': Prim(ELEMWISE),
        'Asinh': Prim(ELEMWISE),
        'Acosh': Prim(ELEMWISE),
        'InplaceAssign': Prim(ELEMWISE),
        '@ReduceInit': Prim(ELEMWISE),
        'Reshape': Prim(RESHAPE),
        'Squeeze': Prim(RESHAPE),
        'Flatten': Prim(RESHAPE),
        'FlattenGrad': Prim(RESHAPE),
        'Transpose': Prim(OPAQUE),
        'Tile': Prim(BROADCAST),
        'BroadcastTo': Prim(BROADCAST),
        'StridedSlice': Prim(OPAQUE),
        'MatMul': Prim(OPAQUE),
        'TransData': Prim(OPAQUE),
        'BatchMatMul': Prim(OPAQUE),
        'UnPadAkg': Prim(OPAQUE),
        'PadAkg': Prim(OPAQUE),
        'Conv2D': Prim(OPAQUE),
        'CReal': Prim(ELEMWISE),
        'CImag': Prim(ELEMWISE),
        'Complex': Prim(ELEMWISE),
        'Atan': Prim(ELEMWISE),
        'Atan2': Prim(ELEMWISE),
        'Expm1': Prim(ELEMWISE),
        'TensorScatterAdd': Prim(OPAQUE),
        'Gather': Prim(OPAQUE),
        'GatherNd': Prim(OPAQUE),
        'UnsortedSegmentSum': Prim(OPAQUE),
        'StandardNormal': Prim(OPAQUE),
        'UserDefined': Prim(OPAQUE),
    }

    # Fallback record used for primitives that are not registered above.
    default_primtive = Prim(UNKNOWN)

    @classmethod
    def get_prim(cls, op):
        """Get op primtive record, falling back to the UNKNOWN default with a warning."""
        prim = cls.primtives.get(op.prim, None)
        if prim is None:
            print('[WARN] primtive is not registered: ' + op.prim)
            prim = cls.default_primtive
        return prim

    @classmethod
    def input_relation(cls, op, input_idx):
        """Get op's input_relation according to input_idx."""
        return cls.get_prim(op).relation_func(op, input_idx)

    @classmethod
    def iter_type(cls, op):
        """Get op's iter type."""
        return cls.get_prim(op).iter_type

    @classmethod
    def is_reduce(cls, op):
        """Check whether op's iter type is reduce."""
        return cls.get_prim(op).iter_type == cls.REDUCE

    @classmethod
    def calibrate_iter_size(cls, op, iter_size):
        """Get calibrate_iter_size: iter_size scaled by the op's calibration weight."""
        return cls.get_prim(op).calibrate * iter_size

    @classmethod
    def dtype_bytes(cls, dtype):
        """Get dtype bytes by parsing the trailing digits of the dtype name.

        E.g. 'float32' -> 4, 'int8' -> 1. The running total is seeded with 1
        so that e.g. 32 bits becomes 33 // 8 == 4. Dtype names with no
        numeric suffix (e.g. 'bool') yield 0 -- NOTE(review): looks
        intentional upstream, confirm.
        """
        bits, unit = 1, 1
        # Walk digits right-to-left; stop at the first non-digit character.
        # Index 0 is never examined, which is fine: dtype names start with a letter.
        for i in range(len(dtype) - 1, 0, -1):
            if dtype[i].isdecimal():
                bits += int(dtype[i]) * unit
                unit *= 10
            else:
                break
        return bits // 8

    @classmethod
    def inplace_reuse(cls, op, input_idx, start_axis=0):
        """Check whether op's output can reuse the input_idx-th input buffer in place."""
        # The output must not need more bytes per element than the input.
        if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
            return False
        _, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
        # Every axis from start_axis onward must be purely elementwise.
        for i in range(start_axis, len(elem_relation)):
            if elem_relation[i] != cls.ELEMWISE:
                return False
        return True
  264. class Tensor:
  265. """Tensor"""
  266. PARA_NONE = 0
  267. PARA_INPUT = 1
  268. PARA_OUTPUT = 2
  269. class Buddy:
  270. """Buddy"""
  271. def __init__(self, leader):
  272. self.members = [leader]
  273. def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
  274. self.name = name
  275. self.shape = shape
  276. self.dtype = dtype
  277. self.data_format = data_format
  278. self.para_type = para_type
  279. self.op = None
  280. self.to_ops = []
  281. self.buddy = None
  282. def __str__(self):
  283. return self.name + str(list(self.shape))
  284. def __repr__(self):
  285. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  286. def get_size(self):
  287. """Get size"""
  288. size = PrimLib.dtype_bytes(self.dtype)
  289. for i in self.shape:
  290. size *= i
  291. return size
  292. def add_buddy(self, tensor):
  293. """Add buddy"""
  294. if self.buddy is None:
  295. self.buddy = self.Buddy(self)
  296. self.buddy.members.append(tensor)
  297. tensor.buddy = self.buddy
  298. class Value:
  299. """Value"""
  300. def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
  301. self.name = name
  302. self.shape = [1]
  303. self.dtype = dtype
  304. self.value = value
  305. self.data_format = data_format
  306. def __str__(self):
  307. return self.name + str(list(self.shape))
  308. def __repr__(self):
  309. return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
  310. def get_size(self):
  311. """Get size"""
  312. return 1
  313. class Operator:
  314. """Operator"""
  315. def __init__(self, primtive, inputs, output, attrs):
  316. self.prim = primtive
  317. self.inputs = inputs
  318. self.output = output
  319. self.attrs = attrs
  320. for t in inputs:
  321. t.to_ops.append(self)
  322. if output.op is None:
  323. output.op = self
  324. self.all_inputs = [] # include Tensor inputs and Value inputs.
  325. def __str__(self):
  326. args = ', '.join([str(t) for t in self.all_inputs])
  327. expr = "%s = %s.%s(%s) id:%s" % (
  328. str(self.output), self.prim, self.output.dtype, args, id(self))
  329. return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
  330. def __repr__(self):
  331. return str(self)
class Graph:
    """An operator graph: a named, topologically ordered list of ops."""

    def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
        self.name = name
        self.ops = ops  # in topo order, can not use set
        self.inputs = []
        self.outputs = []
        self.stitch_info = stitch_info
        self.recompute_ops = recompute_ops
        self.processor = ""

    def set_processor(self, processor):
        """Set processor (target backend string) used when dumping the graph."""
        self.processor = processor

    def add(self, ops):
        """Add ops: accepts a single Operator or an iterable of Operators."""
        if isinstance(ops, Operator):
            self.ops.append(ops)
        else:
            self.ops.extend(ops)

    def extract_subgraph(self, graph_name, tensor_names, difference=False):
        """Extract subgraph from this graph, selecting ops by output-tensor name.

        With difference=False, keep ops whose output name is in tensor_names;
        every requested name must be matched or ValueError is raised.
        With difference=True, keep ops whose output name is NOT in tensor_names.
        """
        graph = Graph(graph_name, [])
        outputs = set(tensor_names)
        if difference:
            for op in self.ops:
                if op.output.name not in outputs:
                    graph.add(op)
        else:
            for op in self.ops:
                if op.output.name in outputs:
                    graph.add(op)
                    outputs.remove(op.output.name)
            # Any name left over matched no op output in this graph.
            for name in outputs:
                raise ValueError("invalid input tensor : " + name)
        return graph

    def deduce_parameters(self):
        """Deduce parameters: infer the graph's external inputs and outputs.

        Inputs are tensors consumed here but produced by no op in the graph.
        Outputs are tensors flagged PARA_OUTPUT, tensors with no consumers,
        or tensors with a consumer outside the graph. Explicitly set
        self.inputs / self.outputs take precedence when non-empty.
        """
        inputs, outputs = [], []
        for op in self.ops:
            for t in op.inputs:
                if t not in inputs and t.op not in self.ops:
                    inputs.append(t)
            if op.output in outputs:
                continue
            if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
                outputs.append(op.output)
                continue
            if any([succ not in self.ops for succ in op.output.to_ops]):
                outputs.append(op.output)
        if self.inputs:
            inputs = self.inputs
        if self.outputs:
            outputs = self.outputs
        return inputs, outputs

    def __str__(self):
        inputs, outputs = self.deduce_parameters()
        para_str = ', '.join([repr(t) for t in inputs])
        out_str = ', '.join([repr(t) for t in outputs])
        lines = []
        lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
        if self.stitch_info:
            if self.stitch_info.stitch_ops:
                lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops))
            if self.stitch_info.stitch_atomic_ops:
                lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))
        for op in self.ops:
            lines.append(' ' + str(op))
        lines.append('}')
        return '\n'.join(lines)

    def __repr__(self):
        return str(self)

    def dump(self):
        """Dump Graph to json-like dict in the AKG composite-kernel format."""
        # Rename map: model attribute name -> name expected by the dump format.
        attr_name = {'reduce_axis': 'axis'}
        inputs, outputs = self.deduce_parameters()
        input_desc, output_desc, op_desc = [], [], []
        for t in inputs:
            # NOTE(review): graph inputs are wrapped in one-element lists while
            # graph outputs are not -- presumably required by the AKG schema; confirm.
            input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
                                'tensor_name': t.name, 'format': t.data_format}])
        for t in outputs:
            output_desc.append({'data_type': t.dtype, 'shape': t.shape,
                                'tensor_name': t.name, 'format': t.data_format})
        for op in self.ops:
            attrs, in_desc = [], []
            for a in op.attrs:
                name = attr_name.get(a, a)
                attrs.append(
                    {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
            for t in op.all_inputs:
                if isinstance(t, Tensor):
                    in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
                                     'tensor_name': t.name, 'format': t.data_format}])
                else:
                    # Value inputs additionally carry their constant value.
                    in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
                                     'tensor_name': t.name, 'format': t.data_format}])
            out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
                         'tensor_name': op.output.name, 'format': op.output.data_format}]
            op_desc.append({'attr': attrs, 'impl_path': '',
                            'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
        graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
                      'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
                      'platform': 'AKG', 'process': self.processor}
        if self.stitch_info and self.stitch_info.stitch_ops:
            buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
            if self.stitch_info.stitch_atomic_ops:
                buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
            graph_desc['buffer_stitch'] = buffer_stitch
        return graph_desc
  440. class GraphVisitor:
  441. """Graph visitor"""
  442. def __init__(self, forward=True):
  443. self.forward = forward
  444. def visit_graph(self, graph):
  445. """Visit graph"""
  446. if self.forward:
  447. for op in graph.ops:
  448. self.visit(op)
  449. else:
  450. for i in range(len(graph.ops)-1, -1, -1):
  451. self.visit(graph.ops[i])
  452. class AlignShape(GraphVisitor):
  453. """Align shape"""
  454. def __init__(self):
  455. super(AlignShape, self).__init__()
  456. def visit(self, op):
  457. """Visit op node"""
  458. prim = PrimLib.get_prim(op)
  459. if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
  460. out_dim = len(op.output.shape)
  461. align_dim = out_dim
  462. for t in op.inputs:
  463. if len(t.shape) > align_dim:
  464. align_dim = len(t.shape)
  465. if align_dim > out_dim:
  466. op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
  467. class AddControlBuddy(GraphVisitor):
  468. """Add control buddy"""
  469. def __init__(self):
  470. super(AddControlBuddy, self).__init__()
  471. self.buddies = {} # {op : [ctrl_op]}
  472. def visit(self, op):
  473. """Visit op node"""
  474. if op.prim == "MakeTuple":
  475. if len(op.output.to_ops) != 1:
  476. raise ValueError("operator's output size is abnormal")
  477. owner = op.output.to_ops[0]
  478. if owner in self.buddies:
  479. self.buddies[owner].append(op)
  480. else:
  481. self.buddies[owner] = [op]
  482. if op in self.buddies:
  483. ops = self.buddies.pop(op)
  484. self.buddies[owner].extend(ops)
  485. def visit_graph(self, graph):
  486. """Visit graph nodes"""
  487. super(AddControlBuddy, self).visit_graph(graph)
  488. for owner in self.buddies:
  489. for op in self.buddies[owner]:
  490. owner.add_buddy(op.output)
  491. class GraphKernelUnsupportedException(Exception):
  492. """"GraphKernel Unsupported Exception"""
  493. def __init__(self, message):
  494. super(GraphKernelUnsupportedException, self).__init__()
  495. self.message = message