You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

executor.h 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
#define MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_

#include <atomic>
#include <condition_variable>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <vector>

#include "fl/server/common.h"
#include "fl/server/parameter_aggregator.h"
#ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_unmask.h"
#endif
  30. namespace mindspore {
  31. namespace fl {
  32. namespace server {
  33. // Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles
  34. // logics relevant to kernel launching.
  35. class Executor {
  36. public:
  37. static Executor &GetInstance() {
  38. static Executor instance;
  39. return instance;
  40. }
  41. // FuncGraphPtr func_graph is the graph compiled by the frontend. aggregation_count is the number which will
  42. // be used for aggregators.
  43. // As noted in header file parameter_aggregator.h, we create aggregators by trainable parameters, which is the
  44. // optimizer cnode's input. So we need to initialize server executor using func_graph.
  45. void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count);
  46. // Reinitialize parameter aggregators after scaling operations are done.
  47. bool ReInitForScaling();
  48. // After hyper-parameters are updated, some parameter aggregators should be reinitialized.
  49. bool ReInitForUpdatingHyperParams(size_t aggr_threshold);
  50. // Called in parameter server training mode to do Push operation.
  51. // For the same trainable parameter, HandlePush method must be called aggregation_count_ times before it's considered
  52. // as completed.
  53. bool HandlePush(const std::string &param_name, const UploadData &upload_data);
  54. // Called in parameter server training mode to do Pull operation.
  55. // Returns the value of parameter param_name.
  56. // HandlePull method must be called the same times as HandlePush is called before it's considered as
  57. // completed.
  58. AddressPtr HandlePull(const std::string &param_name);
  59. // Called in federated learning training mode. Update value for parameter param_name.
  60. bool HandleModelUpdate(const std::string &param_name, const UploadData &upload_data);
  61. // Called in asynchronous federated learning training mode. Update current model with the new feature map
  62. // asynchronously.
  63. bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
  64. // Overwrite the weights in server using pushed feature map.
  65. bool HandlePushWeight(const std::map<std::string, Address> &feature_map);
  66. // Returns multiple trainable parameters passed by weight_names.
  67. std::map<std::string, AddressPtr> HandlePullWeight(const std::vector<std::string> &param_names);
  68. // Reset the aggregation status for all aggregation kernels in the server.
  69. void ResetAggregationStatus();
  70. // Judge whether aggregation processes for all weights/gradients are completed.
  71. bool IsAllWeightAggregationDone();
  72. // Judge whether the aggregation processes for the given param_names are completed.
  73. bool IsWeightAggrDone(const std::vector<std::string> &param_names);
  74. // Returns whole model in key-value where key refers to the parameter name.
  75. std::map<std::string, AddressPtr> GetModel();
  76. // Returns whether the executor singleton is already initialized.
  77. bool initialized() const;
  78. const std::vector<std::string> &param_names() const;
  79. // The unmasking method for pairwise encrypt algorithm.
  80. bool Unmask();
  81. // The setter and getter for unmasked flag to judge whether the unmasking is completed.
  82. void set_unmasked(bool unmasked);
  83. bool unmasked() const;
  84. private:
  85. Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}), unmasked_(false) {}
  86. ~Executor() = default;
  87. Executor(const Executor &) = delete;
  88. Executor &operator=(const Executor &) = delete;
  89. // Returns the trainable parameter name parsed from this cnode.
  90. std::string GetTrainableParamName(const CNodePtr &cnode);
  91. // Server's graph is basically the same as Worker's graph, so we can get all information from func_graph for later
  92. // computations. Including forward and backward propagation, aggregation, optimizing, etc.
  93. bool InitParamAggregator(const FuncGraphPtr &func_graph);
  94. bool initialized_;
  95. size_t aggregation_count_;
  96. std::vector<std::string> param_names_;
  97. // The map for trainable parameter names and its ParameterAggregator, as noted in the header file
  98. // parameter_aggregator.h
  99. std::map<std::string, std::shared_ptr<ParameterAggregator>> param_aggrs_;
  100. // The mutex ensures that the operation on whole model is threadsafe.
  101. // The whole model is constructed by all trainable parameters.
  102. std::mutex model_mutex_;
  103. // Because ParameterAggregator is not threadsafe, we have to create mutex for each ParameterAggregator so we can
  104. // acquire lock before calling its method.
  105. std::map<std::string, std::mutex> parameter_mutex_;
  106. #ifdef ENABLE_ARMOUR
  107. armour::CipherUnmask cipher_unmask_;
  108. #endif
  109. // The flag represents the unmasking status.
  110. std::atomic<bool> unmasked_;
  111. };
  112. } // namespace server
  113. } // namespace fl
  114. } // namespace mindspore
  115. #endif // MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_