You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

client.cc 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "client/cpp/client.h"
#include <grpcpp/grpcpp.h>
#include <google/protobuf/text_format.h>
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <sstream>
#include <unordered_map>
#include <utility>
#include "proto/ms_service.pb.h"
#include "proto/ms_service.grpc.pb.h"
  25. namespace mindspore {
  26. namespace serving {
  27. namespace client {
  28. Status &Status::operator<<(DataType val) {
  29. std::unordered_map<DataType, std::string> data_type_map = {
  30. {DT_UINT8, "uint8"}, {DT_UINT16, "uint16"}, {DT_UINT32, "uint32"}, {DT_UINT64, "uint64"},
  31. {DT_INT8, "int8"}, {DT_INT16, "int16"}, {DT_INT32, "int32"}, {DT_INT64, "int64"},
  32. {DT_BOOL, "bool"}, {DT_FLOAT16, "float16"}, {DT_FLOAT32, "float32"}, {DT_FLOAT64, "float64"},
  33. {DT_STRING, "string"}, {DT_BYTES, "bytes"}, {DT_UNKNOWN, "unknown"},
  34. };
  35. auto it = data_type_map.find(val);
  36. if (it == data_type_map.end()) {
  37. status_msg_ += "unknown";
  38. } else {
  39. status_msg_ += it->second;
  40. }
  41. return *this;
  42. }
  43. Status &operator<<(Status &status, proto::DataType val) {
  44. std::unordered_map<proto::DataType, std::string> data_type_map = {
  45. {proto::MS_UINT8, "uint8"}, {proto::MS_UINT16, "uint16"}, {proto::MS_UINT32, "uint32"},
  46. {proto::MS_UINT64, "uint64"}, {proto::MS_INT8, "int8"}, {proto::MS_INT16, "int16"},
  47. {proto::MS_INT32, "int32"}, {proto::MS_INT64, "int64"}, {proto::MS_BOOL, "bool"},
  48. {proto::MS_FLOAT16, "float16"}, {proto::MS_FLOAT32, "float32"}, {proto::MS_FLOAT64, "float64"},
  49. {proto::MS_STRING, "string"}, {proto::MS_BYTES, "bytes"}, {proto::MS_UNKNOWN, "unknown"},
  50. };
  51. auto it = data_type_map.find(val);
  52. if (it == data_type_map.end()) {
  53. status << "unknown";
  54. } else {
  55. status << it->second;
  56. }
  57. return status;
  58. }
  59. Status &operator<<(Status &status, grpc::StatusCode val) {
  60. std::unordered_map<grpc::StatusCode, std::string> data_type_map = {
  61. {grpc::OK, "OK"},
  62. {grpc::CANCELLED, "CANCELLED"},
  63. {grpc::UNKNOWN, "UNKNOWN"},
  64. {grpc::INVALID_ARGUMENT, "INVALID_ARGUMENT"},
  65. {grpc::DEADLINE_EXCEEDED, "DEADLINE_EXCEEDED"},
  66. {grpc::NOT_FOUND, "NOT_FOUND"},
  67. {grpc::ALREADY_EXISTS, "ALREADY_EXISTS"},
  68. {grpc::PERMISSION_DENIED, "PERMISSION_DENIED"},
  69. {grpc::UNAUTHENTICATED, "UNAUTHENTICATED"},
  70. {grpc::RESOURCE_EXHAUSTED, "RESOURCE_EXHAUSTED"},
  71. {grpc::FAILED_PRECONDITION, "FAILED_PRECONDITION"},
  72. {grpc::ABORTED, "ABORTED"},
  73. {grpc::OUT_OF_RANGE, "OUT_OF_RANGE"},
  74. {grpc::UNIMPLEMENTED, "UNIMPLEMENTED"},
  75. {grpc::INTERNAL, "INTERNAL"},
  76. {grpc::UNAVAILABLE, "UNAVAILABLE"},
  77. {grpc::DATA_LOSS, "DATA_LOSS"},
  78. };
  79. auto it = data_type_map.find(val);
  80. if (it == data_type_map.end()) {
  81. status << "unknown";
  82. } else {
  83. status << it->second;
  84. }
  85. return status;
  86. }
  87. Status MutableTensor::SetBytesData(const std::vector<uint8_t> &val) {
  88. if (mutable_proto_tensor_ == nullptr) {
  89. return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr";
  90. }
  91. auto proto_shape = mutable_proto_tensor_->mutable_shape();
  92. proto_shape->add_dims(1);
  93. mutable_proto_tensor_->set_dtype(proto::MS_BYTES);
  94. if (val.empty()) {
  95. return Status(INVALID_INPUTS) << "Input index bytes val len is empty";
  96. }
  97. mutable_proto_tensor_->add_bytes_val(val.data(), val.size());
  98. return SUCCESS;
  99. }
  100. Status MutableTensor::SetStrData(const std::string &val) {
  101. if (mutable_proto_tensor_ == nullptr) {
  102. return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr";
  103. }
  104. auto proto_shape = mutable_proto_tensor_->mutable_shape();
  105. proto_shape->add_dims(val.size());
  106. mutable_proto_tensor_->set_dtype(proto::MS_STRING);
  107. if (val.empty()) {
  108. return Status(INVALID_INPUTS) << "string index string val len is empty";
  109. }
  110. mutable_proto_tensor_->add_bytes_val(val);
  111. return SUCCESS;
  112. }
  113. Status MutableTensor::SetData(const std::vector<uint8_t> &val, const std::vector<int64_t> &shape) {
  114. return SetData(val.data(), val.size() * sizeof(uint8_t), shape, DT_UINT8);
  115. }
  116. Status MutableTensor::SetData(const std::vector<uint16_t> &val, const std::vector<int64_t> &shape) {
  117. return SetData(val.data(), val.size() * sizeof(uint16_t), shape, DT_UINT16);
  118. }
  119. Status MutableTensor::SetData(const std::vector<uint32_t> &val, const std::vector<int64_t> &shape) {
  120. return SetData(val.data(), val.size() * sizeof(uint32_t), shape, DT_UINT32);
  121. }
  122. Status MutableTensor::SetData(const std::vector<uint64_t> &val, const std::vector<int64_t> &shape) {
  123. return SetData(val.data(), val.size() * sizeof(uint64_t), shape, DT_UINT64);
  124. }
  125. Status MutableTensor::SetData(const std::vector<int8_t> &val, const std::vector<int64_t> &shape) {
  126. return SetData(val.data(), val.size() * sizeof(int8_t), shape, DT_INT8);
  127. }
  128. Status MutableTensor::SetData(const std::vector<int16_t> &val, const std::vector<int64_t> &shape) {
  129. return SetData(val.data(), val.size() * sizeof(int16_t), shape, DT_INT16);
  130. }
  131. Status MutableTensor::SetData(const std::vector<int32_t> &val, const std::vector<int64_t> &shape) {
  132. return SetData(val.data(), val.size() * sizeof(int32_t), shape, DT_INT32);
  133. }
  134. Status MutableTensor::SetData(const std::vector<int64_t> &val, const std::vector<int64_t> &shape) {
  135. return SetData(val.data(), val.size() * sizeof(int64_t), shape, DT_INT64);
  136. }
  137. Status MutableTensor::SetData(const std::vector<bool> &val, const std::vector<int64_t> &shape) {
  138. std::vector<uint8_t> val_uint8;
  139. std::transform(val.begin(), val.end(), std::back_inserter(val_uint8),
  140. [](bool item) { return static_cast<uint8_t>(item); });
  141. return SetData(val_uint8.data(), val_uint8.size() * sizeof(bool), shape, DT_BOOL);
  142. }
  143. Status MutableTensor::SetData(const std::vector<float> &val, const std::vector<int64_t> &shape) {
  144. return SetData(val.data(), val.size() * sizeof(float), shape, DT_FLOAT32);
  145. }
  146. Status MutableTensor::SetData(const std::vector<double> &val, const std::vector<int64_t> &shape) {
  147. return SetData(val.data(), val.size() * sizeof(double), shape, DT_FLOAT64);
  148. }
  149. Status MutableTensor::SetData(const void *data, size_t data_len, const std::vector<int64_t> &shape,
  150. DataType data_type) {
  151. if (mutable_proto_tensor_ == nullptr) {
  152. return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr";
  153. }
  154. if (data == nullptr || data_len == 0) {
  155. return Status(INVALID_INPUTS) << "data cannot be nullptr, or data len cannot be 0";
  156. }
  157. mutable_proto_tensor_->set_data(data, data_len);
  158. auto proto_shape = mutable_proto_tensor_->mutable_shape();
  159. std::unordered_map<DataType, std::pair<proto::DataType, int64_t>> data_type_map = {
  160. {DT_UINT8, {proto::MS_UINT8, sizeof(uint8_t)}},
  161. {DT_UINT16, {proto::MS_UINT16, sizeof(uint16_t)}},
  162. {DT_UINT32, {proto::MS_UINT32, sizeof(uint32_t)}},
  163. {DT_UINT64, {proto::MS_UINT64, sizeof(uint64_t)}},
  164. {DT_INT8, {proto::MS_INT8, sizeof(int8_t)}},
  165. {DT_INT16, {proto::MS_INT16, sizeof(int16_t)}},
  166. {DT_INT32, {proto::MS_INT32, sizeof(int32_t)}},
  167. {DT_INT64, {proto::MS_INT64, sizeof(int64_t)}},
  168. {DT_BOOL, {proto::MS_BOOL, sizeof(bool)}},
  169. {DT_FLOAT16, {proto::MS_FLOAT16, 2}},
  170. {DT_FLOAT32, {proto::MS_FLOAT32, 4}},
  171. {DT_FLOAT64, {proto::MS_FLOAT64, 8}},
  172. };
  173. auto it = data_type_map.find(data_type);
  174. if (it == data_type_map.end()) {
  175. return Status(INVALID_INPUTS) << "Input unsupported find data type " << data_type;
  176. }
  177. mutable_proto_tensor_->set_dtype(it->second.first);
  178. auto shape_str = [](const std::vector<int64_t> &val) noexcept {
  179. std::stringstream sstream;
  180. sstream << "[";
  181. for (size_t i = 0; i < val.size(); i++) {
  182. sstream << val[i];
  183. if (i + 1 < val.size()) {
  184. sstream << ", ";
  185. }
  186. }
  187. sstream << "]";
  188. return sstream.str();
  189. };
  190. int64_t element_cnt = 1;
  191. for (auto &item : shape) {
  192. proto_shape->add_dims(item);
  193. if (item <= 0 || item >= INT64_MAX || INT64_MAX / element_cnt < item) {
  194. return Status(INVALID_INPUTS) << "Input input shape invalid " << shape_str(shape);
  195. }
  196. }
  197. auto item_size = it->second.second;
  198. if (static_cast<int64_t>(data_len) / element_cnt < item_size ||
  199. element_cnt * item_size != static_cast<int64_t>(data_len)) {
  200. return Status(INVALID_INPUTS) << "Input input shape " << shape_str(shape) << " does not match data len "
  201. << data_len;
  202. }
  203. return SUCCESS;
  204. }
  205. Status Tensor::GetBytesData(std::vector<uint8_t> *val) const {
  206. if (val == nullptr) {
  207. return Status(SYSTEM_ERROR) << "input val cannot be nullptr";
  208. }
  209. if (proto_tensor_ == nullptr) {
  210. return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr";
  211. }
  212. if (proto_tensor_->dtype() != proto::MS_BYTES) {
  213. return Status(INVALID_INPUTS) << "Output data type is not match, its' real data type is " << proto_tensor_->dtype();
  214. }
  215. auto &bytes_data = proto_tensor_->bytes_val();
  216. if (bytes_data.size() != 1) {
  217. return Status(INVALID_INPUTS) << "Bytes value type size can only be 1";
  218. }
  219. val->resize(bytes_data[0].size());
  220. memcpy(val->data(), val->data(), bytes_data[0].size());
  221. return SUCCESS;
  222. }
  223. Status Tensor::GetStrData(std::string *val) const {
  224. if (val == nullptr) {
  225. return Status(SYSTEM_ERROR) << "input val cannot be nullptr";
  226. }
  227. if (proto_tensor_ == nullptr) {
  228. return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr";
  229. }
  230. if (proto_tensor_->dtype() != proto::MS_STRING) {
  231. return Status(INVALID_INPUTS) << "Output data type is not match, its' real data type is " << proto_tensor_->dtype();
  232. }
  233. auto &bytes_data = proto_tensor_->bytes_val();
  234. if (bytes_data.size() != 1) {
  235. return Status(INVALID_INPUTS) << "Bytes value type size can only be 1";
  236. }
  237. val->resize(bytes_data[0].size());
  238. memcpy(val->data(), val->data(), bytes_data[0].size());
  239. return SUCCESS;
  240. }
  241. template <proto::DataType proto_dtype, class DT>
  242. Status GetInputImp(const proto::Tensor *proto_tensor, std::vector<DT> *val) {
  243. if (val == nullptr) {
  244. return Status(SYSTEM_ERROR) << "input val cannot be nullptr";
  245. }
  246. if (proto_tensor == nullptr) {
  247. return Status(SYSTEM_ERROR) << "proto tensor cannot be nullptr";
  248. }
  249. if (proto_tensor->dtype() != proto_dtype) {
  250. return Status(INVALID_INPUTS) << "Output data type is not match, its' real data type is " << proto_tensor->dtype();
  251. }
  252. auto data = proto_tensor->data().data();
  253. auto data_len = proto_tensor->data().length();
  254. val->resize(data_len / sizeof(DT));
  255. memcpy(val->data(), data, data_len);
  256. return SUCCESS;
  257. }
  258. Status Tensor::GetData(std::vector<uint8_t> *val) const { return GetInputImp<proto::MS_UINT8>(proto_tensor_, val); }
  259. Status Tensor::GetData(std::vector<uint16_t> *val) const { return GetInputImp<proto::MS_UINT16>(proto_tensor_, val); }
  260. Status Tensor::GetData(std::vector<uint32_t> *val) const { return GetInputImp<proto::MS_UINT32>(proto_tensor_, val); }
  261. Status Tensor::GetData(std::vector<uint64_t> *val) const { return GetInputImp<proto::MS_UINT64>(proto_tensor_, val); }
  262. Status Tensor::GetData(std::vector<int8_t> *val) const { return GetInputImp<proto::MS_INT8>(proto_tensor_, val); }
  263. Status Tensor::GetData(std::vector<int16_t> *val) const { return GetInputImp<proto::MS_INT16>(proto_tensor_, val); }
  264. Status Tensor::GetData(std::vector<int32_t> *val) const { return GetInputImp<proto::MS_INT32>(proto_tensor_, val); }
  265. Status Tensor::GetData(std::vector<int64_t> *val) const { return GetInputImp<proto::MS_INT64>(proto_tensor_, val); }
  266. Status Tensor::GetData(std::vector<bool> *val) const {
  267. if (val == nullptr) {
  268. return Status(SYSTEM_ERROR) << "input val cannot be nullptr";
  269. }
  270. std::vector<uint8_t> val_uint8;
  271. Status status = GetInputImp<proto::MS_BOOL>(proto_tensor_, &val_uint8);
  272. if (!status.IsSuccess()) {
  273. return status;
  274. }
  275. std::transform(val_uint8.begin(), val_uint8.end(), std::back_inserter(*val), [](uint8_t item) { return item != 0; });
  276. return SUCCESS;
  277. }
  278. Status Tensor::GetData(std::vector<float> *val) const { return GetInputImp<proto::MS_FLOAT32>(proto_tensor_, val); }
  279. Status Tensor::GetData(std::vector<double> *val) const { return GetInputImp<proto::MS_FLOAT64>(proto_tensor_, val); }
  280. Status Tensor::GetFp16Data(std::vector<uint16_t> *val) const {
  281. return GetInputImp<proto::MS_FLOAT16>(proto_tensor_, val);
  282. }
  283. DataType Tensor::GetDataType() const {
  284. if (proto_tensor_ == nullptr) {
  285. std::cout << "proto tensor cannot be nullptr" << std::endl;
  286. return DT_UNKNOWN;
  287. }
  288. std::unordered_map<proto::DataType, DataType> data_type_map = {
  289. {proto::MS_UNKNOWN, DT_UNKNOWN}, {proto::MS_UINT8, DT_UINT8}, {proto::MS_UINT16, DT_UINT16},
  290. {proto::MS_UINT32, DT_UINT32}, {proto::MS_UINT64, DT_UINT64}, {proto::MS_INT8, DT_INT8},
  291. {proto::MS_INT16, DT_INT16}, {proto::MS_INT32, DT_INT32}, {proto::MS_INT64, DT_INT64},
  292. {proto::MS_BOOL, DT_BOOL}, {proto::MS_FLOAT16, DT_FLOAT16}, {proto::MS_FLOAT32, DT_FLOAT32},
  293. {proto::MS_FLOAT64, DT_FLOAT64}, {proto::MS_STRING, DT_STRING}, {proto::MS_BYTES, DT_BYTES},
  294. };
  295. auto it_dt = data_type_map.find(proto_tensor_->dtype());
  296. if (it_dt == data_type_map.end()) {
  297. std::cout << "Unsupported data type " << proto_tensor_->dtype() << std::endl;
  298. return DT_UNKNOWN;
  299. }
  300. return it_dt->second;
  301. }
  302. std::vector<int64_t> Tensor::GetShape() const {
  303. if (proto_tensor_ == nullptr) {
  304. std::cout << "proto tensor cannot be nullptr" << std::endl;
  305. return std::vector<int64_t>();
  306. }
  307. std::vector<int64_t> shape;
  308. auto &dims = proto_tensor_->shape().dims();
  309. std::copy(dims.begin(), dims.end(), std::back_inserter(shape));
  310. return shape;
  311. }
  312. Tensor Instance::Get(const std::string &item_name) const {
  313. if (proto_instance_ == nullptr) {
  314. std::cout << "proto instance cannot be nullptr" << std::endl;
  315. return Tensor(nullptr, nullptr);
  316. }
  317. auto &items = proto_instance_->items();
  318. auto it = items.find(item_name);
  319. if (it == items.end()) {
  320. std::cout << "Cannot find item name " << item_name << std::endl;
  321. return Tensor(nullptr, nullptr);
  322. }
  323. return Tensor(message_owner_, &it->second);
  324. }
  325. bool Instance::HasErrorMsg(int64_t *error_code, std::string *error_msg) const {
  326. if (error_code == nullptr) {
  327. return false;
  328. }
  329. if (error_msg == nullptr) {
  330. return false;
  331. }
  332. if (error_msg_ == nullptr) {
  333. return false;
  334. }
  335. *error_code = error_msg_->error_code();
  336. *error_msg = error_msg_->error_msg();
  337. return true;
  338. }
  339. MutableTensor MutableInstance::Add(const std::string &item_name) {
  340. if (mutable_proto_instance_ == nullptr) {
  341. std::cout << "proto instance cannot be nullptr" << std::endl;
  342. return MutableTensor(nullptr, nullptr);
  343. }
  344. auto items = mutable_proto_instance_->mutable_items();
  345. auto &proto_tensor = (*items)[item_name];
  346. return MutableTensor(message_owner_, &proto_tensor);
  347. }
  348. InstancesRequest::InstancesRequest() { request_ = std::make_shared<proto::PredictRequest>(); }
  349. MutableInstance InstancesRequest::AddInstance() {
  350. auto proto_instance = request_->add_instances();
  351. return MutableInstance(request_, proto_instance);
  352. }
  353. InstancesReply::InstancesReply() { reply_ = std::make_shared<proto::PredictReply>(); }
  354. std::vector<Instance> InstancesReply::GetResult() const {
  355. std::vector<Instance> instances;
  356. auto &proto_instances = reply_->instances();
  357. auto &proto_error_msgs = reply_->error_msg();
  358. for (int i = 0; i < proto_instances.size(); i++) {
  359. auto &proto_instance = proto_instances[i];
  360. const proto::ErrorMsg *error_msg = nullptr;
  361. if (proto_error_msgs.size() == 1) {
  362. error_msg = &proto_error_msgs[0];
  363. } else if (proto_error_msgs.size() == proto_instances.size() && proto_error_msgs[i].error_code() != 0) {
  364. error_msg = &proto_error_msgs[i];
  365. }
  366. instances.push_back(Instance(reply_, &proto_instance, error_msg));
  367. }
  368. return instances;
  369. }
  370. class ClientImpl {
  371. public:
  372. ClientImpl(const std::string &server_ip, uint64_t server_port) {
  373. std::string target_str = server_ip + ":" + std::to_string(server_port);
  374. auto channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials());
  375. stub_ = proto::MSService::NewStub(channel);
  376. }
  377. Status Predict(const proto::PredictRequest &request, proto::PredictReply *reply) {
  378. if (reply == nullptr) {
  379. return Status(SYSTEM_ERROR, "ClientImpl::Predict input reply cannot be nullptr");
  380. }
  381. grpc::ClientContext context;
  382. // The actual RPC.
  383. grpc::Status status = stub_->Predict(&context, request, reply);
  384. if (status.ok()) {
  385. return SUCCESS;
  386. } else {
  387. std::cout << status.error_code() << ": " << status.error_message() << std::endl;
  388. return Status(FAILED, status.error_message());
  389. }
  390. }
  391. private:
  392. std::unique_ptr<proto::MSService::Stub> stub_;
  393. };
// Builds a client bound to one servable method (name + method + version) on
// a given server. The gRPC channel/stub is created eagerly inside ClientImpl;
// no RPC happens during construction.
// NOTE(review): initializer order must match the member declaration order in
// client.h (not visible here) — keep in sync if the header changes.
Client::Client(const std::string &server_ip, uint64_t server_port, const std::string &servable_name,
               const std::string &method_name, uint64_t version_number)
    : server_ip_(server_ip),
      server_port_(server_port),
      servable_name_(servable_name),
      method_name_(method_name),
      version_number_(version_number),
      impl_(std::make_shared<ClientImpl>(server_ip, server_port)) {}
  402. Status Client::SendRequest(const InstancesRequest &request, InstancesReply *reply) {
  403. if (reply == nullptr) {
  404. return Status(SYSTEM_ERROR) << "input reply cannot be nullptr";
  405. }
  406. proto::PredictRequest *proto_request = request.request_.get();
  407. proto::PredictReply *proto_reply = reply->reply_.get();
  408. auto servable_spec = proto_request->mutable_servable_spec();
  409. servable_spec->set_name(servable_name_);
  410. servable_spec->set_method_name(method_name_);
  411. servable_spec->set_version_number(version_number_);
  412. Status result = impl_->Predict(*proto_request, proto_reply);
  413. return result;
  414. }
  415. } // namespace client
  416. } // namespace serving
  417. } // namespace mindspore

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.