You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

infer_session.cc 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/session/infer_session.h"
  17. #include <memory>
  18. #include <algorithm>
  19. #include <fstream>
  20. #include "include/inference.h"
  21. #include "utils/load_onnx/anf_converter.h"
  22. #include "backend/session/session_basic.h"
  23. #include "backend/session/session_factory.h"
  24. #include "backend/session/executor_manager.h"
  25. #include "base/base_ref_utils.h"
  26. #include "backend/kernel_compiler/oplib/oplib.h"
  27. #include "utils/context/context_extends.h"
  28. #include "runtime/device/kernel_runtime_manager.h"
  29. #include "pybind11/pybind11.h"
  30. #ifdef ENABLE_D
  31. #include "utils/ms_context.h"
  32. #endif
  33. using std::string;
  34. using std::vector;
  35. namespace py = pybind11;
  36. namespace mindspore {
  37. namespace inference {
  38. std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
  39. try {
  40. auto session = std::make_shared<MSInferSession>();
  41. Status ret = session->InitEnv(device, device_id);
  42. if (ret != SUCCESS) {
  43. return nullptr;
  44. }
  45. return session;
  46. } catch (std::bad_alloc &e) {
  47. MS_LOG(ERROR) << "Inference CreatSession failed, failed to alloc memory";
  48. return nullptr;
  49. }
  50. }
// Default construction/destruction; all real setup happens in InitEnv() and
// teardown in FinalizeEnv().
MSInferSession::MSInferSession() = default;
MSInferSession::~MSInferSession() = default;
  53. std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &file) {
  54. if (file.empty()) {
  55. MS_LOG(ERROR) << "file is nullptr";
  56. return nullptr;
  57. }
  58. std::string realPath = file;
  59. std::ifstream ifs(realPath);
  60. if (!ifs.good()) {
  61. MS_LOG(ERROR) << "file: " << realPath << " is not exist";
  62. return nullptr;
  63. }
  64. if (!ifs.is_open()) {
  65. MS_LOG(ERROR) << "file: " << realPath << "open failed";
  66. return nullptr;
  67. }
  68. ifs.seekg(0, std::ios::end);
  69. size_t size = ifs.tellg();
  70. std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size));
  71. if (buf == nullptr) {
  72. MS_LOG(ERROR) << "malloc buf failed, file: " << realPath;
  73. ifs.close();
  74. return nullptr;
  75. }
  76. ifs.seekg(0, std::ios::beg);
  77. ifs.read(buf->data(), size);
  78. ifs.close();
  79. return buf;
  80. }
// Load a serialized model from `file_name`, compile it, and return the graph
// id through `model_id`. Returns FAILED if the file cannot be read, parsed,
// or compiled. On Ascend builds (ENABLE_D) the current runtime context is
// captured into context_ so ExecuteModel can re-activate it later.
Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
  auto graphBuf = ReadFile(file_name);
  if (graphBuf == nullptr) {
    MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
    return FAILED;
  }
  auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
    return FAILED;
  }
  Status ret = CompileGraph(graph, model_id);
  if (ret != SUCCESS) {
    MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str();
    return FAILED;
  }
  MS_LOG(INFO) << "Load model from file " << file_name << " success";
#ifdef ENABLE_D
  // Capture the current Ascend device context; ExecuteModel calls
  // rtCtxSetCurrent(context_) before each run.
  rtError_t rt_ret = rtCtxGetCurrent(&context_);
  if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
    MS_LOG(ERROR) << "the ascend device context is null";
    return FAILED;
  }
