You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tbe_python_funcs.cc 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "kernel/tbe/tbe_python_funcs.h"
#include <cstring>
#include <memory>
#include <string>
#include "kernel/tbe/tbe_utils.h"
#include "common/utils.h"
#include "utils/context/ms_context.h"
  20. namespace mindspore {
  21. namespace kernel {
  22. using mindspore::kernel::tbe::TbeUtils;
  23. constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process";
  24. constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler";
  25. constexpr auto kOpSelectFormatFunc = "op_select_format";
  26. constexpr auto kCheckSupportedFunc = "check_supported";
  27. constexpr auto kTBEException = "TBEException";
  28. PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr;
  29. PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr;
  30. PyObject *TbePythonFuncs::pOpSelectFormatFunc_ = nullptr;
  31. PyObject *TbePythonFuncs::pCheckSupportedFunc_ = nullptr;
  32. bool TbePythonFuncs::Init() {
  33. static bool initialized = false;
  34. if (initialized) {
  35. return true;
  36. }
  37. // Initialize cache
  38. TbeUtils::LoadCache();
  39. // tbe_process
  40. PyObject *pTbeProcessModule = nullptr;
  41. pTbeProcessModule = PyImport_ImportModule(kTbeProcessModule);
  42. if (pTbeProcessModule == nullptr) {
  43. MS_LOG(ERROR) << "Failed to import [" << kTbeProcessModule << "] module.";
  44. return false;
  45. }
  46. pCreateTbeParallelCompilerFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCreateTbeParallelCompilerFunc);
  47. if (pCreateTbeParallelCompilerFunc_ == nullptr) {
  48. MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule
  49. << "], FuncName:[" << kCreateTbeParallelCompilerFunc << "].";
  50. return false;
  51. }
  52. pTbeCompiler_ = PyEval_CallObject(pCreateTbeParallelCompilerFunc_, nullptr);
  53. if (pTbeCompiler_ == nullptr) {
  54. PyErr_Print();
  55. MS_EXCEPTION(ArgumentError) << "Failed to call function : create_parallel_compiler.";
  56. return false;
  57. }
  58. pOpSelectFormatFunc_ = PyObject_GetAttrString(pTbeProcessModule, kOpSelectFormatFunc);
  59. if (pOpSelectFormatFunc_ == nullptr) {
  60. MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule
  61. << "], FuncName:[" << kOpSelectFormatFunc << "].";
  62. return false;
  63. }
  64. pCheckSupportedFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCheckSupportedFunc);
  65. if (pCheckSupportedFunc_ == nullptr) {
  66. MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule
  67. << "], FuncName:[" << kCheckSupportedFunc << "].";
  68. return false;
  69. }
  70. initialized = true;
  71. MS_LOG(INFO) << "TbePythonFuncs initialized Success.";
  72. return true;
  73. }
  74. std::string TbePythonFuncs::PyObjectToStr(PyObject *PyObj) {
  75. char *pChar = nullptr;
  76. std::string str_res;
  77. if (PyObj == nullptr) {
  78. MS_LOG(ERROR) << "Input parameter is nullptr.";
  79. return str_res;
  80. }
  81. PyObject *strArgs = PyObject_Str(PyObj);
  82. if (strArgs != nullptr) {
  83. (void)PyArg_Parse(strArgs, "s", &pChar);
  84. }
  85. if (pChar == nullptr) {
  86. MS_LOG(ERROR) << "pChar is nullptr.";
  87. return str_res;
  88. }
  89. str_res = pChar;
  90. return str_res;
  91. }
  92. std::string TbePythonFuncs::OpSelectFormat(const nlohmann::json &kernel_json) {
  93. PyObject *pArg = nullptr;
  94. PyObject *pRet = nullptr;
  95. std::string res_json_str;
  96. if (!Init()) {
  97. MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !";
  98. return res_json_str;
  99. }
  100. // assembly Args
  101. pArg = PyTuple_New(1);
  102. std::string json_str = kernel_json.dump();
  103. (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", json_str.c_str()));
  104. if (pArg == nullptr) {
  105. MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject.";
  106. return res_json_str;
  107. }
  108. // call functions
  109. if (pOpSelectFormatFunc_ == nullptr) {
  110. MS_LOG(ERROR) << "function is nullptr.";
  111. return res_json_str;
  112. }
  113. pRet = PyEval_CallObject(pOpSelectFormatFunc_, pArg);
  114. if (pRet == nullptr) {
  115. PyErr_Print();
  116. MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc
  117. << "], function args:" << PyObjectToStr(pArg);
  118. }
  119. char *pstr = nullptr;
  120. (void)PyArg_Parse(pRet, "s", &pstr);
  121. res_json_str = pstr;
  122. if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) {
  123. MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str
  124. << " ,function args:" << PyObjectToStr(pArg);
  125. }
  126. return res_json_str;
  127. }
  128. bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) {
  129. PyObject *pArg = nullptr;
  130. PyObject *pRes = nullptr;
  131. bool ret = false;
  132. if (!Init()) {
  133. MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !";
  134. return ret;
  135. }
  136. // assembly Args
  137. pArg = PyTuple_New(1);
  138. std::string json_str = kernel_json.dump();
  139. PyObject *arg1 = Py_BuildValue("s", json_str.c_str());
  140. (void)PyTuple_SetItem(pArg, 0, arg1);
  141. if (pArg == nullptr) {
  142. MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject.";
  143. return ret;
  144. }
  145. // call functions
  146. if (pCheckSupportedFunc_ == nullptr) {
  147. MS_LOG(ERROR) << "function is nullptr.";
  148. return ret;
  149. }
  150. pRes = PyEval_CallObject(pCheckSupportedFunc_, pArg);
  151. if (pRes == nullptr) {
  152. PyErr_Print();
  153. MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc
  154. << "], function args: " << PyObjectToStr(pArg);
  155. }
  156. if (PyBool_Check(pRes)) {
  157. ret = PyObject_IsTrue(pRes) != 0;
  158. } else {
  159. char *pstr = nullptr;
  160. (void)PyArg_Parse(pRes, "s", &pstr);
  161. std::string res_str = pstr;
  162. if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) {
  163. MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str
  164. << ", function args: " << PyObjectToStr(pArg);
  165. }
  166. }
  167. return ret;
  168. }
  169. PyObject *TbePythonFuncs::TbeParallelCompiler() {
  170. if (!Init()) {
  171. MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !";
  172. return nullptr;
  173. }
  174. return pTbeCompiler_;
  175. }
  176. } // namespace kernel
  177. } // namespace mindspore