| @@ -10,7 +10,7 @@ if(ENABLE_ACL) | |||||
| include_directories(${CMAKE_SOURCE_DIR}/graphengine/ge) | include_directories(${CMAKE_SOURCE_DIR}/graphengine/ge) | ||||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
| file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} | ||||
| "python_utils.cc" | |||||
| "akg_kernel_register.cc" | |||||
| "model/acl/*.cc" | "model/acl/*.cc" | ||||
| "model/model_converter_utils/*.cc" | "model/model_converter_utils/*.cc" | ||||
| "graph/acl/*.cc" | "graph/acl/*.cc" | ||||
| @@ -19,11 +19,12 @@ endif() | |||||
| if(ENABLE_D) | if(ENABLE_D) | ||||
| file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} | ||||
| "python_utils.cc" "model/ms/*.cc" "graph/ascend/*.cc") | |||||
| "akg_kernel_register.cc" "model/ms/*.cc" "graph/ascend/*.cc") | |||||
| endif() | endif() | ||||
| if(ENABLE_GPU) | if(ENABLE_GPU) | ||||
| file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc" "graph/gpu/*.cc") | |||||
| file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| "akg_kernel_register.cc" "model/ms/*.cc" "graph/gpu/*.cc") | |||||
| endif() | endif() | ||||
| set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc | set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc | ||||
| @@ -13,52 +13,18 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "cxx_api/python_utils.h" | |||||
| #include "cxx_api/akg_kernel_register.h" | |||||
| #include <dlfcn.h> | #include <dlfcn.h> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include <vector> | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <fstream> | #include <fstream> | ||||
| #include "mindspore/core/utils/ms_context.h" | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "backend/kernel_compiler/oplib/oplib.h" | #include "backend/kernel_compiler/oplib/oplib.h" | ||||
| namespace py = pybind11; | |||||
| static std::mutex init_mutex; | static std::mutex init_mutex; | ||||
| static bool Initialized = false; | static bool Initialized = false; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| static void RegAllOpFromPython() { | |||||
| MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||||
| Py_Initialize(); | |||||
| auto c_expression = PyImport_ImportModule("mindspore._c_expression"); | |||||
| MS_EXCEPTION_IF_NULL(c_expression); | |||||
| PyObject *c_expression_dict = PyModule_GetDict(c_expression); | |||||
| MS_EXCEPTION_IF_NULL(c_expression_dict); | |||||
| PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); | |||||
| MS_EXCEPTION_IF_NULL(op_info_loader_class); | |||||
| PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); | |||||
| MS_EXCEPTION_IF_NULL(op_info_loader); | |||||
| PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); | |||||
| MS_EXCEPTION_IF_NULL(op_info_loader_ins); | |||||
| auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); | |||||
| MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul); | |||||
| auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); | |||||
| auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr); | |||||
| for (auto op_info : *all_ops_info) { | |||||
| kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info)); | |||||
| } | |||||
| all_ops_info->clear(); | |||||
| delete all_ops_info; | |||||
| Py_DECREF(op_info_loader); | |||||
| Py_DECREF(op_info_loader_class); | |||||
| Py_DECREF(c_expression_dict); | |||||
| Py_DECREF(c_expression); | |||||
| } | |||||
| static bool RegAllOpFromFile() { | static bool RegAllOpFromFile() { | ||||
| Dl_info info; | Dl_info info; | ||||
| int dl_ret = dladdr(reinterpret_cast<void *>(RegAllOpFromFile), &info); | int dl_ret = dladdr(reinterpret_cast<void *>(RegAllOpFromFile), &info); | ||||
| @@ -111,36 +77,10 @@ void RegAllOp() { | |||||
| } | } | ||||
| bool ret = RegAllOpFromFile(); | bool ret = RegAllOpFromFile(); | ||||
| if (!ret) { | if (!ret) { | ||||
| MS_LOG(INFO) << "Reg all op from file failed, start to reg from python."; | |||||
| RegAllOpFromPython(); | |||||
| MS_LOG(ERROR) << "Register operators failed. The package may damaged or file is missing."; | |||||
| return; | |||||
| } | } | ||||
| Initialized = true; | Initialized = true; | ||||
| } | } | ||||
| bool PythonIsInited() { return Py_IsInitialized() != 0; } | |||||
| void InitPython() { | |||||
| if (!PythonIsInited()) { | |||||
| Py_Initialize(); | |||||
| } | |||||
| } | |||||
| void FinalizePython() { | |||||
| if (PythonIsInited()) { | |||||
| Py_Finalize(); | |||||
| } | |||||
| } | |||||
| PythonEnvGuard::PythonEnvGuard() { | |||||
| origin_init_status_ = PythonIsInited(); | |||||
| InitPython(); | |||||
| } | |||||
| PythonEnvGuard::~PythonEnvGuard() { | |||||
| // finalize when init by this | |||||
| if (!origin_init_status_) { | |||||
| FinalizePython(); | |||||
| } | |||||
| } | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -13,22 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H | |||||
| #define MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H | |||||
| #ifndef MINDSPORE_CCSRC_CXXAPI_AKG_KERNEL_REGISTER_H_ | |||||
| #define MINDSPORE_CCSRC_CXXAPI_AKG_KERNEL_REGISTER_H_ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| void RegAllOp(); | void RegAllOp(); | ||||
| bool PythonIsInited(); | |||||
| void InitPython(); | |||||
| void FinalizePython(); | |||||
| class PythonEnvGuard { | |||||
| public: | |||||
| PythonEnvGuard(); | |||||
| ~PythonEnvGuard(); | |||||
| private: | |||||
| bool origin_init_status_; | |||||
| }; | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_CXXAPI_PYTHON_UTILS_H | |||||
| #endif // MINDSPORE_CCSRC_CXXAPI_AKG_KERNEL_REGISTER_H_ | |||||
| @@ -17,7 +17,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "include/api/context.h" | #include "include/api/context.h" | ||||
| #include "cxx_api/factory.h" | #include "cxx_api/factory.h" | ||||
| #include "cxx_api/python_utils.h" | |||||
| #include "cxx_api/akg_kernel_register.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/context/context_extends.h" | #include "utils/context/context_extends.h" | ||||
| #include "mindspore/core/base/base_ref_utils.h" | #include "mindspore/core/base/base_ref_utils.h" | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "runtime/dev.h" | #include "runtime/dev.h" | ||||
| #include "pipeline/jit/pipeline.h" | #include "pipeline/jit/pipeline.h" | ||||
| #include "frontend/parallel/step_parallel.h" | #include "frontend/parallel/step_parallel.h" | ||||
| #include "pybind11/pybind11.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); | API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); | ||||
| @@ -380,4 +381,30 @@ std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv | |||||
| std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_; | std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_; | ||||
| std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_; | std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_; | ||||
| PythonEnvGuard::PythonEnvGuard() { | |||||
| origin_init_status_ = PythonIsInited(); | |||||
| InitPython(); | |||||
| } | |||||
| PythonEnvGuard::~PythonEnvGuard() { | |||||
| // finalize when init by this | |||||
| if (!origin_init_status_) { | |||||
| FinalizePython(); | |||||
| } | |||||
| } | |||||
| bool PythonEnvGuard::PythonIsInited() { return Py_IsInitialized() != 0; } | |||||
| void PythonEnvGuard::InitPython() { | |||||
| if (!PythonIsInited()) { | |||||
| Py_Initialize(); | |||||
| } | |||||
| } | |||||
| void PythonEnvGuard::FinalizePython() { | |||||
| if (PythonIsInited()) { | |||||
| Py_Finalize(); | |||||
| } | |||||
| } | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -79,5 +79,17 @@ class AscendGraphImpl::MsEnvGuard { | |||||
| Status errno_; | Status errno_; | ||||
| uint32_t device_id_; | uint32_t device_id_; | ||||
| }; | }; | ||||
| class PythonEnvGuard { | |||||
| public: | |||||
| PythonEnvGuard(); | |||||
| ~PythonEnvGuard(); | |||||
| private: | |||||
| bool PythonIsInited(); | |||||
| void InitPython(); | |||||
| void FinalizePython(); | |||||
| bool origin_init_status_; | |||||
| }; | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H | #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "include/api/context.h" | #include "include/api/context.h" | ||||
| #include "cxx_api/factory.h" | #include "cxx_api/factory.h" | ||||
| #include "cxx_api/akg_kernel_register.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "mindspore/core/base/base_ref_utils.h" | #include "mindspore/core/base/base_ref_utils.h" | ||||
| #include "backend/session/session_factory.h" | #include "backend/session/session_factory.h" | ||||
| @@ -43,6 +44,8 @@ Status GPUGraphImpl::InitEnv() { | |||||
| return kSuccess; | return kSuccess; | ||||
| } | } | ||||
| // Register op implemented with AKG. | |||||
| RegAllOp(); | |||||
| auto ms_context = MsContext::GetInstance(); | auto ms_context = MsContext::GetInstance(); | ||||
| if (ms_context == nullptr) { | if (ms_context == nullptr) { | ||||
| MS_LOG(ERROR) << "Get Context failed!"; | MS_LOG(ERROR) << "Get Context failed!"; | ||||