Browse Source

!28765 Reject repeat registrations of alive nodes for ps

Merge pull request !28765 from zyli2020/parameter_server_cache
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
922af56db1
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 282 additions and 10 deletions
  1. +2
    -0
      mindspore/ccsrc/ps/CMakeLists.txt
  2. +1
    -1
      mindspore/ccsrc/ps/core/abstract_node.h
  3. +43
    -0
      mindspore/ccsrc/ps/core/ps_scheduler_node.h
  4. +55
    -0
      mindspore/ccsrc/ps/core/ps_server_node.cc
  5. +43
    -0
      mindspore/ccsrc/ps/core/ps_server_node.h
  6. +55
    -0
      mindspore/ccsrc/ps/core/ps_worker_node.cc
  7. +43
    -0
      mindspore/ccsrc/ps/core/ps_worker_node.h
  8. +14
    -0
      mindspore/ccsrc/ps/core/scheduler_node.cc
  9. +4
    -0
      mindspore/ccsrc/ps/core/scheduler_node.h
  10. +1
    -1
      mindspore/ccsrc/ps/parameter_server.cc
  11. +1
    -1
      mindspore/ccsrc/ps/parameter_server.h
  12. +3
    -3
      mindspore/ccsrc/ps/scheduler.cc
  13. +15
    -2
      mindspore/ccsrc/ps/scheduler.h
  14. +2
    -2
      mindspore/ccsrc/ps/worker.h

+ 2
- 0
mindspore/ccsrc/ps/CMakeLists.txt View File

@@ -16,7 +16,9 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/ps_worker_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/ps_server_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_client.cc")


+ 1
- 1
mindspore/ccsrc/ps/core/abstract_node.h View File

@@ -150,7 +150,7 @@ class AbstractNode : public Node {
const std::shared_ptr<TaskExecutor> &task_executor);

protected:
void Register(const std::shared_ptr<TcpClient> &client);
virtual void Register(const std::shared_ptr<TcpClient> &client);
bool Heartbeat(const std::shared_ptr<TcpClient> &client);
void FetchServers(const std::shared_ptr<TcpClient> &client);



+ 43
- 0
mindspore/ccsrc/ps/core/ps_scheduler_node.h View File

@@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_
#define MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_

#include "ps/core/scheduler_node.h"
#include "ps/core/node_info.h"

namespace mindspore {
namespace ps {
namespace core {
// This class is a derived class of SchedulerNode specialized for Parameter Server. It is used to rewrite the specific
// logic for Parameter Server mode training in SchedulerNode. For example, the Scheduler of Parameter Server will reject
// the registration request of alive nodes.
class PSSchedulerNode : public SchedulerNode {
public:
PSSchedulerNode() = default;
~PSSchedulerNode() override = default;

private:
// Determine whether the registration request of the node should be rejected, the registration of the
// alive node should be rejected.
bool NeedRejectRegister(const NodeInfo &node_info) override { return node_info.is_alive; }
};
} // namespace core
} // namespace ps
} // namespace mindspore

#endif // MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_

+ 55
- 0
mindspore/ccsrc/ps/core/ps_server_node.cc View File

@@ -0,0 +1,55 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ps/core/ps_server_node.h"

