diff --git a/mindspore/ccsrc/ps/core/http_message_handler.cc b/mindspore/ccsrc/ps/core/http_message_handler.cc index 93e43ccf0c..c7f6c008ad 100644 --- a/mindspore/ccsrc/ps/core/http_message_handler.cc +++ b/mindspore/ccsrc/ps/core/http_message_handler.cc @@ -132,19 +132,17 @@ std::string HttpMessageHandler::GetUriFragment() { return std::string(fragment); } -std::string HttpMessageHandler::GetPostMsg() { +uint64_t HttpMessageHandler::GetPostMsg(unsigned char **buffer) { MS_EXCEPTION_IF_NULL(event_request_); - if (body_ != nullptr) { - return *body_; - } + MS_EXCEPTION_IF_NULL(buffer); + size_t len = evbuffer_get_length(event_request_->input_buffer); if (len == 0) { MS_LOG(EXCEPTION) << "The post message is empty!"; } - const char *post_message = reinterpret_cast(evbuffer_pullup(event_request_->input_buffer, -1)); - MS_EXCEPTION_IF_NULL(post_message); - body_ = std::make_unique(post_message, len); - return *body_; + *buffer = evbuffer_pullup(event_request_->input_buffer, -1); + MS_EXCEPTION_IF_NULL(*buffer); + return len; } void HttpMessageHandler::AddRespHeadParam(const std::string &key, const std::string &val) { diff --git a/mindspore/ccsrc/ps/core/http_message_handler.h b/mindspore/ccsrc/ps/core/http_message_handler.h index 72f0322c97..f5c7a0e619 100644 --- a/mindspore/ccsrc/ps/core/http_message_handler.h +++ b/mindspore/ccsrc/ps/core/http_message_handler.h @@ -62,7 +62,7 @@ class HttpMessageHandler { std::string GetHeadParam(const std::string &key); std::string GetPathParam(const std::string &key); std::string GetPostParam(const std::string &key); - std::string GetPostMsg(); + uint64_t GetPostMsg(unsigned char **buffer); std::string GetUriPath(); std::string GetUriQuery(); diff --git a/mindspore/ccsrc/ps/core/http_server.cc b/mindspore/ccsrc/ps/core/http_server.cc index 548e6ec1c6..44af4bb81f 100644 --- a/mindspore/ccsrc/ps/core/http_server.cc +++ b/mindspore/ccsrc/ps/core/http_server.cc @@ -49,6 +49,7 @@ bool HttpServer::InitServer() { MS_LOG(EXCEPTION) << "The http server ip:" << server_address_ << " is illegal!"; } + is_stop_ = false; event_base_ = event_base_new(); MS_EXCEPTION_IF_NULL(event_base_); event_http_ = evhttp_new(event_base_); @@ -146,13 +147,21 @@ bool HttpServer::Start() { void HttpServer::Stop() { MS_LOG(INFO) << "Stop http server!"; - if (event_http_) { - evhttp_free(event_http_); - event_http_ = nullptr; - } - if (event_base_) { - event_base_free(event_base_); - event_base_ = nullptr; + + if (!is_stop_.load()) { + int ret = event_base_loopbreak(event_base_); + if (ret != 0) { + MS_LOG(EXCEPTION) << "event base loop break failed!"; + } + if (event_http_) { + evhttp_free(event_http_); + event_http_ = nullptr; + } + if (event_base_) { + event_base_free(event_base_); + event_base_ = nullptr; + } + is_stop_ = true; } } diff --git a/mindspore/ccsrc/ps/core/http_server.h b/mindspore/ccsrc/ps/core/http_server.h index acea23db65..69d2936800 100644 --- a/mindspore/ccsrc/ps/core/http_server.h +++ b/mindspore/ccsrc/ps/core/http_server.h @@ -32,6 +32,7 @@ #include #include #include +#include namespace mindspore { namespace ps { @@ -55,7 +56,12 @@ class HttpServer { public: // Server address only support IPV4 now, and should be in format of "x.x.x.x" explicit HttpServer(const std::string &address, std::uint16_t port) - : server_address_(address), server_port_(port), event_base_(nullptr), event_http_(nullptr), is_init_(false) {} + : server_address_(address), + server_port_(port), + event_base_(nullptr), + event_http_(nullptr), + is_init_(false), + is_stop_(true) {} ~HttpServer(); @@ -84,6 +90,7 @@ class HttpServer { struct event_base *event_base_; struct evhttp *event_http_; bool is_init_; + std::atomic is_stop_; }; } // namespace core diff --git a/mindspore/ccsrc/ps/core/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc index 9a5c1b5987..9093eec627 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/tcp_server.cc @@ -198,7 +198,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); if (!bev) { MS_LOG(ERROR) << "Error constructing buffer event!"; - event_base_loopbreak(base); + int ret = event_base_loopbreak(base); + if (ret != 0) { + MS_LOG(EXCEPTION) << "event base loop break failed!"; + } return; } diff --git a/tests/ut/cpp/ps/core/http_server_test.cc b/tests/ut/cpp/ps/core/http_server_test.cc index 9646b235b7..bb0f9bede1 100644 --- a/tests/ut/cpp/ps/core/http_server_test.cc +++ b/tests/ut/cpp/ps/core/http_server_test.cc @@ -43,11 +43,19 @@ class TestHttpServer : public UT::Common { std::string path_param = resp->GetPathParam("key1"); std::string header_param = resp->GetHeadParam("headerKey"); std::string post_param = resp->GetPostParam("postKey"); - std::string post_message = resp->GetPostMsg(); + unsigned char *data = nullptr; + const uint64_t len = resp->GetPostMsg(&data); + char post_message[len + 1]; + if (memset_s(post_message, len + 1, 0, len + 1) != 0) { + MS_LOG(EXCEPTION) << "The memset_s error"; + } + if (memcpy_s(post_message, len, data, len) != 0) { + MS_LOG(EXCEPTION) << "The memset_s error"; + } EXPECT_STREQ(path_param.c_str(), "value1"); EXPECT_STREQ(header_param.c_str(), "headerValue"); EXPECT_STREQ(post_param.c_str(), "postValue"); - EXPECT_STREQ(post_message.c_str(), "postKey=postValue"); + EXPECT_STREQ(post_message, "postKey=postValue"); const std::string rKey("headKey"); const std::string rVal("headValue"); @@ -79,11 +87,19 @@ class TestHttpServer : public UT::Common { std::string path_param = resp->GetPathParam("key1"); std::string header_param = resp->GetHeadParam("headerKey"); std::string post_param = resp->GetPostParam("postKey"); - std::string post_message = resp->GetPostMsg(); + unsigned char *data = nullptr; + const uint64_t len = resp->GetPostMsg(&data); + char post_message[len + 1]; + if (memset_s(post_message, len + 1, 0, len + 1) != 0) { + MS_LOG(EXCEPTION) << "The memset_s error"; + } + if (memcpy_s(post_message, len, data, len) != 0) { + MS_LOG(EXCEPTION) << "The memset_s error"; + } EXPECT_STREQ(path_param.c_str(), "value1"); EXPECT_STREQ(header_param.c_str(), "headerValue"); EXPECT_STREQ(post_param.c_str(), "postValue"); - EXPECT_STREQ(post_message.c_str(), "postKey=postValue"); + EXPECT_STREQ(post_message, "postKey=postValue"); const std::string rKey("headKey"); const std::string rVal("headValue"); @@ -157,6 +173,6 @@ TEST_F(TestHttpServer, addressException) { ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception); } -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore