You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

http_process.cc 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <map>
  17. #include <vector>
  18. #include <string>
  19. #include <nlohmann/json.hpp>
  20. #include "serving/ms_service.pb.h"
  21. #include "util/status.h"
  22. #include "core/session.h"
  23. #include "core/http_process.h"
  24. using ms_serving::MSService;
  25. using ms_serving::PredictReply;
  26. using ms_serving::PredictRequest;
  27. using nlohmann::json;
  28. namespace mindspore {
  29. namespace serving {
  30. const int BUF_MAX = 0x7FFFFFFF;
  31. static constexpr char HTTP_DATA[] = "data";
  32. static constexpr char HTTP_TENSOR[] = "tensor";
  33. enum HTTP_TYPE { TYPE_DATA = 0, TYPE_TENSOR };
  34. enum HTTP_DATA_TYPE { HTTP_DATA_NONE, HTTP_DATA_INT, HTTP_DATA_FLOAT };
  35. static const std::map<HTTP_DATA_TYPE, ms_serving::DataType> http_to_infer_map{
  36. {HTTP_DATA_NONE, ms_serving::MS_UNKNOWN},
  37. {HTTP_DATA_INT, ms_serving::MS_INT32},
  38. {HTTP_DATA_FLOAT, ms_serving::MS_FLOAT32}};
  39. Status GetPostMessage(struct evhttp_request *req, std::string *buf) {
  40. Status status(SUCCESS);
  41. size_t post_size = evbuffer_get_length(req->input_buffer);
  42. if (post_size == 0) {
  43. ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message invalid");
  44. return status;
  45. } else if (post_size > BUF_MAX) {
  46. ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message is bigger than 0x7FFFFFFF.");
  47. return status;
  48. } else {
  49. buf->resize(post_size);
  50. memcpy(buf->data(), evbuffer_pullup(req->input_buffer, -1), post_size);
  51. return status;
  52. }
  53. }
  54. Status CheckRequestValid(struct evhttp_request *http_request) {
  55. Status status(SUCCESS);
  56. switch (evhttp_request_get_command(http_request)) {
  57. case EVHTTP_REQ_POST:
  58. return status;
  59. default:
  60. ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message only support POST right now");
  61. return status;
  62. }
  63. }
  64. void ErrorMessage(struct evhttp_request *req, Status status) {
  65. json error_json = {{"error_message", status.StatusMessage()}};
  66. std::string out_error_str = error_json.dump();
  67. struct evbuffer *retbuff = evbuffer_new();
  68. evbuffer_add(retbuff, out_error_str.data(), out_error_str.size());
  69. evhttp_send_reply(req, HTTP_OK, "Client", retbuff);
  70. evbuffer_free(retbuff);
  71. }
  72. Status CheckMessageValid(const json &message_info, HTTP_TYPE *type) {
  73. Status status(SUCCESS);
  74. int count = 0;
  75. if (message_info.find(HTTP_DATA) != message_info.end()) {
  76. *type = TYPE_DATA;
  77. count++;
  78. }
  79. if (message_info.find(HTTP_TENSOR) != message_info.end()) {
  80. *type = TYPE_TENSOR;
  81. count++;
  82. }
  83. if (count != 1) {
  84. ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message must have only one type of (data, tensor)");
  85. return status;
  86. }
  87. return status;
  88. }
  89. Status GetDataFromJson(const json &json_data, std::string *data, HTTP_DATA_TYPE *type) {
  90. Status status(SUCCESS);
  91. if (json_data.is_number_integer()) {
  92. if (*type == HTTP_DATA_NONE) {
  93. *type = HTTP_DATA_INT;
  94. } else if (*type != HTTP_DATA_INT) {
  95. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input data type should be consistent");
  96. return status;
  97. }
  98. auto s_data = json_data.get<int32_t>();
  99. data->append(reinterpret_cast<char *>(&s_data), sizeof(int32_t));
  100. } else if (json_data.is_number_float()) {
  101. if (*type == HTTP_DATA_NONE) {
  102. *type = HTTP_DATA_FLOAT;
  103. } else if (*type != HTTP_DATA_FLOAT) {
  104. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input data type should be consistent");
  105. return status;
  106. }
  107. auto s_data = json_data.get<float>();
  108. data->append(reinterpret_cast<char *>(&s_data), sizeof(float));
  109. } else {
  110. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input data type should be int or float");
  111. return status;
  112. }
  113. return SUCCESS;
  114. }
  115. Status RecusiveGetTensor(const json &json_data, size_t depth, std::vector<int> *shape, std::string *data,
  116. HTTP_DATA_TYPE *type) {
  117. Status status(SUCCESS);
  118. if (depth >= 10) {
  119. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the tensor shape dims is larger than 10");
  120. return status;
  121. }
  122. if (!json_data.is_array()) {
  123. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the tensor is constructed illegally");
  124. return status;
  125. }
  126. int cur_dim = json_data.size();
  127. if (shape->size() <= depth) {
  128. shape->push_back(cur_dim);
  129. } else if ((*shape)[depth] != cur_dim) {
  130. return INFER_STATUS(INVALID_INPUTS) << "the tensor shape is constructed illegally";
  131. }
  132. if (json_data.at(0).is_array()) {
  133. for (const auto &item : json_data) {
  134. status = RecusiveGetTensor(item, depth + 1, shape, data, type);
  135. if (status != SUCCESS) {
  136. return status;
  137. }
  138. }
  139. } else {
  140. // last dim, read the data
  141. for (auto item : json_data) {
  142. status = GetDataFromJson(item, data, type);
  143. if (status != SUCCESS) {
  144. return status;
  145. }
  146. }
  147. }
  148. return status;
  149. }
  150. Status TransDataToPredictRequest(const json &message_info, PredictRequest *request) {
  151. Status status = SUCCESS;
  152. auto tensors = message_info.find(HTTP_DATA);
  153. if (tensors == message_info.end()) {
  154. ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message do not have data type");
  155. return status;
  156. }
  157. if (tensors->size() == 0) {
  158. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input tensor list is null");
  159. return status;
  160. }
  161. for (const auto &tensor : *tensors) {
  162. std::string msg_data;
  163. HTTP_DATA_TYPE type{HTTP_DATA_NONE};
  164. if (!tensor.is_array()) {
  165. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the tensor is constructed illegally");
  166. return status;
  167. }
  168. if (tensor.size() == 0) {
  169. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input tensor is null");
  170. return status;
  171. }
  172. for (const auto &tensor_data : tensor) {
  173. status = GetDataFromJson(tensor_data, &msg_data, &type);
  174. if (status != SUCCESS) {
  175. return status;
  176. }
  177. }
  178. auto iter = http_to_infer_map.find(type);
  179. if (iter == http_to_infer_map.end()) {
  180. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input type is not supported right now");
  181. return status;
  182. }
  183. auto infer_tensor = request->add_data();
  184. infer_tensor->set_tensor_type(iter->second);
  185. infer_tensor->set_data(msg_data.data(), msg_data.size());
  186. }
  187. // get model required shape
  188. std::vector<inference::InferTensor> tensor_list;
  189. status = Session::Instance().GetModelInputsInfo(tensor_list);
  190. if (status != SUCCESS) {
  191. ERROR_INFER_STATUS(status, FAILED, "get model inputs info failed");
  192. return status;
  193. }
  194. if (request->data_size() != static_cast<int64_t>(tensor_list.size())) {
  195. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the inputs number is not equal to model required");
  196. return status;
  197. }
  198. for (int i = 0; i < request->data_size(); i++) {
  199. for (size_t j = 0; j < tensor_list[i].shape().size(); ++j) {
  200. request->mutable_data(i)->mutable_tensor_shape()->add_dims(tensor_list[i].shape()[j]);
  201. }
  202. }
  203. return SUCCESS;
  204. }
  205. Status TransTensorToPredictRequest(const json &message_info, PredictRequest *request) {
  206. Status status(SUCCESS);
  207. auto tensors = message_info.find(HTTP_TENSOR);
  208. if (tensors == message_info.end()) {
  209. ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message do not have tensor type");
  210. return status;
  211. }
  212. for (const auto &tensor : *tensors) {
  213. std::vector<int> shape;
  214. std::string msg_data;
  215. HTTP_DATA_TYPE type{HTTP_DATA_NONE};
  216. RecusiveGetTensor(tensor, 0, &shape, &msg_data, &type);
  217. auto iter = http_to_infer_map.find(type);
  218. if (iter == http_to_infer_map.end()) {
  219. ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input type is not supported right now");
  220. return status;
  221. }
  222. auto infer_tensor = request->add_data();
  223. infer_tensor->set_tensor_type(iter->second);
  224. infer_tensor->set_data(msg_data.data(), msg_data.size());
  225. for (const auto dim : shape) {
  226. infer_tensor->mutable_tensor_shape()->add_dims(dim);
  227. }
  228. }
  229. return status;
  230. }
  231. Status TransHTTPMsgToPredictRequest(struct evhttp_request *http_request, PredictRequest *request, HTTP_TYPE *type) {
  232. Status status = CheckRequestValid(http_request);
  233. if (status != SUCCESS) {
  234. return status;
  235. }
  236. std::string post_message;
  237. status = GetPostMessage(http_request, &post_message);
  238. if (status != SUCCESS) {
  239. return status;
  240. }
  241. json message_info;
  242. try {
  243. message_info = nlohmann::json::parse(post_message);
  244. } catch (nlohmann::json::exception &e) {
  245. std::string json_exception = e.what();
  246. std::string error_message = "Illegal JSON format." + json_exception;
  247. ERROR_INFER_STATUS(status, INVALID_INPUTS, error_message);
  248. return status;
  249. }
  250. status = CheckMessageValid(message_info, type);
  251. if (status != SUCCESS) {
  252. return status;
  253. }
  254. switch (*type) {
  255. case TYPE_DATA:
  256. status = TransDataToPredictRequest(message_info, request);
  257. break;
  258. case TYPE_TENSOR:
  259. status = TransTensorToPredictRequest(message_info, request);
  260. break;
  261. default:
  262. ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message must have only one type of (data, tensor)");
  263. return status;
  264. }
  265. return status;
  266. }
  267. Status GetJsonFromTensor(const ms_serving::Tensor &tensor, int len, int *pos, json *out_json) {
  268. Status status(SUCCESS);
  269. switch (tensor.tensor_type()) {
  270. case ms_serving::MS_INT32: {
  271. std::vector<int> result_tensor;
  272. for (int j = 0; j < len; j++) {
  273. int val;
  274. memcpy(&val, reinterpret_cast<const int *>(tensor.data().data()) + *pos + j, sizeof(int));
  275. result_tensor.push_back(val);
  276. }
  277. *out_json = result_tensor;
  278. *pos += len;
  279. break;
  280. }
  281. case ms_serving::MS_FLOAT32: {
  282. std::vector<float> result_tensor;
  283. for (int j = 0; j < len; j++) {
  284. float val;
  285. memcpy(&val, reinterpret_cast<const float *>(tensor.data().data()) + *pos + j, sizeof(float));
  286. result_tensor.push_back(val);
  287. }
  288. *out_json = result_tensor;
  289. *pos += len;
  290. break;
  291. }
  292. default:
  293. MSI_LOG(ERROR) << "the result type is not supported in restful api, type is " << tensor.tensor_type();
  294. ERROR_INFER_STATUS(status, FAILED, "reply have unsupported type");
  295. }
  296. return status;
  297. }
  298. Status TransPredictReplyToData(const PredictReply &reply, json *out_json) {
  299. Status status(SUCCESS);
  300. for (int i = 0; i < reply.result_size(); i++) {
  301. json tensor_json;
  302. int num = 1;
  303. for (auto j = 0; j < reply.result(i).tensor_shape().dims_size(); j++) {
  304. num *= reply.result(i).tensor_shape().dims(j);
  305. }
  306. int pos = 0;
  307. status = GetJsonFromTensor(reply.result(i), num, &pos, &tensor_json);
  308. if (status != SUCCESS) {
  309. return status;
  310. }
  311. (*out_json)["data"].push_back(tensor_json);
  312. }
  313. return status;
  314. }
  315. Status RecusiveGetJson(const ms_serving::Tensor &tensor, int depth, int *pos, json *out_json) {
  316. Status status(SUCCESS);
  317. if (depth >= 10) {
  318. ERROR_INFER_STATUS(status, FAILED, "result tensor shape dims is larger than 10");
  319. return status;
  320. }
  321. if (depth == tensor.tensor_shape().dims_size() - 1) {
  322. status = GetJsonFromTensor(tensor, tensor.tensor_shape().dims(depth), pos, out_json);
  323. if (status != SUCCESS) {
  324. return status;
  325. }
  326. } else {
  327. for (int i = 0; i < tensor.tensor_shape().dims(depth); i++) {
  328. json tensor_json;
  329. status = RecusiveGetJson(tensor, depth + 1, pos, &tensor_json);
  330. if (status != SUCCESS) {
  331. return status;
  332. }
  333. out_json->push_back(tensor_json);
  334. }
  335. }
  336. return status;
  337. }
  338. Status TransPredictReplyToTensor(const PredictReply &reply, json *out_json) {
  339. Status status(SUCCESS);
  340. for (int i = 0; i < reply.result_size(); i++) {
  341. json tensor_json;
  342. int pos = 0;
  343. status = RecusiveGetJson(reply.result(i), 0, &pos, &tensor_json);
  344. if (status != SUCCESS) {
  345. return status;
  346. }
  347. (*out_json)["tensor"].push_back(tensor_json);
  348. }
  349. return status;
  350. }
  351. Status TransPredictReplyToHTTPMsg(const PredictReply &reply, const HTTP_TYPE &type, struct evbuffer *buf) {
  352. Status status(SUCCESS);
  353. json out_json;
  354. switch (type) {
  355. case TYPE_DATA:
  356. status = TransPredictReplyToData(reply, &out_json);
  357. break;
  358. case TYPE_TENSOR:
  359. status = TransPredictReplyToTensor(reply, &out_json);
  360. break;
  361. default:
  362. ERROR_INFER_STATUS(status, FAILED, "http message must have only one type of (data, tensor)");
  363. return status;
  364. }
  365. std::string out_str = out_json.dump();
  366. evbuffer_add(buf, out_str.data(), out_str.size());
  367. return status;
  368. }
  369. void http_handler_msg(struct evhttp_request *req, void *arg) {
  370. std::cout << "in handle" << std::endl;
  371. PredictRequest request;
  372. PredictReply reply;
  373. HTTP_TYPE type;
  374. auto status = TransHTTPMsgToPredictRequest(req, &request, &type);
  375. if (status != SUCCESS) {
  376. ErrorMessage(req, status);
  377. MSI_LOG(ERROR) << "restful trans to request failed";
  378. return;
  379. }
  380. MSI_TIME_STAMP_START(Predict)
  381. status = Session::Instance().Predict(request, reply);
  382. if (status != SUCCESS) {
  383. ErrorMessage(req, status);
  384. MSI_LOG(ERROR) << "restful predict failed";
  385. }
  386. MSI_TIME_STAMP_END(Predict)
  387. struct evbuffer *retbuff = evbuffer_new();
  388. status = TransPredictReplyToHTTPMsg(reply, type, retbuff);
  389. if (status != SUCCESS) {
  390. ErrorMessage(req, status);
  391. MSI_LOG(ERROR) << "restful trans to reply failed";
  392. return;
  393. }
  394. evhttp_send_reply(req, HTTP_OK, "Client", retbuff);
  395. evbuffer_free(retbuff);
  396. }
  397. } // namespace serving
  398. } // namespace mindspore