Browse Source

fixed ssl error

tags/v1.5.0-rc1
chendongsheng 4 years ago
parent
commit
b05dbd5756
13 changed files with 317 additions and 400 deletions
  1. +25
    -31
      mindspore/ccsrc/ps/core/communicator/http_request_handler.cc
  2. +2
    -2
      mindspore/ccsrc/ps/core/communicator/http_request_handler.h
  3. +17
    -5
      mindspore/ccsrc/ps/core/communicator/http_server.cc
  4. +1
    -1
      mindspore/ccsrc/ps/core/communicator/ssl_client.cc
  5. +6
    -1
      mindspore/ccsrc/ps/core/communicator/ssl_http.cc
  6. +168
    -141
      mindspore/ccsrc/ps/core/communicator/ssl_wrapper.cc
  7. +16
    -28
      mindspore/ccsrc/ps/core/communicator/ssl_wrapper.h
  8. +13
    -85
      mindspore/ccsrc/ps/core/communicator/tcp_client.cc
  9. +2
    -2
      mindspore/ccsrc/ps/core/communicator/tcp_client.h
  10. +43
    -97
      mindspore/ccsrc/ps/core/communicator/tcp_server.cc
  11. +5
    -6
      mindspore/ccsrc/ps/core/communicator/tcp_server.h
  12. +6
    -0
      mindspore/ccsrc/ps/ps_context.cc
  13. +13
    -1
      mindspore/ccsrc/ps/ps_context.h

+ 25
- 31
mindspore/ccsrc/ps/core/communicator/http_request_handler.cc View File

@@ -19,6 +19,13 @@
namespace mindspore {
namespace ps {
namespace core {
HttpRequestHandler::~HttpRequestHandler() {
if (evbase_) {
event_base_free(evbase_);
evbase_ = nullptr;
}
}

bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string, OnRequestReceive *> &handlers) {
evbase_ = event_base_new();
MS_EXCEPTION_IF_NULL(evbase_);
@@ -27,33 +34,18 @@ bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string

if (PSContext::instance()->enable_ssl()) {
MS_LOG(INFO) << "Enable ssl support.";
SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(),
SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2);
EC_KEY *ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
MS_EXCEPTION_IF_NULL(ecdh);

X509 *cert = SSLWrapper::GetInstance().ReadCertFromFile(kCertificateChain);
if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) {
MS_LOG(INFO) << "Verify cert time failed.";
return false;
}

if (!SSL_CTX_use_certificate_chain_file(SSLWrapper::GetInstance().GetSSLCtx(), kCertificateChain)) {
MS_LOG(ERROR) << "SSL use certificate chain file failed!";
return false;
}

if (!SSL_CTX_use_PrivateKey_file(SSLWrapper::GetInstance().GetSSLCtx(), kPrivateKey, SSL_FILETYPE_PEM)) {
MS_LOG(ERROR) << "SSL use private key file failed!";
return false;
}

if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx())) {
MS_LOG(ERROR) << "SSL check private key file failed!";
return false;
if (!SSL_CTX_set_options(SSLHTTP::GetInstance().GetSSLCtx(), SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE |
SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
SSL_OP_NO_TLSv1_1)) {
if (evbase_) {
event_base_free(evbase_);
evbase_ = nullptr;
}
evhttp_free(http);
http = nullptr;
MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
}

evhttp_set_bevcb(http, BuffereventCallback, SSLWrapper::GetInstance().GetSSLCtx());
evhttp_set_bevcb(http, BuffereventCallback, SSLHTTP::GetInstance().GetSSLCtx());
}

int result = evhttp_accept_socket(http, fd);
@@ -67,6 +59,7 @@ bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string
MS_EXCEPTION_IF_NULL(req);
MS_EXCEPTION_IF_NULL(arg);
auto httpReq = std::make_shared<HttpMessageHandler>();
MS_EXCEPTION_IF_NULL(httpReq);
httpReq->set_request(req);
httpReq->InitHttpMessage();
OnRequestReceive *func = reinterpret_cast<OnRequestReceive *>(arg);
@@ -74,6 +67,7 @@ bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string
};

// O SUCCESS,-1 ALREADY_EXIST,-2 FAILURE
MS_EXCEPTION_IF_NULL(handler.second);
int ret = evhttp_set_cb(http, handler.first.c_str(), TransFunc, reinterpret_cast<void *>(handler.second));
std::string log_prefix = "Ev http register handle of:";
if (ret == 0) {
@@ -101,16 +95,12 @@ void HttpRequestHandler::Run() {
} else {
MS_LOG(ERROR) << "Event base dispatch with unexpected error code!";
}

if (evbase_) {
event_base_free(evbase_);
evbase_ = nullptr;
}
}

