| @@ -132,19 +132,17 @@ std::string HttpMessageHandler::GetUriFragment() { | |||
| return std::string(fragment); | |||
| } | |||
| std::string HttpMessageHandler::GetPostMsg() { | |||
| uint64_t HttpMessageHandler::GetPostMsg(unsigned char **buffer) { | |||
| MS_EXCEPTION_IF_NULL(event_request_); | |||
| if (body_ != nullptr) { | |||
| return *body_; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(buffer); | |||
| size_t len = evbuffer_get_length(event_request_->input_buffer); | |||
| if (len == 0) { | |||
| MS_LOG(EXCEPTION) << "The post message is empty!"; | |||
| } | |||
| const char *post_message = reinterpret_cast<const char *>(evbuffer_pullup(event_request_->input_buffer, -1)); | |||
| MS_EXCEPTION_IF_NULL(post_message); | |||
| body_ = std::make_unique<std::string>(post_message, len); | |||
| return *body_; | |||
| *buffer = evbuffer_pullup(event_request_->input_buffer, -1); | |||
| MS_EXCEPTION_IF_NULL(*buffer); | |||
| return len; | |||
| } | |||
| void HttpMessageHandler::AddRespHeadParam(const std::string &key, const std::string &val) { | |||
| @@ -62,7 +62,7 @@ class HttpMessageHandler { | |||
| std::string GetHeadParam(const std::string &key); | |||
| std::string GetPathParam(const std::string &key); | |||
| std::string GetPostParam(const std::string &key); | |||
| std::string GetPostMsg(); | |||
| uint64_t GetPostMsg(unsigned char **buffer); | |||
| std::string GetUriPath(); | |||
| std::string GetUriQuery(); | |||
| @@ -49,6 +49,7 @@ bool HttpServer::InitServer() { | |||
| MS_LOG(EXCEPTION) << "The http server ip:" << server_address_ << " is illegal!"; | |||
| } | |||
| is_stop_ = false; | |||
| event_base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| event_http_ = evhttp_new(event_base_); | |||
| @@ -146,13 +147,21 @@ bool HttpServer::Start() { | |||
| void HttpServer::Stop() { | |||
| MS_LOG(INFO) << "Stop http server!"; | |||
| if (event_http_) { | |||
| evhttp_free(event_http_); | |||
| event_http_ = nullptr; | |||
| } | |||
| if (event_base_) { | |||
| event_base_free(event_base_); | |||
| event_base_ = nullptr; | |||
| if (!is_stop_.load()) { | |||
| int ret = event_base_loopbreak(event_base_); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "event base loop break failed!"; | |||
| } | |||
| if (event_http_) { | |||
| evhttp_free(event_http_); | |||
| event_http_ = nullptr; | |||
| } | |||
| if (event_base_) { | |||
| event_base_free(event_base_); | |||
| event_base_ = nullptr; | |||
| } | |||
| is_stop_ = true; | |||
| } | |||
| } | |||
| @@ -32,6 +32,7 @@ | |||
| #include <functional> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <atomic> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -55,7 +56,12 @@ class HttpServer { | |||
| public: | |||
| // Server address only support IPV4 now, and should be in format of "x.x.x.x" | |||
| explicit HttpServer(const std::string &address, std::uint16_t port) | |||
| : server_address_(address), server_port_(port), event_base_(nullptr), event_http_(nullptr), is_init_(false) {} | |||
| : server_address_(address), | |||
| server_port_(port), | |||
| event_base_(nullptr), | |||
| event_http_(nullptr), | |||
| is_init_(false), | |||
| is_stop_(true) {} | |||
| ~HttpServer(); | |||
| @@ -84,6 +90,7 @@ class HttpServer { | |||
| struct event_base *event_base_; | |||
| struct evhttp *event_http_; | |||
| bool is_init_; | |||
| std::atomic<bool> is_stop_; | |||
| }; | |||
| } // namespace core | |||
| @@ -198,7 +198,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||
| struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); | |||
| if (!bev) { | |||
| MS_LOG(ERROR) << "Error constructing buffer event!"; | |||
| event_base_loopbreak(base); | |||
| int ret = event_base_loopbreak(base); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "event base loop break failed!"; | |||
| } | |||
| return; | |||
| } | |||
| @@ -43,11 +43,19 @@ class TestHttpServer : public UT::Common { | |||
| std::string path_param = resp->GetPathParam("key1"); | |||
| std::string header_param = resp->GetHeadParam("headerKey"); | |||
| std::string post_param = resp->GetPostParam("postKey"); | |||
| std::string post_message = resp->GetPostMsg(); | |||
| unsigned char *data = nullptr; | |||
| const uint64_t len = resp->GetPostMsg(&data); | |||
| char post_message[len + 1]; | |||
| if (memset_s(post_message, len + 1, 0, len + 1) != 0) { | |||
| MS_LOG(EXCEPTION) << "The memset_s error"; | |||
| } | |||
| if (memcpy_s(post_message, len, data, len) != 0) { | |||
| MS_LOG(EXCEPTION) << "The memset_s error"; | |||
| } | |||
| EXPECT_STREQ(path_param.c_str(), "value1"); | |||
| EXPECT_STREQ(header_param.c_str(), "headerValue"); | |||
| EXPECT_STREQ(post_param.c_str(), "postValue"); | |||
| EXPECT_STREQ(post_message.c_str(), "postKey=postValue"); | |||
| EXPECT_STREQ(post_message, "postKey=postValue"); | |||
| const std::string rKey("headKey"); | |||
| const std::string rVal("headValue"); | |||
| @@ -79,11 +87,19 @@ class TestHttpServer : public UT::Common { | |||
| std::string path_param = resp->GetPathParam("key1"); | |||
| std::string header_param = resp->GetHeadParam("headerKey"); | |||
| std::string post_param = resp->GetPostParam("postKey"); | |||
| std::string post_message = resp->GetPostMsg(); | |||
| unsigned char *data = nullptr; | |||
| const uint64_t len = resp->GetPostMsg(&data); | |||
| char post_message[len + 1]; | |||
| if (memset_s(post_message, len + 1, 0, len + 1) != 0) { | |||
| MS_LOG(EXCEPTION) << "The memset_s error"; | |||
| } | |||
| if (memcpy_s(post_message, len, data, len) != 0) { | |||
| MS_LOG(EXCEPTION) << "The memset_s error"; | |||
| } | |||
| EXPECT_STREQ(path_param.c_str(), "value1"); | |||
| EXPECT_STREQ(header_param.c_str(), "headerValue"); | |||
| EXPECT_STREQ(post_param.c_str(), "postValue"); | |||
| EXPECT_STREQ(post_message.c_str(), "postKey=postValue"); | |||
| EXPECT_STREQ(post_message, "postKey=postValue"); | |||
| const std::string rKey("headKey"); | |||
| const std::string rVal("headValue"); | |||
| @@ -157,6 +173,6 @@ TEST_F(TestHttpServer, addressException) { | |||
| ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception); | |||
| } | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||