/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "debug/debugger/grpc_client.h"

#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/pkcs12.h>
#include <openssl/x509.h>

#include <algorithm>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <list>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include "utils/log_adapter.h"

using debugger::Chunk;
using debugger::EventListener;
using debugger::EventReply;
using debugger::EventReply_Status_FAILED;
using debugger::GraphProto;
using debugger::Metadata;
using debugger::TensorProto;
using debugger::WatchpointHit;

namespace mindspore {
namespace {
// Maximum payload, in bytes, of a single Chunk message when streaming a serialized graph.
constexpr size_t kChunkSize = 1024 * 1024 * 3;

// Deleters that let std::unique_ptr own the C handles used while loading the client certificate.
struct FileCloser {
  void operator()(FILE *fp) const { (void)fclose(fp); }
};
struct Pkcs12Deleter {
  void operator()(PKCS12 *p12) const { PKCS12_free(p12); }
};
struct PkeyDeleter {
  void operator()(EVP_PKEY *pkey) const { EVP_PKEY_free(pkey); }
};
struct X509Deleter {
  void operator()(X509 *cert) const { X509_free(cert); }
};
struct X509StackDeleter {
  void operator()(STACK_OF(X509) *stack) const { sk_X509_pop_free(stack, X509_free); }
};

using FilePtr = std::unique_ptr<FILE, FileCloser>;
using Pkcs12Ptr = std::unique_ptr<PKCS12, Pkcs12Deleter>;
using PkeyPtr = std::unique_ptr<EVP_PKEY, PkeyDeleter>;
using X509Ptr = std::unique_ptr<X509, X509Deleter>;
using X509StackPtr = std::unique_ptr<STACK_OF(X509), X509StackDeleter>;

// Runs an OpenSSL i2d_* encoder and returns the DER bytes, or an empty string on failure.
//
// `encode` is called as encode(unsigned char **out): first with nullptr to obtain the encoded length, then
// with a pointer to a cursor into the output buffer. i2d_* functions advance the pointer they are given past
// the bytes they wrote, so a scratch cursor is passed and the buffer start is never handed to OpenSSL directly.
template <typename Encoder>
std::string EncodeDer(Encoder encode) {
  const int length = encode(nullptr);
  if (length <= 0) {
    return std::string();
  }
  std::string der(static_cast<size_t>(length), '\0');
  unsigned char *cursor = reinterpret_cast<unsigned char *>(&der[0]);
  if (encode(&cursor) != length) {
    return std::string();
  }
  return der;
}

// Loads the PKCS#12 bundle at `certificate_dir`, validates it, and builds the gRPC SSL options from it.
//
// Any failure (unreadable file, bad passphrase, missing key/cert/CA, expired certificate, encoding error) is
// logged and terminates the process with EXIT_FAILURE, as the inline code in Init() did before.
grpc::SslCredentialsOptions LoadSslCredentialsOptions(const std::string &certificate_dir,
                                                      const std::string &certificate_passphrase) {
  Pkcs12Ptr p12;
  {
    FilePtr fp(fopen(certificate_dir.c_str(), "rb"));
    if (fp == nullptr) {
      MS_LOG(ERROR) << "Error opening file: " << certificate_dir;
      exit(EXIT_FAILURE);
    }
    p12.reset(d2i_PKCS12_fp(fp.get(), nullptr));
  }
  if (p12 == nullptr) {
    MS_LOG(ERROR) << "Error reading PKCS#12 file";
    exit(EXIT_FAILURE);
  }

  EVP_PKEY *raw_pkey = nullptr;
  X509 *raw_cert = nullptr;
  STACK_OF(X509) *raw_ca = nullptr;
  const int parsed = PKCS12_parse(p12.get(), certificate_passphrase.c_str(), &raw_pkey, &raw_cert, &raw_ca);
  // Take ownership right away so every path below releases whatever PKCS12_parse produced.
  PkeyPtr pkey(raw_pkey);
  X509Ptr cert(raw_cert);
  X509StackPtr ca(raw_ca);
  if (!parsed) {
    MS_LOG(ERROR) << "Error parsing PKCS#12 file";
    exit(EXIT_FAILURE);
  }
  // An empty CA stack is rejected together with a missing one: the first CA certificate is required below.
  if (pkey == nullptr || cert == nullptr || ca == nullptr || sk_X509_num(ca.get()) <= 0) {
    MS_LOG(ERROR) << "Error private key or cert or CA certificate.";
    exit(EXIT_FAILURE);
  }

  ASN1_TIME *validtime = X509_getm_notAfter(cert.get());
  if (X509_cmp_current_time(validtime) < 0) {
    MS_LOG(ERROR) << "This certificate is over its valid time, please use a new certificate.";
    exit(EXIT_FAILURE);
  }

  const int nid = X509_get_signature_nid(cert.get());
  const int keybit = EVP_PKEY_bits(pkey.get());
  if (nid == NID_sha1) {
    MS_LOG(WARNING) << "Signature algrithm is sha1, which maybe not secure enough.";
  } else if (keybit < 2048) {
    MS_LOG(WARNING) << "The private key bits is: " << keybit << ", which maybe not secure enough.";
  }

  // DER encodings of the private key, the client certificate and the first CA certificate.
  const std::string strkey = EncodeDer([&pkey](unsigned char **out) { return i2d_PrivateKey(pkey.get(), out); });
  const std::string strcert = EncodeDer([&cert](unsigned char **out) { return i2d_X509(cert.get(), out); });
  X509 *first_ca = sk_X509_value(ca.get(), 0);
  const std::string strca = EncodeDer([first_ca](unsigned char **out) { return i2d_X509(first_ca, out); });
  if (strkey.empty() || strcert.empty() || strca.empty()) {
    MS_LOG(ERROR) << "Error encoding private key or certificate to DER.";
    exit(EXIT_FAILURE);
  }

  // NOTE(review): gRPC documents these three fields as PEM-encoded buffers; DER is passed here exactly as the
  // original code intended. Confirm against the server setup before switching encodings.
  grpc::SslCredentialsOptions opts = {strca, strkey, strcert};
  return opts;
}
}  // namespace

// Creates a client and immediately connects it via Init().
GrpcClient::GrpcClient(const std::string &host, const std::string &port, const bool &ssl_certificate,
                       const std::string &certificate_dir, const std::string &certificate_passphrase)
    : stub_(nullptr) {
  Init(host, port, ssl_certificate, certificate_dir, certificate_passphrase);
}

// Opens a channel to "<host>:<port>" and creates the EventListener stub on it.
//
// When `ssl_certificate` is true the channel uses SSL credentials built from the PKCS#12 file at
// `certificate_dir`, unlocked with `certificate_passphrase`; a certificate that cannot be loaded terminates
// the process. Otherwise an insecure channel is created and the certificate arguments are ignored.
void GrpcClient::Init(const std::string &host, const std::string &port, const bool &ssl_certificate,
                      const std::string &certificate_dir, const std::string &certificate_passphrase) {
  std::string target_str = host + ":" + port;
  MS_LOG(INFO) << "GrpcClient connecting to: " << target_str;

  std::shared_ptr<grpc::Channel> channel;
  if (ssl_certificate) {
    grpc::SslCredentialsOptions opts = LoadSslCredentialsOptions(certificate_dir, certificate_passphrase);
    channel = grpc::CreateChannel(target_str, grpc::SslCredentials(opts));
  } else {
    channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials());
  }
  stub_ = EventListener::NewStub(channel);
}

