You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

graph_kernel_split.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """graph kernel split"""
  16. import json
  17. import getopt
  18. import sys
  19. import model
  20. def print_usage():
  21. print('Usage: graph_kernel_split.py [OPTION] <JSON_FILE>')
  22. print('Options:')
  23. print(' -s <config/auto>\tsplit graph with config')
  24. print(' -e \t\testimate graph')
  25. print(' -i \t\tnaive estimate')
  26. print(' -o <prefix>\toutput split graphs')
  27. print(' -v \t\tverbose mode')
  28. print(' -h \t\tprint this help')
  29. class Option:
  30. """Options"""
  31. def __init__(self):
  32. self.split = None
  33. self.estimate = False
  34. self.estimate_naive = False
  35. self.output = None
  36. self.verbose = False
  37. self.help = False
  38. def parse(self, options):
  39. """parse options"""
  40. for name, val in options:
  41. if name == '-h':
  42. self.help = True
  43. elif name == '-v':
  44. self.verbose = True
  45. elif name == '-o':
  46. self.output = val
  47. elif name == '-e':
  48. self.estimate = True
  49. elif name == '-s':
  50. self.split = val
  51. elif name == '-i':
  52. self.estimate_naive = True
# Module-level options instance: main() fills it from the command line and
# estimate() reads its verbose switch.
opt = Option()
  54. def estimate(graph_in, parts_in, naive):
  55. """estimate graphs costs"""
  56. def _print_cost(name, c):
  57. print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
  58. (name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
  59. main_cost, _ = model.estimate(graph_in, naive)
  60. split_cost, sub_costs = model.estimate(parts_in, naive) if parts_in else (None, None)
  61. _print_cost("MainGraph:", main_cost)
  62. if parts_in:
  63. _print_cost("Subgraphs:", split_cost)
  64. if opt.verbose:
  65. for i, sub_cost in enumerate(sub_costs):
  66. _print_cost(" |_%d:\t" % (i), sub_cost)
  67. def split_graph(graph_in, config):
  68. """split graph"""
  69. if config == 'auto':
  70. return model.split(graph_in)
  71. subgraphs = []
  72. all_tensors = []
  73. subgraph_idx = 0
  74. config_parts = config.split('|')
  75. for part in config_parts:
  76. tensor_names = part.split(',')
  77. graph_name = "%s_%d" % (graph_in.name, subgraph_idx)
  78. g = graph_in.extract_subgraph(graph_name, tensor_names)
  79. assert len(g.ops) == len(tensor_names)
  80. subgraphs.append(g)
  81. all_tensors += tensor_names
  82. subgraph_idx += 1
  83. if len(all_tensors) < len(graph_in.ops):
  84. graph_name = "%s_%d" % (graph_in.name, subgraph_idx)
  85. g = graph_in.extract_subgraph(graph_name, all_tensors, True)
  86. subgraphs.append(g)
  87. return subgraphs
  88. def main():
  89. opts, args = getopt.getopt(sys.argv[1:], 'heivo:s:')
  90. opt.parse(opts)
  91. if len(args) != 1 or opt.help:
  92. print_usage()
  93. sys.exit(0)
  94. in_file = args[0]
  95. with open(in_file, 'r') as f:
  96. desc = json.loads(f.read())
  97. comp = model.load_composite(desc)
  98. graph = comp.graph
  99. parts = []
  100. # 1. split sub-graphs
  101. if opt.split is not None:
  102. parts = split_graph(graph, opt.split)
  103. if opt.verbose:
  104. print('----------- main graph --------------')
  105. print(graph)
  106. for i, _ in enumerate(parts):
  107. print('---------------- sub graph %d ---------------' % (i))
  108. print(parts[i])
  109. # 2. estimate cost
  110. if opt.estimate:
  111. print('------------- cost --------------')
  112. estimate(graph, parts, False)
  113. if opt.estimate_naive:
  114. print('------------- naive cost --------------')
  115. estimate(graph, parts, True)
  116. # 3. output parts
  117. if opt.output is not None:
  118. for graph_part in parts:
  119. desc = comp.dump(graph_part)
  120. s_desc = json.dumps(desc)
  121. fname = "%s_%s.json" % (opt.output, graph_part.name)
  122. with open(fname, 'w', encoding='utf-8') as of:
  123. of.write(s_desc)
# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()