From: @anancds Reviewed-by: @limingqi107 Signed-off-by:pull/14569/MERGE
| @@ -16,7 +16,7 @@ message("libevent using openssl stub dir: " ${openssl_ROOT}) | |||
| mindspore_add_pkg(libevent | |||
| VER 2.1.12 | |||
| LIBS event event_pthreads event_core | |||
| LIBS event event_pthreads event_core event_openssl | |||
| URL ${REQ_URL} | |||
| MD5 ${MD5} | |||
| CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release -DBUILD_TESTING=OFF -DOPENSSL_ROOT_DIR:PATH=${openssl_ROOT}) | |||
| @@ -26,3 +26,4 @@ include_directories(${libevent_INC}) | |||
| add_library(mindspore::event ALIAS libevent::event) | |||
| add_library(mindspore::event_pthreads ALIAS libevent::event_pthreads) | |||
| add_library(mindspore::event_core ALIAS libevent::event_core) | |||
| add_library(mindspore::event_openssl ALIAS libevent::event_openssl) | |||
| @@ -370,8 +370,9 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load) | |||
| else() | |||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| find_package(OpenSSL REQUIRED) | |||
| target_link_libraries(mindspore proto_input mindspore::protobuf | |||
| mindspore::event mindspore::event_pthreads) | |||
| mindspore::event mindspore::event_pthreads mindspore::event_openssl OpenSSL::SSL OpenSSL::Crypto) | |||
| target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) | |||
| if(${ENABLE_IBVERBS} STREQUAL "ON") | |||
| target_link_libraries(mindspore ibverbs rdmacm) | |||
| @@ -22,7 +22,8 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_client.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "worker.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "parameter_server.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/worker_queue.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_request_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/ssl_wrapper.cc") | |||
| endif() | |||
| if(NOT ENABLE_D) | |||
| @@ -51,6 +51,9 @@ constexpr char kSparseAdamOp[] = "Adam"; | |||
| constexpr char kSparseLazyAdamOp[] = "LazyAdam"; | |||
| constexpr char kSparseFtrlOp[] = "FTRL"; | |||
| constexpr char kCertificateChain[] = "server.crt"; | |||
| constexpr char kPrivateKey[] = "server.key.unsecure"; | |||
| constexpr int64_t kInitWeightsCmd = 10; | |||
| constexpr int64_t kInitWeightToOptimIdCmd = 11; | |||
| constexpr int64_t kInitOptimInputsShapeCmd = 12; | |||
| @@ -14,16 +14,39 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/communicator/worker_queue.h" | |||
| #include "ps/core/communicator/http_request_handler.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| bool WorkerQueue::Initialize(int fd, std::unordered_map<std::string, OnRequestReceive *> handlers) { | |||
| bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string, OnRequestReceive *> &handlers) { | |||
| evbase_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(evbase_); | |||
| struct evhttp *http = evhttp_new(evbase_); | |||
| MS_EXCEPTION_IF_NULL(http); | |||
| SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(), | |||
| SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2); | |||
| EC_KEY *ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); | |||
| MS_EXCEPTION_IF_NULL(ecdh); | |||
| if (!SSL_CTX_use_certificate_chain_file(SSLWrapper::GetInstance().GetSSLCtx(), kCertificateChain)) { | |||
| MS_LOG(ERROR) << "SSL use certificate chain file failed!"; | |||
| return false; | |||
| } | |||
| if (!SSL_CTX_use_PrivateKey_file(SSLWrapper::GetInstance().GetSSLCtx(), kPrivateKey, SSL_FILETYPE_PEM)) { | |||
| MS_LOG(ERROR) << "SSL use private key file failed!"; | |||
| return false; | |||
| } | |||
| if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx())) { | |||
| MS_LOG(ERROR) << "SSL check private key file failed!"; | |||
| return false; | |||
| } | |||
| evhttp_set_bevcb(http, BuffereventCallback, SSLWrapper::GetInstance().GetSSLCtx()); | |||
| int result = evhttp_accept_socket(http, fd); | |||
| if (result < 0) { | |||
| MS_LOG(ERROR) << "Evhttp accept socket failed!"; | |||
| @@ -56,7 +79,7 @@ bool WorkerQueue::Initialize(int fd, std::unordered_map<std::string, OnRequestRe | |||
| return true; | |||
| } | |||
| void WorkerQueue::Run() { | |||
| void HttpRequestHandler::Run() { | |||
| MS_LOG(INFO) << "Start http server!"; | |||
| MS_EXCEPTION_IF_NULL(evbase_); | |||
| int ret = event_base_dispatch(evbase_); | |||
| @@ -76,7 +99,7 @@ void WorkerQueue::Run() { | |||
| } | |||
| } | |||
| void WorkerQueue::Stop() { | |||
| void HttpRequestHandler::Stop() { | |||
| MS_LOG(INFO) << "Stop http server!"; | |||
| int ret = event_base_loopbreak(evbase_); | |||
| @@ -84,6 +107,13 @@ void WorkerQueue::Stop() { | |||
| MS_LOG(EXCEPTION) << "event base loop break failed!"; | |||
| } | |||
| } | |||
| bufferevent *HttpRequestHandler::BuffereventCallback(event_base *base, void *arg) { | |||
| SSL_CTX *ctx = reinterpret_cast<SSL_CTX *>(arg); | |||
| SSL *ssl = SSL_new(ctx); | |||
| bufferevent *bev = bufferevent_openssl_socket_new(base, -1, ssl, BUFFEREVENT_SSL_ACCEPTING, BEV_OPT_CLOSE_ON_FREE); | |||
| return bev; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -14,12 +14,14 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_WORKER_QUEUE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_WORKER_QUEUE_H_ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_REQUEST_HANDLER_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_REQUEST_HANDLER_H_ | |||
| #include <event2/event.h> | |||
| #include <event2/http.h> | |||
| #include <event2/http_struct.h> | |||
| #include <event2/bufferevent.h> | |||
| #include <event2/bufferevent_ssl.h> | |||
| #include <string> | |||
| #include <memory> | |||
| @@ -27,19 +29,26 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "ps/core/communicator/http_message_handler.h" | |||
| #include "ps/core/communicator/ssl_wrapper.h" | |||
| #include "ps/constants.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| using OnRequestReceive = std::function<void(std::shared_ptr<HttpMessageHandler>)>; | |||
| class WorkerQueue { | |||
| /* Each thread corresponds to one HttpRequestHandler, which is used to create one eventbase. All eventbase are listened | |||
| * on the same fd. Every evhttp_request is executed in one thread. | |||
| */ | |||
| class HttpRequestHandler { | |||
| public: | |||
| WorkerQueue() : evbase_(nullptr) {} | |||
| virtual ~WorkerQueue() = default; | |||
| HttpRequestHandler() : evbase_(nullptr) {} | |||
| virtual ~HttpRequestHandler() = default; | |||
| bool Initialize(int fd, std::unordered_map<std::string, OnRequestReceive *> handlers); | |||
| bool Initialize(int fd, const std::unordered_map<std::string, OnRequestReceive *> &handlers); | |||
| void Run(); | |||
| void Stop(); | |||
| static bufferevent *BuffereventCallback(event_base *base, void *arg); | |||
| private: | |||
| struct event_base *evbase_; | |||
| @@ -47,4 +56,4 @@ class WorkerQueue { | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_WORKER_QUEUE_H_ | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_REQUEST_HANDLER_H_ | |||
| @@ -118,10 +118,10 @@ bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *functio | |||
| bool HttpServer::Start() { | |||
| MS_LOG(INFO) << "Start http server!"; | |||
| for (size_t i = 0; i < thread_num_; i++) { | |||
| auto worker_queue = std::make_shared<WorkerQueue>(); | |||
| worker_queue->Initialize(fd_, request_handlers_); | |||
| worker_queues_.push_back(worker_queue); | |||
| worker_threads_.emplace_back(std::make_shared<std::thread>(&WorkerQueue::Run, worker_queue)); | |||
| auto http_request_handler = std::make_shared<HttpRequestHandler>(); | |||
| http_request_handler->Initialize(fd_, request_handlers_); | |||
| http_request_handlers.push_back(http_request_handler); | |||
| worker_threads_.emplace_back(std::make_shared<std::thread>(&HttpRequestHandler::Run, http_request_handler)); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -139,7 +139,7 @@ void HttpServer::Stop() { | |||
| if (!is_stop_.load()) { | |||
| for (size_t i = 0; i < thread_num_; i++) { | |||
| worker_queues_[i]->Stop(); | |||
| http_request_handlers[i]->Stop(); | |||
| } | |||
| is_stop_ = true; | |||
| } | |||
| @@ -38,7 +38,7 @@ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "ps/core/communicator/worker_queue.h" | |||
| #include "ps/core/communicator/http_request_handler.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -75,7 +75,7 @@ class HttpServer { | |||
| int request_timeout_; | |||
| size_t thread_num_; | |||
| std::vector<std::shared_ptr<std::thread>> worker_threads_; | |||
| std::vector<std::shared_ptr<WorkerQueue>> worker_queues_; | |||
| std::vector<std::shared_ptr<HttpRequestHandler>> http_request_handlers; | |||
| int32_t backlog_; | |||
| std::unordered_map<std::string, OnRequestReceive *> request_handlers_; | |||
| int fd_; | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/communicator/ssl_wrapper.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| SSLWrapper::SSLWrapper() : ssl_ctx_(nullptr) { InitSSL(); } | |||
| SSLWrapper::~SSLWrapper() { CleanSSL(); } | |||
| void SSLWrapper::InitSSL() { | |||
| SSL_library_init(); | |||
| ERR_load_crypto_strings(); | |||
| SSL_load_error_strings(); | |||
| OpenSSL_add_all_algorithms(); | |||
| int rand = RAND_poll(); | |||
| if (rand == 0) { | |||
| MS_LOG(ERROR) << "RAND_poll failed"; | |||
| } | |||
| ssl_ctx_ = SSL_CTX_new(SSLv23_server_method()); | |||
| if (!ssl_ctx_) { | |||
| MS_LOG(ERROR) << "SSL_CTX_new failed"; | |||
| } | |||
| X509_STORE *store = SSL_CTX_get_cert_store(ssl_ctx_); | |||
| if (X509_STORE_set_default_paths(store) != 1) { | |||
| MS_LOG(ERROR) << "X509_STORE_set_default_paths failed"; | |||
| } | |||
| } | |||
| void SSLWrapper::CleanSSL() { | |||
| if (ssl_ctx_ != nullptr) { | |||
| SSL_CTX_free(ssl_ctx_); | |||
| } | |||
| ERR_free_strings(); | |||
| EVP_cleanup(); | |||
| ERR_remove_thread_state(nullptr); | |||
| CRYPTO_cleanup_all_ex_data(); | |||
| } | |||
| SSL_CTX *SSLWrapper::GetSSLCtx() { return ssl_ctx_; } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_SSL_WRAPPER_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_SSL_WRAPPER_H_ | |||
| #include <openssl/ssl.h> | |||
| #include <openssl/rand.h> | |||
| #include <openssl/err.h> | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class SSLWrapper { | |||
| public: | |||
| static SSLWrapper &GetInstance() { | |||
| static SSLWrapper instance; | |||
| return instance; | |||
| } | |||
| SSL_CTX *GetSSLCtx(); | |||
| private: | |||
| SSLWrapper(); | |||
| virtual ~SSLWrapper(); | |||
| SSLWrapper(const SSLWrapper &) = delete; | |||
| SSLWrapper &operator=(const SSLWrapper &) = delete; | |||
| void InitSSL(); | |||
| void CleanSSL(); | |||
| SSL_CTX *ssl_ctx_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_SSL_WRAPPER_H_ | |||
| @@ -161,6 +161,8 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/worker.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/core/communicator/http_request_handler.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/core/communicator/http_server.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/parameter_server.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc") | |||