You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

common.h 9.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
  17. #define MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
  18. #include <map>
  19. #include <string>
  20. #include <numeric>
  21. #include <climits>
  22. #include <memory>
  23. #include <functional>
  24. #include "proto/ps.pb.h"
  25. #include "proto/fl.pb.h"
  26. #include "ir/anf.h"
  27. #include "utils/utils.h"
  28. #include "ir/dtype/type_id.h"
  29. #include "backend/kernel_compiler/cpu/cpu_kernel.h"
  30. #include "schema/fl_job_generated.h"
  31. #include "schema/cipher_generated.h"
  32. #include "ps/ps_context.h"
  33. #include "ps/core/communicator/http_message_handler.h"
  34. #include "ps/core/communicator/tcp_server.h"
  35. #include "ps/core/communicator/message_handler.h"
  36. namespace mindspore {
  37. namespace ps {
  38. namespace server {
// Definitions for the server framework.
// Role this server process runs as: a traditional parameter server or a federated-learning server.
enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
// Transport protocol used by the server communicators (see the ps/core/communicator headers above).
enum CommType { HTTP = 0, TCP };
// Supported aggregation algorithms.
// NOTE(review): "FedAdagarg" looks like a typo for "FedAdagrad" and "qffl" breaks the naming
// convention, but renaming enumerators would break existing references — confirm before changing.
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
// Configuration of one round in the server framework.
struct RoundConfig {
  // Name identifying the round.
  std::string name;
  // Whether this round is bounded by a time window.
  bool check_timeout = false;
  // Length of the time window — presumably milliseconds, matching CURRENT_TIME_MILLI; confirm at the timer code.
  size_t time_window = 3000;
  // Whether this round completes after a threshold number of requests is reached.
  bool check_count = false;
  // The request count that completes the round (only meaningful when check_count is true).
  size_t threshold_count = 0;
};
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;
// Shorthand for the FlatBuffers builder used to serialize FBS-protocol messages.
using FBBuilder = flatbuffers::FlatBufferBuilder;
// Callback types registered with the server framework. The meaning of each bool argument is
// defined by the registrant — see the call sites.
using TimeOutCb = std::function<void(bool)>;
using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(bool)>;
using FinalizeCb = std::function<void(void)>;
// Handler invoked for an incoming message from a communicator.
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
// Information about whether server kernel will reuse kernel node memory from the front end.
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
// Value refers to the kernel node's parameter index.
using ReuseKernelNodeInfo = std::map<std::string, size_t>;
// UploadData refers to the data which is uploaded by workers.
// Key refers to the data name. For example: "weights", "grad", "learning_rate", etc. This will be set by the worker.
// Value refers to the data of the key.
// We use Address instead of AddressPtr because:
// 1. Address doesn't need to call make_shared<T> so it has better performance.
// 2. The data uploaded by worker is normally parsed from FlatterBuffers or ProtoBuffer. For example: learning rate, new
// weights, etc. Address is enough to store these data.
// Pay attention that Address only stores the void* pointer of the data, so the data must not be released before the
// related logic is done.
using UploadData = std::map<std::string, Address>;
// String keys for kernel parameters exchanged between workers and server kernels
// (used as keys in UploadData and the *NameToIdx maps below).
constexpr auto kWeight = "weight";
constexpr auto kNewWeight = "new_weight";
constexpr auto kAccumulation = "accum";
constexpr auto kLearningRate = "lr";
constexpr auto kGradient = "grad";
constexpr auto kNewGradient = "new_grad";
constexpr auto kMomentum = "momentum";
constexpr auto kIndices = "indices";
// Adam optimizer state and hyper-parameters.
constexpr auto kAdamM = "m";
constexpr auto kAdamV = "v";
constexpr auto kAdamBeta1Power = "beta1_power";
constexpr auto kAdamBeta2Power = "beta2_power";
constexpr auto kAdamBeta1 = "beta1";
constexpr auto kAdamBeta2 = "beta2";
constexpr auto kAdamEps = "eps";
// FTRL optimizer state.
constexpr auto kFtrlLinear = "linear";
// Sizes of the data sets reported by workers.
constexpr auto kDataSize = "data_size";
constexpr auto kNewDataSize = "new_data_size";
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
// launched.
using OptimParamNameToIndex = std::map<std::string, std::map<std::string, size_t>>;
// Input offsets for the ApplyMomentum kernel (see kNameToIdxMap below).
const OptimParamNameToIndex kMomentumNameToIdx = {
  {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kLearningRate, 2}, {kGradient, 3}, {kMomentum, 4}}}, {"outputs", {}}};
// Input offsets for the ApplyAdam kernel (see kNameToIdxMap below).
const OptimParamNameToIndex kAdamNameToIdx = {{"inputs",
                                               {{kWeight, 0},
                                                {kAdamM, 1},
                                                {kAdamV, 2},
                                                {kAdamBeta1Power, 3},
                                                {kAdamBeta2Power, 4},
                                                {kLearningRate, 5},
                                                {kAdamBeta1, 6},
                                                {kAdamBeta2, 7},
                                                {kAdamEps, 8},
                                                {kGradient, 9}}},
                                              {"outputs", {}}};
  108. const OptimParamNameToIndex kSparseAdamNameToIdx = {{"inputs",
  109. {{kWeight, 0},
  110. {kAdamM, 1},
  111. {kAdamV, 2},
  112. {kAdamBeta1Power, 3},
  113. {kAdamBeta2Power, 4},
  114. {kLearningRate, 5},
  115. {kAdamBeta1, 6},
  116. {kAdamBeta1, 7},
  117. {kAdamEps, 8},
  118. {kGradient, 9},
  119. {kIndices, 10}}},
  120. {"outputs", {}}};
// Input offsets for the SparseApplyFtrl kernel (see kNameToIdxMap below).
const OptimParamNameToIndex kSparseFtrlNameToIdx = {
  {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kFtrlLinear, 2}, {kGradient, 3}, {kIndices, 4}}}, {"outputs", {}}};