bool HttpRequestHandler::Stop() {
MS_LOG(INFO) << "Stop http server!";

MS_EXCEPTION_IF_NULL(evbase_);
int ret = event_base_loopbreak(evbase_);
if (ret != 0) {
MS_LOG(ERROR) << "event base loop break failed!";
@@ -120,9 +110,13 @@ bool HttpRequestHandler::Stop() {
}

bufferevent *HttpRequestHandler::BuffereventCallback(event_base *base, void *arg) {
MS_EXCEPTION_IF_NULL(base);
MS_EXCEPTION_IF_NULL(arg);
SSL_CTX *ctx = reinterpret_cast<SSL_CTX *>(arg);
SSL *ssl = SSL_new(ctx);
MS_EXCEPTION_IF_NULL(ssl);
bufferevent *bev = bufferevent_openssl_socket_new(base, -1, ssl, BUFFEREVENT_SSL_ACCEPTING, BEV_OPT_CLOSE_ON_FREE);
MS_EXCEPTION_IF_NULL(bev);
return bev;
}
} // namespace core


+ 2
- 2
mindspore/ccsrc/ps/core/communicator/http_request_handler.h View File

@@ -29,7 +29,7 @@

#include "utils/log_adapter.h"
#include "ps/core/communicator/http_message_handler.h"
#include "ps/core/communicator/ssl_wrapper.h"
#include "ps/core/communicator/ssl_http.h"
#include "ps/constants.h"
#include "ps/ps_context.h"

@@ -44,7 +44,7 @@ using OnRequestReceive = std::function<void(std::shared_ptr<HttpMessageHandler>)
class HttpRequestHandler {
public:
HttpRequestHandler() : evbase_(nullptr) {}
virtual ~HttpRequestHandler() = default;
virtual ~HttpRequestHandler();

bool Initialize(int fd, const std::unordered_map<std::string, OnRequestReceive *> &handlers);
void Run();


+ 17
- 5
mindspore/ccsrc/ps/core/communicator/http_server.cc View File

@@ -41,11 +41,15 @@
namespace mindspore {
namespace ps {
namespace core {
HttpServer::~HttpServer() { Stop(); }
HttpServer::~HttpServer() {
if (!Stop()) {
MS_LOG(WARNING) << "Stop http server failed.";
}
}

bool HttpServer::InitServer() {
if (server_address_ == "") {
MS_LOG(INFO) << "The server ip is empty.";
MS_LOG(WARNING) << "The server address is empty.";
std::string interface;
std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
@@ -91,22 +95,24 @@ bool HttpServer::InitServer() {

result = ::bind(fd_, (struct sockaddr *)&addr, sizeof(addr));
if (result < 0) {
MS_LOG(ERROR) << "Bind ip:" << server_address_ << " port:" << server_port_ << " failed!";
MS_LOG(ERROR) << "Bind ip:" << server_address_ << " port:" << server_port_ << "failed!";
close(fd_);
fd_ = -1;
return false;
}

MS_LOG(INFO) << "Bind ip:" << server_address_ << " port:" << server_port_ << " successful!";

result = ::listen(fd_, backlog_);
if (result < 0) {
MS_LOG(ERROR) << "Listen ip:" << server_address_ << " port:" << server_port_ << " failed!";
MS_LOG(ERROR) << "Listen ip:" << server_address_ << " port:" << server_port_ << "failed!";
close(fd_);
fd_ = -1;
return false;
}

int flags = 0;
if ((flags = fcntl(fd_, F_GETFL, 0)) < 0 || fcntl(fd_, F_SETFL, flags | O_NONBLOCK) < 0) {
if ((flags = fcntl(fd_, F_GETFL, 0)) < 0 || fcntl(fd_, F_SETFL, (unsigned int)flags | O_NONBLOCK) < 0) {
MS_LOG(ERROR) << "Set fcntl O_NONBLOCK failed!";
close(fd_);
fd_ = -1;
@@ -135,12 +141,14 @@ bool HttpServer::Start(bool is_detach) {
MS_LOG(INFO) << "Start http server!";
for (size_t i = 0; i < thread_num_; i++) {
auto http_request_handler = std::make_shared<HttpRequestHandler>();
MS_EXCEPTION_IF_NULL(http_request_handler);
if (!http_request_handler->Initialize(fd_, request_handlers_)) {
MS_LOG(ERROR) << "Http initialize failed.";
return false;
}
http_request_handlers.push_back(http_request_handler);
auto thread = std::make_shared<std::thread>(&HttpRequestHandler::Run, http_request_handler);
MS_EXCEPTION_IF_NULL(thread);
if (is_detach) {
thread->detach();
}
@@ -168,6 +176,10 @@ bool HttpServer::Stop() {
result = false;
}
}
if (fd_ != -1) {
close(fd_);
fd_ = -1;
}
is_stop_ = true;
}
return result;


+ 1
- 1
mindspore/ccsrc/ps/core/communicator/ssl_client.cc View File

@@ -58,7 +58,7 @@ void SSLClient::InitSSL() {
client_cert = path;

// 2. Parse the client password.
std::string client_password = CommUtil::ParseConfig(*config_, kClientPassword);
std::string client_password = PSContext::instance()->client_password();
if (client_password.empty()) {
MS_LOG(EXCEPTION) << "The client password's value is empty.";
}


+ 6
- 1
mindspore/ccsrc/ps/core/communicator/ssl_http.cc View File

@@ -62,7 +62,7 @@ void SSLHTTP::InitSSL() {
server_cert = path;

// 2. Parse the server password.
std::string server_password = CommUtil::ParseConfig(*(config_), kServerPassword);
std::string server_password = PSContext::instance()->server_password();
if (server_password.empty()) {
MS_LOG(EXCEPTION) << "The client password's value is empty.";
}
@@ -93,6 +93,11 @@ void SSLHTTP::InitSSL() {
if (!SSL_CTX_check_private_key(ssl_ctx_)) {
MS_LOG(EXCEPTION) << "SSL check private key file failed!";
}

if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
}
}

void SSLHTTP::CleanSSL() {


+ 168
- 141
mindspore/ccsrc/ps/core/communicator/ssl_wrapper.cc View File

@@ -1,4 +1,3 @@

/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
@@ -33,11 +32,11 @@ namespace ps {
namespace core {
SSLWrapper::SSLWrapper()
: ssl_ctx_(nullptr),
client_ssl_ctx_(nullptr),
rootFirstCA_(nullptr),
rootSecondCA_(nullptr),
rootFirstCrl_(nullptr),
rootSecondCrl_(nullptr) {
check_time_thread_(nullptr),
running_(false),
is_ready_(false) {
InitSSL();
}

@@ -45,178 +44,206 @@ SSLWrapper::~SSLWrapper() { CleanSSL(); }

void SSLWrapper::InitSSL() {
CommUtil::InitOpenSSLEnv();
int rand = RAND_poll();
if (rand == 0) {
MS_LOG(ERROR) << "RAND_poll failed";
}
ssl_ctx_ = SSL_CTX_new(SSLv23_server_method());
if (!ssl_ctx_) {
MS_LOG(ERROR) << "SSL_CTX_new failed";
MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
}
X509_STORE *store = SSL_CTX_get_cert_store(ssl_ctx_);
MS_EXCEPTION_IF_NULL(store);
if (X509_STORE_set_default_paths(store) != 1) {
MS_LOG(ERROR) << "X509_STORE_set_default_paths failed";
}
client_ssl_ctx_ = SSL_CTX_new(SSLv23_client_method());
if (!ssl_ctx_) {
MS_LOG(ERROR) << "SSL_CTX_new failed";
MS_LOG(EXCEPTION) << "X509_STORE_set_default_paths failed";
}
}
void SSLWrapper::CleanSSL() {
if (ssl_ctx_ != nullptr) {
SSL_CTX_free(ssl_ctx_);
std::unique_ptr<Configuration> config_ =
std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(EXCEPTION) << "The config file is empty.";
}
ERR_free_strings();
EVP_cleanup();
ERR_remove_thread_state(nullptr);
CRYPTO_cleanup_all_ex_data();
}

SSL_CTX *SSLWrapper::GetSSLCtx(bool is_server) {
if (is_server) {
return ssl_ctx_;
} else {
return client_ssl_ctx_;
// 1.Parse the server's certificate and the ciphertext of key.
std::string server_cert = kCertificateChain;
std::string path = CommUtil::ParseConfig(*(config_), kServerCertPath);
if (!CommUtil::IsFileExists(path)) {
MS_LOG(EXCEPTION) << "The key:" << kServerCertPath << "'s value is not exist.";
}
}

X509 *SSLWrapper::ReadCertFromFile(const std::string &certPath) const {
BIO *bio = BIO_new_file(certPath.c_str(), "r");
return PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
}

X509 *SSLWrapper::ReadCertFromPerm(std::string cert) {
BIO *bio = BIO_new_mem_buf(reinterpret_cast<void *>(cert.data()), -1);
return PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
}

X509_CRL *SSLWrapper::ReadCrlFromFile(const std::string &crlPath) const {
BIO *bio = BIO_new_file(crlPath.c_str(), "r");
return PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
}
server_cert = path;

bool SSLWrapper::VerifyCertTime(const X509 *cert) const {
ASN1_TIME *start = X509_getm_notBefore(cert);
ASN1_TIME *end = X509_getm_notAfter(cert);
MS_LOG(INFO) << "The server cert path:" << server_cert;

int day = 0;
int sec = 0;
ASN1_TIME_diff(&day, &sec, start, NULL);
// 2. Parse the server password.
std::string server_password = PSContext::instance()->server_password();
if (server_password.empty()) {
MS_LOG(EXCEPTION) << "The client password's value is empty.";
}

if (day < 0 || sec < 0) {
MS_LOG(INFO) << "Cert start time is later than now time.";
return false;
EVP_PKEY *pkey = nullptr;
X509 *cert = nullptr;
STACK_OF(X509) *ca_stack = nullptr;
BIO *bio = BIO_new_file(server_cert.c_str(), "rb");
MS_EXCEPTION_IF_NULL(bio);
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
MS_EXCEPTION_IF_NULL(p12);
BIO_free_all(bio);
if (!PKCS12_parse(p12, server_password.c_str(), &pkey, &cert, &ca_stack)) {
MS_LOG(EXCEPTION) << "PKCS12_parse failed.";
}
PKCS12_free(p12);
std::string default_cipher_list = CommUtil::ParseConfig(*config_, kCipherList);
std::vector<std::string> ciphers = CommUtil::Split(default_cipher_list, kColon);
if (!CommUtil::VerifyCipherList(ciphers)) {
MS_LOG(EXCEPTION) << "The cipher is wrong.";
}
day = 0;
sec = 0;
ASN1_TIME_diff(&day, &sec, NULL, end);
if (day < 0 || sec < 0) {
MS_LOG(INFO) << "Cert end time is sooner than now time.";
return false;
if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
}

return true;
}
std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath);
if (crl_path.empty()) {
MS_LOG(INFO) << "The crl path is empty.";
} else if (!CommUtil::VerifyCRL(cert, crl_path)) {
MS_LOG(EXCEPTION) << "Verify crl failed.";
}

bool SSLWrapper::VerifyCAChain(const std::string &keyAttestation, const std::string &equipCert,
const std::string &equipCACert, std::string) {
X509 *keyAttestationCertObj = ReadCertFromPerm(keyAttestation);
X509 *equipCertObj = ReadCertFromPerm(equipCert);
X509 *equipCACertObj = ReadCertFromPerm(equipCACert);
std::string client_ca = kCAcrt;
std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
if (!CommUtil::IsFileExists(ca_path)) {
MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
}
client_ca = ca_path;

if (!VerifyCertTime(keyAttestationCertObj) || !VerifyCertTime(equipCertObj) || !VerifyCertTime(equipCACertObj)) {
return false;
if (!CommUtil::VerifyCommonName(cert, client_ca)) {
MS_LOG(EXCEPTION) << "Verify common name failed.";
}

EVP_PKEY *equipPubKey = X509_get_pubkey(equipCertObj);
EVP_PKEY *equipCAPubKey = X509_get_pubkey(equipCACertObj);
SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER, 0);
if (!SSL_CTX_load_verify_locations(ssl_ctx_, client_ca.c_str(), nullptr)) {
MS_LOG(EXCEPTION) << "SSL load ca location failed!";
}

EVP_PKEY *rootFirstPubKey = X509_get_pubkey(rootFirstCA_);
EVP_PKEY *rootSecondPubKey = X509_get_pubkey(rootSecondCA_);
if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) {
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
}

int ret = 0;
ret = X509_verify(keyAttestationCertObj, equipPubKey);
if (ret != 1) {
MS_LOG(INFO) << "keyAttestationCert verify is failed";
return false;
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) {
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
}
ret = X509_verify(equipCertObj, equipCAPubKey);
if (ret != 1) {
MS_LOG(INFO) << "Equip cert verify is failed";
return false;

if (!SSL_CTX_check_private_key(ssl_ctx_)) {
MS_LOG(EXCEPTION) << "SSL check private key file failed!";
}
int ret_first = X509_verify(equipCACertObj, rootFirstPubKey);
int ret_second = X509_verify(equipCACertObj, rootSecondPubKey);
if (ret_first != 1 && ret_second != 1) {
MS_LOG(INFO) << "Equip ca cert verify is failed";
return false;
if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
}
MS_LOG(INFO) << "VerifyCAChain success.";

EVP_PKEY_free(equipPubKey);
EVP_PKEY_free(equipCAPubKey);
EVP_PKEY_free(rootFirstPubKey);
EVP_PKEY_free(rootSecondPubKey);
return true;
}

bool SSLWrapper::VerifyCRL(const std::string &equipCert) {
X509 *equipCertObj = ReadCertFromPerm(equipCert);
if (rootFirstCrl_ == nullptr && rootSecondCrl_ == nullptr) {
MS_LOG(INFO) << "RootFirstCrl && rootSecondCrl is nullptr.";
return false;
if (!SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_AUTO_RETRY)) {
MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!";
}

EVP_PKEY *evp_pkey = X509_get_pubkey(equipCertObj);
int ret = X509_CRL_verify(rootFirstCrl_, evp_pkey);
if (ret == 1) {
MS_LOG(INFO) << "Equip cert in root first crl, verify failed";
return false;
}
ret = X509_CRL_verify(rootSecondCrl_, evp_pkey);
if (ret == 1) {
MS_LOG(INFO) << "Equip cert in root second crl, verify failed";
return false;
}
MS_LOG(INFO) << "VerifyCRL success.";
return true;
StartCheckCertTime(*config_, cert, client_ca);
}

bool SSLWrapper::VerifyRSAKey(const std::string &keyAttestation, const unsigned char *srcData,
const unsigned char *signData, int srcDataLen) {
if (keyAttestation.empty() || srcData == nullptr || signData == nullptr) {
MS_LOG(INFO) << "KeyAttestation or srcData or signData is empty.";
return false;
void SSLWrapper::CleanSSL() {
if (ssl_ctx_ != nullptr) {
SSL_CTX_free(ssl_ctx_);
}
ERR_free_strings();
EVP_cleanup();
ERR_remove_thread_state(nullptr);
CRYPTO_cleanup_all_ex_data();
StopCheckCertTime();
}

X509 *keyAttestationCertObj = ReadCertFromPerm(keyAttestation);
time_t SSLWrapper::ConvertAsn1Time(const ASN1_TIME *const time) const {
MS_EXCEPTION_IF_NULL(time);
struct tm t;
const char *str = (const char *)time->data;
MS_EXCEPTION_IF_NULL(str);
size_t i = 0;

if (memset_s(&t, sizeof(t), 0, sizeof(t)) != EOK) {
MS_LOG(EXCEPTION) << "Memset Failed!";
}

if (time->type == V_ASN1_UTCTIME) {
t.tm_year = (str[i++] - '0') * kBase;
t.tm_year += (str[i++] - '0');
if (t.tm_year < kSeventyYear) {
t.tm_year += kHundredYear;
}
} else if (time->type == V_ASN1_GENERALIZEDTIME) {
t.tm_year = (str[i++] - '0') * kThousandYear;
t.tm_year += (str[i++] - '0') * kHundredYear;
t.tm_year += (str[i++] - '0') * kBase;
t.tm_year += (str[i++] - '0');
t.tm_year -= kBaseYear;
}
t.tm_mon = (str[i++] - '0') * kBase;
// -1 since January is 0 not 1.
t.tm_mon += (str[i++] - '0') - kJanuary;
t.tm_mday = (str[i++] - '0') * kBase;
t.tm_mday += (str[i++] - '0');
t.tm_hour = (str[i++] - '0') * kBase;
t.tm_hour += (str[i++] - '0');
t.tm_min = (str[i++] - '0') * kBase;
t.tm_min += (str[i++] - '0');
t.tm_sec = (str[i++] - '0') * kBase;
t.tm_sec += (str[i++] - '0');

return mktime(&t);
}

EVP_PKEY *pubKey = X509_get_pubkey(keyAttestationCertObj);
RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey);
if (pRSAPublicKey == nullptr) {
MS_LOG(INFO) << "Get rsa public key failed.";
return false;
}
void SSLWrapper::StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path) {
MS_EXCEPTION_IF_NULL(cert);
MS_LOG(INFO) << "The server start check cert.";
int64_t interval = kCertCheckIntervalInHour;

int64_t warning_time = kCertExpireWarningTimeInDay;
if (config.Exists(kCertExpireWarningTime)) {
int64_t res_time = config.GetInt(kCertExpireWarningTime, 0);
if (res_time < kMinWarningTime || res_time > kMaxWarningTime) {
MS_LOG(EXCEPTION) << "The Certificate expiration warning time should be [7, 180]";
}
warning_time = res_time;
}
MS_LOG(INFO) << "The interval time is:" << interval << ", the warning time is:" << warning_time;
BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
MS_EXCEPTION_IF_NULL(ca_bio);
X509 *ca_cert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
BIO_free_all(ca_bio);
MS_EXCEPTION_IF_NULL(ca_cert);

running_ = true;
check_time_thread_ = std::make_unique<std::thread>([&, cert, ca_cert, interval, warning_time]() {
while (running_) {
if (!CommUtil::VerifyCertTime(cert, warning_time)) {
MS_LOG(WARNING) << "Verify server cert time failed.";
}

if (!CommUtil::VerifyCertTime(ca_cert, warning_time)) {
MS_LOG(WARNING) << "Verify ca cert time failed.";
}
std::unique_lock<std::mutex> lock(mutex_);
bool res = cond_.wait_for(lock, std::chrono::hours(interval), [&] {
bool result = is_ready_.load();
return result;
});
MS_LOG(INFO) << "Wait for res:" << res;
}
});
MS_EXCEPTION_IF_NULL(check_time_thread_);
}

int pubKeyLen = RSA_size(pRSAPublicKey);
int ret = RSA_verify(NID_sha256, srcData, srcDataLen, signData, pubKeyLen, pRSAPublicKey);
if (ret != 1) {
MS_LOG(WARNING) << "Verify error.";
int64_t ulErr = ERR_get_error();
char szErrMsg[1024] = {0};
MS_LOG(WARNING) << "Error number: " << ulErr;
ERR_error_string(ulErr, szErrMsg);
MS_LOG(INFO) << "Error message:" << szErrMsg;
return false;
void SSLWrapper::StopCheckCertTime() {
running_ = false;
is_ready_ = true;
cond_.notify_all();
if (check_time_thread_ != nullptr) {
check_time_thread_->join();
}
RSA_free(pRSAPublicKey);
X509_free(keyAttestationCertObj);
CRYPTO_cleanup_all_ex_data();

MS_LOG(INFO) << "VerifyRSAKey success.";
return true;
}

SSL_CTX *SSLWrapper::GetSSLCtx(bool) { return ssl_ctx_; }
} // namespace core
} // namespace ps
} // namespace mindspore

+ 16
- 28
mindspore/ccsrc/ps/core/communicator/ssl_wrapper.h View File

@@ -27,9 +27,16 @@

#include <iostream>
#include <string>
#include <memory>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <atomic>

#include "utils/log_adapter.h"
#include "ps/core/comm_util.h"
#include "ps/core/file_configuration.h"
#include "ps/constants.h"

namespace mindspore {
namespace ps {
@@ -42,29 +49,6 @@ class SSLWrapper {
}
SSL_CTX *GetSSLCtx(bool is_server = true);

// read certificate from file path
X509 *ReadCertFromFile(const std::string &certPath) const;

// read Certificate Revocation List from file absolute path
X509_CRL *ReadCrlFromFile(const std::string &crlPath) const;

// read certificate from pem string
X509 *ReadCertFromPerm(std::string cert);

// verify valid of certificate time
bool VerifyCertTime(const X509 *cert) const;

// verify valid of certificate chain
bool VerifyCAChain(const std::string &keyAttestation, const std::string &equipCert, const std::string &equipCACert,
std::string rootCert);

// verify valid of sign data
bool VerifyRSAKey(const std::string &keyAttestation, const unsigned char *srcData, const unsigned char *signData,
int srcDataLen);

// verify valid of equip certificate with CRL
bool VerifyCRL(const std::string &equipCert);

private:
SSLWrapper();
virtual ~SSLWrapper();
@@ -73,18 +57,22 @@ class SSLWrapper {

void InitSSL();
void CleanSSL();
time_t ConvertAsn1Time(const ASN1_TIME *const time) const;
void StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path);
void StopCheckCertTime();

SSL_CTX *ssl_ctx_;
SSL_CTX *client_ssl_ctx_;

// The firset root ca certificate.
X509 *rootFirstCA_;
// The second root ca certificate.
X509 *rootSecondCA_;
// The firset root revocation list
X509_CRL *rootFirstCrl_;
// The second root revocation list
X509_CRL *rootSecondCrl_;
std::unique_ptr<std::thread> check_time_thread_;
std::atomic<bool> running_;
std::atomic<bool> is_ready_;
std::mutex mutex_;
std::condition_variable cond_;
std::mutex verify_mutex_;
};
} // namespace core
} // namespace ps


+ 13
- 85
mindspore/ccsrc/ps/core/communicator/tcp_client.cc View File

@@ -37,7 +37,7 @@ event_base *TcpClient::event_base_ = nullptr;
std::mutex TcpClient::event_base_mutex_;
bool TcpClient::is_started_ = false;

TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *config)
TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *const config)
: event_timeout_(nullptr),
buffer_event_(nullptr),
server_address_(std::move(address)),
@@ -87,8 +87,6 @@ void TcpClient::Init() {
MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
}

event_enable_debug_logging(EVENT_DBG_ALL);
event_set_log_callback(CommUtil::LogCallback);
int result = evthread_use_pthreads();
if (result != 0) {
MS_LOG(EXCEPTION) << "Use event pthread failed!";
@@ -108,10 +106,12 @@ void TcpClient::Init() {
sin.sin_port = htons(server_port_);

if (!PSContext::instance()->enable_ssl()) {
MS_LOG(INFO) << "SSL is disable.";
buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
} else {
if (!EstablishSSL()) {
MS_LOG(EXCEPTION) << "Establish SSL failed.";
MS_LOG(WARNING) << "Establish SSL failed.";
return;
}
}

@@ -135,18 +135,21 @@ void TcpClient::StartWithDelay(int seconds) {
}

event_base_ = event_base_new();
MS_EXCEPTION_IF_NULL(event_base_);

timeval timeout_value{};
timeout_value.tv_sec = seconds;
timeout_value.tv_usec = 0;

event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this);
MS_EXCEPTION_IF_NULL(event_timeout_);
if (evtimer_add(event_timeout_, &timeout_value) == -1) {
MS_LOG(EXCEPTION) << "Event timeout failed!";
}
}

void TcpClient::Stop() {
MS_EXCEPTION_IF_NULL(event_base_);
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Stop tcp client!";
int ret = event_base_loopbreak(event_base_);
@@ -177,8 +180,8 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
char read_buffer[kMessageChunkLength];
int read = 0;

while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) {
tcp_client->OnReadHandler(read_buffer, read);
while ((read = bufferevent_read(bev, &read_buffer, SizeToInt(sizeof(read_buffer)))) > 0) {
tcp_client->OnReadHandler(read_buffer, IntToSize(read));
}
}

@@ -207,70 +210,9 @@ void TcpClient::NotifyConnected() {
bool TcpClient::EstablishSSL() {
MS_LOG(INFO) << "Enable ssl support.";

if (config_ == nullptr) {
MS_LOG(EXCEPTION) << "The config is empty.";
}

SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx(false));

// 1.Parse the client's certificate and the ciphertext of key.
std::string client_cert = kCertificateChain;
std::string path = CommUtil::ParseConfig(*config_, kClientCertPath);
if (!CommUtil::IsFileExists(path)) {
MS_LOG(WARNING) << "The key:" << kClientCertPath << "'s value is not exist.";
return false;
}
client_cert = path;

MS_LOG(INFO) << "The client cert path:" << client_cert;

// 2. Parse the client password.
std::string client_password = CommUtil::ParseConfig(*config_, kClientPassword);
if (client_password.empty()) {
MS_LOG(WARNING) << "The key:" << kClientPassword << "'s value is empty.";
return false;
}

MS_LOG(INFO) << "The client password:" << client_password;

EVP_PKEY *pkey = nullptr;
X509 *cert = nullptr;
STACK_OF(X509) *ca_stack = nullptr;
BIO *bio = BIO_new_file(client_cert.c_str(), "rb");
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
BIO_free_all(bio);
PKCS12_parse(p12, client_password.c_str(), &pkey, &cert, &ca_stack);
PKCS12_free(p12);

if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) {
MS_LOG(EXCEPTION) << "Verify cert time failed.";
}

if (!SSL_CTX_use_certificate(SSLWrapper::GetInstance().GetSSLCtx(false), cert)) {
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
}

if (!SSL_CTX_use_PrivateKey(SSLWrapper::GetInstance().GetSSLCtx(false), pkey)) {
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
}

std::string client_ca = kCAcrt;
std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
if (!CommUtil::IsFileExists(ca_path)) {
MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
}
client_ca = ca_path;
MS_LOG(INFO) << "The ca cert path:" << client_ca;

if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx(false))) {
MS_LOG(EXCEPTION) << "SSL check private key file failed!";
}

if (!SSL_CTX_load_verify_locations(SSLWrapper::GetInstance().GetSSLCtx(false), client_ca.c_str(), nullptr)) {
MS_LOG(EXCEPTION) << "SSL load ca location failed!";
}

SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(false), SSL_OP_NO_SSLv2);
SSL *ssl = SSL_new(SSLClient::GetInstance().GetSSLCtx());
MS_ERROR_IF_NULL_W_RET_VAL(ssl, false);
MS_ERROR_IF_NULL_W_RET_VAL(event_base_, false);

buffer_event_ = bufferevent_openssl_socket_new(event_base_, -1, ssl, BUFFEREVENT_SSL_CONNECTING,
BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
@@ -334,7 +276,7 @@ bool TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
bufferevent_lock(buffer_event_);
bool res = true;
size_t buf_size = IntToUint(message.ByteSizeLong());
size_t buf_size = message.ByteSizeLong();
uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong());
MessageHeader header;
header.message_proto_ = Protos::PROTOBUF;
@@ -390,20 +332,6 @@ bool TcpClient::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Prot
return res;
}

void TcpClient::StartTimer(const uint32_t &time) {
MS_EXCEPTION_IF_NULL(event_base_);
struct event *ev = nullptr;
if (time == 0) {
MS_LOG(EXCEPTION) << "The time should not be 0!";
}
struct timeval timeout {};
timeout.tv_sec = time;
timeout.tv_usec = 0;
ev = event_new(event_base_, -1, EV_PERSIST, TimerCallback, this);
MS_EXCEPTION_IF_NULL(ev);
evtimer_add(ev, &timeout);
}

void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }

const event_base &TcpClient::eventbase() const { return *event_base_; }


+ 2
- 2
mindspore/ccsrc/ps/core/communicator/tcp_client.h View File

@@ -34,6 +34,7 @@
#include "ps/core/cluster_config.h"
#include "utils/convert_utils_base.h"
#include "ps/core/comm_util.h"
#include "ps/core/communicator/ssl_client.h"
#include "ps/core/communicator/ssl_wrapper.h"
#include "ps/constants.h"
#include "ps/ps_context.h"
@@ -69,7 +70,6 @@ class TcpClient {
void SetMessageCallback(const OnMessage &cb);
bool SendMessage(const CommMessage &message) const;
bool SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size);
void StartTimer(const uint32_t &time);
void set_timer_callback(const OnTimer &timer);
const event_base &eventbase() const;

@@ -107,7 +107,7 @@ class TcpClient {
std::atomic<bool> is_stop_;
std::atomic<bool> is_connected_;
// The Configuration file
Configuration *const config_;
Configuration *config_;
};
} // namespace core
} // namespace ps


+ 43
- 97
mindspore/ccsrc/ps/core/communicator/tcp_server.cc View File

@@ -35,12 +35,19 @@ namespace core {
TcpConnection::~TcpConnection() { bufferevent_free(buffer_event_); }
void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); }

void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); }
void TcpConnection::OnReadHandler(const void *buffer, size_t num) {
MS_EXCEPTION_IF_NULL(buffer);
tcp_message_handler_.ReceiveMessage(buffer, num);
}

void TcpConnection::SendMessage(const void *buffer, size_t num) const {
MS_EXCEPTION_IF_NULL(buffer);
MS_EXCEPTION_IF_NULL(buffer_event_);
bufferevent_lock(buffer_event_);
if (bufferevent_write(buffer_event_, buffer, num) == -1) {
MS_LOG(ERROR) << "Write message to buffer event failed!";
}
bufferevent_unlock(buffer_event_);
}

const TcpServer *TcpConnection::GetServer() const { return server_; }
@@ -93,24 +100,22 @@ bool TcpConnection::SendMessage(const std::shared_ptr<MessageMeta> &meta, const
}
int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
if (result < 0) {
bufferevent_unlock(buffer_event_);
MS_LOG(EXCEPTION) << "Bufferevent flush failed!";
}
bufferevent_unlock(buffer_event_);
MS_LOG(DEBUG) << "SendMessage the request id is:" << meta->request_id() << " the current time is:"
<< std::chrono::time_point_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now())
.time_since_epoch()
.count();
return res;
}

