You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

parameter_aggregator.h 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
  17. #define MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
  18. #include <map>
  19. #include <memory>
  20. #include <string>
  21. #include <vector>
  22. #include <utility>
  23. #include "ps/server/common.h"
  24. #include "ps/server/memory_register.h"
  25. #include "ps/server/kernel/aggregation_kernel_factory.h"
  26. #include "ps/server/kernel/optimizer_kernel_factory.h"
  27. namespace mindspore {
  28. namespace ps {
  29. namespace server {
  30. // Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch server
  31. // kernels.
  32. typedef struct {
  33. std::vector<AddressPtr> inputs;
  34. std::vector<AddressPtr> workspace;
  35. std::vector<AddressPtr> outputs;
  36. } KernelParams;
  37. // ParameterAggregator includes methods for aggregating gradients and optimizing weights(launching aggregation and
  38. // optimizer kernels), getting weights, etc. It's not thread-safe, which means the caller must acquire lock before
  39. // calling ParameterAggregator methods concurrently.
  40. // Each ParameterAggregator is corresponding to one weight for now.
  41. // ParameterAggregator is stateful because the process of aggregation and optimizing could be stateful.
  42. // For example, the finite-state machine for the ParameterAggregator in parameter server training mode is below:
  43. // Initial->Aggregating->Aggregation done->Optimizing->Optimizing done->Pulling->Pull done->Initial.
  44. class ParameterAggregator {
  45. public:
  46. ParameterAggregator()
  47. : server_mode_(ServerMode::PARAMETER_SERVER),
  48. required_push_count_(0),
  49. required_pull_count_(0),
  50. current_pull_count_(0),
  51. aggregation_done_(false),
  52. optimizing_done_(false),
  53. pulling_done_(true),
  54. memory_register_(nullptr) {}
  55. ~ParameterAggregator() = default;
  56. // Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now.
  57. // The parameter threshold_count helps ParameterAggregator to judge the current status if it's stateful.
  58. bool Init(const CNodePtr &cnode, size_t threshold_count = 0);
  59. // Update old data stored in ParameterAggregator with new data.
  60. // The data could have many meanings: weights, gradients, learning_rate, momentum, etc.
  61. bool UpdateData(const std::map<std::string, Address> &new_data);
  62. // Launch aggregators/optimizers of this ParameterAggregator in order.
  63. bool LaunchAggregators();
  64. bool LaunchOptimizers();
  65. // The implementation for primitive Pull in parameter server training mode.
  66. // Every call of this method will increase the count for pull by 1.
  67. AddressPtr Pull();
  68. // Different from the method Pull, this method simply returns the weight of this ParameterAggregator without causing
  69. // any change of status.
  70. AddressPtr GetWeight();
  71. // After aggregation/optimizing/pulling of one iteration is done, caller must reset the status to ensure the
  72. // correctness of the aggregation/optimizing/pulling for next iteration.
  73. void ResetAggregationStatus();
  74. void ResetOptimizingStatus();
  75. void ResetPullingStatus();
  76. // Returns the aggregation/optimizing/pulling status to the caller.
  77. bool IsAggregationDone() const;
  78. bool IsOptimizingDone() const;
  79. bool IsPullingDone() const;
  80. private:
  81. // Initializing aggregation/optimizer kenerls based on the cnode. The reason of this is described in the file
  82. // kernel/kernel_factory.h.
  83. bool InitAggregationKernels(const CNodePtr &cnode);
  84. bool InitOptimizerKernels(const CNodePtr &cnode);
  85. // Assign memory for server kernel K(AggregationKernel/OptimizerKernel).
  86. // The memory assigned can be accessed by MemoryRegister. The memory could be weights, gradients, learning_rate,
  87. // momentum, etc.
  88. template <typename K>
  89. bool AssignMemory(K server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
  90. std::shared_ptr<MemoryRegister> memory_register);
  91. // Generate kernel parameters for aggregation/optimizer kernels. All the parameters is registered and stored in
  92. // memory_register.
  93. bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel,
  94. const std::shared_ptr<MemoryRegister> memory_register);
  95. bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optim_kernel,
  96. const std::shared_ptr<MemoryRegister> memory_register);
  97. // The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user
  98. // configuration, etc.
  99. std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode);
  100. ServerMode server_mode_;
  101. size_t required_push_count_;
  102. size_t required_pull_count_;
  103. size_t current_pull_count_;
  104. // The status of aggregation/optimizing/pulling.
  105. bool aggregation_done_;
  106. bool optimizing_done_;
  107. bool pulling_done_;
  108. // ParameterAggregator stores all data that it needs for aggregation, optimizing, etc.
  109. std::shared_ptr<MemoryRegister> memory_register_;
  110. // Update could have multiple aggregation and optimizer server kernels.
  111. // Here stores multiple pairs of server kernels to parameters of their Launch function.
  112. std::vector<std::pair<std::shared_ptr<kernel::AggregationKernel>, KernelParams>> aggregation_kernel_parameters_;
  113. std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
  114. };
  115. } // namespace server
  116. } // namespace ps
  117. } // namespace mindspore
  118. #endif // MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_