#endif
  return SUCCESS;
}
// Currently a no-op: compiled graphs stay resident in the backend session
// until teardown (FinalizeEnv). The model_id is accepted for API symmetry.
Status MSInferSession::UnloadModel(uint32_t model_id) { return SUCCESS; }
  109. Status ServingTensor2MSTensor(size_t index, const InferTensorBase &out_tensor, tensor::TensorPtr &ms_tensor) {
  110. std::vector<int> shape;
  111. for (auto dim : out_tensor.shape()) {
  112. shape.push_back(static_cast<int>(dim));
  113. }
  114. TypeId data_type;
  115. const std::map<inference::DataType, TypeId> type2id_map{
  116. {inference::kMSI_Unknown, TypeId::kNumberTypeBegin}, {inference::kMSI_Bool, TypeId::kNumberTypeBool},
  117. {inference::kMSI_Int8, TypeId::kNumberTypeInt8}, {inference::kMSI_Uint8, TypeId::kNumberTypeUInt8},
  118. {inference::kMSI_Int16, TypeId::kNumberTypeInt16}, {inference::kMSI_Uint16, TypeId::kNumberTypeUInt16},
  119. {inference::kMSI_Int32, TypeId::kNumberTypeInt32}, {inference::kMSI_Uint32, TypeId::kNumberTypeUInt32},
  120. {inference::kMSI_Int64, TypeId::kNumberTypeInt64}, {inference::kMSI_Uint64, TypeId::kNumberTypeUInt64},
  121. {inference::kMSI_Float16, TypeId::kNumberTypeFloat16}, {inference::kMSI_Float32, TypeId::kNumberTypeFloat32},
  122. {inference::kMSI_Float64, TypeId::kNumberTypeFloat64},
  123. };
  124. auto it = type2id_map.find(out_tensor.data_type());
  125. if (it == type2id_map.end()) {
  126. MSI_LOG_WARNING << "undefined MSI data type " << out_tensor.data_type();
  127. return FAILED;
  128. } else {
  129. data_type = it->second;
  130. }
  131. ms_tensor = std::make_shared<tensor::Tensor>(data_type, shape);
  132. if (out_tensor.data_size() == 0 || ms_tensor->Size() != out_tensor.data_size()) {
  133. MSI_LOG_ERROR << "input " << std::to_string(index)
  134. << " data size not match shape and dtype, calculated required size " << ms_tensor->Size()
  135. << ", given " << out_tensor.data_size();
  136. return INFER_STATUS(INVALID_INPUTS) << "input " << std::to_string(index)
  137. << " data size not match shape and dtype, calculated required size "
  138. << ms_tensor->Size() << ", given " << out_tensor.data_size();
  139. }
  140. if (out_tensor.data() == nullptr || ms_tensor->data_c() == nullptr) {
  141. MSI_LOG_ERROR << "invalid data buffer";
  142. return FAILED;
  143. }
  144. auto ret_code = memcpy_s(ms_tensor->data_c(), ms_tensor->Size(), out_tensor.data(), out_tensor.data_size());
  145. if (ret_code != 0) {
  146. MS_LOG(ERROR) << "Failed to copy data from ms_tensor to out_tensor.";
  147. }
  148. return SUCCESS;
  149. }
  150. void MSTensor2ServingTensor(tensor::TensorPtr ms_tensor, InferTensorBase &out_tensor) {
  151. vector<int64_t> shape;
  152. for (auto dim : ms_tensor->shape()) {
  153. shape.push_back(dim);
  154. }
  155. out_tensor.set_shape(shape);
  156. const std::map<TypeId, inference::DataType> id2type_map{
  157. {TypeId::kNumberTypeBegin, inference::kMSI_Unknown}, {TypeId::kNumberTypeBool, inference::kMSI_Bool},
  158. {TypeId::kNumberTypeFloat64, inference::kMSI_Float64}, {TypeId::kNumberTypeInt8, inference::kMSI_Int8},
  159. {TypeId::kNumberTypeUInt8, inference::kMSI_Uint8}, {TypeId::kNumberTypeInt16, inference::kMSI_Int16},
  160. {TypeId::kNumberTypeUInt16, inference::kMSI_Uint16}, {TypeId::kNumberTypeInt32, inference::kMSI_Int32},
  161. {TypeId::kNumberTypeUInt32, inference::kMSI_Uint32}, {TypeId::kNumberTypeInt64, inference::kMSI_Int64},
  162. {TypeId::kNumberTypeUInt64, inference::kMSI_Uint64}, {TypeId::kNumberTypeFloat16, inference::kMSI_Float16},
  163. {TypeId::kNumberTypeFloat32, inference::kMSI_Float32},
  164. };
  165. auto it = id2type_map.find(ms_tensor->data_type());
  166. if (it == id2type_map.end()) {
  167. MSI_LOG_WARNING << "undefined MS data type " << ms_tensor->data_type();
  168. out_tensor.set_data_type(inference::kMSI_Unknown);
  169. } else {
  170. out_tensor.set_data_type(it->second);
  171. }
  172. out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size());
  173. }