TcpServer::TcpServer(const std::string &address, std::uint16_t port, Configuration *config)
TcpServer::TcpServer(const std::string &address, std::uint16_t port, Configuration *const config)
: base_(nullptr),
signal_event_(nullptr),
listener_(nullptr),
server_address_(std::move(address)),
server_port_(port),
is_stop_(true),
config_(config) {}
config_(config),
max_connection_(0) {}

TcpServer::~TcpServer() {
if (signal_event_ != nullptr) {
@@ -146,14 +151,18 @@ void TcpServer::Init() {
MS_LOG(EXCEPTION) << "Use event pthread failed!";
}

event_enable_debug_logging(EVENT_DBG_ALL);
event_set_log_callback(CommUtil::LogCallback);
is_stop_ = false;
base_ = event_base_new();
MS_EXCEPTION_IF_NULL(base_);
if (!CommUtil::CheckIp(server_address_)) {
MS_LOG(EXCEPTION) << "The tcp server ip:" << server_address_ << " is illegal!";
}
MS_EXCEPTION_IF_NULL(config_);
max_connection_ = kConnectionNumDefault;
if (config_->Exists(kConnectionNum)) {
max_connection_ = config_->GetInt(kConnectionNum, 0);
}
MS_LOG(INFO) << "The max connection is:" << max_connection_;

struct sockaddr_in sin {};
if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
@@ -210,35 +219,8 @@ void TcpServer::StartWithNoBlock() {
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
}

void TcpServer::StartTimerOnlyOnce(const uint32_t &time) {
MS_EXCEPTION_IF_NULL(base_);
if (time == 0) {
MS_LOG(EXCEPTION) << "The time should not be 0!";
}
struct event *ev = nullptr;
struct timeval timeout {};
timeout.tv_sec = time;
timeout.tv_usec = 0;
ev = evtimer_new(base_, TimerOnceCallback, this);
MS_EXCEPTION_IF_NULL(ev);
evtimer_add(ev, &timeout);
}

void TcpServer::StartTimer(const uint32_t &time) {
MS_EXCEPTION_IF_NULL(base_);
struct event *ev = nullptr;
if (time == 0) {
MS_LOG(EXCEPTION) << "The time should not be 0!";
}
struct timeval timeout {};
timeout.tv_sec = time;
timeout.tv_usec = 0;
ev = event_new(base_, -1, EV_PERSIST, TimerCallback, this);
MS_EXCEPTION_IF_NULL(ev);
evtimer_add(ev, &timeout);
}

void TcpServer::Stop() {
MS_EXCEPTION_IF_NULL(base_);
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Stop tcp server!";
if (event_base_got_break(base_)) {
@@ -279,71 +261,29 @@ std::shared_ptr<TcpConnection> TcpServer::GetConnectionByFd(const evutil_socket_
void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int,
void *data) {
auto server = reinterpret_cast<class TcpServer *>(data);
auto base = reinterpret_cast<struct event_base *>(server->base_);
MS_EXCEPTION_IF_NULL(server);
auto base = reinterpret_cast<struct event_base *>(server->base_);
MS_EXCEPTION_IF_NULL(base);
MS_EXCEPTION_IF_NULL(sockaddr);

if (server->ConnectionNum() >= server->max_connection_) {
MS_LOG(WARNING) << "The current connection num:" << server->ConnectionNum() << " is greater or equal to "
<< server->max_connection_;
return;
}

struct bufferevent *bev = nullptr;

if (!PSContext::instance()->enable_ssl()) {
MS_LOG(INFO) << "SSL is disable.";
bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
} else {
MS_LOG(INFO) << "Enable ssl support.";
if (server->config_ == nullptr) {
MS_LOG(EXCEPTION) << "The config is empty.";
}

SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx());

// 1.Parse the server's certificate and the ciphertext of key.
std::string server_cert = kCertificateChain;
std::string path = CommUtil::ParseConfig(*(server->config_), kServerCertPath);
if (!CommUtil::IsFileExists(path)) {
MS_LOG(EXCEPTION) << "The key:" << kServerCertPath << "'s value is not exist.";
}
server_cert = path;

MS_LOG(INFO) << "The server cert path:" << server_cert;

// 2. Parse the server password.
std::string server_password = CommUtil::ParseConfig(*(server->config_), kServerPassword);
if (server_password.empty()) {
MS_LOG(EXCEPTION) << "The key:" << kServerPassword << "'s value is empty.";
}

MS_LOG(INFO) << "The server password:" << server_password;

EVP_PKEY *pkey = nullptr;
X509 *cert = nullptr;
STACK_OF(X509) *ca_stack = nullptr;
BIO *bio = BIO_new_file(server_cert.c_str(), "rb");
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
BIO_free_all(bio);
PKCS12_parse(p12, server_password.c_str(), &pkey, &cert, &ca_stack);
PKCS12_free(p12);

if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) {
MS_LOG(EXCEPTION) << "Verify cert time failed.";
}

if (!SSL_CTX_use_certificate(SSLWrapper::GetInstance().GetSSLCtx(), cert)) {
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
}

if (!SSL_CTX_use_PrivateKey(SSLWrapper::GetInstance().GetSSLCtx(), pkey)) {
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
}

if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx())) {
MS_LOG(EXCEPTION) << "SSL check private key file failed!";
}
SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(), SSL_OP_NO_SSLv2);

MS_EXCEPTION_IF_NULL(ssl);
bev = bufferevent_openssl_socket_new(base, fd, ssl, BUFFEREVENT_SSL_ACCEPTING,
BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
}

if (bev == nullptr) {
MS_LOG(ERROR) << "Error constructing buffer event!";
int ret = event_base_loopbreak(base);
@@ -386,9 +326,10 @@ std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent
OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; }

void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
MS_EXCEPTION_IF_NULL(data);
auto server = reinterpret_cast<class TcpServer *>(data);
MS_EXCEPTION_IF_NULL(server);
struct event_base *base = server->base_;
MS_EXCEPTION_IF_NULL(base);
struct timeval delay = {0, 0};
MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds.";
if (event_base_loopexit(base, &delay) == -1) {
@@ -402,6 +343,7 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {

auto conn = static_cast<class TcpConnection *>(connection);
struct evbuffer *buf = bufferevent_get_input(bev);
MS_EXCEPTION_IF_NULL(buf);
char read_buffer[kMessageChunkLength];
while (EVBUFFER_LENGTH(buf) > 0) {
int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer));
@@ -409,11 +351,6 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
}
conn->OnReadHandler(read_buffer, IntToSize(read));
MS_LOG(DEBUG) << "the current time is:"
<< std::chrono::time_point_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now())
.time_since_epoch()
.count()
<< " the read size is:" << read;
}
}

