You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

executor.h 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
  17. #define MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
  18. #include <map>
  19. #include <set>
  20. #include <memory>
  21. #include <string>
  22. #include <vector>
  23. #include <mutex>
  24. #include <condition_variable>
  25. #include "ps/server/common.h"
  26. #include "ps/server/parameter_aggregator.h"
  27. namespace mindspore {
  28. namespace ps {
  29. namespace server {
  30. // Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles
  31. // logics relevant to kernel launching.
  32. class Executor {
  33. public:
  34. static Executor &GetInstance() {
  35. static Executor instance;
  36. return instance;
  37. }
  38. // FuncGraphPtr func_graph is the graph compiled by the frontend. aggregation_count is the number which will
  39. // be used for aggregators.
  40. // As noted in header file parameter_aggregator.h, we create aggregators by trainable parameters, which is the
  41. // optimizer cnode's input. So we need to initialize server executor using func_graph.
  42. void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count);
  43. // Called in parameter server training mode to do Push operation.
  44. // For the same trainable parameter, HandlePush method must be called aggregation_count_ times before it's considered
  45. // as completed.
  46. bool HandlePush(const std::string &param_name, const UploadData &upload_data);
  47. // Called in parameter server training mode to do Pull operation.
  48. // Returns the value of parameter param_name.
  49. // HandlePull method must be called the same times as HandlePush is called before it's considered as
  50. // completed.
  51. AddressPtr HandlePull(const std::string &param_name);
  52. // Called in federated learning training mode. Update value for parameter param_name.
  53. bool HandleModelUpdate(const std::string &param_name, const UploadData &upload_data);
  54. // Called in asynchronous federated learning training mode. Update current model with the new feature map
  55. // asynchronously.
  56. bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
  57. // Forcibly overwrite specific weights in overwriteWeights message.
  58. bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map);
  59. // Returns value for multiple trainable parameters passed by weight_names.
  60. std::map<std::string, AddressPtr> HandleGetWeightsByKey(const std::vector<std::string> &param_names);
  61. // Reset the aggregation status for all aggregation kernels in the server.
  62. void ResetAggregationStatus();
  63. // Judge whether aggregation processes for all weights/gradients are completed.
  64. bool IsAllWeightAggregationDone();
  65. // Judge whether the aggregation processes for the given param_names are completed.
  66. bool IsWeightAggrDone(const std::vector<std::string> &param_names);
  67. // Returns whole model in key-value where key refers to the parameter name.
  68. std::map<std::string, AddressPtr> GetModel();
  69. // Returns whether the executor singleton is already initialized.
  70. bool initialized() const;
  71. const std::vector<std::string> &param_names() const;
  72. private:
  73. Executor() {}
  74. ~Executor() = default;
  75. Executor(const Executor &) = delete;
  76. Executor &operator=(const Executor &) = delete;
  77. // Returns the trainable parameter name parsed from this cnode.
  78. std::string GetTrainableParamName(const CNodePtr &cnode);
  79. // Server's graph is basically the same as Worker's graph, so we can get all information from func_graph for later
  80. // computations. Including forward and backward propagation, aggregation, optimizing, etc.
  81. bool InitParamAggregator(const FuncGraphPtr &func_graph);
  82. bool initialized_;
  83. size_t aggregation_count_;
  84. std::vector<std::string> param_names_;
  85. // The map for trainable parameter names and its ParameterAggregator, as noted in the header file
  86. // parameter_aggregator.h
  87. std::map<std::string, std::shared_ptr<ParameterAggregator>> param_aggrs_;
  88. // The mutex ensures that the operation on whole model is threadsafe.
  89. // The whole model is constructed by all trainable parameters.
  90. std::mutex model_mutex_;
  91. // Because ParameterAggregator is not threadsafe, we have to create mutex for each ParameterAggregator so we can
  92. // acquire lock before calling its method.
  93. std::map<std::string, std::mutex> parameter_mutex_;
  94. };
  95. } // namespace server
  96. } // namespace ps
  97. } // namespace mindspore
  98. #endif // MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_