Browse Source

added retry and unit test

tags/v1.2.0-rc1
chendongsheng 5 years ago
parent
commit
b289c6184a
16 changed files with 109 additions and 19 deletions
  1. +5
    -4
      mindspore/ccsrc/ps/core/abstract_node.cc
  2. +10
    -0
      mindspore/ccsrc/ps/core/comm_util.cc
  3. +2
    -0
      mindspore/ccsrc/ps/core/comm_util.h
  4. +1
    -1
      mindspore/ccsrc/ps/core/node.h
  5. +4
    -2
      mindspore/ccsrc/ps/core/node_manager.cc
  6. +3
    -0
      mindspore/ccsrc/ps/core/node_manager.h
  7. +1
    -1
      mindspore/ccsrc/ps/core/scheduler_node.h
  8. +1
    -1
      mindspore/ccsrc/ps/core/server_node.cc
  9. +1
    -1
      mindspore/ccsrc/ps/core/server_node.h
  10. +10
    -5
      mindspore/ccsrc/ps/core/tcp_client.cc
  11. +4
    -0
      mindspore/ccsrc/ps/core/tcp_client.h
  12. +2
    -3
      mindspore/ccsrc/ps/core/tcp_server.cc
  13. +1
    -0
      mindspore/ccsrc/ps/core/tcp_server.h
  14. +1
    -1
      mindspore/ccsrc/ps/core/worker_node.h
  15. +43
    -0
      tests/ut/cpp/ps/core/cluster_available_timeout_test.cc
  16. +20
    -0
      tests/ut/cpp/ps/core/common_util_test.cc

+ 5
- 4
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -33,11 +33,12 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(register_message.SerializeAsString());
if (!SendMessageSync(client, comm_message)) {
MS_LOG(EXCEPTION) << "Node register timeout!";
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!";
}

MS_LOG(INFO) << "The node id:" << node_info_.node_id_
<< "is registering to scheduler, the request id is:" << message_meta.request_id();
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is registering to scheduler!";
}

void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
@@ -395,7 +396,7 @@ bool AbstractNode::InitClientToScheduler() {

client_to_scheduler_->Init();
client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
MS_LOG(INFO) << "The worker node start a tcp client!";
MS_LOG(INFO) << "The node start a tcp client!";
client_to_scheduler_->Start();
});



+ 10
- 0
mindspore/ccsrc/ps/core/comm_util.cc View File

@@ -129,6 +129,16 @@ bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &ra
}
return true;
}

bool CommUtil::Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds) {
for (size_t attempt = 0; attempt < max_attempts; ++attempt) {
if (func()) {
return true;
}
std::this_thread::sleep_for(std::chrono::milliseconds(interval_milliseconds));
}
return false;
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 2
- 0
mindspore/ccsrc/ps/core/comm_util.h View File

@@ -45,6 +45,7 @@
#include <sstream>
#include <string>
#include <utility>
#include <thread>

#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
@@ -68,6 +69,7 @@ class CommUtil {
static std::string GenerateUUID();
static std::string NodeRoleToString(const NodeRole &role);
static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id);
static bool Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds);

private:
static std::random_device rd;


+ 1
- 1
mindspore/ccsrc/ps/core/node.h View File

@@ -57,7 +57,7 @@ class Node {
using OnNodeEventMessage = std::function<void(const NodeEvent &event)>;
using MessageCallback = std::function<void()>;

virtual bool Start(const uint32_t &timeout = kTimeoutInSeconds) = 0;
virtual bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) = 0;
virtual bool Stop() = 0;
virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0;



+ 4
- 2
mindspore/ccsrc/ps/core/node_manager.cc View File

@@ -105,7 +105,7 @@ void NodeManager::UpdateClusterState() {
}