@@ -421,9 +358,10 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(data);
struct evbuffer *output = bufferevent_get_output(bev);
size_t remain = evbuffer_get_length(output);
MS_EXCEPTION_IF_NULL(output);
auto conn = static_cast<class TcpConnection *>(data);
auto srv = const_cast<TcpServer *>(conn->GetServer());
MS_EXCEPTION_IF_NULL(srv);

if (events & BEV_EVENT_EOF) {
MS_LOG(INFO) << "Event buffer end of file, a client is disconnected from this server!";
@@ -434,7 +372,15 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void
// Free connection structures
srv->RemoveConnection(conn->GetFd());
} else if (events & BEV_EVENT_ERROR) {
MS_LOG(WARNING) << "Event buffer remain data: " << remain;
MS_LOG(WARNING) << "Connect to server error.";
if (PSContext::instance()->enable_ssl()) {
uint64_t err = bufferevent_get_openssl_error(bev);
MS_LOG(WARNING) << "The error number is:" << err;

MS_LOG(WARNING) << "Error message:" << ERR_reason_error_string(err)
<< ", the error lib:" << ERR_lib_error_string(err)
<< ", the error func:" << ERR_func_error_string(err);
}
// Free connection structures
srv->RemoveConnection(conn->GetFd());

@@ -443,7 +389,7 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void
srv->client_disconnection_(*srv, *conn);
}
} else {
MS_LOG(WARNING) << "Unhandled event!";
MS_LOG(WARNING) << "Unhandled event:" << events;
}
}