// Run model `model_id` on the tensors in `request`, writing outputs into
// `reply`. Each serving input is converted to a MS tensor (deep copy),
// validated against the compiled graph, then the graph is executed and every
// output is wrapped back into the reply. Returns FAILED (or the conversion/
// validation status) on any error.
Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
#ifdef ENABLE_D
  // Re-activate the Ascend runtime context captured in LoadModelFromFile.
  if (context_ == nullptr) {
    MS_LOG(ERROR) << "rtCtx is nullptr";
    return FAILED;
  }
  rtError_t rt_ret = rtCtxSetCurrent(context_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "set Ascend rtCtx failed";
    return FAILED;
  }
#endif
  vector<tensor::TensorPtr> inputs;
  for (size_t i = 0; i < request.size(); i++) {
    if (request[i] == nullptr) {
      MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, input tensor is null, index " << i;
      return FAILED;
    }
    tensor::TensorPtr input = nullptr;
    auto ret = ServingTensor2MSTensor(i, *request[i], input);
    if (ret != SUCCESS) {
      MS_LOG(ERROR) << "Tensor convert failed";
      return ret;
    }
    inputs.push_back(input);
  }
  auto ret = CheckModelInputs(model_id, inputs);
  if (ret != SUCCESS) {
    MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
    return ret;
  }
  vector<tensor::TensorPtr> outputs = RunGraph(model_id, inputs);
  if (outputs.empty()) {
    MS_LOG(ERROR) << "Execute Model " << model_id << " Failed";
    return FAILED;
  }
  reply.clear();
  for (const auto &tensor : outputs) {
    auto out_tensor = reply.add();
    if (out_tensor == nullptr) {
      MS_LOG(ERROR) << "Execute Model " << model_id << " Failed add output tensor failed";
      return FAILED;
    }
    // Reply tensors share the MS output buffers (see MSTensor2ServingTensor);
    // no per-output copy happens here.
    MSTensor2ServingTensor(tensor, *out_tensor);
  }
  return SUCCESS;
}
// Tear down the inference environment. Order matters: executor state and
// kernel-runtime resources are cleared before the TSD (device) connection is
// closed. Returns FAILED if the global context is missing or CloseTsd fails.
Status MSInferSession::FinalizeEnv() {
  session::ExecutorManager::Instance().Clear();
  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
  auto ms_context = MsContext::GetInstance();
  if (ms_context == nullptr) {
    MS_LOG(ERROR) << "Get Context failed!";
    return FAILED;
  }
  if (!context::CloseTsd(ms_context)) {
    MS_LOG(ERROR) << "Inference CloseTsd failed!";
    return FAILED;
  }
  return SUCCESS;
}
  235. std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
  236. try {
  237. auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
  238. return anf_graph;
  239. } catch (std::exception &e) {
  240. MS_LOG(ERROR) << "Inference LoadModel failed";
  241. return nullptr;
  242. }
  243. }
  244. void MSInferSession::RegAllOp() {
  245. static std::mutex init_mutex;
  246. static bool Initialized = false;
  247. std::lock_guard<std::mutex> lock(init_mutex);
  248. if (Initialized) {
  249. return;
  250. }
  251. Initialized = true;
  252. MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
  253. Py_Initialize();
  254. auto c_expression = PyImport_ImportModule("mindspore._c_expression");
  255. MS_EXCEPTION_IF_NULL(c_expression);
  256. PyObject *c_expression_dict = PyModule_GetDict(c_expression);
  257. MS_EXCEPTION_IF_NULL(c_expression_dict);
  258. PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy");
  259. MS_EXCEPTION_IF_NULL(op_info_loader_class);
  260. PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class);
  261. MS_EXCEPTION_IF_NULL(op_info_loader);
  262. PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr);
  263. MS_EXCEPTION_IF_NULL(op_info_loader_ins);
  264. auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr);
  265. MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul);
  266. auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul);
  267. auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr);
  268. for (auto op_info : *all_ops_info) {
  269. kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
  270. }
  271. all_ops_info->clear();
  272. delete all_ops_info;
  273. Py_DECREF(op_info_loader);
  274. Py_DECREF(op_info_loader_class);
  275. Py_DECREF(c_expression_dict);
  276. Py_DECREF(c_expression);
  277. return;
  278. }
