You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grpc_client.h 3.7 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H
  17. #define MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H
  18. #include <grpcpp/grpcpp.h>
  19. #include <grpcpp/health_check_service_interface.h>
  20. #include <grpcpp/ext/proto_server_reflection_plugin.h>
  21. #include <memory>
  22. #include <functional>
  23. #include <thread>
  24. #include <string>
  25. #include <utility>
  26. #include "common/serving_common.h"
  27. #include "proto/ms_service.pb.h"
  28. #include "proto/ms_service.grpc.pb.h"
  29. #include "proto/ms_master.pb.h"
  30. #include "proto/ms_master.grpc.pb.h"
  31. #include "proto/ms_worker.grpc.pb.h"
  32. #include "proto/ms_agent.pb.h"
  33. #include "proto/ms_agent.grpc.pb.h"
  34. namespace mindspore {
  35. namespace serving {
  36. using PredictOnFinish = std::function<void()>;
  37. using AsyncPredictCallback = std::function<void(Status status)>;
  38. template <typename Request, typename Reply, typename MSStub>
  39. class MSServiceClient {
  40. public:
  41. MSServiceClient() = default;
  42. ~MSServiceClient() {
  43. if (in_running_) {
  44. cq_.Shutdown();
  45. if (client_thread_.joinable()) {
  46. try {
  47. client_thread_.join();
  48. } catch (const std::system_error &) {
  49. } catch (...) {
  50. }
  51. }
  52. }
  53. in_running_ = false;
  54. }
  55. void Start() {
  56. client_thread_ = std::thread(&MSServiceClient::AsyncCompleteRpc, this);
  57. in_running_ = true;
  58. }
  59. void AsyncCompleteRpc() {
  60. void *got_tag;
  61. bool ok = false;
  62. while (cq_.Next(&got_tag, &ok)) {
  63. AsyncClientCall *call = static_cast<AsyncClientCall *>(got_tag);
  64. if (call->status.ok()) {
  65. call->callback(SUCCESS);
  66. } else {
  67. MSI_LOG_ERROR << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message();
  68. call->callback(Status(FAILED, call->status.error_message()));
  69. }
  70. delete call;
  71. }
  72. }
  73. void PredictAsync(const Request &request, Reply *reply, MSStub *stub, AsyncPredictCallback callback) {
  74. AsyncClientCall *call = new AsyncClientCall;
  75. call->reply = reply;
  76. call->callback = std::move(callback);
  77. call->response_reader = stub->PrepareAsyncPredict(&call->context, request, &cq_);
  78. call->response_reader->StartCall();
  79. call->response_reader->Finish(call->reply, &call->status, call);
  80. }
  81. private:
  82. struct AsyncClientCall {
  83. grpc::ClientContext context;
  84. grpc::Status status;
  85. Reply *reply;
  86. AsyncPredictCallback callback;
  87. std::shared_ptr<grpc::ClientAsyncResponseReader<Reply>> response_reader;
  88. };
  89. grpc::CompletionQueue cq_;
  90. std::thread client_thread_;
  91. bool in_running_ = false;
  92. };
  93. using MSPredictClient = MSServiceClient<proto::PredictRequest, proto::PredictReply, proto::MSWorker::Stub>;
  94. using MSDistributedClient =
  95. MSServiceClient<proto::DistributedPredictRequest, proto::DistributedPredictReply, proto::MSAgent::Stub>;
  96. extern std::unique_ptr<MSPredictClient> client_;
  97. extern std::unique_ptr<MSDistributedClient> distributed_client_;
  98. } // namespace serving
  99. } // namespace mindspore
  100. #endif // MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H

A lightweight, high-performance service module that helps MindSpore developers efficiently deploy online inference services in production environments.