+ 5
- 6
mindspore/ccsrc/ps/core/communicator/tcp_server.h View File

@@ -57,11 +57,11 @@ class TcpConnection {

using Callback = std::function<void(const std::shared_ptr<CommMessage>)>;

virtual void InitConnection(const messageReceive &callback);
virtual void SendMessage(const void *buffer, size_t num) const;
void InitConnection(const messageReceive &callback);
void SendMessage(const void *buffer, size_t num) const;
bool SendMessage(const std::shared_ptr<CommMessage> &message) const;
bool SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) const;
virtual void OnReadHandler(const void *buffer, size_t numBytes);
void OnReadHandler(const void *buffer, size_t numBytes);
const TcpServer *GetServer() const;
const evutil_socket_t &GetFd() const;
void set_callback(const Callback &callback);
@@ -97,8 +97,6 @@ class TcpServer {
void Init();
void Start();
void StartWithNoBlock();
void StartTimerOnlyOnce(const uint32_t &time);
void StartTimer(const uint32_t &time);
void Stop();
void SendToAllClients(const char *data, size_t len);
void AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection);
@@ -142,7 +140,8 @@ class TcpServer {
OnTimerOnce on_timer_once_callback_;
OnTimer on_timer_callback_;
// The Configuration file
Configuration *const config_;
Configuration *config_;
int64_t max_connection_;
};
} // namespace core
} // namespace ps