// Compile `funcGraphPtr` through the backend session and return the resulting
// graph id via `model_id`. Any exception thrown during compilation is logged
// and mapped to FAILED.
Status MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
  MS_ASSERT(session_impl_ != nullptr);
  try {
    auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
    // NOTE(review): the GIL is released only AFTER CompileGraph returns —
    // presumably compilation re-enters Python; confirm before reordering.
    py::gil_scoped_release gil_release;
    model_id = graph_id;
    return SUCCESS;
  } catch (std::exception &e) {
    MS_LOG(ERROR) << "Inference CompileGraph failed";
    return FAILED;
  }
}
  291. std::vector<tensor::TensorPtr> MSInferSession::RunGraph(uint32_t graph_id,
  292. const std::vector<tensor::TensorPtr> &inputs) {
  293. try {
  294. VectorRef outputs;
  295. session_impl_->RunGraph(graph_id, inputs, &outputs);
  296. return TransformVectorRefToMultiTensor(outputs);
  297. } catch (std::exception &e) {
  298. MS_LOG(ERROR) << "Inference Rungraph failed";
  299. return std::vector<tensor::TensorPtr>();
  300. }
  301. }
  302. string MSInferSession::AjustTargetName(const std::string &device) {
  303. if (device == kAscendDevice) {
  304. return std::string(kAscendDevice) + "Inference";
  305. } else {
  306. MS_LOG(ERROR) << "Only support device Ascend right now";
  307. return "";
  308. }
  309. }
// Initialize the inference environment for `device`/`device_id`:
// register all ops, configure the global context (graph mode, device id and
// target), open the TSD connection, then create and init the backend session.
// Steps are order-dependent; returns FAILED on the first failure.
Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
  RegAllOp();
  auto ms_context = MsContext::GetInstance();
  if (ms_context == nullptr) {
    MS_LOG(ERROR) << "Get Context failed!";
    return FAILED;
  }
  ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
  ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
  auto ajust_device = AjustTargetName(device);
  if (ajust_device == "") {
    return FAILED;
  }
  // The context stores the plain device name; only the session factory uses
  // the adjusted "<device>Inference" name.
  ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, device);
  if (!context::OpenTsd(ms_context)) {
    MS_LOG(ERROR) << "Session init OpenTsd failed!";
    return FAILED;
  }
  session_impl_ = session::SessionFactory::Get().Create(ajust_device);
  if (session_impl_ == nullptr) {
    MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
    return FAILED;
  }
  session_impl_->Init(device_id);
  return SUCCESS;
}
  336. Status MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
  337. MS_ASSERT(session_impl_ != nullptr);
  338. std::string error_msg;
  339. if (!session_impl_->CheckModelInputs(graph_id, inputs, &error_msg)) {
  340. return INFER_STATUS(INVALID_INPUTS) << error_msg;
  341. }
  342. return SUCCESS;
  343. }
  344. Status MSInferSession::GetModelInputsInfo(uint32_t model_id, std::vector<inference::InferTensor> *tensor_list) const {
  345. vector<tensor::TensorPtr> inputs;
  346. session_impl_->GetModelInputsInfo(model_id, &inputs);
  347. if (inputs.size() == 0) {
  348. MS_LOG(ERROR) << "The model inputs is NULL";
  349. return FAILED;
  350. }
  351. for (const auto &tensor : inputs) {
  352. InferTensor infer_tensor = InferTensor();
  353. MSTensor2ServingTensor(tensor, infer_tensor);
  354. tensor_list->push_back(infer_tensor);
  355. }
  356. return SUCCESS;
  357. }
  358. } // namespace inference
  359. } // namespace mindspore