From: @anancds Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -132,19 +132,17 @@ std::string HttpMessageHandler::GetUriFragment() { | |||||
| return std::string(fragment); | return std::string(fragment); | ||||
| } | } | ||||
| std::string HttpMessageHandler::GetPostMsg() { | |||||
| uint64_t HttpMessageHandler::GetPostMsg(unsigned char **buffer) { | |||||
| MS_EXCEPTION_IF_NULL(event_request_); | MS_EXCEPTION_IF_NULL(event_request_); | ||||
| if (body_ != nullptr) { | |||||
| return *body_; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(buffer); | |||||
| size_t len = evbuffer_get_length(event_request_->input_buffer); | size_t len = evbuffer_get_length(event_request_->input_buffer); | ||||
| if (len == 0) { | if (len == 0) { | ||||
| MS_LOG(EXCEPTION) << "The post message is empty!"; | MS_LOG(EXCEPTION) << "The post message is empty!"; | ||||
| } | } | ||||
| const char *post_message = reinterpret_cast<const char *>(evbuffer_pullup(event_request_->input_buffer, -1)); | |||||
| MS_EXCEPTION_IF_NULL(post_message); | |||||
| body_ = std::make_unique<std::string>(post_message, len); | |||||
| return *body_; | |||||
| *buffer = evbuffer_pullup(event_request_->input_buffer, -1); | |||||
| MS_EXCEPTION_IF_NULL(*buffer); | |||||
| return len; | |||||
| } | } | ||||
| void HttpMessageHandler::AddRespHeadParam(const std::string &key, const std::string &val) { | void HttpMessageHandler::AddRespHeadParam(const std::string &key, const std::string &val) { | ||||
| @@ -62,7 +62,7 @@ class HttpMessageHandler { | |||||
| std::string GetHeadParam(const std::string &key); | std::string GetHeadParam(const std::string &key); | ||||
| std::string GetPathParam(const std::string &key); | std::string GetPathParam(const std::string &key); | ||||
| std::string GetPostParam(const std::string &key); | std::string GetPostParam(const std::string &key); | ||||
| std::string GetPostMsg(); | |||||
| uint64_t GetPostMsg(unsigned char **buffer); | |||||
| std::string GetUriPath(); | std::string GetUriPath(); | ||||
| std::string GetUriQuery(); | std::string GetUriQuery(); | ||||
| @@ -49,6 +49,7 @@ bool HttpServer::InitServer() { | |||||
| MS_LOG(EXCEPTION) << "The http server ip:" << server_address_ << " is illegal!"; | MS_LOG(EXCEPTION) << "The http server ip:" << server_address_ << " is illegal!"; | ||||
| } | } | ||||
| is_stop_ = false; | |||||
| event_base_ = event_base_new(); | event_base_ = event_base_new(); | ||||
| MS_EXCEPTION_IF_NULL(event_base_); | MS_EXCEPTION_IF_NULL(event_base_); | ||||
| event_http_ = evhttp_new(event_base_); | event_http_ = evhttp_new(event_base_); | ||||
| @@ -146,13 +147,21 @@ bool HttpServer::Start() { | |||||
| void HttpServer::Stop() { | void HttpServer::Stop() { | ||||
| MS_LOG(INFO) << "Stop http server!"; | MS_LOG(INFO) << "Stop http server!"; | ||||
| if (event_http_) { | |||||
| evhttp_free(event_http_); | |||||
| event_http_ = nullptr; | |||||
| } | |||||
| if (event_base_) { | |||||
| event_base_free(event_base_); | |||||
| event_base_ = nullptr; | |||||
| if (!is_stop_.load()) { | |||||
| int ret = event_base_loopbreak(event_base_); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "event base loop break failed!"; | |||||
| } | |||||
| if (event_http_) { | |||||
| evhttp_free(event_http_); | |||||
| event_http_ = nullptr; | |||||
| } | |||||
| if (event_base_) { | |||||
| event_base_free(event_base_); | |||||
| event_base_ = nullptr; | |||||
| } | |||||
| is_stop_ = true; | |||||
| } | } | ||||
| } | } | ||||
| @@ -32,6 +32,7 @@ | |||||
| #include <functional> | #include <functional> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <atomic> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -55,7 +56,12 @@ class HttpServer { | |||||
| public: | public: | ||||
| // Server address only support IPV4 now, and should be in format of "x.x.x.x" | // Server address only support IPV4 now, and should be in format of "x.x.x.x" | ||||
| explicit HttpServer(const std::string &address, std::uint16_t port) | explicit HttpServer(const std::string &address, std::uint16_t port) | ||||
| : server_address_(address), server_port_(port), event_base_(nullptr), event_http_(nullptr), is_init_(false) {} | |||||
| : server_address_(address), | |||||
| server_port_(port), | |||||
| event_base_(nullptr), | |||||
| event_http_(nullptr), | |||||
| is_init_(false), | |||||
| is_stop_(true) {} | |||||
| ~HttpServer(); | ~HttpServer(); | ||||
| @@ -84,6 +90,7 @@ class HttpServer { | |||||
| struct event_base *event_base_; | struct event_base *event_base_; | ||||
| struct evhttp *event_http_; | struct evhttp *event_http_; | ||||
| bool is_init_; | bool is_init_; | ||||
| std::atomic<bool> is_stop_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| @@ -198,7 +198,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||||
| struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); | struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); | ||||
| if (!bev) { | if (!bev) { | ||||
| MS_LOG(ERROR) << "Error constructing buffer event!"; | MS_LOG(ERROR) << "Error constructing buffer event!"; | ||||
| event_base_loopbreak(base); | |||||
| int ret = event_base_loopbreak(base); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "event base loop break failed!"; | |||||
| } | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -43,11 +43,19 @@ class TestHttpServer : public UT::Common { | |||||
| std::string path_param = resp->GetPathParam("key1"); | std::string path_param = resp->GetPathParam("key1"); | ||||
| std::string header_param = resp->GetHeadParam("headerKey"); | std::string header_param = resp->GetHeadParam("headerKey"); | ||||
| std::string post_param = resp->GetPostParam("postKey"); | std::string post_param = resp->GetPostParam("postKey"); | ||||
| std::string post_message = resp->GetPostMsg(); | |||||
| unsigned char *data = nullptr; | |||||
| const uint64_t len = resp->GetPostMsg(&data); | |||||
| char post_message[len + 1]; | |||||
| if (memset_s(post_message, len + 1, 0, len + 1) != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memset_s error"; | |||||
| } | |||||
| if (memcpy_s(post_message, len, data, len) != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memset_s error"; | |||||
| } | |||||
| EXPECT_STREQ(path_param.c_str(), "value1"); | EXPECT_STREQ(path_param.c_str(), "value1"); | ||||
| EXPECT_STREQ(header_param.c_str(), "headerValue"); | EXPECT_STREQ(header_param.c_str(), "headerValue"); | ||||
| EXPECT_STREQ(post_param.c_str(), "postValue"); | EXPECT_STREQ(post_param.c_str(), "postValue"); | ||||
| EXPECT_STREQ(post_message.c_str(), "postKey=postValue"); | |||||
| EXPECT_STREQ(post_message, "postKey=postValue"); | |||||
| const std::string rKey("headKey"); | const std::string rKey("headKey"); | ||||
| const std::string rVal("headValue"); | const std::string rVal("headValue"); | ||||
| @@ -79,11 +87,19 @@ class TestHttpServer : public UT::Common { | |||||
| std::string path_param = resp->GetPathParam("key1"); | std::string path_param = resp->GetPathParam("key1"); | ||||
| std::string header_param = resp->GetHeadParam("headerKey"); | std::string header_param = resp->GetHeadParam("headerKey"); | ||||
| std::string post_param = resp->GetPostParam("postKey"); | std::string post_param = resp->GetPostParam("postKey"); | ||||
| std::string post_message = resp->GetPostMsg(); | |||||
| unsigned char *data = nullptr; | |||||
| const uint64_t len = resp->GetPostMsg(&data); | |||||
| char post_message[len + 1]; | |||||
| if (memset_s(post_message, len + 1, 0, len + 1) != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memset_s error"; | |||||
| } | |||||
| if (memcpy_s(post_message, len, data, len) != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memset_s error"; | |||||
| } | |||||
| EXPECT_STREQ(path_param.c_str(), "value1"); | EXPECT_STREQ(path_param.c_str(), "value1"); | ||||
| EXPECT_STREQ(header_param.c_str(), "headerValue"); | EXPECT_STREQ(header_param.c_str(), "headerValue"); | ||||
| EXPECT_STREQ(post_param.c_str(), "postValue"); | EXPECT_STREQ(post_param.c_str(), "postValue"); | ||||
| EXPECT_STREQ(post_message.c_str(), "postKey=postValue"); | |||||
| EXPECT_STREQ(post_message, "postKey=postValue"); | |||||
| const std::string rKey("headKey"); | const std::string rKey("headKey"); | ||||
| const std::string rVal("headValue"); | const std::string rVal("headValue"); | ||||
| @@ -157,6 +173,6 @@ TEST_F(TestHttpServer, addressException) { | |||||
| ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception); | ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception); | ||||
| } | } | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||