+ 6
- 0
mindspore/ccsrc/ps/ps_context.cc View File

@@ -413,5 +413,11 @@ std::string PSContext::config_file_path() const { return config_file_path_; }
void PSContext::set_node_id(const std::string &node_id) { node_id_ = node_id; }

const std::string &PSContext::node_id() const { return node_id_; }

std::string PSContext::client_password() const { return client_password_; }
void PSContext::set_client_password(const std::string &password) { client_password_ = password; }

std::string PSContext::server_password() const { return server_password_; }
void PSContext::set_server_password(const std::string &password) { server_password_ = password; }
} // namespace ps
} // namespace mindspore

+ 13
- 1
mindspore/ccsrc/ps/ps_context.h View File

@@ -83,6 +83,11 @@ class PSContext {
bool enable_ssl() const;
void set_enable_ssl(bool enabled);

std::string client_password() const;
void set_client_password(const std::string &password);
std::string server_password() const;
void set_server_password(const std::string &password);

// In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set
// by ps_context.
void set_server_mode(const std::string &server_mode);
@@ -218,7 +223,9 @@ class PSContext {
dp_delta_(0.01),
dp_norm_clip_(1.0),
encrypt_type_(kNotEncryptType),
node_id_("") {}
node_id_(""),
client_password_(""),
server_password_("") {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
@@ -310,6 +317,11 @@ class PSContext {

// Unique id of the node
std::string node_id_;

// Password used to decode p12 file.
std::string client_password_;
// Password used to decode p12 file.
std::string server_password_;
};
} // namespace ps
} // namespace mindspore


Loading…
Cancel
Save