// 2. update cluster finish state
if (finish_nodes_id_.size() == total_node_num_) {
if (finish_nodes_id_.size() == total_node_num_ || SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
is_cluster_finish_ = true;
is_cluster_ready_ = true;
}
@@ -119,7 +119,9 @@ void NodeManager::UpdateClusterState() {
void NodeManager::CheckClusterTimeout() {
if (total_node_num_ != nodes_info_.size()) {
MS_LOG(WARNING) << "The cluster is not ready after " << ClusterConfig::cluster_available_timeout()
<< " seconds,so finish the cluster";
<< " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to "
<< nodes_info_.size();
current_node_num_ = nodes_info_.size();
is_cluster_timeout_ = true;
}
}


+ 3
- 0
mindspore/ccsrc/ps/core/node_manager.h View File

@@ -35,6 +35,7 @@
#include "proto/ps.pb.h"
#include "ps/core/node.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"

namespace mindspore {
namespace ps {
@@ -47,6 +48,7 @@ class NodeManager {
is_cluster_timeout_(false),
is_node_timeout_(false),
total_node_num_(0),
current_node_num_(-1),
next_worker_rank_id_(-1),
next_server_rank_id_(-1) {}
virtual ~NodeManager() = default;
@@ -75,6 +77,7 @@ class NodeManager {
std::atomic<bool> is_cluster_timeout_;
std::atomic<bool> is_node_timeout_;
uint32_t total_node_num_;
int32_t current_node_num_;
std::atomic<int> next_worker_rank_id_;
std::atomic<int> next_server_rank_id_;
// worker nodes and server nodes


+ 1
- 1
mindspore/ccsrc/ps/core/scheduler_node.h View File

@@ -44,7 +44,7 @@ class SchedulerNode : public Node {
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
~SchedulerNode() override;

bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;



+ 1
- 1
mindspore/ccsrc/ps/core/server_node.cc View File

@@ -30,7 +30,7 @@ bool ServerNode::Start(const uint32_t &timeout) {
StartHeartbeatTimer(client_to_scheduler_);

if (!WaitForStart(timeout)) {
MS_LOG(EXCEPTION) << "Start Worker node timeout!";
MS_LOG(ERROR) << "Start Server node timeout!";
}
MS_LOG(INFO) << "The cluster is ready to use!";



+ 1
- 1
mindspore/ccsrc/ps/core/server_node.h View File

@@ -40,7 +40,7 @@ class ServerNode : public AbstractNode {
ServerNode() : server_(nullptr), server_thread_(nullptr) {}
~ServerNode() override;

bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;



+ 10
- 5
mindspore/ccsrc/ps/core/tcp_client.cc View File

@@ -36,6 +36,8 @@ namespace mindspore {
namespace ps {
namespace core {
event_base *TcpClient::event_base_ = nullptr;
std::mutex TcpClient::event_base_mutex_;
bool TcpClient::is_started_ = false;

TcpClient::TcpClient(const std::string &address, std::uint16_t port)
: event_timeout_(nullptr),
@@ -60,10 +62,6 @@ TcpClient::~TcpClient() {
event_free(event_timeout_);
event_timeout_ = nullptr;
}
if (event_base_) {
event_base_free(event_base_);
event_base_ = nullptr;
}
}

std::string TcpClient::GetServerAddress() const { return server_address_; }
@@ -234,6 +232,13 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void
}

void TcpClient::Start() {
event_base_mutex_.lock();
if (is_started_) {
event_base_mutex_.unlock();
return;
}
is_started_ = true;
event_base_mutex_.unlock();
MS_EXCEPTION_IF_NULL(event_base_);
int ret = event_base_dispatch(event_base_);
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
@@ -260,7 +265,7 @@ void TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
size_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
}


+ 4
- 0
mindspore/ccsrc/ps/core/tcp_client.h View File

@@ -35,6 +35,7 @@
#include "ps/core/cluster_config.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "utils/convert_utils_base.h"

namespace mindspore {
namespace ps {
@@ -86,6 +87,9 @@ class TcpClient {
OnTimer on_timer_callback_;

static event_base *event_base_;
static std::mutex event_base_mutex_;
static bool is_started_;

std::mutex connection_mutex_;
std::condition_variable connection_cond_;
event *event_timeout_;


+ 2
- 3
mindspore/ccsrc/ps/core/tcp_server.cc View File

@@ -32,7 +32,6 @@
namespace mindspore {
namespace ps {
namespace core {

void TcpConnection::InitConnection() {
tcp_message_handler_.SetCallback([&](const CommMessage &message) {
OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
@@ -58,7 +57,7 @@ void TcpConnection::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
size_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size,
sizeof(buf_size)) == -1) {
MS_LOG(EXCEPTION) << "Event buffer add header failed!";
@@ -304,7 +303,7 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
if (read == -1) {
MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
}
conn->OnReadHandler(read_buffer, static_cast<size_t>(read));
conn->OnReadHandler(read_buffer, IntToSize(read));
}
}



+ 1
- 0
mindspore/ccsrc/ps/core/tcp_server.h View File

@@ -39,6 +39,7 @@
#include "ps/core/tcp_message_handler.h"
#include "ps/core/cluster_config.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"

namespace mindspore {
namespace ps {


+ 1
- 1
mindspore/ccsrc/ps/core/worker_node.h View File

@@ -40,7 +40,7 @@ class WorkerNode : public AbstractNode {
WorkerNode() = default;
~WorkerNode() override;

bool Start(const uint32_t &timeout = kTimeoutInSeconds) override;
bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;



+ 43
- 0
tests/ut/cpp/ps/core/cluster_available_timeout_test.cc View File

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/common_test.h"
#include "ps/core/node.h"
#include "ps/core/scheduler_node.h"

namespace mindspore {
namespace ps {
namespace core {
class TestClusterAvailableTimeout : public UT::Common {
public:
TestClusterAvailableTimeout() = default;
~TestClusterAvailableTimeout() override = default;

void SetUp() override {}
void TearDown() override {}
};

TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) {
ClusterConfig::Init(1, 1, std::make_unique<std::string>("127.0.0.1"), 9999);
ClusterConfig::set_cluster_available_timeout(3);
SchedulerNode node;
node.Start();
node.Finish();
node.Stop();
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 20
- 0
tests/ut/cpp/ps/core/common_util_test.cc View File

@@ -27,6 +27,18 @@ class TestCommUtil : public UT::Common {
public:
TestCommUtil() = default;
virtual ~TestCommUtil() = default;
struct MockRetry {
bool operator()(std::string mock) {
++count_;

if (count_ > 3) {
return true;
}
return false;
}

int count_{0};
};

void SetUp() override {}
void TearDown() override {}
@@ -47,6 +59,14 @@ TEST_F(TestCommUtil, ValidateRankId) {
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1));
EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2));
}

TEST_F(TestCommUtil, Retry) {
bool const ret = CommUtil::Retry([]() -> bool { return false; }, 5, 100);
EXPECT_FALSE(ret);
MockRetry mock_retry;
bool const mock_ret = CommUtil::Retry([&] { return mock_retry("mock"); }, 5, 100);
EXPECT_TRUE(mock_ret);
}
} // namespace core
} // namespace ps
} // namespace mindspore

Loading…
Cancel
Save