| @@ -20,6 +20,7 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/http_client.cc") | |||||
| endif() | endif() | ||||
| if(NOT ENABLE_D) | if(NOT ENABLE_D) | ||||
| @@ -61,6 +61,11 @@ constexpr int kGroup3RandomLength = 4; | |||||
| constexpr int kGroup4RandomLength = 4; | constexpr int kGroup4RandomLength = 4; | ||||
| constexpr int kGroup5RandomLength = 12; | constexpr int kGroup5RandomLength = 12; | ||||
| // The size of the buffer for sending and receiving data is 4096 bytes. | |||||
| constexpr int kMessageChunkLength = 4096; | |||||
| // The timeout period for the http client to connect to the http server is 120 seconds. | |||||
| constexpr int kConnectionTimeout = 120; | |||||
| class CommUtil { | class CommUtil { | ||||
| public: | public: | ||||
| static bool CheckIpWithRegex(const std::string &ip); | static bool CheckIpWithRegex(const std::string &ip); | ||||
| @@ -80,5 +85,4 @@ class CommUtil { | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ | #endif // MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ | ||||
| @@ -0,0 +1,226 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/core/http_client.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| HttpClient::~HttpClient() { | |||||
| if (event_base_ != nullptr) { | |||||
| event_base_free(event_base_); | |||||
| event_base_ = nullptr; | |||||
| } | |||||
| } | |||||
| void HttpClient::Init() { | |||||
| event_base_ = event_base_new(); | |||||
| MS_EXCEPTION_IF_NULL(event_base_); | |||||
| dns_base_ = evdns_base_new(event_base_, 1); | |||||
| MS_EXCEPTION_IF_NULL(dns_base_); | |||||
| } | |||||
| Status HttpClient::Post(const std::string &url, const void *body, size_t len, std::shared_ptr<std::vector<char>> output, | |||||
| const std::map<std::string, std::string> &headers) { | |||||
| MS_EXCEPTION_IF_NULL(body); | |||||
| MS_EXCEPTION_IF_NULL(output); | |||||
| auto handler = std::make_shared<HttpMessageHandler>(); | |||||
| output->clear(); | |||||
| handler->set_body(output); | |||||
| struct evhttp_request *request = evhttp_request_new(ReadCallback, reinterpret_cast<void *>(handler.get())); | |||||
| MS_EXCEPTION_IF_NULL(request); | |||||
| InitRequest(handler, url, request); | |||||
| struct evhttp_connection *connection = | |||||
| evhttp_connection_base_new(event_base_, dns_base_, handler->GetHostByUri(), handler->GetUriPort()); | |||||
| if (!connection) { | |||||
| MS_LOG(ERROR) << "Create http connection failed!"; | |||||
| return Status::BADREQUEST; | |||||
| } | |||||
| struct evbuffer *buffer = evhttp_request_get_output_buffer(request); | |||||
| if (evbuffer_add(buffer, body, len) != 0) { | |||||
| MS_LOG(ERROR) << "Add buffer failed!"; | |||||
| return Status::INTERNAL; | |||||
| } | |||||
| AddHeaders(headers, request, handler); | |||||
| return CreateRequest(handler, connection, request, HttpMethod::HM_POST); | |||||
| } | |||||
| Status HttpClient::Get(const std::string &url, std::shared_ptr<std::vector<char>> output, | |||||
| const std::map<std::string, std::string> &headers) { | |||||
| MS_EXCEPTION_IF_NULL(output); | |||||
| auto handler = std::make_shared<HttpMessageHandler>(); | |||||
| output->clear(); | |||||
| handler->set_body(output); | |||||
| struct evhttp_request *request = evhttp_request_new(ReadCallback, reinterpret_cast<void *>(handler.get())); | |||||
| MS_EXCEPTION_IF_NULL(request); | |||||
| InitRequest(handler, url, request); | |||||
| struct evhttp_connection *connection = | |||||
| evhttp_connection_base_new(event_base_, dns_base_, handler->GetHostByUri(), handler->GetUriPort()); | |||||
| if (!connection) { | |||||
| MS_LOG(ERROR) << "Create http connection failed!"; | |||||
| return Status::BADREQUEST; | |||||
| } | |||||
| AddHeaders(headers, request, handler); | |||||
| return CreateRequest(handler, connection, request, HttpMethod::HM_GET); | |||||
| } | |||||
| void HttpClient::set_connection_timeout(const int &timeout) { connection_timout_ = timeout; } | |||||
| void HttpClient::ReadCallback(struct evhttp_request *request, void *arg) { | |||||
| MS_EXCEPTION_IF_NULL(request); | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| auto handler = static_cast<HttpMessageHandler *>(arg); | |||||
| if (event_base_loopexit(handler->http_base(), nullptr) != 0) { | |||||
| MS_LOG(EXCEPTION) << "event base loop exit failed!"; | |||||
| } | |||||
| } | |||||
| int HttpClient::ReadHeaderDoneCallback(struct evhttp_request *request, void *arg) { | |||||
| MS_EXCEPTION_IF_NULL(request); | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| auto handler = static_cast<HttpMessageHandler *>(arg); | |||||
| handler->set_request(request); | |||||
| MS_LOG(DEBUG) << "The http response code is:" << evhttp_request_get_response_code(request) | |||||
| << ", The request code line is:" << evhttp_request_get_response_code_line(request); | |||||
| struct evkeyvalq *headers = evhttp_request_get_input_headers(request); | |||||
| struct evkeyval *header; | |||||
| TAILQ_FOREACH(header, headers, next) { | |||||
| MS_LOG(DEBUG) << "The key:" << header->key << ",The value:" << header->value; | |||||
| std::string len = "Content-Length"; | |||||
| if (!strcmp(header->key, len.c_str())) { | |||||
| handler->set_content_len(strtouq(header->value, nullptr, 10)); | |||||
| handler->InitBodySize(); | |||||
| } | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| void HttpClient::ReadChunkDataCallback(struct evhttp_request *request, void *arg) { | |||||
| MS_EXCEPTION_IF_NULL(request); | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| auto handler = static_cast<HttpMessageHandler *>(arg); | |||||
| char buf[kMessageChunkLength]; | |||||
| struct evbuffer *evbuf = evhttp_request_get_input_buffer(request); | |||||
| MS_EXCEPTION_IF_NULL(evbuf); | |||||
| int n = 0; | |||||
| while ((n = evbuffer_remove(evbuf, &buf, sizeof(buf))) > 0) { | |||||
| handler->ReceiveMessage(buf, n); | |||||
| } | |||||
| } | |||||
| void HttpClient::RequestErrorCallback(enum evhttp_request_error error, void *arg) { | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| auto handler = static_cast<HttpMessageHandler *>(arg); | |||||
| MS_LOG(ERROR) << "The request failed, the error is:" << error; | |||||
| if (event_base_loopexit(handler->http_base(), nullptr) != 0) { | |||||
| MS_LOG(EXCEPTION) << "event base loop exit failed!"; | |||||
| } | |||||
| } | |||||
| void HttpClient::ConnectionCloseCallback(struct evhttp_connection *connection, void *arg) { | |||||
| MS_EXCEPTION_IF_NULL(connection); | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| MS_LOG(ERROR) << "Remote connection closed!"; | |||||
| if (event_base_loopexit((struct event_base *)arg, nullptr) != 0) { | |||||
| MS_LOG(EXCEPTION) << "event base loop exit failed!"; | |||||
| } | |||||
| } | |||||
| void HttpClient::AddHeaders(const std::map<std::string, std::string> &headers, struct evhttp_request *request, | |||||
| std::shared_ptr<HttpMessageHandler> handler) { | |||||
| MS_EXCEPTION_IF_NULL(request); | |||||
| if (evhttp_add_header(evhttp_request_get_output_headers(request), "Host", handler->GetHostByUri()) != 0) { | |||||
| MS_LOG(EXCEPTION) << "Add header failed!"; | |||||
| } | |||||
| for (auto &header : headers) { | |||||
| if (evhttp_add_header(evhttp_request_get_output_headers(request), header.first.data(), header.second.data()) != 0) { | |||||
| MS_LOG(EXCEPTION) << "Add header failed!"; | |||||
| } | |||||
| } | |||||
| } | |||||
| void HttpClient::InitRequest(std::shared_ptr<HttpMessageHandler> handler, const std::string &url, | |||||
| struct evhttp_request *request) { | |||||
| MS_EXCEPTION_IF_NULL(request); | |||||
| MS_EXCEPTION_IF_NULL(handler); | |||||
| handler->set_http_base(event_base_); | |||||
| handler->ParseUrl(url); | |||||
| evhttp_request_set_header_cb(request, ReadHeaderDoneCallback); | |||||
| evhttp_request_set_chunked_cb(request, ReadChunkDataCallback); | |||||
| evhttp_request_set_error_cb(request, RequestErrorCallback); | |||||
| MS_LOG(DEBUG) << "The url is:" << url << ", The host is:" << handler->GetHostByUri() | |||||
| << ", The port is:" << handler->GetUriPort() << ", The request_url is:" << handler->GetRequestPath(); | |||||
| } | |||||
| Status HttpClient::CreateRequest(std::shared_ptr<HttpMessageHandler> handler, struct evhttp_connection *connection, | |||||
| struct evhttp_request *request, HttpMethod method) { | |||||
| MS_EXCEPTION_IF_NULL(handler); | |||||
| MS_EXCEPTION_IF_NULL(connection); | |||||
| MS_EXCEPTION_IF_NULL(request); | |||||
| evhttp_connection_set_closecb(connection, ConnectionCloseCallback, event_base_); | |||||
| evhttp_connection_set_timeout(connection, connection_timout_); | |||||
| if (evhttp_make_request(connection, request, evhttp_cmd_type(method), handler->GetRequestPath().c_str()) != 0) { | |||||
| MS_LOG(ERROR) << "Make request failed!"; | |||||
| return Status::INTERNAL; | |||||
| } | |||||
| if (!Start()) { | |||||
| MS_LOG(ERROR) << "Start http client failed!"; | |||||
| return Status::INTERNAL; | |||||
| } | |||||
| if (handler->request()) { | |||||
| MS_LOG(DEBUG) << "The http response code is:" << evhttp_request_get_response_code(handler->request()) | |||||
| << ", The request code line is:" << evhttp_request_get_response_code_line(handler->request()); | |||||
| return Status(evhttp_request_get_response_code(handler->request())); | |||||
| } | |||||
| return Status::INTERNAL; | |||||
| } | |||||
| bool HttpClient::Start() { | |||||
| MS_EXCEPTION_IF_NULL(event_base_); | |||||
| // int ret = event_base_dispatch(event_base_); | |||||
| int ret = event_base_loop(event_base_, 0); | |||||
| if (ret == 0) { | |||||
| MS_LOG(DEBUG) << "Event base dispatch success!"; | |||||
| return true; | |||||
| } else if (ret == 1) { | |||||
| MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; | |||||
| return false; | |||||
| } else if (ret == -1) { | |||||
| MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; | |||||
| return false; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Event base dispatch with unexpected error code!"; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,97 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_HTTP_CLIENT_H_ | |||||
| #define MINDSPORE_CCSRC_PS_CORE_HTTP_CLIENT_H_ | |||||
| #include <event2/buffer.h> | |||||
| #include <event2/event.h> | |||||
| #include <event2/http.h> | |||||
| #include <event2/keyvalq_struct.h> | |||||
| #include <event2/listener.h> | |||||
| #include <event2/util.h> | |||||
| #include <event2/http_struct.h> | |||||
| #include <event2/dns.h> | |||||
| #include <event2/thread.h> | |||||
| #include <sys/queue.h> | |||||
| #include <cstdio> | |||||
| #include <cstdlib> | |||||
| #include <cstring> | |||||
| #include <functional> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <atomic> | |||||
| #include <thread> | |||||
| #include <vector> | |||||
| #include <map> | |||||
| #include "ps/core/http_message_handler.h" | |||||
| #include "ps/core/comm_util.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| enum class HttpMethod { HM_GET = 1 << 0, HM_POST = 1 << 1 }; | |||||
| enum class Status : int { | |||||
| OK = 200, // request completed ok | |||||
| BADREQUEST = 400, // invalid http request was made | |||||
| NOTFOUND = 404, // could not find content for uri | |||||
| INTERNAL = 500 // internal error | |||||
| }; | |||||
| class HttpClient { | |||||
| public: | |||||
| HttpClient() : event_base_(nullptr), dns_base_(nullptr), is_init_(false), connection_timout_(kConnectionTimeout) { | |||||
| Init(); | |||||
| } | |||||
| virtual ~HttpClient(); | |||||
| Status Post(const std::string &url, const void *body, size_t len, std::shared_ptr<std::vector<char>> output, | |||||
| const std::map<std::string, std::string> &headers = {}); | |||||
| Status Get(const std::string &url, std::shared_ptr<std::vector<char>> output, | |||||
| const std::map<std::string, std::string> &headers = {}); | |||||
| void set_connection_timeout(const int &timeout); | |||||
| private: | |||||
| static void ReadCallback(struct evhttp_request *remote_rsp, void *arg); | |||||
| static int ReadHeaderDoneCallback(struct evhttp_request *remote_rsp, void *arg); | |||||
| static void ReadChunkDataCallback(struct evhttp_request *remote_rsp, void *arg); | |||||
| static void RequestErrorCallback(enum evhttp_request_error error, void *arg); | |||||
| static void ConnectionCloseCallback(struct evhttp_connection *connection, void *arg); | |||||
| void AddHeaders(const std::map<std::string, std::string> &headers, struct evhttp_request *request, | |||||
| std::shared_ptr<HttpMessageHandler> handler); | |||||
| void InitRequest(std::shared_ptr<HttpMessageHandler> handler, const std::string &url, struct evhttp_request *request); | |||||
| Status CreateRequest(std::shared_ptr<HttpMessageHandler> handler, struct evhttp_connection *connection, | |||||
| struct evhttp_request *request, HttpMethod method); | |||||
| bool Start(); | |||||
| void Init(); | |||||
| struct event_base *event_base_; | |||||
| struct evdns_base *dns_base_; | |||||
| bool is_init_; | |||||
| int connection_timout_; | |||||
| }; | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_CORE_HTTP_CLIENT_H_ | |||||
| @@ -44,6 +44,7 @@ void HttpMessageHandler::InitHttpMessage() { | |||||
| const char *query = evhttp_uri_get_query(event_uri_); | const char *query = evhttp_uri_get_query(event_uri_); | ||||
| if (query) { | if (query) { | ||||
| MS_LOG(WARNING) << "The query is:" << query; | |||||
| evhttp_parse_query_str(query, &path_params_); | evhttp_parse_query_str(query, &path_params_); | ||||
| } | } | ||||
| @@ -52,6 +53,11 @@ void HttpMessageHandler::InitHttpMessage() { | |||||
| resp_buf_ = evhttp_request_get_output_buffer(event_request_); | resp_buf_ = evhttp_request_get_output_buffer(event_request_); | ||||
| } | } | ||||
| void HttpMessageHandler::ParseUrl(const std::string &url) { | |||||
| event_uri_ = evhttp_uri_parse(url.c_str()); | |||||
| MS_EXCEPTION_IF_NULL(event_uri_); | |||||
| } | |||||
| std::string HttpMessageHandler::GetHeadParam(const std::string &key) { | std::string HttpMessageHandler::GetHeadParam(const std::string &key) { | ||||
| MS_EXCEPTION_IF_NULL(head_params_); | MS_EXCEPTION_IF_NULL(head_params_); | ||||
| const char *val = evhttp_find_header(head_params_, key.c_str()); | const char *val = evhttp_find_header(head_params_, key.c_str()); | ||||
| @@ -74,8 +80,8 @@ void HttpMessageHandler::ParsePostParam() { | |||||
| post_param_parsed_ = true; | post_param_parsed_ = true; | ||||
| const char *post_message = reinterpret_cast<const char *>(evbuffer_pullup(event_request_->input_buffer, -1)); | const char *post_message = reinterpret_cast<const char *>(evbuffer_pullup(event_request_->input_buffer, -1)); | ||||
| MS_EXCEPTION_IF_NULL(post_message); | MS_EXCEPTION_IF_NULL(post_message); | ||||
| body_ = std::make_unique<std::string>(post_message, len); | |||||
| int ret = evhttp_parse_query_str(body_->c_str(), &post_params_); | |||||
| post_message_ = std::make_unique<std::string>(post_message, len); | |||||
| int ret = evhttp_parse_query_str(post_message_->c_str(), &post_params_); | |||||
| if (ret == -1) { | if (ret == -1) { | ||||
| MS_LOG(EXCEPTION) << "Parse post parameter failed!"; | MS_LOG(EXCEPTION) << "Parse post parameter failed!"; | ||||
| } | } | ||||
| @@ -105,9 +111,20 @@ std::string HttpMessageHandler::GetRequestHost() { | |||||
| return std::string(host); | return std::string(host); | ||||
| } | } | ||||
| const char *HttpMessageHandler::GetHostByUri() { | |||||
| MS_EXCEPTION_IF_NULL(event_uri_); | |||||
| const char *host = evhttp_uri_get_host(event_uri_); | |||||
| MS_EXCEPTION_IF_NULL(host); | |||||
| return host; | |||||
| } | |||||
| int HttpMessageHandler::GetUriPort() { | int HttpMessageHandler::GetUriPort() { | ||||
| MS_EXCEPTION_IF_NULL(event_uri_); | MS_EXCEPTION_IF_NULL(event_uri_); | ||||
| return evhttp_uri_get_port(event_uri_); | |||||
| int port = evhttp_uri_get_port(event_uri_); | |||||
| if (port < 0) { | |||||
| MS_LOG(EXCEPTION) << "The port:" << port << " should not be less than 0!"; | |||||
| } | |||||
| return port; | |||||
| } | } | ||||
| std::string HttpMessageHandler::GetUriPath() { | std::string HttpMessageHandler::GetUriPath() { | ||||
| @@ -117,6 +134,21 @@ std::string HttpMessageHandler::GetUriPath() { | |||||
| return std::string(path); | return std::string(path); | ||||
| } | } | ||||
| std::string HttpMessageHandler::GetRequestPath() { | |||||
| MS_EXCEPTION_IF_NULL(event_uri_); | |||||
| const char *path = evhttp_uri_get_path(event_uri_); | |||||
| if (path == nullptr || strlen(path) == 0) { | |||||
| path = "/"; | |||||
| } | |||||
| std::string path_res(path); | |||||
| const char *query = evhttp_uri_get_query(event_uri_); | |||||
| if (query) { | |||||
| path_res.append("?"); | |||||
| path_res.append(query); | |||||
| } | |||||
| return path_res; | |||||
| } | |||||
| std::string HttpMessageHandler::GetUriQuery() { | std::string HttpMessageHandler::GetUriQuery() { | ||||
| MS_EXCEPTION_IF_NULL(event_uri_); | MS_EXCEPTION_IF_NULL(event_uri_); | ||||
| const char *query = evhttp_uri_get_query(event_uri_); | const char *query = evhttp_uri_get_query(event_uri_); | ||||
| @@ -202,6 +234,41 @@ void HttpMessageHandler::RespError(int nCode, const std::string &message) { | |||||
| evhttp_send_error(event_request_, nCode, message.c_str()); | evhttp_send_error(event_request_, nCode, message.c_str()); | ||||
| } | } | ||||
| } | } | ||||
| void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||||
| MS_EXCEPTION_IF_NULL(buffer); | |||||
| int ret = memcpy_s(body_->data() + offset_, num, buffer, num); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| offset_ += num; | |||||
| } | |||||
| void HttpMessageHandler::set_content_len(const uint64_t &len) { content_len_ = len; } | |||||
| uint64_t HttpMessageHandler::content_len() { return content_len_; } | |||||
| event_base *HttpMessageHandler::http_base() { return event_base_; } | |||||
| void HttpMessageHandler::set_http_base(const struct event_base *base) { | |||||
| MS_EXCEPTION_IF_NULL(base); | |||||
| event_base_ = const_cast<event_base *>(base); | |||||
| } | |||||
| void HttpMessageHandler::set_request(const struct evhttp_request *req) { | |||||
| MS_EXCEPTION_IF_NULL(req); | |||||
| event_request_ = const_cast<evhttp_request *>(req); | |||||
| } | |||||
| struct evhttp_request *HttpMessageHandler::request() { | |||||
| return event_request_; | |||||
| } | |||||
| void HttpMessageHandler::InitBodySize() { body_->resize(content_len()); } | |||||
| std::shared_ptr<std::vector<char>> HttpMessageHandler::body() { return body_; } | |||||
| void HttpMessageHandler::set_body(std::shared_ptr<std::vector<char>> body) { body_ = body; } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -32,37 +32,49 @@ | |||||
| #include <list> | #include <list> | ||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include "ps/core/comm_util.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| using HttpHeaders = std::map<std::string, std::list<std::string>>; | using HttpHeaders = std::map<std::string, std::list<std::string>>; | ||||
| using VectorPtr = std::shared_ptr<std::vector<char>>; | |||||
| class HttpMessageHandler { | class HttpMessageHandler { | ||||
| public: | public: | ||||
| explicit HttpMessageHandler(struct evhttp_request *req) | |||||
| : event_request_(req), | |||||
| HttpMessageHandler() | |||||
| : event_request_(nullptr), | |||||
| event_uri_(nullptr), | event_uri_(nullptr), | ||||
| path_params_{0}, | path_params_{0}, | ||||
| head_params_(nullptr), | head_params_(nullptr), | ||||
| post_params_{0}, | post_params_{0}, | ||||
| post_param_parsed_(false), | post_param_parsed_(false), | ||||
| post_message_(nullptr), | |||||
| body_(nullptr), | body_(nullptr), | ||||
| resp_headers_(nullptr), | resp_headers_(nullptr), | ||||
| resp_buf_(nullptr), | resp_buf_(nullptr), | ||||
| resp_code_(HTTP_OK) {} | |||||
| resp_code_(HTTP_OK), | |||||
| content_len_(0), | |||||
| event_base_(nullptr), | |||||
| offset_(0) {} | |||||
| virtual ~HttpMessageHandler() = default; | virtual ~HttpMessageHandler() = default; | ||||
| void InitHttpMessage(); | void InitHttpMessage(); | ||||
| void ParseUrl(const std::string &url); | |||||
| std::string GetRequestUri(); | std::string GetRequestUri(); | ||||
| std::string GetRequestHost(); | std::string GetRequestHost(); | ||||
| const char *GetHostByUri(); | |||||
| std::string GetHeadParam(const std::string &key); | std::string GetHeadParam(const std::string &key); | ||||
| std::string GetPathParam(const std::string &key); | std::string GetPathParam(const std::string &key); | ||||
| std::string GetPostParam(const std::string &key); | std::string GetPostParam(const std::string &key); | ||||
| uint64_t GetPostMsg(unsigned char **buffer); | uint64_t GetPostMsg(unsigned char **buffer); | ||||
| std::string GetUriPath(); | std::string GetUriPath(); | ||||
| std::string GetRequestPath(); | |||||
| std::string GetUriQuery(); | std::string GetUriQuery(); | ||||
| // It will return -1 if no port set | // It will return -1 if no port set | ||||
| @@ -83,6 +95,18 @@ class HttpMessageHandler { | |||||
| // If message is empty, libevent will use default error code message instead | // If message is empty, libevent will use default error code message instead | ||||
| void RespError(int nCode, const std::string &message); | void RespError(int nCode, const std::string &message); | ||||
| // Body length should no more than MAX_POST_BODY_LEN, default 64kB | |||||
| void ParsePostParam(); | |||||
| void ReceiveMessage(const void *buffer, size_t num); | |||||
| void set_content_len(const uint64_t &len); | |||||
| uint64_t content_len(); | |||||
| event_base *http_base(); | |||||
| void set_http_base(const struct event_base *base); | |||||
| void set_request(const struct evhttp_request *req); | |||||
| struct evhttp_request *request(); | |||||
| void InitBodySize(); | |||||
| VectorPtr body(); | |||||
| void set_body(VectorPtr body); | |||||
| private: | private: | ||||
| struct evhttp_request *event_request_; | struct evhttp_request *event_request_; | ||||
| @@ -91,13 +115,14 @@ class HttpMessageHandler { | |||||
| struct evkeyvalq *head_params_; | struct evkeyvalq *head_params_; | ||||
| struct evkeyvalq post_params_; | struct evkeyvalq post_params_; | ||||
| bool post_param_parsed_; | bool post_param_parsed_; | ||||
| std::unique_ptr<std::string> body_; | |||||
| std::unique_ptr<std::string> post_message_; | |||||
| VectorPtr body_; | |||||
| struct evkeyvalq *resp_headers_; | struct evkeyvalq *resp_headers_; | ||||
| struct evbuffer *resp_buf_; | struct evbuffer *resp_buf_; | ||||
| int resp_code_; | int resp_code_; | ||||
| // Body length should no more than MAX_POST_BODY_LEN, default 64kB | |||||
| void ParsePostParam(); | |||||
| uint64_t content_len_; | |||||
| struct event_base *event_base_; | |||||
| uint64_t offset_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -57,6 +57,7 @@ bool HttpServer::InitServer() { | |||||
| MS_EXCEPTION_IF_NULL(event_base_); | MS_EXCEPTION_IF_NULL(event_base_); | ||||
| event_http_ = evhttp_new(event_base_); | event_http_ = evhttp_new(event_base_); | ||||
| MS_EXCEPTION_IF_NULL(event_http_); | MS_EXCEPTION_IF_NULL(event_http_); | ||||
| evhttp_set_timeout(event_http_, request_timeout_); | |||||
| int ret = evhttp_bind_socket(event_http_, server_address_.c_str(), server_port_); | int ret = evhttp_bind_socket(event_http_, server_address_.c_str(), server_port_); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "Http bind server addr:" << server_address_.c_str() << " port:" << server_port_ << "failed"; | MS_LOG(EXCEPTION) << "Http bind server addr:" << server_address_.c_str() << " port:" << server_port_ << "failed"; | ||||
| @@ -70,7 +71,7 @@ void HttpServer::SetTimeOut(int seconds) { | |||||
| if (seconds < 0) { | if (seconds < 0) { | ||||
| MS_LOG(EXCEPTION) << "The timeout seconds:" << seconds << "is less than 0!"; | MS_LOG(EXCEPTION) << "The timeout seconds:" << seconds << "is less than 0!"; | ||||
| } | } | ||||
| evhttp_set_timeout(event_http_, seconds); | |||||
| request_timeout_ = seconds; | |||||
| } | } | ||||
| void HttpServer::SetAllowedMethod(u_int16_t methods) { | void HttpServer::SetAllowedMethod(u_int16_t methods) { | ||||
| @@ -105,7 +106,8 @@ bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *functio | |||||
| auto TransFunc = [](struct evhttp_request *req, void *arg) { | auto TransFunc = [](struct evhttp_request *req, void *arg) { | ||||
| MS_EXCEPTION_IF_NULL(req); | MS_EXCEPTION_IF_NULL(req); | ||||
| MS_EXCEPTION_IF_NULL(arg); | MS_EXCEPTION_IF_NULL(arg); | ||||
| auto httpReq = std::make_shared<HttpMessageHandler>(req); | |||||
| auto httpReq = std::make_shared<HttpMessageHandler>(); | |||||
| httpReq->set_request(req); | |||||
| httpReq->InitHttpMessage(); | httpReq->InitHttpMessage(); | ||||
| OnRequestReceive *func = reinterpret_cast<OnRequestReceive *>(arg); | OnRequestReceive *func = reinterpret_cast<OnRequestReceive *>(arg); | ||||
| (*func)(httpReq); | (*func)(httpReq); | ||||
| @@ -144,8 +146,9 @@ bool HttpServer::Start() { | |||||
| MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; | MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; | ||||
| return false; | return false; | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; | |||||
| MS_LOG(EXCEPTION) << "Event base dispatch with unexpected error code!"; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void HttpServer::Stop() { | void HttpServer::Stop() { | ||||
| @@ -38,18 +38,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| typedef enum eHttpMethod { | |||||
| HM_GET = 1 << 0, | |||||
| HM_POST = 1 << 1, | |||||
| HM_HEAD = 1 << 2, | |||||
| HM_PUT = 1 << 3, | |||||
| HM_DELETE = 1 << 4, | |||||
| HM_OPTIONS = 1 << 5, | |||||
| HM_TRACE = 1 << 6, | |||||
| HM_CONNECT = 1 << 7, | |||||
| HM_PATCH = 1 << 8 | |||||
| } HttpMethod; | |||||
| using OnRequestReceive = std::function<void(std::shared_ptr<HttpMessageHandler>)>; | using OnRequestReceive = std::function<void(std::shared_ptr<HttpMessageHandler>)>; | ||||
| class HttpServer { | class HttpServer { | ||||
| @@ -61,12 +49,13 @@ class HttpServer { | |||||
| event_base_(nullptr), | event_base_(nullptr), | ||||
| event_http_(nullptr), | event_http_(nullptr), | ||||
| is_init_(false), | is_init_(false), | ||||
| is_stop_(true) {} | |||||
| is_stop_(true), | |||||
| request_timeout_(300) {} | |||||
| ~HttpServer(); | ~HttpServer(); | ||||
| bool InitServer(); | bool InitServer(); | ||||
| void SetTimeOut(int seconds = 5); | |||||
| void SetTimeOut(int seconds); | |||||
| // Default allowed methods: GET, POST, HEAD, PUT, DELETE | // Default allowed methods: GET, POST, HEAD, PUT, DELETE | ||||
| void SetAllowedMethod(u_int16_t methods); | void SetAllowedMethod(u_int16_t methods); | ||||
| @@ -91,9 +80,9 @@ class HttpServer { | |||||
| struct evhttp *event_http_; | struct evhttp *event_http_; | ||||
| bool is_init_; | bool is_init_; | ||||
| std::atomic<bool> is_stop_; | std::atomic<bool> is_stop_; | ||||
| int request_timeout_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ | #endif // MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ | ||||
| @@ -176,7 +176,7 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { | |||||
| struct evbuffer *input = bufferevent_get_input(const_cast<struct bufferevent *>(bev)); | struct evbuffer *input = bufferevent_get_input(const_cast<struct bufferevent *>(bev)); | ||||
| MS_EXCEPTION_IF_NULL(input); | MS_EXCEPTION_IF_NULL(input); | ||||
| char read_buffer[4096]; | |||||
| char read_buffer[kMessageChunkLength]; | |||||
| while (EVBUFFER_LENGTH(input) > 0) { | while (EVBUFFER_LENGTH(input) > 0) { | ||||
| int read = evbuffer_remove(input, &read_buffer, sizeof(read_buffer)); | int read = evbuffer_remove(input, &read_buffer, sizeof(read_buffer)); | ||||
| @@ -330,7 +330,7 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { | |||||
| auto conn = static_cast<class TcpConnection *>(connection); | auto conn = static_cast<class TcpConnection *>(connection); | ||||
| struct evbuffer *buf = bufferevent_get_input(bev); | struct evbuffer *buf = bufferevent_get_input(bev); | ||||
| char read_buffer[4096]; | |||||
| char read_buffer[kMessageChunkLength]; | |||||
| while (EVBUFFER_LENGTH(buf) > 0) { | while (EVBUFFER_LENGTH(buf) > 0) { | ||||
| int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)); | int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)); | ||||
| if (read == -1) { | if (read == -1) { | ||||
| @@ -0,0 +1,121 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <algorithm> | |||||
| #include <cstdio> | |||||
| #include <cstdlib> | |||||
| #include <cstring> | |||||
| #include <iostream> | |||||
| #include <string> | |||||
| #include <thread> | |||||
| #include <memory> | |||||
| #include "common/common_test.h" | |||||
| #include "ps/core/http_server.h" | |||||
| #include "ps/core/http_client.h" | |||||
| using namespace std; | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| class TestHttpClient : public UT::Common { | |||||
| public: | |||||
| TestHttpClient() : server_(nullptr), http_server_thread_(nullptr) {} | |||||
| virtual ~TestHttpClient() = default; | |||||
| OnRequestReceive http_get_func = std::bind( | |||||
| [](std::shared_ptr<HttpMessageHandler> resp) { | |||||
| EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget"); | |||||
| const unsigned char ret[] = "get request success!\n"; | |||||
| resp->QuickResponse(200, ret, 22); | |||||
| }, | |||||
| std::placeholders::_1); | |||||
| OnRequestReceive http_handler_func = std::bind( | |||||
| [](std::shared_ptr<HttpMessageHandler> resp) { | |||||
| std::string host = resp->GetRequestHost(); | |||||
| EXPECT_STREQ(host.c_str(), "127.0.0.1"); | |||||
| std::string path_param = resp->GetPathParam("key1"); | |||||
| std::string header_param = resp->GetHeadParam("headerKey"); | |||||
| unsigned char *data = nullptr; | |||||
| const uint64_t len = resp->GetPostMsg(&data); | |||||
| char post_message[len + 1]; | |||||
| if (memset_s(post_message, len + 1, 0, len + 1) != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memset_s error"; | |||||
| } | |||||
| if (memcpy_s(post_message, len, data, len) != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memset_s error"; | |||||
| } | |||||
| EXPECT_STREQ(path_param.c_str(), "value1"); | |||||
| EXPECT_STREQ(header_param.c_str(), "headerValue"); | |||||
| EXPECT_STREQ(post_message, "postKey=postValue"); | |||||
| const std::string rKey("headKey"); | |||||
| const std::string rVal("headValue"); | |||||
| const std::string rBody("post request success!\n"); | |||||
| resp->AddRespHeadParam(rKey, rVal); | |||||
| resp->AddRespString(rBody); | |||||
| resp->SetRespCode(200); | |||||
| resp->SendResponse(); | |||||
| }, | |||||
| std::placeholders::_1); | |||||
| void SetUp() override { | |||||
| server_ = std::make_unique<HttpServer>("0.0.0.0", 9999); | |||||
| server_->RegisterRoute("/httpget", &http_get_func); | |||||
| server_->RegisterRoute("/handler", &http_handler_func); | |||||
| http_server_thread_ = std::make_unique<std::thread>([&]() { server_->Start(); }); | |||||
| http_server_thread_->detach(); | |||||
| std::this_thread::sleep_for(std::chrono::milliseconds(1000)); | |||||
| } | |||||
| void TearDown() override { | |||||
| server_->Stop(); | |||||
| std::this_thread::sleep_for(std::chrono::milliseconds(2000)); | |||||
| } | |||||
| private: | |||||
| std::unique_ptr<HttpServer> server_; | |||||
| std::unique_ptr<std::thread> http_server_thread_; | |||||
| }; | |||||
| TEST_F(TestHttpClient, Get) { | |||||
| HttpClient client; | |||||
| std::map<std::string, std::string> headers = {{"headerKey", "headerValue"}}; | |||||
| auto output = std::make_shared<std::vector<char>>(); | |||||
| auto ret = client.Get("http://127.0.0.1:9999/httpget", output, headers); | |||||
| EXPECT_STREQ("get request success!\n", output->data()); | |||||
| EXPECT_EQ(Status::OK, ret); | |||||
| } | |||||
| TEST_F(TestHttpClient, Post) { | |||||
| HttpClient client; | |||||
| std::map<std::string, std::string> headers = {{"headerKey", "headerValue"}}; | |||||
| auto output = std::make_shared<std::vector<char>>(); | |||||
| std::string post_data = "postKey=postValue"; | |||||
| auto ret = | |||||
| client.Post("http://127.0.0.1:9999/handler?key1=value1", post_data.c_str(), post_data.length(), output, headers); | |||||
| EXPECT_STREQ("post request success!\n", output->data()); | |||||
| EXPECT_EQ(Status::OK, ret); | |||||
| } | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -42,7 +42,6 @@ class TestHttpServer : public UT::Common { | |||||
| std::string path_param = resp->GetPathParam("key1"); | std::string path_param = resp->GetPathParam("key1"); | ||||
| std::string header_param = resp->GetHeadParam("headerKey"); | std::string header_param = resp->GetHeadParam("headerKey"); | ||||
| std::string post_param = resp->GetPostParam("postKey"); | |||||
| unsigned char *data = nullptr; | unsigned char *data = nullptr; | ||||
| const uint64_t len = resp->GetPostMsg(&data); | const uint64_t len = resp->GetPostMsg(&data); | ||||
| char post_message[len + 1]; | char post_message[len + 1]; | ||||
| @@ -54,7 +53,6 @@ class TestHttpServer : public UT::Common { | |||||
| } | } | ||||
| EXPECT_STREQ(path_param.c_str(), "value1"); | EXPECT_STREQ(path_param.c_str(), "value1"); | ||||
| EXPECT_STREQ(header_param.c_str(), "headerValue"); | EXPECT_STREQ(header_param.c_str(), "headerValue"); | ||||
| EXPECT_STREQ(post_param.c_str(), "postValue"); | |||||
| EXPECT_STREQ(post_message, "postKey=postValue"); | EXPECT_STREQ(post_message, "postKey=postValue"); | ||||
| const std::string rKey("headKey"); | const std::string rKey("headKey"); | ||||