You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grpc_client.cc 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "debug/debugger/grpc_client.h"
  #include <stdio.h>
  #include <stdlib.h>
  #include <algorithm>
  #include <chrono>
  #include <list>
  #include <memory>
  #include <string>
  #include <thread>
  #include <utility>
  #include <vector>
  #include <openssl/bio.h>
  #include <openssl/err.h>
  #include <openssl/evp.h>
  #include <openssl/pem.h>
  #include <openssl/pkcs12.h>
  #include <openssl/x509.h>
  #include "utils/log_adapter.h"
  using debugger::Chunk;
  using debugger::EventListener;
  using debugger::EventReply;
  using debugger::EventReply_Status_FAILED;
  using debugger::GraphProto;
  using debugger::Metadata;
  using debugger::TensorProto;
  using debugger::WatchpointHit;
  // Maximum number of bytes placed in one streamed chunk by ChunkString (3 MiB).
  // A typed constant instead of an unparenthesized macro: `x / CHUNK_SIZE` now means what it says.
  constexpr int CHUNK_SIZE = 1024 * 1024 * 3;
  36. namespace mindspore {
  37. GrpcClient::GrpcClient(const std::string &host, const std::string &port, const bool &ssl_certificate,
  38. const std::string &certificate_dir, const std::string &certificate_passphrase)
  39. : stub_(nullptr) {
  40. Init(host, port, ssl_certificate, certificate_dir, certificate_passphrase);
  41. }
  42. void GrpcClient::Init(const std::string &host, const std::string &port, const bool &ssl_certificate,
  43. const std::string &certificate_dir, const std::string &certificate_passphrase) {
  44. std::string target_str = host + ":" + port;
  45. MS_LOG(INFO) << "GrpcClient connecting to: " << target_str;
  46. std::shared_ptr<grpc::Channel> channel;
  47. if (ssl_certificate) {
  48. FILE *fp;
  49. EVP_PKEY *pkey = NULL;
  50. X509 *cert = NULL;
  51. STACK_OF(X509) *ca = NULL;
  52. PKCS12 *p12 = NULL;
  53. if ((fp = fopen(certificate_dir.c_str(), "rb")) == NULL) {
  54. MS_LOG(ERROR) << "Error opening file: " << certificate_dir;
  55. exit(EXIT_FAILURE);
  56. }
  57. p12 = d2i_PKCS12_fp(fp, NULL);
  58. fclose(fp);
  59. if (p12 == NULL) {
  60. MS_LOG(ERROR) << "Error reading PKCS#12 file";
  61. X509_free(cert);
  62. EVP_PKEY_free(pkey);
  63. sk_X509_pop_free(ca, X509_free);
  64. exit(EXIT_FAILURE);
  65. }
  66. if (!PKCS12_parse(p12, certificate_passphrase.c_str(), &pkey, &cert, &ca)) {
  67. MS_LOG(ERROR) << "Error parsing PKCS#12 file";
  68. X509_free(cert);
  69. EVP_PKEY_free(pkey);
  70. sk_X509_pop_free(ca, X509_free);
  71. exit(EXIT_FAILURE);
  72. }
  73. std::string strca;
  74. std::string strcert;
  75. std::string strkey;
  76. if (pkey == NULL || cert == NULL || ca == NULL) {
  77. MS_LOG(ERROR) << "Error private key or cert or CA certificate.";
  78. X509_free(cert);
  79. EVP_PKEY_free(pkey);
  80. sk_X509_pop_free(ca, X509_free);
  81. exit(EXIT_FAILURE);
  82. } else {
  83. ASN1_TIME *validtime = X509_getm_notAfter(cert);
  84. if (X509_cmp_current_time(validtime) < 0) {
  85. MS_LOG(ERROR) << "This certificate is over its valid time, please use a new certificate.";
  86. X509_free(cert);
  87. EVP_PKEY_free(pkey);
  88. sk_X509_pop_free(ca, X509_free);
  89. exit(EXIT_FAILURE);
  90. }
  91. int nid = X509_get_signature_nid(cert);
  92. int keybit = EVP_PKEY_bits(pkey);
  93. if (nid == NID_sha1) {
  94. MS_LOG(WARNING) << "Signature algrithm is sha1, which maybe not secure enough.";
  95. } else if (keybit < 2048) {
  96. MS_LOG(WARNING) << "The private key bits is: " << keybit << ", which maybe not secure enough.";
  97. }
  98. int dwPriKeyLen = i2d_PrivateKey(pkey, NULL); // get the length of private key
  99. unsigned char *pribuf = (unsigned char *)malloc(sizeof(unsigned char) * dwPriKeyLen);
  100. i2d_PrivateKey(pkey, &pribuf); // PrivateKey DER code
  101. strkey = std::string(reinterpret_cast<char const *>(pribuf), dwPriKeyLen);
  102. int dwcertLen = i2d_X509(cert, NULL); // get the length of private key
  103. unsigned char *certbuf = (unsigned char *)malloc(sizeof(unsigned char) * dwcertLen);
  104. i2d_X509(cert, &certbuf); // PrivateKey DER code
  105. strcert = std::string(reinterpret_cast<char const *>(certbuf), dwcertLen);
  106. int dwcaLen = i2d_X509(sk_X509_value(ca, 0), NULL); // get the length of private key
  107. unsigned char *cabuf = (unsigned char *)malloc(sizeof(unsigned char) * dwcaLen);
  108. i2d_X509(sk_X509_value(ca, 0), &cabuf); // PrivateKey DER code
  109. strca = std::string(reinterpret_cast<char const *>(cabuf), dwcaLen);
  110. free(pribuf);
  111. free(certbuf);
  112. free(cabuf);
  113. }
  114. grpc::SslCredentialsOptions opts = {strca, strkey, strcert};
  115. channel = grpc::CreateChannel(target_str, grpc::SslCredentials(opts));
  116. } else {
  117. channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials());
  118. }
  119. stub_ = EventListener::NewStub(channel);
  120. }
  121. void GrpcClient::Reset() { stub_ = nullptr; }
  122. EventReply GrpcClient::WaitForCommand(const Metadata &metadata) {
  123. EventReply reply;
  124. grpc::ClientContext context;
  125. grpc::Status status = stub_->WaitCMD(&context, metadata, &reply);
  126. if (!status.ok()) {
  127. MS_LOG(ERROR) << "RPC failed: WaitForCommand";
  128. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  129. reply.set_status(EventReply_Status_FAILED);
  130. }
  131. return reply;
  132. }
  133. EventReply GrpcClient::SendMetadata(const Metadata &metadata) {
  134. EventReply reply;
  135. grpc::ClientContext context;
  136. grpc::Status status = stub_->SendMetadata(&context, metadata, &reply);
  137. if (!status.ok()) {
  138. MS_LOG(ERROR) << "RPC failed: SendMetadata";
  139. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  140. reply.set_status(EventReply_Status_FAILED);
  141. }
  142. return reply;
  143. }
  144. std::vector<std::string> ChunkString(std::string str, int graph_size) {
  145. std::vector<std::string> buf;
  146. int size_iter = 0;
  147. while (size_iter < graph_size) {
  148. int chunk_size = CHUNK_SIZE;
  149. if (graph_size - size_iter < CHUNK_SIZE) {
  150. chunk_size = graph_size - size_iter;
  151. }
  152. std::string buffer;
  153. buffer.resize(chunk_size);
  154. memcpy(reinterpret_cast<char *>(buffer.data()), str.data() + size_iter, chunk_size);
  155. buf.push_back(buffer);
  156. size_iter += CHUNK_SIZE;
  157. }
  158. return buf;
  159. }
  160. EventReply GrpcClient::SendGraph(const GraphProto &graph) {
  161. EventReply reply;
  162. grpc::ClientContext context;
  163. Chunk chunk;
  164. std::unique_ptr<grpc::ClientWriter<Chunk> > writer(stub_->SendGraph(&context, &reply));
  165. std::string str = graph.SerializeAsString();
  166. int graph_size = graph.ByteSize();
  167. auto buf = ChunkString(str, graph_size);
  168. for (unsigned int i = 0; i < buf.size(); i++) {
  169. MS_LOG(INFO) << "RPC:sending the " << i << "chunk in graph";
  170. chunk.set_buffer(buf[i]);
  171. if (!writer->Write(chunk)) {
  172. break;
  173. }
  174. std::this_thread::sleep_for(std::chrono::milliseconds(1));
  175. }
  176. writer->WritesDone();
  177. grpc::Status status = writer->Finish();
  178. if (!status.ok()) {
  179. MS_LOG(ERROR) << "RPC failed: SendGraph";
  180. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  181. reply.set_status(EventReply_Status_FAILED);
  182. }
  183. return reply;
  184. }
  185. EventReply GrpcClient::SendTensors(const std::list<TensorProto> &tensors) {
  186. EventReply reply;
  187. grpc::ClientContext context;
  188. std::unique_ptr<grpc::ClientWriter<TensorProto> > writer(stub_->SendTensors(&context, &reply));
  189. for (const auto &tensor : tensors) {
  190. if (!writer->Write(tensor)) {
  191. break;
  192. }
  193. std::this_thread::sleep_for(std::chrono::milliseconds(1));
  194. }
  195. writer->WritesDone();
  196. grpc::Status status = writer->Finish();
  197. if (!status.ok()) {
  198. MS_LOG(ERROR) << "RPC failed: SendTensors";
  199. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  200. reply.set_status(EventReply_Status_FAILED);
  201. }
  202. return reply;
  203. }
  204. EventReply GrpcClient::SendWatchpointHits(const std::list<WatchpointHit> &watchpoints) {
  205. EventReply reply;
  206. grpc::ClientContext context;
  207. std::unique_ptr<grpc::ClientWriter<WatchpointHit> > writer(stub_->SendWatchpointHits(&context, &reply));
  208. for (const auto &watchpoint : watchpoints) {
  209. if (!writer->Write(watchpoint)) {
  210. break;
  211. }
  212. std::this_thread::sleep_for(std::chrono::milliseconds(1));
  213. }
  214. writer->WritesDone();
  215. grpc::Status status = writer->Finish();
  216. if (!status.ok()) {
  217. MS_LOG(ERROR) << "RPC failed: SendWatchpointHits";
  218. MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
  219. reply.set_status(EventReply_Status_FAILED);
  220. }
  221. return reply;
  222. }
  223. } // namespace mindspore