namespace mindspore {
namespace ps {
namespace core {
void PSServerNode::Register(const std::shared_ptr<TcpClient> &client) {
MS_EXCEPTION_IF_NULL(client);
auto message_meta = std::make_shared<MessageMeta>();
MS_EXCEPTION_IF_NULL(message_meta);
message_meta->set_cmd(NodeCommand::REGISTER);
message_meta->set_rank_id(node_info_.rank_id_);

RegisterMessage register_message;
register_message.set_node_id(node_info_.node_id_);
register_message.set_role(node_info_.node_role_);
register_message.set_ip(node_info_.ip_);
register_message.set_port(node_info_.port_);

MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";

if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
register_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!";
} else {
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send register success!";
}

// Registrations of alive nodes or registrations request exceeding the set total number of nodes should be
// rejected, and the process exits after being rejected.
if (node_info_.rank_id_ == UINT32_MAX) {
MS_LOG(EXCEPTION) << "Register is rejected, and finish the node.";
}
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 43
- 0
mindspore/ccsrc/ps/core/ps_server_node.h View File

@@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_CORE_PS_SERVER_NODE_H_
#define MINDSPORE_CCSRC_PS_CORE_PS_SERVER_NODE_H_

#include <memory>
#include "ps/core/server_node.h"

namespace mindspore {
namespace ps {
namespace core {
// This class is a derived class of ServerNode specialized for Parameter Server. It is used to rewrite the logic
// specific to Parameter Server mode training in ServerNode. For example, the registration of Parameter Server's Server
// node is synchronous.
class PSServerNode : public ServerNode {
public:
PSServerNode() = default;
~PSServerNode() override = default;

private:
// The Server node registers to the Scheduler node, and the registration of the Server node of the Parameter Server
// is synchronous.
void Register(const std::shared_ptr<TcpClient> &client) override;
};
} // namespace core
} // namespace ps
} // namespace mindspore

#endif // MINDSPORE_CCSRC_PS_CORE_PS_SERVER_NODE_H_

+ 55
- 0
mindspore/ccsrc/ps/core/ps_worker_node.cc View File

@@ -0,0 +1,55 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ps/core/ps_worker_node.h"

namespace mindspore {
namespace ps {
namespace core {
void PSWorkerNode::Register(const std::shared_ptr<TcpClient> &client) {
MS_EXCEPTION_IF_NULL(client);
auto message_meta = std::make_shared<MessageMeta>();
MS_EXCEPTION_IF_NULL(message_meta);
message_meta->set_cmd(NodeCommand::REGISTER);
message_meta->set_rank_id(node_info_.rank_id_);

RegisterMessage register_message;
register_message.set_node_id(node_info_.node_id_);
register_message.set_role(node_info_.node_role_);
register_message.set_ip(node_info_.ip_);
register_message.set_port(node_info_.port_);

MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!";

if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
register_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!";
} else {
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send register success!";
}

// Registrations of alive nodes or registrations request exceeding the set total number of nodes should be
// rejected, and the process exits after being rejected.
if (node_info_.rank_id_ == UINT32_MAX) {
MS_LOG(EXCEPTION) << "Register is rejected, and finish the node.";
}
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 43
- 0
mindspore/ccsrc/ps/core/ps_worker_node.h View File

@@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_CORE_PS_WORKER_NODE_H_
#define MINDSPORE_CCSRC_PS_CORE_PS_WORKER_NODE_H_

#include <memory>
#include "ps/core/worker_node.h"

namespace mindspore {
namespace ps {
namespace core {
// This class is a derived class of WorkerNode specialized for Parameter Server. It is used to rewrite the logic
// specific to Parameter Server mode training in WorkerNode. For example, the registration of Parameter Server's Worker
// node is synchronous.
class PSWorkerNode : public WorkerNode {
public:
PSWorkerNode() = default;
~PSWorkerNode() override = default;

private:
// The Worker node registers to the Scheduler node, and the registration of the Worker node of the Parameter Server
// is synchronous.
void Register(const std::shared_ptr<TcpClient> &client) override;
};
} // namespace core
} // namespace ps
} // namespace mindspore

#endif // MINDSPORE_CCSRC_PS_CORE_PS_WORKER_NODE_H_

+ 14
- 0
mindspore/ccsrc/ps/core/scheduler_node.cc View File

@@ -256,6 +256,20 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
MS_LOG(INFO) << "The node id:" << node_id << " is registering to scheduler.";
client_mutex_.lock();
if (node_manager_.IsNodeRegistered(node_id)) {
NodeInfo node_info = node_manager_.QueryNodeInfo(node_id);
if (NeedRejectRegister(node_info)) {
MS_LOG(WARNING) << "The node(id: " << node_id << ") is alive, register is rejected!";
RegisterRespMessage register_rejected_message;
register_rejected_message.set_node_id(node_id);
register_rejected_message.set_rank_id(UINT32_MAX);
if (!server->SendMessage(conn, meta, Protos::PROTOBUF, register_rejected_message.SerializeAsString().data(),
register_rejected_message.ByteSizeLong())) {
MS_LOG(WARNING) << "Server response rejected message failed.";
}
client_mutex_.unlock();
return;
}

MS_LOG(INFO) << "The node id is registered.";
if (connected_nodes_.count(node_id)) {
(void)connected_nodes_.erase(node_id);


+ 4
- 0
mindspore/ccsrc/ps/core/scheduler_node.h View File

@@ -103,6 +103,10 @@ class SchedulerNode : public Node {
void ProcessSendEvent(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);

// Determine whether the registration request of the node should be rejected, the registration of the
// alive node should be rejected.
virtual bool NeedRejectRegister(const NodeInfo &node_info) { return false; }

// After scheduler collects all registered message, it actively sends finish to the node connected by the client.
void SendMetadata(const std::shared_ptr<TcpClient> &client, uint32_t rank_id);
// After scheduler collects all finish message, it actively sends finish to the node connected by the client.


+ 1
- 1
mindspore/ccsrc/ps/parameter_server.cc View File

@@ -29,7 +29,7 @@ static const uint32_t kCPUCoreNum = std::thread::hardware_concurrency();
void ParameterServer::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
server_node_ = std::make_shared<core::ServerNode>();
server_node_ = std::make_shared<core::PSServerNode>();

MS_LOG(INFO) << "PServer connected successfully.";
if (!PSContext::instance()->is_server()) {


+ 1
- 1
mindspore/ccsrc/ps/parameter_server.h View File

@@ -61,7 +61,7 @@
#include "utils/log_adapter.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/server_node.h"
#include "ps/core/ps_server_node.h"
#include "ps/core/node.h"

namespace mindspore {


+ 3
- 3
mindspore/ccsrc/ps/scheduler.cc View File

@@ -24,15 +24,15 @@ void Scheduler::Run() {
PSContext::instance()->cluster_config().scheduler_port = PSContext::instance()->scheduler_port();
PSContext::instance()->cluster_config().initial_worker_num = PSContext::instance()->initial_worker_num();
PSContext::instance()->cluster_config().initial_server_num = PSContext::instance()->initial_server_num();
if (!scheduler_node_.Start()) {
if (!scheduler_node_->Start()) {
MS_LOG(WARNING) << "Scheduler start failed.";
}
if (!scheduler_node_.Finish()) {
if (!scheduler_node_->Finish()) {
MS_LOG(WARNING) << "Scheduler finis failed.";
}
if (!scheduler_node_.Stop()) {
if (!scheduler_node_->Stop()) {
MS_LOG(WARNING) << "Scheduler stop failed.";
}
exit(1);


+ 15
- 2
mindspore/ccsrc/ps/scheduler.h View File

@@ -17,7 +17,9 @@
#ifndef MINDSPORE_CCSRC_PS_SCHEDULER_H_
#define MINDSPORE_CCSRC_PS_SCHEDULER_H_
#include <memory>
#include "ps/core/scheduler_node.h"
#include "ps/core/ps_scheduler_node.h"
#include "ps/util.h"
#include "ps/ps_context.h"
@@ -33,11 +35,22 @@ class Scheduler {
void Run();
private:
Scheduler() = default;
Scheduler() {
if (scheduler_node_ == nullptr) {
bool is_fl_mode = PSContext::instance()->server_mode() == ps::kServerModeFL ||
PSContext::instance()->server_mode() == ps::kServerModeHybrid;
if (is_fl_mode) {
scheduler_node_ = std::make_unique<core::SchedulerNode>();
} else {
scheduler_node_ = std::make_unique<core::PSSchedulerNode>();
}
}
}
~Scheduler() = default;
Scheduler(const Scheduler &) = delete;
Scheduler &operator=(const Scheduler &) = delete;
core::SchedulerNode scheduler_node_;
std::unique_ptr<core::SchedulerNode> scheduler_node_;
};
} // namespace ps
} // namespace mindspore


+ 2
- 2
mindspore/ccsrc/ps/worker.h View File

@@ -35,7 +35,7 @@
#include "ps/constants.h"
#include "utils/shape_utils.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/core/worker_node.h"
#include "ps/core/ps_worker_node.h"
#include "ps/embedding_table_shard_metadata.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
@@ -135,7 +135,7 @@ class Worker {
std::map<size_t, int64_t> key_to_optimId_;
std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_;
std::map<std::string, bool> param_to_init_in_server_;
core::WorkerNode worker_node_;
core::PSWorkerNode worker_node_;

EmbeddingPartitioner lookup_partitioner_;
KVPartitioner sparse_partitioner_;


Loading…
Cancel
Save