| @@ -38,21 +38,6 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| HttpMessageHandler::~HttpMessageHandler() { | |||
| if (!event_request_) { | |||
| evhttp_request_free(event_request_); | |||
| event_request_ = nullptr; | |||
| } | |||
| if (!event_uri_) { | |||
| evhttp_uri_free(const_cast<evhttp_uri *>(event_uri_)); | |||
| event_uri_ = nullptr; | |||
| } | |||
| if (!resp_buf_) { | |||
| evbuffer_free(resp_buf_); | |||
| resp_buf_ = nullptr; | |||
| } | |||
| } | |||
| void HttpMessageHandler::InitHttpMessage() { | |||
| MS_EXCEPTION_IF_NULL(event_request_); | |||
| event_uri_ = evhttp_request_get_evhttp_uri(event_request_); | |||
| @@ -54,7 +54,7 @@ class HttpMessageHandler { | |||
| resp_buf_(nullptr), | |||
| resp_code_(HTTP_OK) {} | |||
| ~HttpMessageHandler(); | |||
| virtual ~HttpMessageHandler() = default; | |||
| void InitHttpMessage(); | |||
| std::string GetRequestUri(); | |||
| @@ -99,10 +99,10 @@ bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *functio | |||
| auto TransFunc = [](struct evhttp_request *req, void *arg) { | |||
| MS_EXCEPTION_IF_NULL(req); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| HttpMessageHandler httpReq(req); | |||
| httpReq.InitHttpMessage(); | |||
| auto httpReq = std::make_shared<HttpMessageHandler>(req); | |||
| httpReq->InitHttpMessage(); | |||
| OnRequestReceive *func = reinterpret_cast<OnRequestReceive *>(arg); | |||
| (*func)(&httpReq); | |||
| (*func)(httpReq); | |||
| }; | |||
| MS_EXCEPTION_IF_NULL(event_http_); | |||
| @@ -125,6 +125,7 @@ bool HttpServer::UnRegisterRoute(const std::string &url) { | |||
| } | |||
| bool HttpServer::Start() { | |||
| MS_LOG(INFO) << "Start http server!"; | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| int ret = event_base_dispatch(event_base_); | |||
| if (ret == 0) { | |||
| @@ -142,6 +143,7 @@ bool HttpServer::Start() { | |||
| } | |||
| void HttpServer::Stop() { | |||
| MS_LOG(INFO) << "Stop http server!"; | |||
| if (event_http_) { | |||
| evhttp_free(event_http_); | |||
| event_http_ = nullptr; | |||
| @@ -30,6 +30,7 @@ | |||
| #include <cstdlib> | |||
| #include <cstring> | |||
| #include <functional> | |||
| #include <memory> | |||
| #include <string> | |||
| namespace mindspore { | |||
| @@ -48,6 +49,8 @@ typedef enum eHttpMethod { | |||
| HM_PATCH = 1 << 8 | |||
| } HttpMethod; | |||
| using OnRequestReceive = std::function<void(std::shared_ptr<HttpMessageHandler>)>; | |||
| class HttpServer { | |||
| public: | |||
| // Server address only support IPV4 now, and should be in format of "x.x.x.x" | |||
| @@ -56,8 +59,6 @@ class HttpServer { | |||
| ~HttpServer(); | |||
| using OnRequestReceive = std::function<void(HttpMessageHandler *)>; | |||
| bool InitServer(); | |||
| void SetTimeOut(int seconds = 5); | |||
| @@ -85,7 +85,7 @@ void TcpClient::InitTcpClient() { | |||
| bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this); | |||
| if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) { | |||
| MS_LOG(EXCEPTION) << "buffer event enable read and write failed!"; | |||
| MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!"; | |||
| } | |||
| int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin)); | |||
| @@ -107,7 +107,7 @@ void TcpClient::StartWithDelay(int seconds) { | |||
| event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this); | |||
| if (evtimer_add(event_timeout_, &timeout_value) == -1) { | |||
| MS_LOG(EXCEPTION) << "event timeout failed!"; | |||
| MS_LOG(EXCEPTION) << "Event timeout failed!"; | |||
| } | |||
| } | |||
| @@ -212,7 +212,7 @@ void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; } | |||
| void TcpClient::SendMessage(const void *buf, size_t num) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) { | |||
| MS_LOG(EXCEPTION) << "event buffer add failed!"; | |||
| MS_LOG(EXCEPTION) << "Event buffer add failed!"; | |||
| } | |||
| } | |||
| } // namespace comm | |||
| @@ -102,6 +102,7 @@ void TcpServer::InitServer() { | |||
| void TcpServer::Start() { | |||
| std::unique_lock<std::recursive_mutex> l(connection_mutex_); | |||
| MS_LOG(INFO) << "Start tcp server!"; | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| int ret = event_base_dispatch(base_); | |||
| if (ret == 0) { | |||
| @@ -116,6 +117,7 @@ void TcpServer::Start() { | |||
| } | |||
| void TcpServer::Stop() { | |||
| MS_LOG(INFO) << "Stop tcp server!"; | |||
| if (signal_event_ != nullptr) { | |||
| event_free(signal_event_); | |||
| signal_event_ = nullptr; | |||
| @@ -171,7 +173,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||
| server->AddConnection(fd, conn); | |||
| bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn)); | |||
| if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { | |||
| MS_LOG(EXCEPTION) << "buffer event enable read and write failed!"; | |||
| MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!"; | |||
| } | |||
| } | |||
| @@ -194,7 +196,7 @@ void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) { | |||
| struct timeval delay = {0, 0}; | |||
| MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds."; | |||
| if (event_base_loopexit(base, &delay) == -1) { | |||
| MS_LOG(EXCEPTION) << "event base loop exit failed."; | |||
| MS_LOG(EXCEPTION) << "Event base loop exit failed."; | |||
| } | |||
| } | |||
| @@ -234,7 +236,7 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||
| // Notify about disconnection | |||
| if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); | |||
| } else { | |||
| MS_LOG(ERROR) << "unhandled event!"; | |||
| MS_LOG(ERROR) << "Unhandled event!"; | |||
| } | |||
| } | |||
| @@ -33,7 +33,7 @@ class TestHttpServer : public UT::Common { | |||
| public: | |||
| TestHttpServer() = default; | |||
| static void testGetHandler(HttpMessageHandler *resp) { | |||
| static void testGetHandler(std::shared_ptr<HttpMessageHandler> resp) { | |||
| std::string host = resp->GetRequestHost(); | |||
| EXPECT_STREQ(host.c_str(), "127.0.0.1"); | |||
| @@ -58,8 +58,8 @@ class TestHttpServer : public UT::Common { | |||
| void SetUp() override { | |||
| server_ = new HttpServer("0.0.0.0", 9999); | |||
| std::function<void(HttpMessageHandler *)> http_get_func = std::bind( | |||
| [](HttpMessageHandler *resp) { | |||
| OnRequestReceive http_get_func = std::bind( | |||
| [](std::shared_ptr<HttpMessageHandler> resp) { | |||
| EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1"); | |||
| EXPECT_STREQ(resp->GetUriQuery().c_str(), "key1=value1"); | |||
| EXPECT_STREQ(resp->GetRequestUri().c_str(), "/httpget?key1=value1"); | |||
| @@ -68,8 +68,8 @@ class TestHttpServer : public UT::Common { | |||
| }, | |||
| std::placeholders::_1); | |||
| std::function<void(HttpMessageHandler *)> http_handler_func = std::bind( | |||
| [](HttpMessageHandler *resp) { | |||
| OnRequestReceive http_handler_func = std::bind( | |||
| [](std::shared_ptr<HttpMessageHandler> resp) { | |||
| std::string host = resp->GetRequestHost(); | |||
| EXPECT_STREQ(host.c_str(), "127.0.0.1"); | |||
| @@ -97,9 +97,13 @@ class TestHttpServer : public UT::Common { | |||
| std::unique_ptr<std::thread> http_server_thread_(nullptr); | |||
| http_server_thread_ = std::make_unique<std::thread>([&]() { server_->Start(); }); | |||
| http_server_thread_->detach(); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(5000)); | |||
| } | |||
| void TearDown() override { server_->Stop(); } | |||
| void TearDown() override { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(5000)); | |||
| server_->Stop(); | |||
| } | |||
| private: | |||
| HttpServer *server_; | |||
| @@ -140,15 +144,13 @@ TEST_F(TestHttpServer, messageHandler) { | |||
| TEST_F(TestHttpServer, portErrorNoException) { | |||
| HttpServer *server_exception = new HttpServer("0.0.0.0", -1); | |||
| std::function<void(HttpMessageHandler *)> http_handler_func = | |||
| std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); | |||
| OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); | |||
| EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func)); | |||
| } | |||
| TEST_F(TestHttpServer, addressException) { | |||
| HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998); | |||
| std::function<void(HttpMessageHandler *)> http_handler_func = | |||
| std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); | |||
| OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); | |||
| ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception); | |||
| } | |||
| @@ -38,7 +38,7 @@ class TestTcpServer : public UT::Common { | |||
| server_->Start(); | |||
| }); | |||
| http_server_thread_->detach(); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(2000)); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(5000)); | |||
| } | |||
| void TearDown() override { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(2000)); | |||