
http_process.cc

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <map>
#include <vector>
#include <string>
#include <cstring>     // memcpy
#include <cstdint>     // int32_t, int64_t, INT64_MAX
#include <numeric>     // std::accumulate
#include <functional>  // std::multiplies
#include <event2/buffer.h>       // evbuffer_* APIs
#include <event2/http.h>         // evhttp_* APIs, HTTP_OK
#include <event2/http_struct.h>  // direct access to evhttp_request::input_buffer
#include <nlohmann/json.hpp>
#include "serving/ms_service.pb.h"
#include "util/status.h"
#include "core/session.h"
#include "core/http_process.h"
#include "core/serving_tensor.h"

using ms_serving::MSService;
using ms_serving::PredictReply;
using ms_serving::PredictRequest;
using nlohmann::json;

namespace mindspore {
namespace serving {

const int BUF_MAX = 0x7FFFFFFF;
static constexpr char HTTP_DATA[] = "data";
static constexpr char HTTP_TENSOR[] = "tensor";

enum HTTP_TYPE { TYPE_DATA = 0, TYPE_TENSOR };
enum HTTP_DATA_TYPE { HTTP_DATA_NONE, HTTP_DATA_INT, HTTP_DATA_FLOAT };

// Only int32 and float32 model inputs are accepted over the RESTful API.
static const std::map<inference::DataType, HTTP_DATA_TYPE> infer_type2_http_type{
  {inference::DataType::kMSI_Int32, HTTP_DATA_INT}, {inference::DataType::kMSI_Float32, HTTP_DATA_FLOAT}};
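
// The predict endpoint accepts a JSON body in exactly one of two formats.
// Illustrative examples (values are hypothetical, not from the original source):
//
//   {"data": [[1.1, 2.2, 3.3, 4.4]]}
//     one flat value list per model input, holding exactly ElementNum()
//     values for that input, in row-major order
//
//   {"tensor": [[[1.1, 2.2], [3.3, 4.4]]]}
//     one nested array per model input, matching the model's required
//     shape dimension by dimension
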
Status GetPostMessage(struct evhttp_request *req, std::string *buf) {
  Status status(SUCCESS);
  size_t post_size = evbuffer_get_length(req->input_buffer);
  if (post_size == 0) {
    ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message invalid");
    return status;
  } else if (post_size > BUF_MAX) {
    ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message is bigger than 0x7FFFFFFF.");
    return status;
  } else {
    buf->resize(post_size);
    // evbuffer_pullup(buf, -1) linearizes the whole input buffer so the body
    // can be copied out with a single memcpy.
    memcpy(buf->data(), evbuffer_pullup(req->input_buffer, -1), post_size);
    return status;
  }
}

Status CheckRequestValid(struct evhttp_request *http_request) {
  Status status(SUCCESS);
  switch (evhttp_request_get_command(http_request)) {
    case EVHTTP_REQ_POST:
      return status;
    default:
      ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message only supports POST right now");
      return status;
  }
}

// Errors are reported as HTTP 200 with an {"error_message": ...} JSON body.
void ErrorMessage(struct evhttp_request *req, Status status) {
  json error_json = {{"error_message", status.StatusMessage()}};
  std::string out_error_str = error_json.dump();
  struct evbuffer *retbuff = evbuffer_new();
  evbuffer_add(retbuff, out_error_str.data(), out_error_str.size());
  evhttp_send_reply(req, HTTP_OK, "Client", retbuff);
  evbuffer_free(retbuff);
}

Status CheckMessageValid(const json &message_info, HTTP_TYPE *type) {
  Status status(SUCCESS);
  int count = 0;
  if (message_info.find(HTTP_DATA) != message_info.end()) {
    *type = TYPE_DATA;
    count++;
  }
  if (message_info.find(HTTP_TENSOR) != message_info.end()) {
    *type = TYPE_TENSOR;
    count++;
  }
  if (count != 1) {
    ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message must contain exactly one of (data, tensor)");
    return status;
  }
  return status;
}

Status GetDataFromJson(const json &json_data_array, ServingTensor *request_tensor, size_t data_index,
                       HTTP_DATA_TYPE type) {
  Status status(SUCCESS);
  auto type_name = [](const json &json_data) -> std::string {
    if (json_data.is_number_integer()) {
      return "integer";
    } else if (json_data.is_number_float()) {
      return "float";
    }
    return json_data.type_name();
  };
  size_t array_size = json_data_array.size();
  if (type == HTTP_DATA_INT) {
    auto data = reinterpret_cast<int32_t *>(request_tensor->mutable_data()) + data_index;
    for (size_t k = 0; k < array_size; k++) {
      auto &json_data = json_data_array[k];
      if (!json_data.is_number_integer()) {
        status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected integer, given " << type_name(json_data);
        MSI_LOG_ERROR << status.StatusMessage();
        return status;
      }
      data[k] = json_data.get<int32_t>();
    }
  } else if (type == HTTP_DATA_FLOAT) {
    auto data = reinterpret_cast<float *>(request_tensor->mutable_data()) + data_index;
    for (size_t k = 0; k < array_size; k++) {
      auto &json_data = json_data_array[k];
      // Note: is_number_float() rejects JSON integers, so float tensor values
      // must be written with a decimal point (1.0 rather than 1).
      if (!json_data.is_number_float()) {
        status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected float, given " << type_name(json_data);
        MSI_LOG_ERROR << status.StatusMessage();
        return status;
      }
      data[k] = json_data.get<float>();
    }
  }
  return SUCCESS;
}

Status RecusiveGetTensor(const json &json_data, size_t depth, ServingTensor *request_tensor, size_t data_index,
                         HTTP_DATA_TYPE type) {
  Status status(SUCCESS);
  std::vector<int64_t> required_shape = request_tensor->shape();
  if (depth >= required_shape.size()) {
    status = INFER_STATUS(INVALID_INPUTS)
             << "input tensor shape dims is more than required dims " << required_shape.size();
    MSI_LOG_ERROR << status.StatusMessage();
    return status;
  }
  if (!json_data.is_array()) {
    ERROR_INFER_STATUS(status, INVALID_INPUTS, "the tensor is constructed illegally");
    return status;
  }
  if (json_data.size() != static_cast<size_t>(required_shape[depth])) {
    status = INFER_STATUS(INVALID_INPUTS)
             << "tensor format request is constructed illegally, input tensor shape dim " << depth
             << " does not match, required " << required_shape[depth] << ", given " << json_data.size();
    MSI_LOG_ERROR << status.StatusMessage();
    return status;
  }
  if (depth + 1 < required_shape.size()) {
    // Each sub-array at this depth owns the product of all remaining dims, so
    // sub-array k starts at offset data_index + sub_element_cnt * k in the
    // tensor's flat, row-major buffer.
    size_t sub_element_cnt =
      std::accumulate(required_shape.begin() + depth + 1, required_shape.end(), 1LL, std::multiplies<size_t>());
    for (size_t k = 0; k < json_data.size(); k++) {
      status = RecusiveGetTensor(json_data[k], depth + 1, request_tensor, data_index + sub_element_cnt * k, type);
      if (status != SUCCESS) {
        return status;
      }
    }
  } else {
    status = GetDataFromJson(json_data, request_tensor, data_index, type);
    if (status != SUCCESS) {
      return status;
    }
  }
  return status;
}

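// Worked example (illustrative): for a required shape {2, 3}, the JSON input
// [[1, 2, 3], [4, 5, 6]] is consumed as follows. At depth 0, each of the two
// rows owns sub_element_cnt = 3 elements, so row k starts at data_index = 3 * k;
// at depth 1, GetDataFromJson copies that row's three values into the flat
// buffer, yielding 1, 2, 3, 4, 5, 6 in row-major order.
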
std::vector<int64_t> GetJsonArrayShape(const json &json_array) {
  std::vector<int64_t> json_shape;
  const json *tmp_json = &json_array;
  while (tmp_json->is_array()) {
    if (tmp_json->empty()) {
      break;
    }
    json_shape.push_back(tmp_json->size());
    tmp_json = &tmp_json->at(0);
  }
  return json_shape;
}

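// E.g. GetJsonArrayShape([[1, 2], [3, 4]]) returns {2, 2}. Only the first
// element of each level is inspected, so a ragged input such as [[1, 2], [3]]
// also yields {2, 2}; the mismatch is caught later by the per-dimension and
// element-count checks.
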
Status TransDataToPredictRequest(const json &message_info, PredictRequest *request) {
  Status status = SUCCESS;
  auto tensors = message_info.find(HTTP_DATA);
  if (tensors == message_info.end()) {
    ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message does not have data type");
    return status;
  }
  if (!tensors->is_array()) {
    ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input tensor list is not an array");
    return status;
  }
  auto const &json_shape = GetJsonArrayShape(*tensors);
  if (json_shape.size() != 2) {  // the data format is a list of flat value lists, so nesting depth must be 2
    status = INFER_STATUS(INVALID_INPUTS)
             << "the data format request is constructed illegally, expected list nesting depth 2, given "
             << json_shape.size();
    MSI_LOG_ERROR << status.StatusMessage();
    return status;
  }
  if (tensors->size() != static_cast<size_t>(request->data_size())) {
    status = INFER_STATUS(INVALID_INPUTS)
             << "model input count does not match, model required " << request->data_size() << ", given "
             << tensors->size();
    MSI_LOG_ERROR << status.StatusMessage();
    return status;
  }
  for (size_t i = 0; i < tensors->size(); i++) {
    const auto &tensor = tensors->at(i);
    ServingTensor request_tensor(*(request->mutable_data(i)));
    auto iter = infer_type2_http_type.find(request_tensor.data_type());
    if (iter == infer_type2_http_type.end()) {
      ERROR_INFER_STATUS(status, FAILED, "the model input type is not supported right now");
      return status;
    }
    HTTP_DATA_TYPE type = iter->second;
    if (!tensor.is_array()) {
      ERROR_INFER_STATUS(status, INVALID_INPUTS, "the tensor is constructed illegally");
      return status;
    }
    if (tensor.empty()) {
      ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input tensor is null");
      return status;
    }
    if (tensor.size() != static_cast<size_t>(request_tensor.ElementNum())) {
      status = INFER_STATUS(INVALID_INPUTS) << "input " << i << " element count does not match, model required "
                                            << request_tensor.ElementNum() << ", given " << tensor.size();
      MSI_LOG_ERROR << status.StatusMessage();
      return status;
    }
    status = GetDataFromJson(tensor, &request_tensor, 0, type);
    if (status != SUCCESS) {
      return status;
    }
  }
  return SUCCESS;
}

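// Illustrative "data" request for a model with a single float32 input of
// shape {2, 2} (hypothetical values):
//   {"data": [[1.1, 2.2, 3.3, 4.4]]}
// The one inner list supplies all ElementNum() == 4 values of input 0 in
// row-major order; no shape information is carried in this format.
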
Status TransTensorToPredictRequest(const json &message_info, PredictRequest *request) {
  Status status(SUCCESS);
  auto tensors = message_info.find(HTTP_TENSOR);
  if (tensors == message_info.end()) {
    ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message does not have tensor type");
    return status;
  }
  if (!tensors->is_array()) {
    ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input tensor list is not an array");
    return status;
  }
  if (tensors->size() != static_cast<size_t>(request->data_size())) {
    status =
      INFER_STATUS(INVALID_INPUTS)
      << "model input count does not match or json tensor request is constructed illegally, model input count "
      << "required " << request->data_size() << ", given " << tensors->size();
    MSI_LOG_ERROR << status.StatusMessage();
    return status;
  }
  for (size_t i = 0; i < tensors->size(); i++) {
    const auto &tensor = tensors->at(i);
    ServingTensor request_tensor(*(request->mutable_data(i)));
    // check data shape
    auto const &json_shape = GetJsonArrayShape(tensor);
    if (json_shape != request_tensor.shape()) {  // data shape does not match
      status = INFER_STATUS(INVALID_INPUTS)
               << "input " << i << " shape is invalid, expected " << request_tensor.shape() << ", given " << json_shape;
      MSI_LOG_ERROR << status.StatusMessage();
      return status;
    }
    auto iter = infer_type2_http_type.find(request_tensor.data_type());
    if (iter == infer_type2_http_type.end()) {
      ERROR_INFER_STATUS(status, FAILED, "the model input type is not supported right now");
      return status;
    }
    HTTP_DATA_TYPE type = iter->second;
    size_t depth = 0;
    size_t data_index = 0;
    status = RecusiveGetTensor(tensor, depth, &request_tensor, data_index, type);
    if (status != SUCCESS) {
      MSI_LOG_ERROR << "Transfer tensor to predict request failed";
      return status;
    }
  }
  return status;
}

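// Illustrative "tensor" request for the same {2, 2} float32 input
// (hypothetical values):
//   {"tensor": [[[1.1, 2.2], [3.3, 4.4]]]}
// The outer list enumerates model inputs; each entry must nest exactly to
// that input's required shape, which is verified both up front (via
// GetJsonArrayShape) and again during the recursive copy.
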
Status TransHTTPMsgToPredictRequest(struct evhttp_request *http_request, PredictRequest *request, HTTP_TYPE *type) {
  Status status = CheckRequestValid(http_request);
  if (status != SUCCESS) {
    return status;
  }
  std::string post_message;
  status = GetPostMessage(http_request, &post_message);
  if (status != SUCCESS) {
    return status;
  }
  // get model required shape
  std::vector<inference::InferTensor> tensor_list;
  status = Session::Instance().GetModelInputsInfo(tensor_list);
  if (status != SUCCESS) {
    ERROR_INFER_STATUS(status, FAILED, "get model inputs info failed");
    return status;
  }
  for (auto &item : tensor_list) {
    auto input = request->add_data();
    ServingTensor tensor(*input);
    tensor.set_shape(item.shape());
    tensor.set_data_type(item.data_type());
    int64_t element_num = tensor.ElementNum();
    int64_t data_type_size = tensor.GetTypeSize(tensor.data_type());
    // Reject empty shapes and guard the element_num * data_type_size
    // multiplication against int64 overflow before allocating.
    if (element_num <= 0 || INT64_MAX / element_num < data_type_size) {
      ERROR_INFER_STATUS(status, FAILED, "model shape invalid");
      return status;
    }
    tensor.resize_data(element_num * data_type_size);
  }
  MSI_TIME_STAMP_START(ParseJson)
  json message_info;
  try {
    message_info = nlohmann::json::parse(post_message);
  } catch (nlohmann::json::exception &e) {
    std::string json_exception = e.what();
    std::string error_message = "Illegal JSON format: " + json_exception;
    ERROR_INFER_STATUS(status, INVALID_INPUTS, error_message);
    return status;
  }
  MSI_TIME_STAMP_END(ParseJson)
  status = CheckMessageValid(message_info, type);
  if (status != SUCCESS) {
    return status;
  }
  switch (*type) {
    case TYPE_DATA:
      status = TransDataToPredictRequest(message_info, request);
      break;
    case TYPE_TENSOR:
      status = TransTensorToPredictRequest(message_info, request);
      break;
    default:
      ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message must contain exactly one of (data, tensor)");
      return status;
  }
  return status;
}

Status GetJsonFromTensor(const ms_serving::Tensor &tensor, int len, int *pos, json *out_json) {
  Status status(SUCCESS);
  switch (tensor.tensor_type()) {
    case ms_serving::MS_INT32: {
      auto data = reinterpret_cast<const int *>(tensor.data().data()) + *pos;
      std::vector<int32_t> result_tensor(len);
      // memcpy_s is provided by the securec library used across MindSpore.
      memcpy_s(result_tensor.data(), result_tensor.size() * sizeof(int32_t), data, len * sizeof(int32_t));
      *out_json = std::move(result_tensor);
      *pos += len;
      break;
    }
    case ms_serving::MS_FLOAT32: {
      auto data = reinterpret_cast<const float *>(tensor.data().data()) + *pos;
      std::vector<float> result_tensor(len);
      memcpy_s(result_tensor.data(), result_tensor.size() * sizeof(float), data, len * sizeof(float));
      *out_json = std::move(result_tensor);
      *pos += len;
      break;
    }
    default:
      MSI_LOG(ERROR) << "the result type is not supported in restful api, type is " << tensor.tensor_type();
      ERROR_INFER_STATUS(status, FAILED, "reply has unsupported type");
  }
  return status;
}

Status TransPredictReplyToData(const PredictReply &reply, json *out_json) {
  Status status(SUCCESS);
  for (int i = 0; i < reply.result_size(); i++) {
    (*out_json)["data"].push_back(json());
    json &tensor_json = (*out_json)["data"].back();
    int num = 1;
    for (auto j = 0; j < reply.result(i).tensor_shape().dims_size(); j++) {
      num *= reply.result(i).tensor_shape().dims(j);
    }
    int pos = 0;
    status = GetJsonFromTensor(reply.result(i), num, &pos, &tensor_json);
    if (status != SUCCESS) {
      return status;
    }
  }
  return status;
}

Status RecusiveGetJson(const ms_serving::Tensor &tensor, int depth, int *pos, json *out_json) {
  Status status(SUCCESS);
  if (depth >= 10) {
    ERROR_INFER_STATUS(status, FAILED, "result tensor shape dims is larger than 10");
    return status;
  }
  if (depth == tensor.tensor_shape().dims_size() - 1) {
    status = GetJsonFromTensor(tensor, tensor.tensor_shape().dims(depth), pos, out_json);
    if (status != SUCCESS) {
      return status;
    }
  } else {
    for (int i = 0; i < tensor.tensor_shape().dims(depth); i++) {
      out_json->push_back(json());
      json &tensor_json = out_json->back();
      status = RecusiveGetJson(tensor, depth + 1, pos, &tensor_json);
      if (status != SUCCESS) {
        return status;
      }
    }
  }
  return status;
}

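// Worked example (illustrative): a result tensor with dims {2, 2} and flat
// data [1, 2, 3, 4] is rebuilt as [[1, 2], [3, 4]]; *pos advances through the
// flat buffer as each innermost row is emitted by GetJsonFromTensor.
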
Status TransPredictReplyToTensor(const PredictReply &reply, json *out_json) {
  Status status(SUCCESS);
  for (int i = 0; i < reply.result_size(); i++) {
    (*out_json)["tensor"].push_back(json());
    json &tensor_json = (*out_json)["tensor"].back();
    int pos = 0;
    status = RecusiveGetJson(reply.result(i), 0, &pos, &tensor_json);
    if (status != SUCCESS) {
      return status;
    }
  }
  return status;
}

Status TransPredictReplyToHTTPMsg(const PredictReply &reply, const HTTP_TYPE &type, struct evbuffer *buf) {
  Status status(SUCCESS);
  json out_json;
  switch (type) {
    case TYPE_DATA:
      status = TransPredictReplyToData(reply, &out_json);
      break;
    case TYPE_TENSOR:
      status = TransPredictReplyToTensor(reply, &out_json);
      break;
    default:
      ERROR_INFER_STATUS(status, FAILED, "http message must contain exactly one of (data, tensor)");
      return status;
  }
  const std::string &out_str = out_json.dump();
  evbuffer_add(buf, out_str.data(), out_str.size());
  return status;
}

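// Replies mirror the request format (illustrative values): a "data" request
// yields {"data": [[0.1, 0.9]]} with one flat list per result, while a
// "tensor" request yields {"tensor": [[[0.1], [0.9]]]} nested to the result's
// shape.
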
Status HttpHandleMsgDetail(struct evhttp_request *req, void *arg, struct evbuffer *retbuff) {
  PredictRequest request;
  PredictReply reply;
  HTTP_TYPE type;
  MSI_TIME_STAMP_START(ParseRequest)
  auto status = TransHTTPMsgToPredictRequest(req, &request, &type);
  MSI_TIME_STAMP_END(ParseRequest)
  if (status != SUCCESS) {
    MSI_LOG(ERROR) << "restful trans to request failed";
    return status;
  }
  MSI_TIME_STAMP_START(Predict)
  status = Session::Instance().Predict(request, reply);
  MSI_TIME_STAMP_END(Predict)
  if (status != SUCCESS) {
    MSI_LOG(ERROR) << "restful predict failed";
    return status;
  }
  MSI_TIME_STAMP_START(CreateReplyJson)
  status = TransPredictReplyToHTTPMsg(reply, type, retbuff);
  MSI_TIME_STAMP_END(CreateReplyJson)
  if (status != SUCCESS) {
    MSI_LOG(ERROR) << "restful trans to reply failed";
    return status;
  }
  return SUCCESS;
}

void http_handler_msg(struct evhttp_request *req, void *arg) {
  MSI_TIME_STAMP_START(TotalRestfulPredict)
  struct evbuffer *retbuff = evbuffer_new();
  if (retbuff == nullptr) {
    MSI_LOG_ERROR << "Create event buffer failed";
    return;
  }
  auto status = HttpHandleMsgDetail(req, arg, retbuff);
  if (status != SUCCESS) {
    ErrorMessage(req, status);
    evbuffer_free(retbuff);
    return;
  }
  MSI_TIME_STAMP_START(ReplyJson)
  evhttp_send_reply(req, HTTP_OK, "Client", retbuff);
  MSI_TIME_STAMP_END(ReplyJson)
  evbuffer_free(retbuff);
  MSI_TIME_STAMP_END(TotalRestfulPredict)
}

}  // namespace serving
}  // namespace mindspore
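
// Minimal usage sketch (an assumption, not part of this file): the handler is
// typically registered on a libevent HTTP server. The URL path and port below
// are hypothetical.
//
//   struct event_base *base = event_base_new();
//   struct evhttp *http = evhttp_new(base);
//   evhttp_bind_socket(http, "0.0.0.0", 8080);
//   evhttp_set_cb(http, "/model/predict", mindspore::serving::http_handler_msg, nullptr);
//   event_base_dispatch(base);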