You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

worker.h 7.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_WORKER_H_
  17. #define MINDSPORE_CCSRC_PS_WORKER_H_
  18. #include <utility>
  19. #include <memory>
  20. #include <vector>
  21. #include <string>
  22. #include <numeric>
  23. #include <functional>
  24. #include <algorithm>
  25. #include <map>
  26. #include <mutex>
  27. #include <unordered_set>
  28. #include <unordered_map>
  29. #include "utils/log_adapter.h"
  30. #include "ir/tensor.h"
  31. #include "ps/util.h"
  32. #include "ps/constants.h"
  33. #include "utils/shape_utils.h"
  34. #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
  35. #include "ps/core/worker_node.h"
  36. #include "ps/embedding_table_shard_metadata.h"
  37. #include "proto/comm.pb.h"
  38. #include "proto/ps.pb.h"
  39. #include "ps/ps_context.h"
  40. namespace mindspore {
  41. namespace ps {
  42. class Worker {
  43. public:
  44. static Worker &GetInstance() {
  45. static Worker instance;
  46. return instance;
  47. }
  48. using Callback = std::function<void()>;
  49. using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>;
  50. using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>;
  51. using EmbeddingPartitioner = std::function<void(
  52. const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
  53. using KVPartitioner =
  54. std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
  55. void Run();
  56. void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
  57. void Pull(const size_t key, void *dev_addr, const size_t size);
  58. size_t SetParamKey(const std::string &param_name);
  59. size_t GetParamKey(const std::string &param_name);
  60. void SetParamInitInServer(const std::string &param_name, bool init_in_server);
  61. bool GetParamInitInServer(const std::string &param_name);
  62. void SetKeyOptimId(size_t key, const std::string &optimizer_name);
  63. void SetOptimInputShapes(size_t key, const ShapeVector &shape);
  64. void AddEmbeddingTable(const Key &key, const size_t &row_count);
  65. void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
  66. const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape,
  67. const ParamInitInfoMessage &info);
  68. void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
  69. void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
  70. int64_t cmd);
  71. void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
  72. const std::vector<float> &vals);
  73. bool running() { return running_; }
  74. void Finalize();
  75. private:
  76. Worker() : server_num_(-1), running_(false), key_cnt_(0) {}
  77. ~Worker() = default;
  78. Worker(const Worker &) = delete;
  79. Worker &operator=(const Worker &) = delete;
  80. void Initialize();
  81. bool IsKeyInit(const size_t key);
  82. void AddKeyToServerId(const Key &key);
  83. void AddKeyByHashMod(const Key &key);
  84. void InitPSOptimId(const size_t param_key);
  85. void InitPSOptimInputShapes(const size_t key);
  86. void InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size);
  87. bool IsReadyForPush(const Key &key);
  88. bool IsReadyForPull(const Key &key);
  89. void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
  90. const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
  91. const size_t segment_size, float *gradient, int *indices);
  92. void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
  93. const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data);
  94. void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {},
  95. int command = 0, int64_t priority = 0);
  96. void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
  97. size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
  98. void PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens = nullptr,
  99. int cmd = 0, int64_t priority = 0);
  100. void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
  101. const std::map<int64_t, int64_t> &attrs);
  102. void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
  103. const std::map<int64_t, int64_t> &attrs);
  104. void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  105. const std::map<int64_t, int64_t> &attrs);
  106. void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
  107. const std::map<int64_t, int64_t> &attrs);
  108. void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  109. const std::map<int64_t, int64_t> &attrs);
  110. void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  111. const std::map<int64_t, int64_t> &attrs);
  112. void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
  113. const std::map<int64_t, int64_t> &attrs);
  114. void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
  115. const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens);
  116. int64_t server_num_;
  117. bool running_;
  118. std::mutex running_mutex_;
  119. size_t key_cnt_;
  120. std::map<std::string, size_t> param_to_key_;
  121. std::map<size_t, bool> init_keys_;
  122. std::map<size_t, int64_t> key_to_optimId_;
  123. std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_;
  124. std::map<std::string, bool> param_to_init_in_server_;
  125. core::WorkerNode worker_node_;
  126. EmbeddingPartitioner lookup_partitioner_;
  127. KVPartitioner sparse_partitioner_;
  128. KVPartitioner round_robin_partitioner_;
  129. KVPartitioner worker_init_embedding_partitioner_;
  130. KVPartitioner update_embedding_partitioner_;
  131. KVPartitioner broadcast_partitioner_;
  132. std::unordered_map<Key, int64_t> key_to_server_id_;
  133. std::unordered_map<Key, size_t> embedding_row_cnt_;
  134. std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_;
  135. };
  136. } // namespace ps
  137. } // namespace mindspore
  138. #endif // MINDSPORE_CCSRC_PS_WORKER_H_