You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grpc_client.cc 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "debug/debugger/grpc_client.h"
#include <algorithm>
#include <chrono>
#include <climits>
#include <thread>
#include <utility>
#include <vector>
#include "utils/log_adapter.h"
  20. using debugger::Chunk;
  21. using debugger::EventListener;
  22. using debugger::EventReply;
  23. using debugger::EventReply_Status_FAILED;
  24. using debugger::GraphProto;
  25. using debugger::Heartbeat;
  26. using debugger::Metadata;
  27. using debugger::TensorBase;
  28. using debugger::TensorProto;
  29. using debugger::TensorSummary;
  30. using debugger::WatchpointHit;
  31. namespace mindspore {
  32. GrpcClient::GrpcClient(const std::string &host, const std::string &port) : stub_(nullptr) { Init(host, port); }
  33. void GrpcClient::Init(const std::string &host, const std::string &port) {
  34. std::string target_str = host + ":" + port;
  35. MS_LOG(INFO) << "GrpcClient connecting to: " << target_str;
  36. std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials());
  37. stub_ = EventListener::NewStub(channel);
  38. }
  39. void GrpcClient::Reset() { stub_ = nullptr; }
  40. EventReply GrpcClient::WaitForCommand(const Metadata &metadata) {
  41. EventReply reply;
  42. grpc::ClientContext context;
  43. grpc::Status status = stub_->WaitCMD(&context, metadata, &reply);
  44. if (!status.ok()) {
  45. MS_LOG(ERROR) << "RPC failed: WaitForCommand";
  46. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  47. reply.set_status(EventReply_Status_FAILED);
  48. }
  49. return reply;
  50. }
  51. EventReply GrpcClient::SendMetadata(const Metadata &metadata) {
  52. EventReply reply;
  53. grpc::ClientContext context;
  54. grpc::Status status = stub_->SendMetadata(&context, metadata, &reply);
  55. if (!status.ok()) {
  56. MS_LOG(ERROR) << "RPC failed: SendMetadata";
  57. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  58. reply.set_status(EventReply_Status_FAILED);
  59. }
  60. return reply;
  61. }
  62. std::vector<std::string> GrpcClient::ChunkString(std::string str, int graph_size) {
  63. std::vector<std::string> buf;
  64. constexpr auto l_chunk_size = 1024 * 1024 * 3;
  65. int size_iter = 0;
  66. while (size_iter < graph_size) {
  67. int chunk_size = l_chunk_size;
  68. if (graph_size - size_iter < l_chunk_size) {
  69. chunk_size = graph_size - size_iter;
  70. }
  71. std::string buffer;
  72. buffer.resize(chunk_size);
  73. auto err = memcpy_s(reinterpret_cast<char *>(buffer.data()), chunk_size, str.data() + size_iter, chunk_size);
  74. if (err != 0) {
  75. MS_LOG(EXCEPTION) << "memcpy_s failed. errorno is: " << err;
  76. }
  77. buf.push_back(buffer);
  78. if (size_iter > INT_MAX - l_chunk_size) {
  79. MS_EXCEPTION(ValueError) << size_iter << " + " << l_chunk_size << "would lead to integer overflow!";
  80. }
  81. size_iter += l_chunk_size;
  82. }
  83. return buf;
  84. }
  85. EventReply GrpcClient::SendGraph(const GraphProto &graph) {
  86. EventReply reply;
  87. grpc::ClientContext context;
  88. Chunk chunk;
  89. std::unique_ptr<grpc::ClientWriter<Chunk> > writer(stub_->SendGraph(&context, &reply));
  90. std::string str = graph.SerializeAsString();
  91. int graph_size = graph.ByteSize();
  92. auto buf = ChunkString(str, graph_size);
  93. for (unsigned int i = 0; i < buf.size(); i++) {
  94. MS_LOG(INFO) << "RPC:sending the " << i << "chunk in graph";
  95. chunk.set_buffer(buf[i]);
  96. if (!writer->Write(chunk)) {
  97. break;
  98. }
  99. std::this_thread::sleep_for(std::chrono::milliseconds(1));
  100. }
  101. writer->WritesDone();
  102. grpc::Status status = writer->Finish();
  103. if (!status.ok()) {
  104. MS_LOG(ERROR) << "RPC failed: SendGraph";
  105. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  106. reply.set_status(EventReply_Status_FAILED);
  107. }
  108. return reply;
  109. }
  110. EventReply GrpcClient::SendMultiGraphs(const std::list<Chunk> &chunks) {
  111. EventReply reply;
  112. grpc::ClientContext context;
  113. std::unique_ptr<grpc::ClientWriter<Chunk> > writer(stub_->SendMultiGraphs(&context, &reply));
  114. for (const auto &chunk : chunks) {
  115. if (!writer->Write(chunk)) {
  116. break;
  117. }
  118. std::this_thread::sleep_for(std::chrono::milliseconds(1));
  119. }
  120. writer->WritesDone();
  121. grpc::Status status = writer->Finish();
  122. if (!status.ok()) {
  123. MS_LOG(ERROR) << "RPC failed: SendMultigraphs";
  124. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  125. reply.set_status(EventReply_Status_FAILED);
  126. }
  127. return reply;
  128. }
  129. EventReply GrpcClient::SendTensors(const std::list<TensorProto> &tensors) {
  130. EventReply reply;
  131. grpc::ClientContext context;
  132. std::unique_ptr<grpc::ClientWriter<TensorProto> > writer(stub_->SendTensors(&context, &reply));
  133. for (const auto &tensor : tensors) {
  134. if (!writer->Write(tensor)) {
  135. break;
  136. }
  137. std::this_thread::sleep_for(std::chrono::milliseconds(1));
  138. }
  139. writer->WritesDone();
  140. grpc::Status status = writer->Finish();
  141. if (!status.ok()) {
  142. MS_LOG(ERROR) << "RPC failed: SendTensors";
  143. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  144. reply.set_status(EventReply_Status_FAILED);
  145. }
  146. return reply;
  147. }
  148. EventReply GrpcClient::SendWatchpointHits(const std::list<WatchpointHit> &watchpoints) {
  149. EventReply reply;
  150. grpc::ClientContext context;
  151. std::unique_ptr<grpc::ClientWriter<WatchpointHit> > writer(stub_->SendWatchpointHits(&context, &reply));
  152. for (const auto &watchpoint : watchpoints) {
  153. if (!writer->Write(watchpoint)) {
  154. break;
  155. }
  156. std::this_thread::sleep_for(std::chrono::milliseconds(1));
  157. }
  158. writer->WritesDone();
  159. grpc::Status status = writer->Finish();
  160. if (!status.ok()) {
  161. MS_LOG(ERROR) << "RPC failed: SendWatchpointHits";
  162. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  163. reply.set_status(EventReply_Status_FAILED);
  164. }
  165. return reply;
  166. }
  167. EventReply GrpcClient::SendHeartbeat(const Heartbeat &heartbeat) {
  168. EventReply reply;
  169. grpc::ClientContext context;
  170. grpc::Status status = stub_->SendHeartbeat(&context, heartbeat, &reply);
  171. if (!status.ok()) {
  172. MS_LOG(ERROR) << "RPC failed: SendHeartbeat";
  173. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  174. reply.set_status(EventReply_Status_FAILED);
  175. }
  176. return reply;
  177. }
  178. EventReply GrpcClient::SendTensorBase(const std::list<TensorBase> &tensor_base_list) {
  179. EventReply reply;
  180. grpc::ClientContext context;
  181. std::unique_ptr<grpc::ClientWriter<TensorBase> > writer(stub_->SendTensorBase(&context, &reply));
  182. for (const auto &tensor_base : tensor_base_list) {
  183. if (!writer->Write(tensor_base)) {
  184. break;
  185. }
  186. std::this_thread::sleep_for(std::chrono::milliseconds(1));
  187. }
  188. writer->WritesDone();
  189. grpc::Status status = writer->Finish();
  190. if (!status.ok()) {
  191. MS_LOG(ERROR) << "RPC failed: SendTensorBase";
  192. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  193. reply.set_status(EventReply_Status_FAILED);
  194. }
  195. return reply;
  196. }
  197. EventReply GrpcClient::SendTensorStats(const std::list<TensorSummary> &tensor_summary_list) {
  198. EventReply reply;
  199. grpc::ClientContext context;
  200. std::unique_ptr<grpc::ClientWriter<TensorSummary> > writer(stub_->SendTensorStats(&context, &reply));
  201. for (const auto &tensor_summary : tensor_summary_list) {
  202. if (!writer->Write(tensor_summary)) {
  203. break;
  204. }
  205. std::this_thread::sleep_for(std::chrono::milliseconds(1));
  206. }
  207. writer->WritesDone();
  208. grpc::Status status = writer->Finish();
  209. if (!status.ok()) {
  210. MS_LOG(ERROR) << "RPC failed: SendTensorStats";
  211. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  212. reply.set_status(EventReply_Status_FAILED);
  213. }
  214. return reply;
  215. }
  216. } // namespace mindspore