| @@ -370,7 +370,7 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| else() | |||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| target_link_libraries(mindspore proto_input mindspore::protobuf | |||
| mindspore::event mindspore::event_pthreads mindspore::event_openssl) | |||
| mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::json) | |||
| target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) | |||
| if(${ENABLE_IBVERBS} STREQUAL "ON") | |||
| target_link_libraries(mindspore ibverbs rdmacm) | |||
| @@ -25,6 +25,8 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include "ps/core/communicator/request_process_result_code.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| constexpr char kEnvCommType[] = "MS_COMM_TYPE"; | |||
| @@ -70,6 +72,10 @@ constexpr int64_t kPullCmd = 51; | |||
| constexpr size_t kInvalidKey = UINT64_MAX; | |||
| constexpr int64_t kInvalidID = -1; | |||
| constexpr uint32_t kMaxMessageSize = static_cast<uint32_t>(100 * (uint32_t(1) << 20)); | |||
| constexpr char kServerNum[] = "server_num"; | |||
| constexpr char kWorkerNum[] = "worker_num"; | |||
| using DataPtr = std::shared_ptr<unsigned char[]>; | |||
| using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; | |||
| using Key = uint64_t; | |||
| @@ -129,6 +135,10 @@ const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum | |||
| << " is out of bound."; \ | |||
| } \ | |||
| } | |||
| #define ERROR_STATUS(result, code, message) \ | |||
| MS_LOG(ERROR) << message; \ | |||
| result = RequestProcessResult(code, message) | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CONSTANTS_H_ | |||
| @@ -33,8 +33,9 @@ void HttpClient::Init() { | |||
| MS_EXCEPTION_IF_NULL(dns_base_); | |||
| } | |||
| Status HttpClient::Post(const std::string &url, const void *body, size_t len, std::shared_ptr<std::vector<char>> output, | |||
| const std::map<std::string, std::string> &headers) { | |||
| ResponseCode HttpClient::Post(const std::string &url, const void *body, size_t len, | |||
| std::shared_ptr<std::vector<char>> output, | |||
| const std::map<std::string, std::string> &headers) { | |||
| MS_EXCEPTION_IF_NULL(body); | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| auto handler = std::make_shared<HttpMessageHandler>(); | |||
| @@ -50,13 +51,13 @@ Status HttpClient::Post(const std::string &url, const void *body, size_t len, st | |||
| evhttp_connection_base_new(event_base_, dns_base_, handler->GetHostByUri(), handler->GetUriPort()); | |||
| if (connection == nullptr) { | |||
| MS_LOG(ERROR) << "Create http connection failed!"; | |||
| return Status::BADREQUEST; | |||
| return ResponseCode::BADREQUEST; | |||
| } | |||
| struct evbuffer *buffer = evhttp_request_get_output_buffer(request); | |||
| if (evbuffer_add(buffer, body, len) != 0) { | |||
| MS_LOG(ERROR) << "Add buffer failed!"; | |||
| return Status::INTERNAL; | |||
| return ResponseCode::INTERNAL; | |||
| } | |||
| AddHeaders(headers, request, handler); | |||
| @@ -64,8 +65,8 @@ Status HttpClient::Post(const std::string &url, const void *body, size_t len, st | |||
| return CreateRequest(handler, connection, request, HttpMethod::HM_POST); | |||
| } | |||
| Status HttpClient::Get(const std::string &url, std::shared_ptr<std::vector<char>> output, | |||
| const std::map<std::string, std::string> &headers) { | |||
| ResponseCode HttpClient::Get(const std::string &url, std::shared_ptr<std::vector<char>> output, | |||
| const std::map<std::string, std::string> &headers) { | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| auto handler = std::make_shared<HttpMessageHandler>(); | |||
| output->clear(); | |||
| @@ -80,7 +81,7 @@ Status HttpClient::Get(const std::string &url, std::shared_ptr<std::vector<char> | |||
| evhttp_connection_base_new(event_base_, dns_base_, handler->GetHostByUri(), handler->GetUriPort()); | |||
| if (connection == nullptr) { | |||
| MS_LOG(ERROR) << "Create http connection failed!"; | |||
| return Status::BADREQUEST; | |||
| return ResponseCode::BADREQUEST; | |||
| } | |||
| AddHeaders(headers, request, handler); | |||
| @@ -178,8 +179,9 @@ void HttpClient::InitRequest(std::shared_ptr<HttpMessageHandler> handler, const | |||
| << ", The port is:" << handler->GetUriPort() << ", The request_url is:" << handler->GetRequestPath(); | |||
| } | |||
| Status HttpClient::CreateRequest(std::shared_ptr<HttpMessageHandler> handler, struct evhttp_connection *connection, | |||
| struct evhttp_request *request, HttpMethod method) { | |||
| ResponseCode HttpClient::CreateRequest(std::shared_ptr<HttpMessageHandler> handler, | |||
| struct evhttp_connection *connection, struct evhttp_request *request, | |||
| HttpMethod method) { | |||
| MS_EXCEPTION_IF_NULL(handler); | |||
| MS_EXCEPTION_IF_NULL(connection); | |||
| MS_EXCEPTION_IF_NULL(request); | |||
| @@ -188,18 +190,18 @@ Status HttpClient::CreateRequest(std::shared_ptr<HttpMessageHandler> handler, st | |||
| if (evhttp_make_request(connection, request, evhttp_cmd_type(method), handler->GetRequestPath().c_str()) != 0) { | |||
| MS_LOG(ERROR) << "Make request failed!"; | |||
| return Status::INTERNAL; | |||
| return ResponseCode::INTERNAL; | |||
| } | |||
| if (!Start()) { | |||
| MS_LOG(ERROR) << "Start http client failed!"; | |||
| return Status::INTERNAL; | |||
| return ResponseCode::INTERNAL; | |||
| } | |||
| if (handler->request()) { | |||
| return Status(evhttp_request_get_response_code(handler->request())); | |||
| return ResponseCode(evhttp_request_get_response_code(handler->request())); | |||
| } | |||
| return Status::INTERNAL; | |||
| return ResponseCode::INTERNAL; | |||
| } | |||
| bool HttpClient::Start() { | |||
| @@ -47,7 +47,7 @@ namespace ps { | |||
| namespace core { | |||
| enum class HttpMethod { HM_GET = 1 << 0, HM_POST = 1 << 1 }; | |||
| enum class Status : int { | |||
| enum class ResponseCode : int { | |||
| OK = 200, // request completed ok | |||
| BADREQUEST = 400, // invalid http request was made | |||
| NOTFOUND = 404, // could not find content for uri | |||
| @@ -62,10 +62,10 @@ class HttpClient { | |||
| virtual ~HttpClient(); | |||
| Status Post(const std::string &url, const void *body, size_t len, std::shared_ptr<std::vector<char>> output, | |||
| const std::map<std::string, std::string> &headers = {}); | |||
| Status Get(const std::string &url, std::shared_ptr<std::vector<char>> output, | |||
| const std::map<std::string, std::string> &headers = {}); | |||
| ResponseCode Post(const std::string &url, const void *body, size_t len, std::shared_ptr<std::vector<char>> output, | |||
| const std::map<std::string, std::string> &headers = {}); | |||
| ResponseCode Get(const std::string &url, std::shared_ptr<std::vector<char>> output, | |||
| const std::map<std::string, std::string> &headers = {}); | |||
| void set_connection_timeout(const int &timeout); | |||
| @@ -80,8 +80,8 @@ class HttpClient { | |||
| std::shared_ptr<HttpMessageHandler> handler); | |||
| void InitRequest(std::shared_ptr<HttpMessageHandler> handler, const std::string &url, | |||
| const struct evhttp_request *request); | |||
| Status CreateRequest(std::shared_ptr<HttpMessageHandler> handler, struct evhttp_connection *connection, | |||
| struct evhttp_request *request, HttpMethod method); | |||
| ResponseCode CreateRequest(std::shared_ptr<HttpMessageHandler> handler, struct evhttp_connection *connection, | |||
| struct evhttp_request *request, HttpMethod method); | |||
| bool Start(); | |||
| void Init(); | |||
| @@ -87,6 +87,43 @@ void HttpMessageHandler::ParsePostParam() { | |||
| } | |||
| } | |||
| RequestProcessResult HttpMessageHandler::ParsePostMessageToJson() { | |||
| MS_EXCEPTION_IF_NULL(event_request_); | |||
| RequestProcessResult result(RequestProcessResultCode::kSuccess); | |||
| std::string message; | |||
| size_t len = evbuffer_get_length(event_request_->input_buffer); | |||
| if (len == 0) { | |||
| ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "The post message size is invalid."); | |||
| return result; | |||
| } else if (len > kMaxMessageSize) { | |||
| ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "The post message is bigger than 100mb."); | |||
| return result; | |||
| } else { | |||
| message.resize(len); | |||
| auto buffer = evbuffer_pullup(event_request_->input_buffer, -1); | |||
| if (buffer == nullptr) { | |||
| ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Get http post message failed."); | |||
| return result; | |||
| } | |||
| size_t dest_size = len; | |||
| size_t src_size = len; | |||
| if (memcpy_s(message.data(), dest_size, buffer, src_size) != EOK) { | |||
| ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Copy message failed."); | |||
| return result; | |||
| } | |||
| try { | |||
| request_message_ = nlohmann::json::parse(message); | |||
| } catch (nlohmann::json::exception &e) { | |||
| std::string illegal_exception = e.what(); | |||
| ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, "Illegal JSON format:" + illegal_exception); | |||
| return result; | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| std::string HttpMessageHandler::GetPostParam(const std::string &key) { | |||
| if (!post_param_parsed_) { | |||
| ParsePostParam(); | |||
| @@ -204,7 +241,7 @@ void HttpMessageHandler::SetRespCode(int code) { resp_code_ = code; } | |||
| void HttpMessageHandler::SendResponse() { | |||
| MS_EXCEPTION_IF_NULL(event_request_); | |||
| MS_EXCEPTION_IF_NULL(resp_buf_); | |||
| evhttp_send_reply(event_request_, resp_code_, nullptr, resp_buf_); | |||
| evhttp_send_reply(event_request_, resp_code_, "Client", resp_buf_); | |||
| } | |||
| void HttpMessageHandler::QuickResponse(int code, const unsigned char *body, size_t len) { | |||
| @@ -226,6 +263,14 @@ void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, co | |||
| evhttp_send_reply(event_request_, resp_code_, nullptr, resp_buf_); | |||
| } | |||
| void HttpMessageHandler::ErrorResponse(int code, RequestProcessResult result) { | |||
| nlohmann::json error_json = {{"error_message", result.StatusMessage()}}; | |||
| std::string out_error = error_json.dump(); | |||
| AddRespString(out_error); | |||
| SetRespCode(code); | |||
| SendResponse(); | |||
| } | |||
| void HttpMessageHandler::RespError(int nCode, const std::string &message) { | |||
| MS_EXCEPTION_IF_NULL(event_request_); | |||
| if (message.empty()) { | |||
| @@ -269,6 +314,25 @@ void HttpMessageHandler::InitBodySize() { body_->resize(content_len()); } | |||
| std::shared_ptr<std::vector<char>> HttpMessageHandler::body() { return body_; } | |||
| void HttpMessageHandler::set_body(std::shared_ptr<std::vector<char>> body) { body_ = body; } | |||
| const nlohmann::json &HttpMessageHandler::request_message() const { return request_message_; } | |||
| RequestProcessResult HttpMessageHandler::ParseValueFromKey(const std::string &key, int32_t *const value) { | |||
| RequestProcessResult result(RequestProcessResultCode::kSuccess); | |||
| if (!request_message_.contains(key)) { | |||
| std::string message = "The json is not contain the key:" + key; | |||
| ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message); | |||
| return result; | |||
| } | |||
| int32_t res = request_message_.at(key); | |||
| if (res < 0) { | |||
| std::string message = "The value should not be less than 0."; | |||
| ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message); | |||
| return result; | |||
| } | |||
| *value = res; | |||
| return result; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -36,6 +36,9 @@ | |||
| #include "ps/core/comm_util.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "ps/core/communicator/request_process_result_code.h" | |||
| #include "nlohmann/json.hpp" | |||
| #include "ps/constants.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -92,11 +95,13 @@ class HttpMessageHandler { | |||
| void SendResponse(); | |||
| void QuickResponse(int code, const unsigned char *body, size_t len); | |||
| void SimpleResponse(int code, const HttpHeaders &headers, const std::string &body); | |||
| void ErrorResponse(int code, RequestProcessResult status); | |||
| // If message is empty, libevent will use default error code message instead | |||
| void RespError(int nCode, const std::string &message); | |||
| // Body length should no more than MAX_POST_BODY_LEN, default 64kB | |||
| void ParsePostParam(); | |||
| RequestProcessResult ParsePostMessageToJson(); | |||
| void ReceiveMessage(const void *buffer, size_t num); | |||
| void set_content_len(const uint64_t &len); | |||
| uint64_t content_len(); | |||
| @@ -107,6 +112,8 @@ class HttpMessageHandler { | |||
| void InitBodySize(); | |||
| VectorPtr body(); | |||
| void set_body(VectorPtr body); | |||
| const nlohmann::json &request_message() const; | |||
| RequestProcessResult ParseValueFromKey(const std::string &key, int32_t *const value); | |||
| private: | |||
| struct evhttp_request *event_request_; | |||
| @@ -123,6 +130,7 @@ class HttpMessageHandler { | |||
| uint64_t content_len_; | |||
| struct event_base *event_base_; | |||
| uint64_t offset_; | |||
| nlohmann::json request_message_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -0,0 +1,102 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_STATUS_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_STATUS_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <sstream> | |||
| #include <iostream> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| enum class RequestProcessResultCode { kSuccess = 0, kSystemError = 1, kInvalidInputs = 2 }; | |||
| class LogStream { | |||
| public: | |||
| LogStream() { sstream_ = std::make_shared<std::stringstream>(); } | |||
| ~LogStream() = default; | |||
| template <typename T> | |||
| LogStream &operator<<(const T &val) noexcept { | |||
| (*sstream_) << val; | |||
| return *this; | |||
| } | |||
| template <typename T> | |||
| LogStream &operator<<(const std::vector<T> &val) noexcept { | |||
| (*sstream_) << "["; | |||
| for (size_t i = 0; i < val.size(); i++) { | |||
| (*this) << val[i]; | |||
| if (i + 1 < val.size()) { | |||
| (*sstream_) << ", "; | |||
| } | |||
| } | |||
| (*sstream_) << "]"; | |||
| return *this; | |||
| } | |||
| LogStream &operator<<(std::ostream &func(std::ostream &os)) noexcept { | |||
| (*sstream_) << func; | |||
| return *this; | |||
| } | |||
| const std::shared_ptr<std::stringstream> &stream() const { return sstream_; } | |||
| private: | |||
| std::shared_ptr<std::stringstream> sstream_; | |||
| }; | |||
| /* This class encapsulates user defined messages and user defined result codes, used to return http response message. | |||
| * | |||
| */ | |||
| class RequestProcessResult { | |||
| public: | |||
| RequestProcessResult() : code_(RequestProcessResultCode::kSystemError) {} | |||
| explicit RequestProcessResult(enum RequestProcessResultCode code, const std::string &msg = "") | |||
| : code_(code), msg_(msg) {} | |||
| ~RequestProcessResult() = default; | |||
| bool IsSuccess() const { return code_ == RequestProcessResultCode::kSuccess; } | |||
| enum RequestProcessResultCode ResultCode() const { return code_; } | |||
| std::string StatusMessage() const { return msg_; } | |||
| bool operator==(const RequestProcessResult &other) const { return code_ == other.code_; } | |||
| bool operator==(enum RequestProcessResultCode other_code) const { return code_ == other_code; } | |||
| bool operator!=(const RequestProcessResult &other) const { return code_ != other.code_; } | |||
| bool operator!=(enum RequestProcessResultCode other_code) const { return code_ != other_code; } | |||
| operator bool() const = delete; | |||
| RequestProcessResult &operator<(const LogStream &stream) noexcept __attribute__((visibility("default"))) { | |||
| msg_ = stream.stream()->str(); | |||
| return *this; | |||
| } | |||
| RequestProcessResult &operator=(const std::string &message) noexcept __attribute__((visibility("default"))) { | |||
| msg_ = message; | |||
| return *this; | |||
| } | |||
| private: | |||
| enum RequestProcessResultCode code_; | |||
| std::string msg_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_STATUS_H_ | |||