// Drops the stub; Init() must be called again before issuing further RPCs.
void GrpcClient::Reset() { stub_ = nullptr; }

// Issues the WaitCMD RPC with `metadata` and returns the reply.
// On RPC failure the error is logged and the reply status is set to FAILED.
EventReply GrpcClient::WaitForCommand(const Metadata &metadata) {
  EventReply reply;
  grpc::ClientContext context;
  grpc::Status status = stub_->WaitCMD(&context, metadata, &reply);

  if (!status.ok()) {
    MS_LOG(ERROR) << "RPC failed: WaitForCommand";
    MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
    reply.set_status(EventReply_Status_FAILED);
  }
  return reply;
}

// Issues the SendMetadata RPC with `metadata` and returns the reply.
// On RPC failure the error is logged and the reply status is set to FAILED.
EventReply GrpcClient::SendMetadata(const Metadata &metadata) {
  EventReply reply;
  grpc::ClientContext context;
  grpc::Status status = stub_->SendMetadata(&context, metadata, &reply);

  if (!status.ok()) {
    MS_LOG(ERROR) << "RPC failed: SendMetadata";
    MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
    reply.set_status(EventReply_Status_FAILED);
  }
  return reply;
}

// Splits the first `graph_size` bytes of `str` into consecutive pieces of at most kChunkSize bytes.
//
// Returns the pieces in order; the result is empty when `graph_size` is not positive. `graph_size` is clamped
// to str.size() so a length larger than the string can no longer read past its end.
std::vector<std::string> ChunkString(std::string str, int graph_size) {
  std::vector<std::string> buf;
  if (graph_size <= 0) {
    return buf;
  }
  const size_t total = std::min(static_cast<size_t>(graph_size), str.size());
  buf.reserve((total + kChunkSize - 1) / kChunkSize);
  for (size_t offset = 0; offset < total; offset += kChunkSize) {
    buf.push_back(str.substr(offset, std::min(kChunkSize, total - offset)));
  }
  return buf;
}

