You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

multi_compiler.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Providing multi process compile with json"""
import json
import math
import os
import subprocess
import sys
import warnings
from multiprocessing import Pool
  22. def _compiletask(platform, *jsons):
  23. """
  24. compile func called in single process
  25. Parameters:
  26. platform: str. AKG platform or TBE platform
  27. *jsons: str. json str contain kernel info, suitable for json compile
  28. api
  29. """
  30. if platform == "AKG":
  31. p = __import__("_akg", globals(), locals(), ['ms'], 0)
  32. func = getattr(p.ms, "compilewithjson")
  33. for json_item in jsons:
  34. res = func(json_item)
  35. if not res:
  36. raise ValueError("Compile error")
  37. if platform == "TBE":
  38. tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py")
  39. for json_item in jsons:
  40. res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
  41. if res.returncode != 0:
  42. raise ValueError("Tbe compile error")
  43. def compilekernelparallel(jsons, process, waitime):
  44. """
  45. compile kernel use multi processes
  46. Parameters:
  47. jsons: list. json str list contain kernel info
  48. process: int. processes num
  49. waittime: int. max time the function blocked
  50. """
  51. if not isinstance(jsons, list):
  52. raise ValueError("jsons must be a list")
  53. if not isinstance(process, int):
  54. raise ValueError("process must be a num")
  55. if not isinstance(waitime, int):
  56. raise ValueError("waittime must be a num")
  57. jsons_akg = []
  58. jsons_tbe = []
  59. for json_ in jsons:
  60. j = json.loads(json_)
  61. if j["platform"] == "TBE":
  62. jsons_tbe.append(json_)
  63. continue
  64. if j["platform"] == "AKG":
  65. jsons_akg.append(json_)
  66. continue
  67. raise RuntimeError(
  68. "not support this platform {0}".format(j["platform"]))
  69. if jsons_akg:
  70. process_akg = math.floor(len(jsons)/len(jsons_akg)*process)
  71. else:
  72. process_akg = 0
  73. if process_akg == 0 and jsons_akg:
  74. process_akg = 1
  75. process_tbe = process-process_akg
  76. if process_tbe == 0 and jsons_tbe:
  77. process_tbe = 1
  78. raise RuntimeWarning("we add a process for compile more operator")
  79. args = [[] for _ in range(process_akg+process_tbe)]
  80. args_lens = len(args)
  81. for p in range(args_lens):
  82. if p < process_tbe:
  83. args[p].append("TBE")
  84. else:
  85. args[p].append("AKG")
  86. jsons_tbe_lens = len(jsons_tbe)
  87. for p in range(jsons_tbe_lens):
  88. args[p % process_tbe].append(jsons_tbe[p])
  89. jsons_akg_lens = len(jsons_akg)
  90. for p in range(jsons_akg_lens):
  91. args[process-p % process_akg-1].append(jsons_akg[p])
  92. for p in range(args_lens):
  93. args[p] = tuple(args[p])
  94. with Pool(processes=process) as pool:
  95. res = pool.starmap_async(_compiletask, args)
  96. res.get(timeout=waitime)
  97. return True