Browse Source

add tcp communicator

pull/15803/head
chendongsheng 4 years ago
parent
commit
ddb54ae0ac
7 changed files with 253 additions and 3 deletions
  1. +1
    -0
      mindspore/ccsrc/ps/CMakeLists.txt
  2. +2
    -2
      mindspore/ccsrc/ps/core/communicator/communicator_base.h
  3. +1
    -1
      mindspore/ccsrc/ps/core/communicator/task_executor.h
  4. +91
    -0
      mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc
  5. +111
    -0
      mindspore/ccsrc/ps/core/communicator/tcp_communicator.h
  6. +32
    -0
      mindspore/ccsrc/ps/core/server_node.cc
  7. +15
    -0
      mindspore/ccsrc/ps/core/server_node.h

+ 1
- 0
mindspore/ccsrc/ps/CMakeLists.txt View File

@@ -14,6 +14,7 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")


+ 2
- 2
mindspore/ccsrc/ps/core/communicator/communicator_base.h View File

@@ -40,9 +40,9 @@ class CommunicatorBase {
using MessageCallback = std::function<void(std::shared_ptr<MessageHandler>)>;
using HttpMsgCallback = std::function<void(std::shared_ptr<HttpMessageHandler>)>;
using OnNodeEventCallback = std::function<void(const NodeEvent &)>;
using TcpMsgCallBack = std::function<void(std::shared_ptr<core::TcpConnection> conn,
using TcpMsgCallback = std::function<void(std::shared_ptr<core::TcpConnection> conn,
std::shared_ptr<core::MessageMeta> meta, DataPtr data, size_t size)>;
using CertainEventCallBack = std::function<void(void)>;
using CertainEventCallback = std::function<void(void)>;

CommunicatorBase() = default;



+ 1
- 1
mindspore/ccsrc/ps/core/communicator/task_executor.h View File

@@ -50,7 +50,7 @@ class TaskExecutor {
bool Submit(Fun &&function, Args &&... args) {
auto callee = std::bind(function, args...);
std::function<void()> task = [callee]() -> void { callee(); };
auto index = 0;
size_t index = 0;
for (size_t i = 0; i < submit_timeout_; i++) {
std::unique_lock<std::mutex> lock(mtx_);
if (task_num_ >= max_task_num_) {


+ 91
- 0
mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc View File

@@ -0,0 +1,91 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ps/core/communicator/tcp_communicator.h"
#include <memory>

namespace mindspore {
namespace ps {
namespace core {
bool TcpCommunicator::Start() {
if (running_) {
MS_LOG(INFO) << "The TCP communicator has already started.";
return true;
}
MS_EXCEPTION_IF_NULL(server_node_);

// Set message callback. For example, message of push/pull, etc.
tcp_msg_callback_ = std::bind(
[&](std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data,
size_t size) -> void {
TcpUserCommand user_command = static_cast<TcpUserCommand>(meta->user_cmd());
const std::string &msg_type = kUserCommandToMsgType.at(user_command);
if (msg_type == "" || !msg_callbacks_[msg_type]) {
MS_LOG(ERROR) << "Tcp server doesn't support command " << user_command << " " << msg_type;
return;
}

MS_LOG(DEBUG) << "TcpCommunicator receives message for " << msg_type;
std::shared_ptr<MessageHandler> tcp_msg_handler =
std::make_shared<TcpMsgHandler>(server_node_, conn, meta, data, size);
MS_EXCEPTION_IF_NULL(tcp_msg_handler);
task_executor_->Submit(msg_callbacks_[msg_type], tcp_msg_handler);
return;
},
std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
server_node_->set_handler(tcp_msg_callback_);

// Set event callback. For example, event of scaling out/in, etc.
event_callback_ = std::bind(
[&](const core::NodeEvent &event) -> void {
MS_LOG(INFO) << "Server receives event of " << event;
certain_event_to_callback_[event]();
},
std::placeholders::_1);
server_node_->set_event_callback(event_callback_);

server_node_->Start();
running_ = true;
running_thread_ = std::thread([&]() {
while (running_) {
std::this_thread::yield();
}
});
return true;
}

// Stop the communicator: finish and stop the underlying server node, then clear the running flag.
// Always returns true; a null server_node_ is reported through MS_EXCEPTION_IF_NULL.
bool TcpCommunicator::Stop() {
MS_EXCEPTION_IF_NULL(server_node_);
// The node is finished and stopped before the flag is cleared; keep this order.
server_node_->Finish();
server_node_->Stop();
// Clearing the flag ends the `while (running_)` loop of the thread launched in Start().
// NOTE(review): running_thread_ is not joined here - confirm that it is joined elsewhere.
running_ = false;
return true;
}

// Register `cb` as the handler for messages of type `msg_type`.
// try_emplace keeps an already registered callback: a second registration for the same type is a no-op.
void TcpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) {
  (void)msg_callbacks_.try_emplace(msg_type, cb);
}

// Register `event_cb` as the callback to run when `event` is received.
// try_emplace keeps an already registered callback: a second registration for the same event is a no-op.
void TcpCommunicator::RegisterEventCallback(const core::NodeEvent &event, const CertainEventCallback &event_cb) {
  (void)certain_event_to_callback_.try_emplace(event, event_cb);
}

// Accessor for the server node pointer this communicator was constructed with.
ServerNode *TcpCommunicator::server_node() {
  return server_node_;
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 111
- 0
mindspore/ccsrc/ps/core/communicator/tcp_communicator.h View File

@@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_

#include <atomic>
#include <map>
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include "proto/ps.pb.h"
#include "ps/core/server_node.h"
#include "ps/core/cluster_metadata.h"
#include "ps/ps_context.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/communicator/tcp_msg_handler.h"

namespace mindspore {
namespace ps {
namespace core {
// User commands understood by the TCP communicator. The enumerators keep their implicit values (0..7) because
// the value is converted from/to a plain integer when a message is received or sent.
enum class TcpUserCommand {
  kPush,
  kPull,
  kCount,
  kReachThreshold,
  kResetCount,
  kGetValue,
  kPutValue,
  kCounterEvent
};

// Maps each TcpUserCommand to the message-type string under which callbacks are registered.
// Declared `inline` so that every translation unit including this header shares one table instead of building
// its own private copy during static initialization.
inline const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
  {TcpUserCommand::kPush, "push"},
  {TcpUserCommand::kPull, "pull"},
  {TcpUserCommand::kCount, "count"},
  {TcpUserCommand::kReachThreshold, "reachThreshold"},
  {TcpUserCommand::kResetCount, "resetCnt"},
  {TcpUserCommand::kGetValue, "getValue"},
  {TcpUserCommand::kPutValue, "putValue"},
  {TcpUserCommand::kCounterEvent, "counterEvent"},
};

class TcpCommunicator : public CommunicatorBase {
public:
explicit TcpCommunicator(const std::shared_ptr<TaskExecutor> &task_executor, ServerNode *node)
: task_executor_(task_executor),
running_(false),
server_num_(0),
worker_num_(0),
scheduler_ip_(""),
scheduler_port_(0),
server_node_(node) {}
~TcpCommunicator() = default;

bool Start() override;
bool Stop() override;

void RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) override;
void RegisterEventCallback(const core::NodeEvent &event, const CertainEventCallback &event_cb);

ServerNode *server_node();

template <class T>
bool SendPbRequest(const T &pb_msg, const uint32_t &rank_id, TcpUserCommand command,
std::shared_ptr<std::vector<unsigned char>> *output = nullptr) {
const std::string &msg_str = pb_msg.SerializeAsString();
std::shared_ptr<unsigned char[]> msg(new unsigned char[msg_str.size()]);
size_t dest_size = msg_str.size();
size_t src_size = msg_str.size();
auto ret = memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy_s error, error no " << ret;
}

if (output != nullptr) {
if (!server_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command), output)) {
MS_LOG(ERROR) << "Query leader server whether count is enough failed.";
return false;
}
} else {
if (!server_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command))) {
MS_LOG(ERROR) << "Query leader server whether count is enough failed.";
return false;
}
}
return true;
}

private:
std::shared_ptr<TaskExecutor> task_executor_;
bool running_;

TcpMsgCallback tcp_msg_callback_;
OnNodeEventCallback event_callback_;
// Each NodeEvent corresponds to a CertainEventCallback to process the event.
std::map<core::NodeEvent, CertainEventCallback> certain_event_to_callback_;

uint32_t server_num_;
uint32_t worker_num_;

std::string scheduler_ip_;
uint16_t scheduler_port_;

ServerNode *server_node_;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_

+ 32
- 0
mindspore/ccsrc/ps/core/server_node.cc View File

@@ -14,6 +14,8 @@
* limitations under the License.
*/
#include "ps/core/server_node.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "ps/core/communicator/http_communicator.h"

namespace mindspore {
namespace ps {
@@ -124,6 +126,36 @@ void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn,
server_->SendMessage(conn, meta, Protos::RAW, data, size);
}

// Return the HTTP communicator of this node, creating and caching it on first use.
// The communicator table is guarded by communicator_mutex_; `ip`, `port` and `task_executor` are only used when
// the communicator has to be created.
// NOTE(review): `port` is a signed 16-bit value, so ports above 32767 cannot be represented - confirm intended.
std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateHttpComm(const std::string &ip, std::int16_t port,
                                                                  const std::shared_ptr<TaskExecutor> &task_executor) {
  std::lock_guard<std::mutex> lock(communicator_mutex_);
  const auto cached = communicators_.find(kHttpCommunicator);
  if (cached != communicators_.end()) {
    return cached->second;
  }

  MS_LOG(INFO) << "Create Http communicator.";
  auto http_comm = std::make_shared<HttpCommunicator>(ip, port, task_executor);
  MS_EXCEPTION_IF_NULL(http_comm);
  communicators_[kHttpCommunicator] = http_comm;
  return http_comm;
}

// Return the TCP communicator of this node, creating and caching it on first use.
// On creation the cluster metadata singleton is initialized from the worker/server numbers and the scheduler
// address; on later calls these arguments are ignored and the cached communicator is returned.
// The communicator table is guarded by communicator_mutex_.
// NOTE(review): `scheduler_port` is a signed 16-bit value, so ports above 32767 cannot be represented.
std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateTcpComm(const std::string &scheduler_ip,
                                                                 std::int16_t scheduler_port, uint32_t worker_num,
                                                                 uint32_t server_num,
                                                                 const std::shared_ptr<TaskExecutor> &task_executor) {
  std::lock_guard<std::mutex> lock(communicator_mutex_);
  const auto cached = communicators_.find(kTcpCommunicator);
  if (cached != communicators_.end()) {
    return cached->second;
  }

  MS_LOG(INFO) << "Create Tcp communicator.";
  auto tcp_comm = std::make_shared<TcpCommunicator>(task_executor, this);
  MS_EXCEPTION_IF_NULL(tcp_comm);
  ClusterMetadata::instance()->Init(worker_num, server_num, scheduler_ip, scheduler_port);
  MS_LOG(INFO) << "Initialize cluster metadata for server. Worker number:" << worker_num
               << ", Server number:" << server_num << ", Scheduler ip:" << scheduler_ip
               << ", Scheduler port:" << scheduler_port;
  communicators_[kTcpCommunicator] = tcp_comm;
  return tcp_comm;
}

bool ServerNode::Stop() {
MS_LOG(INFO) << "Stop server node!";
if (!is_already_stopped_.load()) {


+ 15
- 0
mindspore/ccsrc/ps/core/server_node.h View File

@@ -24,18 +24,25 @@
#include <thread>
#include <utility>
#include <vector>
#include <unordered_map>

#include "ps/core/cluster_metadata.h"
#include "ps/core/communicator/tcp_client.h"
#include "ps/core/communicator/tcp_server.h"
#include "ps/core/abstract_node.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/core/communicator/communicator_base.h"

namespace mindspore {
namespace ps {
namespace core {
// Keys under which a server node caches its communicators, one entry per transport.
// Key of the TCP communicator.
constexpr char kTcpCommunicator[] = "TCP";
// Key of the HTTP communicator.
constexpr char kHttpCommunicator[] = "HTTP";

class ServerNode : public AbstractNode {
public:
ServerNode() : server_(nullptr), server_thread_(nullptr) {}

~ServerNode() override = default;

bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override;
@@ -48,6 +55,12 @@ class ServerNode : public AbstractNode {
void set_handler(const RequestHandler &handler);
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const void *data, size_t size);

std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, std::int16_t port,
const std::shared_ptr<TaskExecutor> &task_executor);
std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port,
uint32_t worker_num, uint32_t server_num,
const std::shared_ptr<TaskExecutor> &task_executor);

private:
void CreateTcpServer();
void Initialize();
@@ -59,6 +72,8 @@ class ServerNode : public AbstractNode {
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_;
RequestHandler request_handler_;
std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_;
std::mutex communicator_mutex_;
};
} // namespace core
} // namespace ps


Loading…
Cancel
Save