| @@ -19,6 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| HttpRequestHandler::~HttpRequestHandler() { | |||
| if (evbase_) { | |||
| event_base_free(evbase_); | |||
| evbase_ = nullptr; | |||
| } | |||
| } | |||
| bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string, OnRequestReceive *> &handlers) { | |||
| evbase_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(evbase_); | |||
| @@ -27,33 +34,18 @@ bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string | |||
| if (PSContext::instance()->enable_ssl()) { | |||
| MS_LOG(INFO) << "Enable ssl support."; | |||
| SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(), | |||
| SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2); | |||
| EC_KEY *ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); | |||
| MS_EXCEPTION_IF_NULL(ecdh); | |||
| X509 *cert = SSLWrapper::GetInstance().ReadCertFromFile(kCertificateChain); | |||
| if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) { | |||
| MS_LOG(INFO) << "Verify cert time failed."; | |||
| return false; | |||
| } | |||
| if (!SSL_CTX_use_certificate_chain_file(SSLWrapper::GetInstance().GetSSLCtx(), kCertificateChain)) { | |||
| MS_LOG(ERROR) << "SSL use certificate chain file failed!"; | |||
| return false; | |||
| } | |||
| if (!SSL_CTX_use_PrivateKey_file(SSLWrapper::GetInstance().GetSSLCtx(), kPrivateKey, SSL_FILETYPE_PEM)) { | |||
| MS_LOG(ERROR) << "SSL use private key file failed!"; | |||
| return false; | |||
| } | |||
| if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx())) { | |||
| MS_LOG(ERROR) << "SSL check private key file failed!"; | |||
| return false; | |||
| if (!SSL_CTX_set_options(SSLHTTP::GetInstance().GetSSLCtx(), SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | | |||
| SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 | | |||
| SSL_OP_NO_TLSv1_1)) { | |||
| if (evbase_) { | |||
| event_base_free(evbase_); | |||
| evbase_ = nullptr; | |||
| } | |||
| evhttp_free(http); | |||
| http = nullptr; | |||
| MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed."; | |||
| } | |||
| evhttp_set_bevcb(http, BuffereventCallback, SSLWrapper::GetInstance().GetSSLCtx()); | |||
| evhttp_set_bevcb(http, BuffereventCallback, SSLHTTP::GetInstance().GetSSLCtx()); | |||
| } | |||
| int result = evhttp_accept_socket(http, fd); | |||
| @@ -67,6 +59,7 @@ bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string | |||
| MS_EXCEPTION_IF_NULL(req); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto httpReq = std::make_shared<HttpMessageHandler>(); | |||
| MS_EXCEPTION_IF_NULL(httpReq); | |||
| httpReq->set_request(req); | |||
| httpReq->InitHttpMessage(); | |||
| OnRequestReceive *func = reinterpret_cast<OnRequestReceive *>(arg); | |||
| @@ -74,6 +67,7 @@ bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string | |||
| }; | |||
| // O SUCCESS,-1 ALREADY_EXIST,-2 FAILURE | |||
| MS_EXCEPTION_IF_NULL(handler.second); | |||
| int ret = evhttp_set_cb(http, handler.first.c_str(), TransFunc, reinterpret_cast<void *>(handler.second)); | |||
| std::string log_prefix = "Ev http register handle of:"; | |||
| if (ret == 0) { | |||
| @@ -101,16 +95,12 @@ void HttpRequestHandler::Run() { | |||
| } else { | |||
| MS_LOG(ERROR) << "Event base dispatch with unexpected error code!"; | |||
| } | |||
| if (evbase_) { | |||
| event_base_free(evbase_); | |||
| evbase_ = nullptr; | |||
| } | |||
| } | |||
| bool HttpRequestHandler::Stop() { | |||
| MS_LOG(INFO) << "Stop http server!"; | |||
| MS_EXCEPTION_IF_NULL(evbase_); | |||
| int ret = event_base_loopbreak(evbase_); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "event base loop break failed!"; | |||
| @@ -120,9 +110,13 @@ bool HttpRequestHandler::Stop() { | |||
| } | |||
| bufferevent *HttpRequestHandler::BuffereventCallback(event_base *base, void *arg) { | |||
| MS_EXCEPTION_IF_NULL(base); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| SSL_CTX *ctx = reinterpret_cast<SSL_CTX *>(arg); | |||
| SSL *ssl = SSL_new(ctx); | |||
| MS_EXCEPTION_IF_NULL(ssl); | |||
| bufferevent *bev = bufferevent_openssl_socket_new(base, -1, ssl, BUFFEREVENT_SSL_ACCEPTING, BEV_OPT_CLOSE_ON_FREE); | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| return bev; | |||
| } | |||
| } // namespace core | |||
| @@ -29,7 +29,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "ps/core/communicator/http_message_handler.h" | |||
| #include "ps/core/communicator/ssl_wrapper.h" | |||
| #include "ps/core/communicator/ssl_http.h" | |||
| #include "ps/constants.h" | |||
| #include "ps/ps_context.h" | |||
| @@ -44,7 +44,7 @@ using OnRequestReceive = std::function<void(std::shared_ptr<HttpMessageHandler>) | |||
| class HttpRequestHandler { | |||
| public: | |||
| HttpRequestHandler() : evbase_(nullptr) {} | |||
| virtual ~HttpRequestHandler() = default; | |||
| virtual ~HttpRequestHandler(); | |||
| bool Initialize(int fd, const std::unordered_map<std::string, OnRequestReceive *> &handlers); | |||
| void Run(); | |||
| @@ -41,11 +41,15 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| HttpServer::~HttpServer() { Stop(); } | |||
| HttpServer::~HttpServer() { | |||
| if (!Stop()) { | |||
| MS_LOG(WARNING) << "Stop http server failed."; | |||
| } | |||
| } | |||
| bool HttpServer::InitServer() { | |||
| if (server_address_ == "") { | |||
| MS_LOG(INFO) << "The server ip is empty."; | |||
| MS_LOG(WARNING) << "The server address is empty."; | |||
| std::string interface; | |||
| std::string server_ip; | |||
| CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); | |||
| @@ -91,22 +95,24 @@ bool HttpServer::InitServer() { | |||
| result = ::bind(fd_, (struct sockaddr *)&addr, sizeof(addr)); | |||
| if (result < 0) { | |||
| MS_LOG(ERROR) << "Bind ip:" << server_address_ << " port:" << server_port_ << " failed!"; | |||
| MS_LOG(ERROR) << "Bind ip:" << server_address_ << " port:" << server_port_ << "failed!"; | |||
| close(fd_); | |||
| fd_ = -1; | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "Bind ip:" << server_address_ << " port:" << server_port_ << " successful!"; | |||
| result = ::listen(fd_, backlog_); | |||
| if (result < 0) { | |||
| MS_LOG(ERROR) << "Listen ip:" << server_address_ << " port:" << server_port_ << " failed!"; | |||
| MS_LOG(ERROR) << "Listen ip:" << server_address_ << " port:" << server_port_ << "failed!"; | |||
| close(fd_); | |||
| fd_ = -1; | |||
| return false; | |||
| } | |||
| int flags = 0; | |||
| if ((flags = fcntl(fd_, F_GETFL, 0)) < 0 || fcntl(fd_, F_SETFL, flags | O_NONBLOCK) < 0) { | |||
| if ((flags = fcntl(fd_, F_GETFL, 0)) < 0 || fcntl(fd_, F_SETFL, (unsigned int)flags | O_NONBLOCK) < 0) { | |||
| MS_LOG(ERROR) << "Set fcntl O_NONBLOCK failed!"; | |||
| close(fd_); | |||
| fd_ = -1; | |||
| @@ -135,12 +141,14 @@ bool HttpServer::Start(bool is_detach) { | |||
| MS_LOG(INFO) << "Start http server!"; | |||
| for (size_t i = 0; i < thread_num_; i++) { | |||
| auto http_request_handler = std::make_shared<HttpRequestHandler>(); | |||
| MS_EXCEPTION_IF_NULL(http_request_handler); | |||
| if (!http_request_handler->Initialize(fd_, request_handlers_)) { | |||
| MS_LOG(ERROR) << "Http initialize failed."; | |||
| return false; | |||
| } | |||
| http_request_handlers.push_back(http_request_handler); | |||
| auto thread = std::make_shared<std::thread>(&HttpRequestHandler::Run, http_request_handler); | |||
| MS_EXCEPTION_IF_NULL(thread); | |||
| if (is_detach) { | |||
| thread->detach(); | |||
| } | |||
| @@ -168,6 +176,10 @@ bool HttpServer::Stop() { | |||
| result = false; | |||
| } | |||
| } | |||
| if (fd_ != -1) { | |||
| close(fd_); | |||
| fd_ = -1; | |||
| } | |||
| is_stop_ = true; | |||
| } | |||
| return result; | |||
| @@ -58,7 +58,7 @@ void SSLClient::InitSSL() { | |||
| client_cert = path; | |||
| // 2. Parse the client password. | |||
| std::string client_password = CommUtil::ParseConfig(*config_, kClientPassword); | |||
| std::string client_password = PSContext::instance()->client_password(); | |||
| if (client_password.empty()) { | |||
| MS_LOG(EXCEPTION) << "The client password's value is empty."; | |||
| } | |||
| @@ -62,7 +62,7 @@ void SSLHTTP::InitSSL() { | |||
| server_cert = path; | |||
| // 2. Parse the server password. | |||
| std::string server_password = CommUtil::ParseConfig(*(config_), kServerPassword); | |||
| std::string server_password = PSContext::instance()->server_password(); | |||
| if (server_password.empty()) { | |||
| MS_LOG(EXCEPTION) << "The client password's value is empty."; | |||
| } | |||
| @@ -93,6 +93,11 @@ void SSLHTTP::InitSSL() { | |||
| if (!SSL_CTX_check_private_key(ssl_ctx_)) { | |||
| MS_LOG(EXCEPTION) << "SSL check private key file failed!"; | |||
| } | |||
| if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | | |||
| SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) { | |||
| MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed."; | |||
| } | |||
| } | |||
| void SSLHTTP::CleanSSL() { | |||
| @@ -1,4 +1,3 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| @@ -33,11 +32,11 @@ namespace ps { | |||
| namespace core { | |||
| SSLWrapper::SSLWrapper() | |||
| : ssl_ctx_(nullptr), | |||
| client_ssl_ctx_(nullptr), | |||
| rootFirstCA_(nullptr), | |||
| rootSecondCA_(nullptr), | |||
| rootFirstCrl_(nullptr), | |||
| rootSecondCrl_(nullptr) { | |||
| check_time_thread_(nullptr), | |||
| running_(false), | |||
| is_ready_(false) { | |||
| InitSSL(); | |||
| } | |||
| @@ -45,178 +44,206 @@ SSLWrapper::~SSLWrapper() { CleanSSL(); } | |||
| void SSLWrapper::InitSSL() { | |||
| CommUtil::InitOpenSSLEnv(); | |||
| int rand = RAND_poll(); | |||
| if (rand == 0) { | |||
| MS_LOG(ERROR) << "RAND_poll failed"; | |||
| } | |||
| ssl_ctx_ = SSL_CTX_new(SSLv23_server_method()); | |||
| if (!ssl_ctx_) { | |||
| MS_LOG(ERROR) << "SSL_CTX_new failed"; | |||
| MS_LOG(EXCEPTION) << "SSL_CTX_new failed"; | |||
| } | |||
| X509_STORE *store = SSL_CTX_get_cert_store(ssl_ctx_); | |||
| MS_EXCEPTION_IF_NULL(store); | |||
| if (X509_STORE_set_default_paths(store) != 1) { | |||
| MS_LOG(ERROR) << "X509_STORE_set_default_paths failed"; | |||
| } | |||
| client_ssl_ctx_ = SSL_CTX_new(SSLv23_client_method()); | |||
| if (!ssl_ctx_) { | |||
| MS_LOG(ERROR) << "SSL_CTX_new failed"; | |||
| MS_LOG(EXCEPTION) << "X509_STORE_set_default_paths failed"; | |||
| } | |||
| } | |||
| void SSLWrapper::CleanSSL() { | |||
| if (ssl_ctx_ != nullptr) { | |||
| SSL_CTX_free(ssl_ctx_); | |||
| std::unique_ptr<Configuration> config_ = | |||
| std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path()); | |||
| MS_EXCEPTION_IF_NULL(config_); | |||
| if (!config_->Initialize()) { | |||
| MS_LOG(EXCEPTION) << "The config file is empty."; | |||
| } | |||
| ERR_free_strings(); | |||
| EVP_cleanup(); | |||
| ERR_remove_thread_state(nullptr); | |||
| CRYPTO_cleanup_all_ex_data(); | |||
| } | |||
| SSL_CTX *SSLWrapper::GetSSLCtx(bool is_server) { | |||
| if (is_server) { | |||
| return ssl_ctx_; | |||
| } else { | |||
| return client_ssl_ctx_; | |||
| // 1.Parse the server's certificate and the ciphertext of key. | |||
| std::string server_cert = kCertificateChain; | |||
| std::string path = CommUtil::ParseConfig(*(config_), kServerCertPath); | |||
| if (!CommUtil::IsFileExists(path)) { | |||
| MS_LOG(EXCEPTION) << "The key:" << kServerCertPath << "'s value is not exist."; | |||
| } | |||
| } | |||
| X509 *SSLWrapper::ReadCertFromFile(const std::string &certPath) const { | |||
| BIO *bio = BIO_new_file(certPath.c_str(), "r"); | |||
| return PEM_read_bio_X509(bio, nullptr, nullptr, nullptr); | |||
| } | |||
| X509 *SSLWrapper::ReadCertFromPerm(std::string cert) { | |||
| BIO *bio = BIO_new_mem_buf(reinterpret_cast<void *>(cert.data()), -1); | |||
| return PEM_read_bio_X509(bio, nullptr, nullptr, nullptr); | |||
| } | |||
| X509_CRL *SSLWrapper::ReadCrlFromFile(const std::string &crlPath) const { | |||
| BIO *bio = BIO_new_file(crlPath.c_str(), "r"); | |||
| return PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr); | |||
| } | |||
| server_cert = path; | |||
| bool SSLWrapper::VerifyCertTime(const X509 *cert) const { | |||
| ASN1_TIME *start = X509_getm_notBefore(cert); | |||
| ASN1_TIME *end = X509_getm_notAfter(cert); | |||
| MS_LOG(INFO) << "The server cert path:" << server_cert; | |||
| int day = 0; | |||
| int sec = 0; | |||
| ASN1_TIME_diff(&day, &sec, start, NULL); | |||
| // 2. Parse the server password. | |||
| std::string server_password = PSContext::instance()->server_password(); | |||
| if (server_password.empty()) { | |||
| MS_LOG(EXCEPTION) << "The client password's value is empty."; | |||
| } | |||
| if (day < 0 || sec < 0) { | |||
| MS_LOG(INFO) << "Cert start time is later than now time."; | |||
| return false; | |||
| EVP_PKEY *pkey = nullptr; | |||
| X509 *cert = nullptr; | |||
| STACK_OF(X509) *ca_stack = nullptr; | |||
| BIO *bio = BIO_new_file(server_cert.c_str(), "rb"); | |||
| MS_EXCEPTION_IF_NULL(bio); | |||
| PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr); | |||
| MS_EXCEPTION_IF_NULL(p12); | |||
| BIO_free_all(bio); | |||
| if (!PKCS12_parse(p12, server_password.c_str(), &pkey, &cert, &ca_stack)) { | |||
| MS_LOG(EXCEPTION) << "PKCS12_parse failed."; | |||
| } | |||
| PKCS12_free(p12); | |||
| std::string default_cipher_list = CommUtil::ParseConfig(*config_, kCipherList); | |||
| std::vector<std::string> ciphers = CommUtil::Split(default_cipher_list, kColon); | |||
| if (!CommUtil::VerifyCipherList(ciphers)) { | |||
| MS_LOG(EXCEPTION) << "The cipher is wrong."; | |||
| } | |||
| day = 0; | |||
| sec = 0; | |||
| ASN1_TIME_diff(&day, &sec, NULL, end); | |||
| if (day < 0 || sec < 0) { | |||
| MS_LOG(INFO) << "Cert end time is sooner than now time."; | |||
| return false; | |||
| if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) { | |||
| MS_LOG(EXCEPTION) << "SSL use set cipher list failed!"; | |||
| } | |||
| return true; | |||
| } | |||
| std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath); | |||
| if (crl_path.empty()) { | |||
| MS_LOG(INFO) << "The crl path is empty."; | |||
| } else if (!CommUtil::VerifyCRL(cert, crl_path)) { | |||
| MS_LOG(EXCEPTION) << "Verify crl failed."; | |||
| } | |||
| bool SSLWrapper::VerifyCAChain(const std::string &keyAttestation, const std::string &equipCert, | |||
| const std::string &equipCACert, std::string) { | |||
| X509 *keyAttestationCertObj = ReadCertFromPerm(keyAttestation); | |||
| X509 *equipCertObj = ReadCertFromPerm(equipCert); | |||
| X509 *equipCACertObj = ReadCertFromPerm(equipCACert); | |||
| std::string client_ca = kCAcrt; | |||
| std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath); | |||
| if (!CommUtil::IsFileExists(ca_path)) { | |||
| MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist."; | |||
| } | |||
| client_ca = ca_path; | |||
| if (!VerifyCertTime(keyAttestationCertObj) || !VerifyCertTime(equipCertObj) || !VerifyCertTime(equipCACertObj)) { | |||
| return false; | |||
| if (!CommUtil::VerifyCommonName(cert, client_ca)) { | |||
| MS_LOG(EXCEPTION) << "Verify common name failed."; | |||
| } | |||
| EVP_PKEY *equipPubKey = X509_get_pubkey(equipCertObj); | |||
| EVP_PKEY *equipCAPubKey = X509_get_pubkey(equipCACertObj); | |||
| SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER, 0); | |||
| if (!SSL_CTX_load_verify_locations(ssl_ctx_, client_ca.c_str(), nullptr)) { | |||
| MS_LOG(EXCEPTION) << "SSL load ca location failed!"; | |||
| } | |||
| EVP_PKEY *rootFirstPubKey = X509_get_pubkey(rootFirstCA_); | |||
| EVP_PKEY *rootSecondPubKey = X509_get_pubkey(rootSecondCA_); | |||
| if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) { | |||
| MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; | |||
| } | |||
| int ret = 0; | |||
| ret = X509_verify(keyAttestationCertObj, equipPubKey); | |||
| if (ret != 1) { | |||
| MS_LOG(INFO) << "keyAttestationCert verify is failed"; | |||
| return false; | |||
| if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) { | |||
| MS_LOG(EXCEPTION) << "SSL use private key file failed!"; | |||
| } | |||
| ret = X509_verify(equipCertObj, equipCAPubKey); | |||
| if (ret != 1) { | |||
| MS_LOG(INFO) << "Equip cert verify is failed"; | |||
| return false; | |||
| if (!SSL_CTX_check_private_key(ssl_ctx_)) { | |||
| MS_LOG(EXCEPTION) << "SSL check private key file failed!"; | |||
| } | |||
| int ret_first = X509_verify(equipCACertObj, rootFirstPubKey); | |||
| int ret_second = X509_verify(equipCACertObj, rootSecondPubKey); | |||
| if (ret_first != 1 && ret_second != 1) { | |||
| MS_LOG(INFO) << "Equip ca cert verify is failed"; | |||
| return false; | |||
| if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | | |||
| SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) { | |||
| MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed."; | |||
| } | |||
| MS_LOG(INFO) << "VerifyCAChain success."; | |||
| EVP_PKEY_free(equipPubKey); | |||
| EVP_PKEY_free(equipCAPubKey); | |||
| EVP_PKEY_free(rootFirstPubKey); | |||
| EVP_PKEY_free(rootSecondPubKey); | |||
| return true; | |||
| } | |||
| bool SSLWrapper::VerifyCRL(const std::string &equipCert) { | |||
| X509 *equipCertObj = ReadCertFromPerm(equipCert); | |||
| if (rootFirstCrl_ == nullptr && rootSecondCrl_ == nullptr) { | |||
| MS_LOG(INFO) << "RootFirstCrl && rootSecondCrl is nullptr."; | |||
| return false; | |||
| if (!SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_AUTO_RETRY)) { | |||
| MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!"; | |||
| } | |||
| EVP_PKEY *evp_pkey = X509_get_pubkey(equipCertObj); | |||
| int ret = X509_CRL_verify(rootFirstCrl_, evp_pkey); | |||
| if (ret == 1) { | |||
| MS_LOG(INFO) << "Equip cert in root first crl, verify failed"; | |||
| return false; | |||
| } | |||
| ret = X509_CRL_verify(rootSecondCrl_, evp_pkey); | |||
| if (ret == 1) { | |||
| MS_LOG(INFO) << "Equip cert in root second crl, verify failed"; | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "VerifyCRL success."; | |||
| return true; | |||
| StartCheckCertTime(*config_, cert, client_ca); | |||
| } | |||
| bool SSLWrapper::VerifyRSAKey(const std::string &keyAttestation, const unsigned char *srcData, | |||
| const unsigned char *signData, int srcDataLen) { | |||
| if (keyAttestation.empty() || srcData == nullptr || signData == nullptr) { | |||
| MS_LOG(INFO) << "KeyAttestation or srcData or signData is empty."; | |||
| return false; | |||
| void SSLWrapper::CleanSSL() { | |||
| if (ssl_ctx_ != nullptr) { | |||
| SSL_CTX_free(ssl_ctx_); | |||
| } | |||
| ERR_free_strings(); | |||
| EVP_cleanup(); | |||
| ERR_remove_thread_state(nullptr); | |||
| CRYPTO_cleanup_all_ex_data(); | |||
| StopCheckCertTime(); | |||
| } | |||
| X509 *keyAttestationCertObj = ReadCertFromPerm(keyAttestation); | |||
| time_t SSLWrapper::ConvertAsn1Time(const ASN1_TIME *const time) const { | |||
| MS_EXCEPTION_IF_NULL(time); | |||
| struct tm t; | |||
| const char *str = (const char *)time->data; | |||
| MS_EXCEPTION_IF_NULL(str); | |||
| size_t i = 0; | |||
| if (memset_s(&t, sizeof(t), 0, sizeof(t)) != EOK) { | |||
| MS_LOG(EXCEPTION) << "Memset Failed!"; | |||
| } | |||
| if (time->type == V_ASN1_UTCTIME) { | |||
| t.tm_year = (str[i++] - '0') * kBase; | |||
| t.tm_year += (str[i++] - '0'); | |||
| if (t.tm_year < kSeventyYear) { | |||
| t.tm_year += kHundredYear; | |||
| } | |||
| } else if (time->type == V_ASN1_GENERALIZEDTIME) { | |||
| t.tm_year = (str[i++] - '0') * kThousandYear; | |||
| t.tm_year += (str[i++] - '0') * kHundredYear; | |||
| t.tm_year += (str[i++] - '0') * kBase; | |||
| t.tm_year += (str[i++] - '0'); | |||
| t.tm_year -= kBaseYear; | |||
| } | |||
| t.tm_mon = (str[i++] - '0') * kBase; | |||
| // -1 since January is 0 not 1. | |||
| t.tm_mon += (str[i++] - '0') - kJanuary; | |||
| t.tm_mday = (str[i++] - '0') * kBase; | |||
| t.tm_mday += (str[i++] - '0'); | |||
| t.tm_hour = (str[i++] - '0') * kBase; | |||
| t.tm_hour += (str[i++] - '0'); | |||
| t.tm_min = (str[i++] - '0') * kBase; | |||
| t.tm_min += (str[i++] - '0'); | |||
| t.tm_sec = (str[i++] - '0') * kBase; | |||
| t.tm_sec += (str[i++] - '0'); | |||
| return mktime(&t); | |||
| } | |||
| EVP_PKEY *pubKey = X509_get_pubkey(keyAttestationCertObj); | |||
| RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey); | |||
| if (pRSAPublicKey == nullptr) { | |||
| MS_LOG(INFO) << "Get rsa public key failed."; | |||
| return false; | |||
| } | |||
| void SSLWrapper::StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path) { | |||
| MS_EXCEPTION_IF_NULL(cert); | |||
| MS_LOG(INFO) << "The server start check cert."; | |||
| int64_t interval = kCertCheckIntervalInHour; | |||
| int64_t warning_time = kCertExpireWarningTimeInDay; | |||
| if (config.Exists(kCertExpireWarningTime)) { | |||
| int64_t res_time = config.GetInt(kCertExpireWarningTime, 0); | |||
| if (res_time < kMinWarningTime || res_time > kMaxWarningTime) { | |||
| MS_LOG(EXCEPTION) << "The Certificate expiration warning time should be [7, 180]"; | |||
| } | |||
| warning_time = res_time; | |||
| } | |||
| MS_LOG(INFO) << "The interval time is:" << interval << ", the warning time is:" << warning_time; | |||
| BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r"); | |||
| MS_EXCEPTION_IF_NULL(ca_bio); | |||
| X509 *ca_cert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr); | |||
| BIO_free_all(ca_bio); | |||
| MS_EXCEPTION_IF_NULL(ca_cert); | |||
| running_ = true; | |||
| check_time_thread_ = std::make_unique<std::thread>([&, cert, ca_cert, interval, warning_time]() { | |||
| while (running_) { | |||
| if (!CommUtil::VerifyCertTime(cert, warning_time)) { | |||
| MS_LOG(WARNING) << "Verify server cert time failed."; | |||
| } | |||
| if (!CommUtil::VerifyCertTime(ca_cert, warning_time)) { | |||
| MS_LOG(WARNING) << "Verify ca cert time failed."; | |||
| } | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
| bool res = cond_.wait_for(lock, std::chrono::hours(interval), [&] { | |||
| bool result = is_ready_.load(); | |||
| return result; | |||
| }); | |||
| MS_LOG(INFO) << "Wait for res:" << res; | |||
| } | |||
| }); | |||
| MS_EXCEPTION_IF_NULL(check_time_thread_); | |||
| } | |||
| int pubKeyLen = RSA_size(pRSAPublicKey); | |||
| int ret = RSA_verify(NID_sha256, srcData, srcDataLen, signData, pubKeyLen, pRSAPublicKey); | |||
| if (ret != 1) { | |||
| MS_LOG(WARNING) << "Verify error."; | |||
| int64_t ulErr = ERR_get_error(); | |||
| char szErrMsg[1024] = {0}; | |||
| MS_LOG(WARNING) << "Error number: " << ulErr; | |||
| ERR_error_string(ulErr, szErrMsg); | |||
| MS_LOG(INFO) << "Error message:" << szErrMsg; | |||
| return false; | |||
| void SSLWrapper::StopCheckCertTime() { | |||
| running_ = false; | |||
| is_ready_ = true; | |||
| cond_.notify_all(); | |||
| if (check_time_thread_ != nullptr) { | |||
| check_time_thread_->join(); | |||
| } | |||
| RSA_free(pRSAPublicKey); | |||
| X509_free(keyAttestationCertObj); | |||
| CRYPTO_cleanup_all_ex_data(); | |||
| MS_LOG(INFO) << "VerifyRSAKey success."; | |||
| return true; | |||
| } | |||
| SSL_CTX *SSLWrapper::GetSSLCtx(bool) { return ssl_ctx_; } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -27,9 +27,16 @@ | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <chrono> | |||
| #include <condition_variable> | |||
| #include <mutex> | |||
| #include <atomic> | |||
| #include "utils/log_adapter.h" | |||
| #include "ps/core/comm_util.h" | |||
| #include "ps/core/file_configuration.h" | |||
| #include "ps/constants.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -42,29 +49,6 @@ class SSLWrapper { | |||
| } | |||
| SSL_CTX *GetSSLCtx(bool is_server = true); | |||
| // read certificate from file path | |||
| X509 *ReadCertFromFile(const std::string &certPath) const; | |||
| // read Certificate Revocation List from file absolute path | |||
| X509_CRL *ReadCrlFromFile(const std::string &crlPath) const; | |||
| // read certificate from pem string | |||
| X509 *ReadCertFromPerm(std::string cert); | |||
| // verify valid of certificate time | |||
| bool VerifyCertTime(const X509 *cert) const; | |||
| // verify valid of certificate chain | |||
| bool VerifyCAChain(const std::string &keyAttestation, const std::string &equipCert, const std::string &equipCACert, | |||
| std::string rootCert); | |||
| // verify valid of sign data | |||
| bool VerifyRSAKey(const std::string &keyAttestation, const unsigned char *srcData, const unsigned char *signData, | |||
| int srcDataLen); | |||
| // verify valid of equip certificate with CRL | |||
| bool VerifyCRL(const std::string &equipCert); | |||
| private: | |||
| SSLWrapper(); | |||
| virtual ~SSLWrapper(); | |||
| @@ -73,18 +57,22 @@ class SSLWrapper { | |||
| void InitSSL(); | |||
| void CleanSSL(); | |||
| time_t ConvertAsn1Time(const ASN1_TIME *const time) const; | |||
| void StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path); | |||
| void StopCheckCertTime(); | |||
| SSL_CTX *ssl_ctx_; | |||
| SSL_CTX *client_ssl_ctx_; | |||
| // The firset root ca certificate. | |||
| X509 *rootFirstCA_; | |||
| // The second root ca certificate. | |||
| X509 *rootSecondCA_; | |||
| // The firset root revocation list | |||
| X509_CRL *rootFirstCrl_; | |||
| // The second root revocation list | |||
| X509_CRL *rootSecondCrl_; | |||
| std::unique_ptr<std::thread> check_time_thread_; | |||
| std::atomic<bool> running_; | |||
| std::atomic<bool> is_ready_; | |||
| std::mutex mutex_; | |||
| std::condition_variable cond_; | |||
| std::mutex verify_mutex_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -37,7 +37,7 @@ event_base *TcpClient::event_base_ = nullptr; | |||
| std::mutex TcpClient::event_base_mutex_; | |||
| bool TcpClient::is_started_ = false; | |||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *config) | |||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *const config) | |||
| : event_timeout_(nullptr), | |||
| buffer_event_(nullptr), | |||
| server_address_(std::move(address)), | |||
| @@ -87,8 +87,6 @@ void TcpClient::Init() { | |||
| MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; | |||
| } | |||
| event_enable_debug_logging(EVENT_DBG_ALL); | |||
| event_set_log_callback(CommUtil::LogCallback); | |||
| int result = evthread_use_pthreads(); | |||
| if (result != 0) { | |||
| MS_LOG(EXCEPTION) << "Use event pthread failed!"; | |||
| @@ -108,10 +106,12 @@ void TcpClient::Init() { | |||
| sin.sin_port = htons(server_port_); | |||
| if (!PSContext::instance()->enable_ssl()) { | |||
| MS_LOG(INFO) << "SSL is disable."; | |||
| buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); | |||
| } else { | |||
| if (!EstablishSSL()) { | |||
| MS_LOG(EXCEPTION) << "Establish SSL failed."; | |||
| MS_LOG(WARNING) << "Establish SSL failed."; | |||
| return; | |||
| } | |||
| } | |||
| @@ -135,18 +135,21 @@ void TcpClient::StartWithDelay(int seconds) { | |||
| } | |||
| event_base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| timeval timeout_value{}; | |||
| timeout_value.tv_sec = seconds; | |||
| timeout_value.tv_usec = 0; | |||
| event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this); | |||
| MS_EXCEPTION_IF_NULL(event_timeout_); | |||
| if (evtimer_add(event_timeout_, &timeout_value) == -1) { | |||
| MS_LOG(EXCEPTION) << "Event timeout failed!"; | |||
| } | |||
| } | |||
| void TcpClient::Stop() { | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||
| MS_LOG(INFO) << "Stop tcp client!"; | |||
| int ret = event_base_loopbreak(event_base_); | |||
| @@ -177,8 +180,8 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { | |||
| char read_buffer[kMessageChunkLength]; | |||
| int read = 0; | |||
| while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) { | |||
| tcp_client->OnReadHandler(read_buffer, read); | |||
| while ((read = bufferevent_read(bev, &read_buffer, SizeToInt(sizeof(read_buffer)))) > 0) { | |||
| tcp_client->OnReadHandler(read_buffer, IntToSize(read)); | |||
| } | |||
| } | |||
| @@ -207,70 +210,9 @@ void TcpClient::NotifyConnected() { | |||
| bool TcpClient::EstablishSSL() { | |||
| MS_LOG(INFO) << "Enable ssl support."; | |||
| if (config_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "The config is empty."; | |||
| } | |||
| SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx(false)); | |||
| // 1.Parse the client's certificate and the ciphertext of key. | |||
| std::string client_cert = kCertificateChain; | |||
| std::string path = CommUtil::ParseConfig(*config_, kClientCertPath); | |||
| if (!CommUtil::IsFileExists(path)) { | |||
| MS_LOG(WARNING) << "The key:" << kClientCertPath << "'s value is not exist."; | |||
| return false; | |||
| } | |||
| client_cert = path; | |||
| MS_LOG(INFO) << "The client cert path:" << client_cert; | |||
| // 2. Parse the client password. | |||
| std::string client_password = CommUtil::ParseConfig(*config_, kClientPassword); | |||
| if (client_password.empty()) { | |||
| MS_LOG(WARNING) << "The key:" << kClientPassword << "'s value is empty."; | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "The client password:" << client_password; | |||
| EVP_PKEY *pkey = nullptr; | |||
| X509 *cert = nullptr; | |||
| STACK_OF(X509) *ca_stack = nullptr; | |||
| BIO *bio = BIO_new_file(client_cert.c_str(), "rb"); | |||
| PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr); | |||
| BIO_free_all(bio); | |||
| PKCS12_parse(p12, client_password.c_str(), &pkey, &cert, &ca_stack); | |||
| PKCS12_free(p12); | |||
| if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) { | |||
| MS_LOG(EXCEPTION) << "Verify cert time failed."; | |||
| } | |||
| if (!SSL_CTX_use_certificate(SSLWrapper::GetInstance().GetSSLCtx(false), cert)) { | |||
| MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; | |||
| } | |||
| if (!SSL_CTX_use_PrivateKey(SSLWrapper::GetInstance().GetSSLCtx(false), pkey)) { | |||
| MS_LOG(EXCEPTION) << "SSL use private key file failed!"; | |||
| } | |||
| std::string client_ca = kCAcrt; | |||
| std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath); | |||
| if (!CommUtil::IsFileExists(ca_path)) { | |||
| MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist."; | |||
| } | |||
| client_ca = ca_path; | |||
| MS_LOG(INFO) << "The ca cert path:" << client_ca; | |||
| if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx(false))) { | |||
| MS_LOG(EXCEPTION) << "SSL check private key file failed!"; | |||
| } | |||
| if (!SSL_CTX_load_verify_locations(SSLWrapper::GetInstance().GetSSLCtx(false), client_ca.c_str(), nullptr)) { | |||
| MS_LOG(EXCEPTION) << "SSL load ca location failed!"; | |||
| } | |||
| SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(false), SSL_OP_NO_SSLv2); | |||
| SSL *ssl = SSL_new(SSLClient::GetInstance().GetSSLCtx()); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(ssl, false); | |||
| MS_ERROR_IF_NULL_W_RET_VAL(event_base_, false); | |||
| buffer_event_ = bufferevent_openssl_socket_new(event_base_, -1, ssl, BUFFEREVENT_SSL_CONNECTING, | |||
| BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); | |||
| @@ -334,7 +276,7 @@ bool TcpClient::SendMessage(const CommMessage &message) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| bufferevent_lock(buffer_event_); | |||
| bool res = true; | |||
| size_t buf_size = IntToUint(message.ByteSizeLong()); | |||
| size_t buf_size = message.ByteSizeLong(); | |||
| uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong()); | |||
| MessageHeader header; | |||
| header.message_proto_ = Protos::PROTOBUF; | |||
| @@ -390,20 +332,6 @@ bool TcpClient::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Prot | |||
| return res; | |||
| } | |||
| void TcpClient::StartTimer(const uint32_t &time) { | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| struct event *ev = nullptr; | |||
| if (time == 0) { | |||
| MS_LOG(EXCEPTION) << "The time should not be 0!"; | |||
| } | |||
| struct timeval timeout {}; | |||
| timeout.tv_sec = time; | |||
| timeout.tv_usec = 0; | |||
| ev = event_new(event_base_, -1, EV_PERSIST, TimerCallback, this); | |||
| MS_EXCEPTION_IF_NULL(ev); | |||
| evtimer_add(ev, &timeout); | |||
| } | |||
| void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } | |||
| const event_base &TcpClient::eventbase() const { return *event_base_; } | |||
| @@ -34,6 +34,7 @@ | |||
| #include "ps/core/cluster_config.h" | |||
| #include "utils/convert_utils_base.h" | |||
| #include "ps/core/comm_util.h" | |||
| #include "ps/core/communicator/ssl_client.h" | |||
| #include "ps/core/communicator/ssl_wrapper.h" | |||
| #include "ps/constants.h" | |||
| #include "ps/ps_context.h" | |||
| @@ -69,7 +70,6 @@ class TcpClient { | |||
| void SetMessageCallback(const OnMessage &cb); | |||
| bool SendMessage(const CommMessage &message) const; | |||
| bool SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size); | |||
| void StartTimer(const uint32_t &time); | |||
| void set_timer_callback(const OnTimer &timer); | |||
| const event_base &eventbase() const; | |||
| @@ -107,7 +107,7 @@ class TcpClient { | |||
| std::atomic<bool> is_stop_; | |||
| std::atomic<bool> is_connected_; | |||
| // The Configuration file | |||
| Configuration *const config_; | |||
| Configuration *config_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -35,12 +35,19 @@ namespace core { | |||
| TcpConnection::~TcpConnection() { bufferevent_free(buffer_event_); } | |||
| void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); } | |||
| void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); } | |||
| void TcpConnection::OnReadHandler(const void *buffer, size_t num) { | |||
| MS_EXCEPTION_IF_NULL(buffer); | |||
| tcp_message_handler_.ReceiveMessage(buffer, num); | |||
| } | |||
| void TcpConnection::SendMessage(const void *buffer, size_t num) const { | |||
| MS_EXCEPTION_IF_NULL(buffer); | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| bufferevent_lock(buffer_event_); | |||
| if (bufferevent_write(buffer_event_, buffer, num) == -1) { | |||
| MS_LOG(ERROR) << "Write message to buffer event failed!"; | |||
| } | |||
| bufferevent_unlock(buffer_event_); | |||
| } | |||
| const TcpServer *TcpConnection::GetServer() const { return server_; } | |||
| @@ -93,24 +100,22 @@ bool TcpConnection::SendMessage(const std::shared_ptr<MessageMeta> &meta, const | |||
| } | |||
| int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH); | |||
| if (result < 0) { | |||
| bufferevent_unlock(buffer_event_); | |||
| MS_LOG(EXCEPTION) << "Bufferevent flush failed!"; | |||
| } | |||
| bufferevent_unlock(buffer_event_); | |||
| MS_LOG(DEBUG) << "SendMessage the request id is:" << meta->request_id() << " the current time is:" | |||
| << std::chrono::time_point_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now()) | |||
| .time_since_epoch() | |||
| .count(); | |||
| return res; | |||
| } | |||
| TcpServer::TcpServer(const std::string &address, std::uint16_t port, Configuration *config) | |||
| TcpServer::TcpServer(const std::string &address, std::uint16_t port, Configuration *const config) | |||
| : base_(nullptr), | |||
| signal_event_(nullptr), | |||
| listener_(nullptr), | |||
| server_address_(std::move(address)), | |||
| server_port_(port), | |||
| is_stop_(true), | |||
| config_(config) {} | |||
| config_(config), | |||
| max_connection_(0) {} | |||
| TcpServer::~TcpServer() { | |||
| if (signal_event_ != nullptr) { | |||
| @@ -146,14 +151,18 @@ void TcpServer::Init() { | |||
| MS_LOG(EXCEPTION) << "Use event pthread failed!"; | |||
| } | |||
| event_enable_debug_logging(EVENT_DBG_ALL); | |||
| event_set_log_callback(CommUtil::LogCallback); | |||
| is_stop_ = false; | |||
| base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| if (!CommUtil::CheckIp(server_address_)) { | |||
| MS_LOG(EXCEPTION) << "The tcp server ip:" << server_address_ << " is illegal!"; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(config_); | |||
| max_connection_ = kConnectionNumDefault; | |||
| if (config_->Exists(kConnectionNum)) { | |||
| max_connection_ = config_->GetInt(kConnectionNum, 0); | |||
| } | |||
| MS_LOG(INFO) << "The max connection is:" << max_connection_; | |||
| struct sockaddr_in sin {}; | |||
| if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { | |||
| @@ -210,35 +219,8 @@ void TcpServer::StartWithNoBlock() { | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!"; | |||
| } | |||
| void TcpServer::StartTimerOnlyOnce(const uint32_t &time) { | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| if (time == 0) { | |||
| MS_LOG(EXCEPTION) << "The time should not be 0!"; | |||
| } | |||
| struct event *ev = nullptr; | |||
| struct timeval timeout {}; | |||
| timeout.tv_sec = time; | |||
| timeout.tv_usec = 0; | |||
| ev = evtimer_new(base_, TimerOnceCallback, this); | |||
| MS_EXCEPTION_IF_NULL(ev); | |||
| evtimer_add(ev, &timeout); | |||
| } | |||
| void TcpServer::StartTimer(const uint32_t &time) { | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| struct event *ev = nullptr; | |||
| if (time == 0) { | |||
| MS_LOG(EXCEPTION) << "The time should not be 0!"; | |||
| } | |||
| struct timeval timeout {}; | |||
| timeout.tv_sec = time; | |||
| timeout.tv_usec = 0; | |||
| ev = event_new(base_, -1, EV_PERSIST, TimerCallback, this); | |||
| MS_EXCEPTION_IF_NULL(ev); | |||
| evtimer_add(ev, &timeout); | |||
| } | |||
| void TcpServer::Stop() { | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||
| MS_LOG(INFO) << "Stop tcp server!"; | |||
| if (event_base_got_break(base_)) { | |||
| @@ -279,71 +261,29 @@ std::shared_ptr<TcpConnection> TcpServer::GetConnectionByFd(const evutil_socket_ | |||
| void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int, | |||
| void *data) { | |||
| auto server = reinterpret_cast<class TcpServer *>(data); | |||
| auto base = reinterpret_cast<struct event_base *>(server->base_); | |||
| MS_EXCEPTION_IF_NULL(server); | |||
| auto base = reinterpret_cast<struct event_base *>(server->base_); | |||
| MS_EXCEPTION_IF_NULL(base); | |||
| MS_EXCEPTION_IF_NULL(sockaddr); | |||
| if (server->ConnectionNum() >= server->max_connection_) { | |||
| MS_LOG(WARNING) << "The current connection num:" << server->ConnectionNum() << " is greater or equal to " | |||
| << server->max_connection_; | |||
| return; | |||
| } | |||
| struct bufferevent *bev = nullptr; | |||
| if (!PSContext::instance()->enable_ssl()) { | |||
| MS_LOG(INFO) << "SSL is disable."; | |||
| bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); | |||
| } else { | |||
| MS_LOG(INFO) << "Enable ssl support."; | |||
| if (server->config_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "The config is empty."; | |||
| } | |||
| SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx()); | |||
| // 1.Parse the server's certificate and the ciphertext of key. | |||
| std::string server_cert = kCertificateChain; | |||
| std::string path = CommUtil::ParseConfig(*(server->config_), kServerCertPath); | |||
| if (!CommUtil::IsFileExists(path)) { | |||
| MS_LOG(EXCEPTION) << "The key:" << kServerCertPath << "'s value is not exist."; | |||
| } | |||
| server_cert = path; | |||
| MS_LOG(INFO) << "The server cert path:" << server_cert; | |||
| // 2. Parse the server password. | |||
| std::string server_password = CommUtil::ParseConfig(*(server->config_), kServerPassword); | |||
| if (server_password.empty()) { | |||
| MS_LOG(EXCEPTION) << "The key:" << kServerPassword << "'s value is empty."; | |||
| } | |||
| MS_LOG(INFO) << "The server password:" << server_password; | |||
| EVP_PKEY *pkey = nullptr; | |||
| X509 *cert = nullptr; | |||
| STACK_OF(X509) *ca_stack = nullptr; | |||
| BIO *bio = BIO_new_file(server_cert.c_str(), "rb"); | |||
| PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr); | |||
| BIO_free_all(bio); | |||
| PKCS12_parse(p12, server_password.c_str(), &pkey, &cert, &ca_stack); | |||
| PKCS12_free(p12); | |||
| if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) { | |||
| MS_LOG(EXCEPTION) << "Verify cert time failed."; | |||
| } | |||
| if (!SSL_CTX_use_certificate(SSLWrapper::GetInstance().GetSSLCtx(), cert)) { | |||
| MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; | |||
| } | |||
| if (!SSL_CTX_use_PrivateKey(SSLWrapper::GetInstance().GetSSLCtx(), pkey)) { | |||
| MS_LOG(EXCEPTION) << "SSL use private key file failed!"; | |||
| } | |||
| if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx())) { | |||
| MS_LOG(EXCEPTION) << "SSL check private key file failed!"; | |||
| } | |||
| SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(), SSL_OP_NO_SSLv2); | |||
| MS_EXCEPTION_IF_NULL(ssl); | |||
| bev = bufferevent_openssl_socket_new(base, fd, ssl, BUFFEREVENT_SSL_ACCEPTING, | |||
| BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); | |||
| } | |||
| if (bev == nullptr) { | |||
| MS_LOG(ERROR) << "Error constructing buffer event!"; | |||
| int ret = event_base_loopbreak(base); | |||
| @@ -386,9 +326,10 @@ std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent | |||
| OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; } | |||
| void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| auto server = reinterpret_cast<class TcpServer *>(data); | |||
| MS_EXCEPTION_IF_NULL(server); | |||
| struct event_base *base = server->base_; | |||
| MS_EXCEPTION_IF_NULL(base); | |||
| struct timeval delay = {0, 0}; | |||
| MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds."; | |||
| if (event_base_loopexit(base, &delay) == -1) { | |||
| @@ -402,6 +343,7 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { | |||
| auto conn = static_cast<class TcpConnection *>(connection); | |||
| struct evbuffer *buf = bufferevent_get_input(bev); | |||
| MS_EXCEPTION_IF_NULL(buf); | |||
| char read_buffer[kMessageChunkLength]; | |||
| while (EVBUFFER_LENGTH(buf) > 0) { | |||
| int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)); | |||
| @@ -409,11 +351,6 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { | |||
| MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; | |||
| } | |||
| conn->OnReadHandler(read_buffer, IntToSize(read)); | |||
| MS_LOG(DEBUG) << "the current time is:" | |||
| << std::chrono::time_point_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now()) | |||
| .time_since_epoch() | |||
| .count() | |||
| << " the read size is:" << read; | |||
| } | |||
| } | |||
| @@ -421,9 +358,10 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| struct evbuffer *output = bufferevent_get_output(bev); | |||
| size_t remain = evbuffer_get_length(output); | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| auto conn = static_cast<class TcpConnection *>(data); | |||
| auto srv = const_cast<TcpServer *>(conn->GetServer()); | |||
| MS_EXCEPTION_IF_NULL(srv); | |||
| if (events & BEV_EVENT_EOF) { | |||
| MS_LOG(INFO) << "Event buffer end of file, a client is disconnected from this server!"; | |||
| @@ -434,7 +372,15 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||
| // Free connection structures | |||
| srv->RemoveConnection(conn->GetFd()); | |||
| } else if (events & BEV_EVENT_ERROR) { | |||
| MS_LOG(WARNING) << "Event buffer remain data: " << remain; | |||
| MS_LOG(WARNING) << "Connect to server error."; | |||
| if (PSContext::instance()->enable_ssl()) { | |||
| uint64_t err = bufferevent_get_openssl_error(bev); | |||
| MS_LOG(WARNING) << "The error number is:" << err; | |||
| MS_LOG(WARNING) << "Error message:" << ERR_reason_error_string(err) | |||
| << ", the error lib:" << ERR_lib_error_string(err) | |||
| << ", the error func:" << ERR_func_error_string(err); | |||
| } | |||
| // Free connection structures | |||
| srv->RemoveConnection(conn->GetFd()); | |||
| @@ -443,7 +389,7 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||
| srv->client_disconnection_(*srv, *conn); | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "Unhandled event!"; | |||
| MS_LOG(WARNING) << "Unhandled event:" << events; | |||
| } | |||
| } | |||
| @@ -57,11 +57,11 @@ class TcpConnection { | |||
| using Callback = std::function<void(const std::shared_ptr<CommMessage>)>; | |||
| virtual void InitConnection(const messageReceive &callback); | |||
| virtual void SendMessage(const void *buffer, size_t num) const; | |||
| void InitConnection(const messageReceive &callback); | |||
| void SendMessage(const void *buffer, size_t num) const; | |||
| bool SendMessage(const std::shared_ptr<CommMessage> &message) const; | |||
| bool SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) const; | |||
| virtual void OnReadHandler(const void *buffer, size_t numBytes); | |||
| void OnReadHandler(const void *buffer, size_t numBytes); | |||
| const TcpServer *GetServer() const; | |||
| const evutil_socket_t &GetFd() const; | |||
| void set_callback(const Callback &callback); | |||
| @@ -97,8 +97,6 @@ class TcpServer { | |||
| void Init(); | |||
| void Start(); | |||
| void StartWithNoBlock(); | |||
| void StartTimerOnlyOnce(const uint32_t &time); | |||
| void StartTimer(const uint32_t &time); | |||
| void Stop(); | |||
| void SendToAllClients(const char *data, size_t len); | |||
| void AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection); | |||
| @@ -142,7 +140,8 @@ class TcpServer { | |||
| OnTimerOnce on_timer_once_callback_; | |||
| OnTimer on_timer_callback_; | |||
| // The Configuration file | |||
| Configuration *const config_; | |||
| Configuration *config_; | |||
| int64_t max_connection_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -413,5 +413,11 @@ std::string PSContext::config_file_path() const { return config_file_path_; } | |||
| void PSContext::set_node_id(const std::string &node_id) { node_id_ = node_id; } | |||
| const std::string &PSContext::node_id() const { return node_id_; } | |||
| std::string PSContext::client_password() const { return client_password_; } | |||
| void PSContext::set_client_password(const std::string &password) { client_password_ = password; } | |||
| std::string PSContext::server_password() const { return server_password_; } | |||
| void PSContext::set_server_password(const std::string &password) { server_password_ = password; } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -83,6 +83,11 @@ class PSContext { | |||
| bool enable_ssl() const; | |||
| void set_enable_ssl(bool enabled); | |||
| std::string client_password() const; | |||
| void set_client_password(const std::string &password); | |||
| std::string server_password() const; | |||
| void set_server_password(const std::string &password); | |||
| // In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set | |||
| // by ps_context. | |||
| void set_server_mode(const std::string &server_mode); | |||
| @@ -218,7 +223,9 @@ class PSContext { | |||
| dp_delta_(0.01), | |||
| dp_norm_clip_(1.0), | |||
| encrypt_type_(kNotEncryptType), | |||
| node_id_("") {} | |||
| node_id_(""), | |||
| client_password_(""), | |||
| server_password_("") {} | |||
| bool ps_enabled_; | |||
| bool is_worker_; | |||
| bool is_pserver_; | |||
| @@ -310,6 +317,11 @@ class PSContext { | |||
| // Unique id of the node | |||
| std::string node_id_; | |||
| // Password used to decode p12 file. | |||
| std::string client_password_; | |||
| // Password used to decode p12 file. | |||
| std::string server_password_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||