// Serializes `graph` and streams it to the server as a sequence of Chunk messages.
// Stops writing at the first failed write. On RPC failure the error is logged and the reply status is FAILED.
EventReply GrpcClient::SendGraph(const GraphProto &graph) {
  EventReply reply;
  grpc::ClientContext context;
  Chunk chunk;

  std::unique_ptr<grpc::ClientWriter<Chunk> > writer(stub_->SendGraph(&context, &reply));
  std::string str = graph.SerializeAsString();
  // Chunk by the length actually serialized: a failed serialization yields an empty string, which must not be
  // read using the size the message reports for itself.
  const int graph_size = static_cast<int>(str.size());
  auto buf = ChunkString(str, graph_size);

  for (size_t i = 0; i < buf.size(); i++) {
    MS_LOG(INFO) << "RPC:sending the " << i << "chunk in graph";
    chunk.set_buffer(buf[i]);
    if (!writer->Write(chunk)) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
  writer->WritesDone();
  grpc::Status status = writer->Finish();

  if (!status.ok()) {
    MS_LOG(ERROR) << "RPC failed: SendGraph";
    MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
    reply.set_status(EventReply_Status_FAILED);
  }
  return reply;
}

// Streams every tensor in `tensors` to the server over the SendTensors RPC.
// Stops writing at the first failed write. On RPC failure the error is logged and the reply status is FAILED.
EventReply GrpcClient::SendTensors(const std::list<TensorProto> &tensors) {
  EventReply reply;
  grpc::ClientContext context;

  std::unique_ptr<grpc::ClientWriter<TensorProto> > writer(stub_->SendTensors(&context, &reply));
  for (const auto &tensor : tensors) {
    if (!writer->Write(tensor)) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
  writer->WritesDone();
  grpc::Status status = writer->Finish();

  if (!status.ok()) {
    MS_LOG(ERROR) << "RPC failed: SendTensors";
    MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
    reply.set_status(EventReply_Status_FAILED);
  }
  return reply;
}

// Streams every watchpoint hit in `watchpoints` to the server over the SendWatchpointHits RPC.
// Stops writing at the first failed write. On RPC failure the error is logged and the reply status is FAILED.
EventReply GrpcClient::SendWatchpointHits(const std::list<WatchpointHit> &watchpoints) {
  EventReply reply;
  grpc::ClientContext context;

  std::unique_ptr<grpc::ClientWriter<WatchpointHit> > writer(stub_->SendWatchpointHits(&context, &reply));
  for (const auto &watchpoint : watchpoints) {
    if (!writer->Write(watchpoint)) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
  writer->WritesDone();
  grpc::Status status = writer->Finish();

  if (!status.ok()) {
    MS_LOG(ERROR) << "RPC failed: SendWatchpointHits";
    MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
    reply.set_status(EventReply_Status_FAILED);
  }
  return reply;
}
}  // namespace mindspore