You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

server.h 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
  17. #define MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/server/common.h"
#include "ps/server/executor.h"
#include "ps/server/iteration.h"
  27. namespace mindspore {
  28. namespace ps {
  29. namespace server {
  30. // Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
  31. class Server {
  32. public:
  33. static Server &GetInstance() {
  34. static Server instance;
  35. return instance;
  36. }
  37. void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
  38. const FuncGraphPtr &func_graph, size_t executor_threshold);
  39. // According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be
  40. // blocked until the server is finalized.
  41. // func_graph is the frontend graph which will be parse in server's exector and aggregator.
  42. void Run();
  43. private:
  44. Server()
  45. : server_node_(nullptr),
  46. task_executor_(nullptr),
  47. use_tcp_(false),
  48. use_http_(false),
  49. http_port_(0),
  50. func_graph_(nullptr),
  51. executor_threshold_(0),
  52. communicator_with_server_(nullptr),
  53. communicators_with_worker_({}),
  54. iteration_(nullptr),
  55. scheduler_ip_(""),
  56. scheduler_port_(0),
  57. server_num_(0),
  58. worker_num_(0) {}
  59. ~Server() = default;
  60. Server(const Server &) = delete;
  61. Server &operator=(const Server &) = delete;
  62. // Load variables which is set by ps_context.
  63. void InitServerContext();
  64. // Initialize the server cluster, server node and communicators.
  65. void InitCluster();
  66. bool InitCommunicatorWithServer();
  67. bool InitCommunicatorWithWorker();
  68. // Initialize iteration with rounds. Which rounds to use could be set by ps_context as well.
  69. void InitIteration();
  70. // Initialize executor according to the server mode.
  71. void InitExecutor();
  72. // Create round kernels and bind these kernels with corresponding Round.
  73. void RegisterRoundKernel();
  74. // The communicators should be started after all initializations are completed.
  75. void StartCommunicator();
  76. // The server node is initialized in Server.
  77. std::shared_ptr<core::ServerNode> server_node_;
  78. // The task executor of the communicators. This helps server to handle network message concurrently. The tasks
  79. // submitted to this task executor is asynchronous.
  80. std::shared_ptr<core::TaskExecutor> task_executor_;
  81. // Which protocol should communicators use.
  82. bool use_tcp_;
  83. bool use_http_;
  84. uint64_t http_port_;
  85. // The configure of all rounds.
  86. std::vector<RoundConfig> rounds_config_;
  87. // The graph passed by the frontend without backend optimizing.
  88. FuncGraphPtr func_graph_;
  89. // The threshold count for executor to do aggregation or optimizing.
  90. size_t executor_threshold_;
  91. // Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective
  92. // operations, etc.
  93. std::shared_ptr<core::CommunicatorBase> communicator_with_server_;
  94. // The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP.
  95. // In some cases, both types should be supported in one distributed training job. So here we may have multiple
  96. // communicators.
  97. std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;
  98. // Iteration consists of multiple kinds of rounds.
  99. std::shared_ptr<Iteration> iteration_;
  100. // Variables set by ps context.
  101. std::string scheduler_ip_;
  102. uint16_t scheduler_port_;
  103. uint32_t server_num_;
  104. uint32_t worker_num_;
  105. };
  106. } // namespace server
  107. } // namespace ps
  108. } // namespace mindspore
  109. #endif // MINDSPORE_CCSRC_PS_SERVER_SERVER_H_