// Maps an optimizer kernel's op name to its parameter-name-to-input-offset table.
const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {
  {kApplyMomentumOpName, kMomentumNameToIdx},
  {kFusedSparseAdamName, kSparseAdamNameToIdx},
  {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
  {kApplyAdamOpName, kAdamNameToIdx},
};
// Rank of the server instance that acts as the leader.
constexpr uint32_t kLeaderServerRank = 0;
// Thread pool sizes and maximum queued-task counts for the worker manager, cipher manager and executor.
constexpr size_t kWorkerMgrThreadPoolSize = 32;
constexpr size_t kWorkerMgrMaxTaskNum = 64;
constexpr size_t kCipherMgrThreadPoolSize = 32;
constexpr size_t kCipherMgrMaxTaskNum = 64;
constexpr size_t kExecutorThreadPoolSize = 32;
constexpr size_t kExecutorMaxTaskNum = 32;
// HTTP status code for a successful request.
constexpr int kHttpSuccess = 200;
// Identifiers of the two supported message serialization protocols: ProtoBuffer and FlatBuffers.
constexpr auto kPBProtocol = "PB";
constexpr auto kFBSProtocol = "FBS";
constexpr auto kSuccess = "Success";
// Names of aggregation/optimizer kernel categories used when registering server kernels.
constexpr auto kFedAvg = "FedAvg";
constexpr auto kAggregationKernelType = "Aggregation";
constexpr auto kOptimizerKernelType = "Optimizer";
// Keys for values stored in the server context.
constexpr auto kCtxFuncGraph = "FuncGraph";
constexpr auto kCtxIterNum = "iteration";
constexpr auto kCtxDeviceMetas = "device_metas";
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp";
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
constexpr auto kCtxUpdateModelThld = "update_model_threshold";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
  152. // This macro the current timestamp in milliseconds.
  153. #define CURRENT_TIME_MILLI \
  154. std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())
  155. #define RETURN_IF_NULL(expr, ret) \
  156. if (expr == nullptr) { \
  157. MS_LOG(ERROR) << #expr << " is nullptr."; \
  158. return ret; \
  159. }
// This method returns the size in bytes of the given TypeId.
// Only the TypeIds currently needed by the server kernels are handled; any other TypeId
// is rejected via MS_LOG(EXCEPTION).
inline size_t GetTypeIdByte(const TypeId &type) {
  switch (type) {
    case kNumberTypeFloat16:
      return 2;
    case kNumberTypeUInt32:
    case kNumberTypeFloat32:
      return 4;
    case kNumberTypeUInt64:
      return 8;
    default:
      MS_LOG(EXCEPTION) << "TypeId " << type << " not supported.";
      // Unreachable when MS_LOG(EXCEPTION) aborts/throws; satisfies the return-path check.
      return 0;
  }
}
  175. inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size_t param_idx) {
  176. RETURN_IF_NULL(kernel_node, nullptr);
  177. auto param_node =
  178. AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, param_idx), 0).first->cast<ParameterPtr>();
  179. RETURN_IF_NULL(param_node, nullptr);
  180. auto param_tensor = param_node->default_param()->cast<tensor::TensorPtr>();
  181. RETURN_IF_NULL(param_tensor, nullptr);
  182. AddressPtr addr = std::make_shared<kernel::Address>();
  183. addr->addr = param_tensor->data_c();
  184. addr->size = param_tensor->data().nbytes();
  185. return addr;
  186. }
  187. // Definitions for Federated Learning.
  188. // Definitions for Parameter Server.
  189. } // namespace server
  190. } // namespace ps
  191. } // namespace mindspore
  192. #endif // MINDSPORE_CCSRC_PS_SERVER_COMMON_H_