You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

client.h 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_SERVING_CLIENT_H
  17. #define MINDSPORE_SERVING_CLIENT_H
  #include <cstddef>
  #include <cstdint>
  #include <memory>
  #include <sstream>
  #include <string>
  #include <vector>
  // Export marker for the public client API. The preprocessor ignores namespaces, so the macro is defined at
  // file scope; the guard avoids a conflicting redefinition if MS_API was already defined by another header.
  #ifndef MS_API
  #define MS_API __attribute__((visibility("default")))
  #endif

  namespace google {
  namespace protobuf {
  class Message;
  }  // namespace protobuf
  }  // namespace google

  namespace mindspore {
  namespace serving {
  // Protobuf message types are only forward-declared; their definitions are not needed by this header.
  namespace proto {
  class Tensor;
  class Instance;
  class PredictRequest;
  class PredictReply;
  class ErrorMsg;
  }  // namespace proto

  namespace client {
  /// Reference-counted pointer to a protobuf message, shared by the Tensor/Instance views below.
  using ProtoMsgOwner = std::shared_ptr<google::protobuf::Message>;

  /// Element types a Tensor can hold (see Tensor::GetDataType and MutableTensor::SetData).
  enum DataType {
    DT_UNKNOWN,
    DT_UINT8,
    DT_UINT16,
    DT_UINT32,
    DT_UINT64,
    DT_INT8,
    DT_INT16,
    DT_INT32,
    DT_INT64,
    DT_BOOL,
    DT_FLOAT16,
    DT_FLOAT32,
    DT_FLOAT64,
    DT_STRING,
    DT_BYTES,
  };

  /// Result codes carried by Status. SUCCESS (0) is the only code that Status::IsSuccess() accepts.
  enum StatusCode { SUCCESS = 0, FAILED, INVALID_INPUTS, SYSTEM_ERROR, UNAVAILABLE };

  /// Outcome of a client operation: a status code plus a free-form message.
  ///
  /// The comparison operators compare codes only; the message is ignored. Implicit conversion to bool is
  /// deleted, so callers must test IsSuccess() explicitly. The message can be extended with operator<<.
  class MS_API Status {
   public:
    /// A default-constructed Status is FAILED with an empty message.
    Status() : status_code_(FAILED) {}
    /// Builds a Status from a code and an optional initial message.
    Status(enum StatusCode status_code, const std::string &status_msg = "")
        : status_code_(status_code), status_msg_(status_msg) {}

    /// True only when the code is SUCCESS.
    bool IsSuccess() const { return status_code_ == SUCCESS; }
    /// Returns the status code. Because this member shadows the enum name inside the class, the enum type
    /// is spelled `enum StatusCode` within this class.
    enum StatusCode StatusCode() const { return status_code_; }
    /// Returns a copy of the accumulated message. Const-qualified so it is usable on a const Status.
    std::string StatusMessage() const { return status_msg_; }

    bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
    bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
    bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
    bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
    operator bool() const = delete;

    /// Appends `val`, formatted through std::stringstream, to the message and returns *this for chaining.
    template <class T>
    Status &operator<<(T val);
    /// Appends a DataType to the message. Declared here; its definition is not in this header.
    Status &operator<<(DataType val);
    /// Appends the elements of `val` as "[e0, e1, ...]" ("[]" when empty) and returns *this for chaining.
    template <class T>
    Status &operator<<(const std::vector<T> &val);

   private:
    enum StatusCode status_code_;
    std::string status_msg_;
  };

  /// Read-only view of a proto::Tensor.
  ///
  /// The view stores a raw pointer to the proto tensor and a copy of `owner`. Presumably `owner` is the message
  /// that contains the tensor and keeps it alive for the lifetime of the view (confirm at the call sites).
  /// IsValid() reports whether the view was built from a non-null proto tensor.
  class MS_API Tensor {
   public:
    Tensor(const ProtoMsgOwner &owner, const proto::Tensor *proto_tensor)
        : message_owner_(owner), proto_tensor_(proto_tensor) {}
    virtual ~Tensor() = default;

    // Bytes type: for images etc.
    Status GetBytesData(std::vector<uint8_t> *val) const;
    Status GetStrData(std::string *val) const;
    // Typed element readers, one overload per element type; defined outside this header.
    Status GetData(std::vector<uint8_t> *val) const;
    Status GetData(std::vector<uint16_t> *val) const;
    Status GetData(std::vector<uint32_t> *val) const;
    Status GetData(std::vector<uint64_t> *val) const;
    Status GetData(std::vector<int8_t> *val) const;
    Status GetData(std::vector<int16_t> *val) const;
    Status GetData(std::vector<int32_t> *val) const;
    Status GetData(std::vector<int64_t> *val) const;
    Status GetData(std::vector<bool> *val) const;
    Status GetData(std::vector<float> *val) const;
    Status GetData(std::vector<double> *val) const;
    // float16 data is read into uint16_t storage (presumably the raw half-precision bit patterns; confirm).
    Status GetFp16Data(std::vector<uint16_t> *val) const;
    DataType GetDataType() const;
    std::vector<int64_t> GetShape() const;
    /// True when the view was built from a non-null proto::Tensor.
    bool IsValid() const { return proto_tensor_ != nullptr; }

   protected:
    ProtoMsgOwner message_owner_;

   private:
    const proto::Tensor *proto_tensor_;
  };

  /// Writable view of a proto::Tensor that adds setters to the read-only Tensor interface.
  class MS_API MutableTensor : public Tensor {
   public:
    MutableTensor(const ProtoMsgOwner &owner, proto::Tensor *proto_tensor)
        : Tensor(owner, proto_tensor), mutable_proto_tensor_(proto_tensor) {}
    ~MutableTensor() override = default;

    // Bytes type: for images etc.
    Status SetBytesData(const std::vector<uint8_t> &val);
    Status SetStrData(const std::string &val);
    // Typed setters: element values plus the tensor shape; defined outside this header.
    Status SetData(const std::vector<uint8_t> &val, const std::vector<int64_t> &shape);
    Status SetData(const std::vector<uint16_t> &val, const std::vector<int64_t> &shape);
    Status SetData(const std::vector<uint32_t> &val, const std::vector<int64_t> &shape);
    Status SetData(const std::vector<uint64_t> &val, const std::vector<int64_t> &shape);
    Status SetData(const std::vector<int8_t> &val, const std::vector<int64_t> &shape);
    Status SetData(const std::vector<int16_t> &val, const std::vector<int64_t> &shape);
    Status SetData(const std::vector<int32_t> &val, const std::vector<int64_t> &shape);
    Status SetData(const std::vector<int64_t> &val, const std::vector<int64_t> &shape);
    Status SetData(const std::vector<bool> &val, const std::vector<int64_t> &shape);
    Status SetData(const std::vector<float> &val, const std::vector<int64_t> &shape);
    Status SetData(const std::vector<double> &val, const std::vector<int64_t> &shape);
    // Untyped setter: `data_bytes_len` bytes starting at `data`, described by `shape` and `data_type`.
    Status SetData(const void *data, size_t data_bytes_len, const std::vector<int64_t> &shape, DataType data_type);

   private:
    proto::Tensor *mutable_proto_tensor_;
  };

  /// Read-only view of a proto::Instance, optionally paired with a proto::ErrorMsg.
  class MS_API Instance {
   public:
    Instance(const ProtoMsgOwner &owner, const proto::Instance *proto_instance, const proto::ErrorMsg *error_msg)
        : message_owner_(owner), proto_instance_(proto_instance), error_msg_(error_msg) {}
    virtual ~Instance() = default;

    /// Returns a Tensor view for the item named `item_name`. Defined outside this header; how a missing
    /// name is reported is not visible here (presumably via Tensor::IsValid(); confirm).
    Tensor Get(const std::string &item_name) const;
    /// True when the view was built from a non-null proto::Instance.
    bool IsValid() const { return proto_instance_ != nullptr; }
    /// Queries the error attached to this instance. Defined outside this header; presumably returns true and
    /// fills `error_code`/`error_msg` when an error is present (confirm).
    bool HasErrorMsg(int64_t *error_code, std::string *error_msg) const;

   protected:
    ProtoMsgOwner message_owner_;

   private:
    const proto::Instance *proto_instance_;
    const proto::ErrorMsg *error_msg_;
  };

  /// Writable view of a proto::Instance, as returned by InstancesRequest::AddInstance().
  /// It is always constructed without an ErrorMsg (the base receives nullptr).
  class MS_API MutableInstance : public Instance {
   public:
    MutableInstance(const ProtoMsgOwner &owner, proto::Instance *proto_instance)
        : Instance(owner, proto_instance, nullptr), mutable_proto_instance_(proto_instance) {}
    ~MutableInstance() override = default;

    /// Returns a writable Tensor for the item named `item_name`; defined outside this header.
    MutableTensor Add(const std::string &item_name);

   private:
    proto::Instance *mutable_proto_instance_;
  };

  /// Request passed to Client::SendRequest. It shares ownership of a proto::PredictRequest; Client is a
  /// friend so that it can access that message directly.
  class MS_API InstancesRequest {
   public:
    InstancesRequest();
    ~InstancesRequest() = default;
    /// Returns a MutableInstance to populate (presumably a newly appended instance; confirm).
    MutableInstance AddInstance();

   private:
    std::shared_ptr<proto::PredictRequest> request_ = nullptr;
    friend class Client;
  };

  /// Reply filled in by Client::SendRequest. It shares ownership of a proto::PredictReply; Client is a friend
  /// so that it can write that message directly.
  class MS_API InstancesReply {
   public:
    InstancesReply();
    ~InstancesReply() = default;
    /// Returns read-only views of the reply's instances.
    std::vector<Instance> GetResult() const;

   private:
    std::shared_ptr<proto::PredictReply> reply_ = nullptr;
    friend class Client;
  };

  class ClientImpl;

  /// Client bound to one servable method on a serving endpoint.
  ///
  /// Stores the server address, servable name, method name and version number (default 0; its meaning is
  /// defined by the server, not here). The implementation sits behind `impl_`; ClientImpl is only
  /// forward-declared in this header.
  class MS_API Client {
   public:
    Client(const std::string &server_ip, uint64_t server_port, const std::string &servable_name,
           const std::string &method_name, uint64_t version_number = 0);
    ~Client() = default;
    /// Sends `request` and writes the result into `*reply`; defined outside this header.
    Status SendRequest(const InstancesRequest &request, InstancesReply *reply);

   private:
    std::string server_ip_;
    uint64_t server_port_;
    std::string servable_name_;
    std::string method_name_;
    uint64_t version_number_ = 0;
    std::shared_ptr<ClientImpl> impl_;
  };

  template <class T>
  Status &Status::operator<<(T val) {
    std::stringstream stringstream;
    stringstream << val;
    status_msg_ += stringstream.str();
    return *this;
  }

  template <class T>
  Status &Status::operator<<(const std::vector<T> &val) {
    operator<<("[");
    for (size_t i = 0; i < val.size(); i++) {
      // Separator goes before every element except the first, so no size()-1 arithmetic is needed.
      if (i != 0) {
        operator<<(", ");
      }
      // Dispatches per element type: DataType uses the dedicated overload, nested vectors recurse.
      operator<<(val[i]);
    }
    // The list must be closed with ']' so the message stays balanced, e.g. "[1, 2, 3]".
    operator<<("]");
    return *this;
  }
  }  // namespace client
  }  // namespace serving
  }  // namespace mindspore
  208. #endif // MINDSPORE_SERVING_CLIENT_H

A lightweight, high-performance service module that helps MindSpore developers efficiently deploy online inference services in production environments.