You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_build_server.py 5.1 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """kernel build server"""
  16. import os
  17. from mindspore import log as logger
  18. from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
  19. class Messager:
  20. '''Messager'''
  21. def __init__(self, fdin, fdout):
  22. self.fdin = fdin
  23. self.fdout = fdout
  24. self.fin = os.fdopen(fdin, "r")
  25. self.fout = os.fdopen(fdout, "w")
  26. self.message = ''
  27. def __del__(self):
  28. os.close(self.fdin)
  29. os.close(self.fdout)
  30. def get_message(self):
  31. """
  32. Get message from remote
  33. Returns:
  34. message
  35. """
  36. try:
  37. # Not read by input() anymore
  38. res = self.fin.readline()
  39. if not res:
  40. logger.debug("[TRACE] read nothing...")
  41. self.exit()
  42. if res[len(res) - 1] == '\n':
  43. res = res[0:len(res) - 1]
  44. self.message = res
  45. logger.debug(f"[IN] {self.message}")
  46. except (EOFError, KeyboardInterrupt):
  47. self.exit()
  48. finally:
  49. pass
  50. if self.message == '' or self.message == 'FINISH':
  51. self.send_ack()
  52. self.exit()
  53. return self.message
  54. def send_res(self, res, keep_format=True):
  55. """
  56. Send result to remote
  57. Args:
  58. keep_format: True or False
  59. """
  60. logger.debug(f"[OUT] {str(res)}")
  61. if keep_format:
  62. res_str = str(res).replace('\n', '[LF]').replace('\r', '[CR]').replace(' ', '[SP]')
  63. else:
  64. res_str = str(res).replace('\n', '').replace('\r', '').replace(' ', '')
  65. tag = '[~]' # The same as client kTAG
  66. # Not write by print(tag + res_str, flush=True) any more
  67. try:
  68. self.fout.write(tag + res_str + "\n")
  69. self.fout.flush()
  70. except BrokenPipeError as err:
  71. logger.info(f"[TRACE] Write {str(err)}")
  72. self.exit()
  73. finally:
  74. pass
  75. def send_ack(self, success=True):
  76. """
  77. Send ack to remote
  78. Args:
  79. success: True or False
  80. """
  81. if success:
  82. self.send_res('ACK')
  83. else:
  84. self.send_res('ERR')
  85. def loop(self):
  86. """
  87. Messaging loop
  88. """
  89. while True:
  90. self.handle()
  91. def run(self):
  92. self.loop()
  93. def handle(self):
  94. """
  95. A interface communicates with remote.
  96. Note:
  97. All subclasses should override this interface.
  98. """
  99. raise NotImplementedError
  100. def exit(self):
  101. """
  102. A interface handles the procedure before exit.
  103. Note:
  104. All subclasses should override this interface.
  105. """
  106. raise NotImplementedError
  107. class AkgBuilder():
  108. """Akg building wrapper"""
  109. def __init__(self, platform):
  110. self.platform = platform
  111. self.attrs = None
  112. def create(self, process_num, waitime):
  113. """ Create akg processor"""
  114. self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform)
  115. def accept_json(self, json):
  116. """ Accept json"""
  117. return self.akg_processor.accept_json(json)
  118. def compile(self):
  119. """Compile"""
  120. return self.akg_processor.compile(self.attrs)
  121. def handle(self, messager, arg):
  122. """Handle message about akg"""
  123. if arg == 'AKG/START':
  124. messager.send_ack()
  125. process_num_str = messager.get_message()
  126. messager.send_ack()
  127. wait_time_str = messager.get_message()
  128. messager.send_ack()
  129. self.create(int(process_num_str), int(wait_time_str))
  130. elif arg == 'AKG/ATTR':
  131. messager.send_ack()
  132. self.attrs = messager.get_message()
  133. messager.send_ack()
  134. elif arg == 'AKG/DATA':
  135. messager.send_ack()
  136. while True:
  137. req = messager.get_message()
  138. if req.startswith('{'):
  139. self.accept_json(req)
  140. messager.send_ack()
  141. elif req == 'AKG/WAIT':
  142. res = self.compile()
  143. messager.send_res(res)
  144. break
  145. else:
  146. messager.send_ack(False)
  147. break
  148. else:
  149. raise RuntimeError("Unknown message type: %s" % arg)
  150. def get_logger():
  151. return logger