You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

kernel_build_client.cc 5.7 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "backend/session/kernel_build_client.h"
  #include <exception>
  #include <memory>
  #include <string>
  18. namespace mindspore {
  19. namespace kernel {
  // File-local flag shared by every client in this translation unit: TbeStart() issues TbePre() only
  // while it is false and then sets it; TbeReset() clears it again.
  static inline bool init_flag = false;
  // Replaces, in place, every occurrence of the substring `replace` inside `*dest` with the single
  // character `new_char`.
  //
  // dest:     string to modify; a null pointer is ignored.
  // replace:  substring to look for; an empty pattern is a no-op. std::string::find() matches an empty
  //           needle at every position, so without this guard the loop would insert characters forever.
  // new_char: character written in place of each match.
  void ReplaceStr(std::string *dest, const std::string &replace, char new_char) {
    if (dest == nullptr || replace.empty()) {
      return;
    }
    std::string::size_type start = 0;
    while ((start = dest->find(replace, start)) != std::string::npos) {
      dest->replace(start, replace.size(), 1, new_char);
      ++start;  // Resume after the one character just written so it is never re-matched.
    }
  }
  28. bool KernelBuildClient::AkgStart(int process_num, int wait_time) {
  29. // Start compiling..
  30. auto res = SendRequest(kAkgStart);
  31. if (res != kAck) {
  32. MS_LOG(ERROR) << "AKG/START failed, res: " << res;
  33. return false;
  34. }
  35. std::string process_num_str = std::to_string(process_num);
  36. res = SendRequest(process_num_str);
  37. if (res != kAck) {
  38. MS_LOG(ERROR) << "AKG/START(process_num) responds failed, res: " << res;
  39. return false;
  40. }
  41. std::string wait_time_str = std::to_string(wait_time);
  42. res = SendRequest(wait_time_str);
  43. if (res != kAck) {
  44. MS_LOG(ERROR) << "AKG/START(wait_time) responds failed, res: " << res;
  45. return false;
  46. }
  47. return true;
  48. }
  49. bool KernelBuildClient::AkgSendAttr(const std::string &attr) {
  50. auto res = SendRequest(kAkgAttr);
  51. if (res != kAck) {
  52. MS_LOG(ERROR) << "AKG/ATTR failed, res: " << res;
  53. return false;
  54. }
  55. res = SendRequest(attr);
  56. if (res != kAck) {
  57. MS_LOG(ERROR) << "AKG/ATTR.. responds failed, res: " << res << ", when sending [" << attr << "]";
  58. return false;
  59. }
  60. return true;
  61. }
  62. bool KernelBuildClient::AkgSendData(const std::vector<std::string> &jsons) {
  63. auto res = SendRequest(kAkgData);
  64. if (res != kAck) {
  65. MS_LOG(ERROR) << "AKG/DATA failed, res: " << res;
  66. return false;
  67. }
  68. for (auto &json : jsons) {
  69. res = SendRequest(json);
  70. if (res != kAck) {
  71. MS_LOG(ERROR) << "AKG/DATA.. responds failed, res: " << res << ", when sending [" << json << "]";
  72. return false;
  73. }
  74. }
  75. return true;
  76. }
  77. // Fetch the result of AKG compiling.
  78. bool KernelBuildClient::AkgWait() {
  79. auto res = SendRequest(kAkgWait);
  80. if (res != kTrue) {
  81. MS_LOG(ERROR) << "AKG/WAIT failed, res: " << res;
  82. return false;
  83. }
  84. return true;
  85. }
  86. void AscendKernelBuildClient::TbePre(const std::string &mode) {
  87. auto res = SendRequest(kTbePre);
  88. if (res.find(kSuccess) == std::string::npos) {
  89. MS_LOG(EXCEPTION) << "PRE failed, res: " << res;
  90. }
  91. MS_LOG(INFO) << "Pre " << res;
  92. // init env for auto tune
  93. res = SendRequest(kTbeTune);
  94. if (res != kAck) {
  95. MS_LOG(EXCEPTION) << "Send tune single failed, res: " << res;
  96. }
  97. res = SendRequest(mode);
  98. if (res != kSuccess) {
  99. MS_LOG(EXCEPTION) << "PRE failed, res: " << res;
  100. }
  101. }
  102. int AscendKernelBuildClient::TbeStart(const std::string &json, const std::string &mode) {
  103. if (!init_flag) {
  104. TbePre(mode);
  105. init_flag = true;
  106. }
  107. // Start compiling..
  108. auto res = SendRequest(kTbeStart);
  109. if (res != kAck) {
  110. MS_LOG(ERROR) << "START failed, res: " << res;
  111. return -1;
  112. }
  113. // Send the json data.
  114. res = SendRequest(json);
  115. if (res == kFailed) {
  116. MS_LOG(ERROR) << "TBE/START responds failed, res: " << res;
  117. return -1;
  118. }
  119. // Return task id.
  120. return std::stoi(res);
  121. }
  122. std::string AscendKernelBuildClient::TbeSendJob(const std::string &json) {
  123. auto res = SendRequest(kTbeJob);
  124. if (res != kAck) {
  125. MS_LOG(ERROR) << "Send TBE job failed, res: " << res;
  126. return "";
  127. }
  128. // Send the json data.
  129. res = SendRequest(json);
  130. if (res == kFailed) {
  131. MS_LOG(ERROR) << "Send TBE job json failed, res: " << res;
  132. return "";
  133. }
  134. return res;
  135. }
  136. bool AscendKernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) {
  137. // Start waiting..
  138. auto res = SendRequest(kTbeWait);
  139. if (res != kAck) {
  140. MS_LOG(ERROR) << "TBE/WAIT failed, res: " << res;
  141. return false;
  142. }
  143. // Request task id.
  144. *task_id = std::stoi(SendRequest(kContinue));
  145. // Request task result.
  146. *task_result = SendRequest(kContinue);
  147. // Request prebuild result.
  148. *pre_build_result = SendRequest(kContinue);
  149. return true;
  150. }
  151. void AscendKernelBuildClient::TbeReset() {
  152. // Start compiling..
  153. init_flag = false;
  154. auto res = SendRequest(kTbeReset);
  155. if (res != kAck) {
  156. MS_LOG(EXCEPTION) << "TBE/RESET response is: " << res;
  157. }
  158. }
  159. std::string AscendKernelBuildClient::SelectFormat(const std::string &json) {
  160. // Start compiling..
  161. auto res = SendRequest(kFormat);
  162. if (res != kAck) {
  163. MS_LOG(ERROR) << "FORMAT failed, res: " << res;
  164. return "";
  165. }
  166. // Send the json data.
  167. res = SendRequest(json);
  168. if (res == kErr) {
  169. MS_LOG(ERROR) << "FORMAT responds failed, res: " << res;
  170. return "";
  171. }
  172. return res;
  173. }
  174. bool AscendKernelBuildClient::CheckSupported(const std::string &json) {
  175. // Checking support..
  176. auto res = SendRequest(kSupport);
  177. if (res != kAck) {
  178. MS_LOG(ERROR) << "SUPPORT failed, res: " << res;
  179. return false;
  180. }
  181. // Send the json data.
  182. res = SendRequest(json);
  183. if (res != kTrue) {
  184. MS_LOG(INFO) << "SUPPORT responds failed, res: " << res;
  185. return false;
  186. }
  187. return true;
  188. }
  189. } // namespace kernel
  190. } // namespace mindspore