You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

worker.h 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_WORKER_H_
  17. #define MINDSPORE_CCSRC_PS_WORKER_H_
  18. #include <utility>
  19. #include <memory>
  20. #include <vector>
  21. #include <string>
  22. #include <numeric>
  23. #include <functional>
  24. #include <map>
  25. #include "ps/ps.h"
  26. #include "utils/log_adapter.h"
  27. #include "ir/tensor.h"
  28. #include "ps/util.h"
  29. #include "ps/common.h"
  30. #include "ps/worker_proxy.h"
  31. #include "utils/shape_utils.h"
  32. namespace mindspore {
  33. namespace ps {
/**
 * Worker: client-side singleton for parameter-server (PS) training mode.
 * Wraps a WorkerProxy<T> to push gradients / pull weights and to initialize
 * parameters, optimizer inputs and embedding tables on the server.
 */
template <typename T>
class Worker {
 public:
  // Meyers singleton: one Worker per process.
  static Worker &GetInstance() {
    static Worker instance;
    return instance;
  }
  // Starts the ps node and creates the kv proxy; no-op when already running.
  void Run();
  // Pushes gradient data for `keys`; `sizes` are element counts per input.
  void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
  // Pulls `size` bytes of weight data for `key` into dev_addr.
  void Pull(const size_t key, void *dev_addr, const size_t size);
  // Assigns (or returns the previously assigned) key for a parameter name.
  size_t SetParamKey(const std::string &param_name);
  // Records / queries whether a parameter's weights live on the server side.
  void SetParamInitInServer(const std::string &param_name, bool init_in_server);
  bool GetParamInitInServer(const std::string &param_name);
  // Binds a parameter key to the numeric id of an optimizer.
  void SetKeyOptimId(size_t key, const std::string &optimizer_name);
  // Records one optimizer input shape for `key` (may be called repeatedly).
  void SetOptimInputShapes(size_t key, const ShapeVector &shape);
  // Registers an embedding table (key -> row count) with the proxy.
  void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
  // Sends an embedding table's shape to the server (once per key).
  void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const ShapeVector &sizes);
  // Initializes a parameter's weights and optimizer state on the server.
  void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
  // Performs an embedding lookup through the proxy; `cmd` selects the op.
  void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                           const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd);
  // Shuts down the proxy; safe to call when not running.
  void Finalize();

 private:
  Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {}
  ~Worker() = default;
  Worker(const Worker &) = delete;
  Worker &operator=(const Worker &) = delete;

  bool IsKeyInit(const size_t key);
  size_t GetParamKey(const std::string &param_name);
  void InitPSOptimId(const size_t param_key);
  void InitPSOptimInputShapes(const size_t key);
  void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
  // Intentionally empty slicer body — presumably id slicing is handled inside
  // the proxy; TODO confirm against WorkerProxy.
  static void EmbeddingLookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges,
                                      std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {}

  std::shared_ptr<WorkerProxy<T>> kv_worker_;  // transport to scheduler/servers
  bool running_;                               // set by Run(), cleared by Finalize()
  size_t key_cnt_;                             // next key handed out by SetParamKey
  std::map<std::string, size_t> param_to_key_;                      // param name -> key
  std::map<size_t, bool> init_keys_;                                // key -> initialized?
  std::map<size_t, int> key_to_optimId_;                            // key -> optimizer id
  std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_;  // key -> optim input shapes
  std::map<std::string, bool> param_to_init_in_server_;             // param name -> init-in-server flag
};
  76. template <typename T>
  77. void Worker<T>::Run() {
  78. if (running_) {
  79. MS_LOG(INFO) << "'Worker is already running.";
  80. return;
  81. }
  82. MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
  83. ::ps::Start(0);
  84. MS_LOG(INFO) << "Worker connected successfully.";
  85. if (!::ps::IsWorker()) {
  86. MS_LOG(EXCEPTION) << "The role is not worker.";
  87. }
  88. kv_worker_ = std::make_shared<WorkerProxy<T>>(0, 0, 1, 2);
  89. running_ = true;
  90. }
  91. template <typename T>
  92. void Worker<T>::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes) {
  93. if (keys.size() == 0) {
  94. MS_LOG(EXCEPTION) << "key size should be greater than zero";
  95. }
  96. if (key_to_optimId_.count(keys[0]) == 0) {
  97. MS_LOG(EXCEPTION) << "no optim id found for key" << keys[0];
  98. }
  99. Key key = keys[0];
  100. int optim_id = key_to_optimId_[key];
  101. bool is_sparse = false;
  102. if (optim_id == 1 || optim_id == 2 || optim_id == 3) {
  103. is_sparse = true;
  104. }
  105. int grad_index = -1;
  106. int indice_index = -1;
  107. // Sparse adam gradient
  108. if (optim_id == 1 || optim_id == 2) {
  109. grad_index = 6;
  110. indice_index = 7;
  111. // Sparse ftrl gradient
  112. } else if (optim_id == 3) {
  113. grad_index = 0;
  114. indice_index = 1;
  115. }
  116. size_t total_size = std::accumulate(sizes.begin(), sizes.end(), 0, std::plus<int>());
  117. ::ps::SArray<T> total_buffer(total_size, 0);
  118. size_t offset = 0;
  119. size_t dst_size = 0;
  120. size_t src_size = 0;
  121. for (size_t i = 0; i < sizes.size(); i++) {
  122. void *dst_data = total_buffer.data() + offset / sizeof(T);
  123. void *src_data = reinterpret_cast<void *>(addrs[i]);
  124. MS_EXCEPTION_IF_NULL(dst_data);
  125. MS_EXCEPTION_IF_NULL(src_data);
  126. dst_size = sizes[i] * sizeof(T);
  127. src_size = sizes[i] * sizeof(T);
  128. auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  129. if (ret != 0) {
  130. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  131. return;
  132. }
  133. offset += sizes[i] * sizeof(T);
  134. }
  135. while (!kv_worker_->IsReadyForPush(keys[0])) {
  136. continue;
  137. }
  138. if (!is_sparse) {
  139. kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes));
  140. } else {
  141. std::vector<int> &var_shape = key_to_optim_shapes_[key][0];
  142. int first_dim_size = var_shape[0];
  143. int outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies<int>());
  144. kv_worker_->PushSparseData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes), grad_index,
  145. indice_index, first_dim_size, outer_dim_size);
  146. }
  147. }
  148. template <typename T>
  149. void Worker<T>::Pull(const size_t key, void *dev_addr, const size_t size) {
  150. MS_EXCEPTION_IF_NULL(dev_addr);
  151. ::ps::SArray<T> variables(size / sizeof(T), 0);
  152. while (!kv_worker_->IsReadyForPull(key)) {
  153. continue;
  154. }
  155. kv_worker_->PullData({key}, &variables);
  156. size_t dst_size = size;
  157. size_t src_size = size;
  158. auto ret = memcpy_s(dev_addr, dst_size, variables.data(), src_size);
  159. if (ret != 0) {
  160. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  161. return;
  162. }
  163. }
template <typename T>
// Looks up embedding rows for `lookup_ids` via the worker proxy.
// lookup_result must be non-null; `cmd` selects the server-side lookup op.
void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                                    const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd) {
  MS_EXCEPTION_IF_NULL(lookup_result);
  kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd);
}
  170. template <typename T>
  171. void Worker<T>::Finalize() {
  172. if (running_) {
  173. MS_LOG(INFO) << "Worker starts finalizing...";
  174. kv_worker_->Finalize();
  175. kv_worker_.reset();
  176. running_ = false;
  177. MS_LOG(INFO) << "Worker finalized successfully.";
  178. }
  179. }
  180. template <typename T>
  181. void Worker<T>::InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size) {
  182. MS_EXCEPTION_IF_NULL(origin_addr);
  183. ::ps::SArray<T> addr(reinterpret_cast<T *>(origin_addr), size / sizeof(T));
  184. ::ps::SArray<::ps::Key> key(keys);
  185. ::ps::SArray<int> lens;
  186. lens.push_back(addr.size());
  187. kv_worker_->PushData(key, addr, lens, kInitWeightsCmd);
  188. init_keys_[key[0]] = true;
  189. }
  190. template <typename T>
  191. void Worker<T>::SetOptimInputShapes(size_t key, const ShapeVector &shape) {
  192. if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) {
  193. key_to_optim_shapes_[key] = {shape};
  194. } else {
  195. key_to_optim_shapes_[key].push_back(shape);
  196. }
  197. }
template <typename T>
// Pushes the recorded optimizer input shapes for `key` to the server.
// Each shape is flattened into `all_shape` with its rank in `shape_len`;
// a rank-0 (scalar) shape is encoded as a single 1.
void Worker<T>::InitPSOptimInputShapes(const size_t key) {
  ::ps::SArray<::ps::Key> keys;
  ::ps::SArray<int> shape_len;
  ::ps::SArray<T> all_shape;
  std::vector<ShapeVector> shapes = key_to_optim_shapes_[key];
  for (auto shape : shapes) {
    // One entry of `keys` per shape so the server can pair key and shape.
    keys.push_back(key);
    if (shape.size() == 0) {
      shape_len.push_back(1);
      all_shape.push_back(1);
    } else {
      shape_len.push_back(SizeToInt(shape.size()));
      for (auto dim : shape) {
        // NOTE(review): dims travel in the value array and are cast to T
        // (typically float) — very large dims could lose precision; confirm.
        all_shape.push_back(static_cast<T>(dim));
      }
    }
  }
  MS_LOG(INFO) << "keys:" << keys;
  MS_LOG(INFO) << "shape_len:" << shape_len;
  MS_LOG(INFO) << "all_shape:" << all_shape;
  // Mark the key initialized as a side effect of shipping its shapes.
  if (!init_keys_[key]) {
    init_keys_[key] = true;
  }
  kv_worker_->PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd);
}
  224. template <typename T>
  225. bool Worker<T>::IsKeyInit(const size_t key) {
  226. if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) {
  227. return false;
  228. }
  229. return true;
  230. }
  231. template <typename T>
  232. size_t Worker<T>::SetParamKey(const std::string &param_name) {
  233. size_t key = UINT64_MAX;
  234. if (param_to_key_.count(param_name)) {
  235. key = param_to_key_[param_name];
  236. MS_LOG(INFO) << param_name << " key is already set: key value is " << key;
  237. } else {
  238. key = key_cnt_++;
  239. param_to_key_[param_name] = key;
  240. MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name;
  241. }
  242. return key;
  243. }
template <typename T>
// Records whether `param_name`'s weights are initialized on the server side
// (true) instead of being pushed from this worker.
void Worker<T>::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
  MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
  param_to_init_in_server_[param_name] = init_in_server;
}
  249. template <typename T>
  250. bool Worker<T>::GetParamInitInServer(const std::string &param_name) {
  251. if (param_to_init_in_server_.count(param_name) == 0) {
  252. return false;
  253. }
  254. return param_to_init_in_server_[param_name];
  255. }
  256. template <typename T>
  257. size_t Worker<T>::GetParamKey(const std::string &param_name) {
  258. size_t key = kInvalidKey;
  259. if (param_to_key_.find(param_name) != param_to_key_.end()) {
  260. key = param_to_key_[param_name];
  261. MS_LOG(INFO) << "Get key of parameter " << param_name << " key is " << key;
  262. }
  263. return key;
  264. }
template <typename T>
// Binds `key` to the numeric id of `optimizer_name` as resolved by Util.
void Worker<T>::SetKeyOptimId(size_t key, const std::string &optimizer_name) {
  key_to_optimId_[key] = Util::optimizer_id(optimizer_name);
}
  269. template <typename T>
  270. void Worker<T>::InitPSOptimId(const size_t param_key) {
  271. if (key_to_optimId_.count(param_key) == 0) {
  272. MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key;
  273. }
  274. int optim_id = key_to_optimId_[param_key];
  275. ::ps::SArray<::ps::Key> keys = {param_key};
  276. ::ps::SArray<T> optim_id_vals = {static_cast<T>(optim_id)};
  277. ::ps::SArray<int> optim_id_lens = {optim_id_vals.size()};
  278. kv_worker_->PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd);
  279. }
  280. template <typename T>
  281. void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes,
  282. const ShapeVector &sizes) {
  283. bool has_init = IsKeyInit(keys[0]);
  284. if (has_init) {
  285. MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized.";
  286. return;
  287. }
  288. ::ps::SArray<T> shapes_val;
  289. for (auto dim : shapes) {
  290. shapes_val.push_back(static_cast<T>(dim));
  291. }
  292. kv_worker_->Wait(kv_worker_->InitEmbeddingTable(::ps::SArray<::ps::Key>(keys), shapes_val, ::ps::SArray<int>(sizes)));
  293. }
  294. template <typename T>
  295. void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) {
  296. MS_EXCEPTION_IF_NULL(tensor);
  297. MS_EXCEPTION_IF_NULL(input_node);
  298. auto pk_node = input_node->cast<ParameterPtr>();
  299. MS_EXCEPTION_IF_NULL(pk_node);
  300. const std::string &param_name = pk_node->fullname_with_scope();
  301. void *param_data = tensor->data_c();
  302. size_t param_size = LongToSize(tensor->data().nbytes());
  303. if (param_size > INT_MAX) {
  304. MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " << param_size;
  305. }
  306. size_t param_key = GetParamKey(param_name);
  307. if (param_key == kInvalidKey) {
  308. MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned.";
  309. return;
  310. }
  311. bool init_in_server = false;
  312. auto param_info_ptr = pk_node->param_info();
  313. if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) {
  314. init_in_server = true;
  315. }
  316. SetParamInitInServer(param_name, init_in_server);
  317. bool init = IsKeyInit(param_key);
  318. if (!init) {
  319. MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
  320. << ", whether init in server: " << init_in_server;
  321. kv_worker_->AddKeyToServerId(param_key);
  322. if (!init_in_server) {
  323. InitPSParamData({param_key}, param_data, param_size);
  324. }
  325. InitPSOptimId(param_key);
  326. InitPSOptimInputShapes(param_key);
  327. }
  328. }
  329. template <typename T>
  330. void Worker<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
  331. bool has_init = IsKeyInit(key);
  332. if (has_init) {
  333. return;
  334. }
  335. kv_worker_->AddEmbeddingTable(key, row_count);
  336. }
// File-level handle to the float Worker singleton.
// NOTE(review): `static` at namespace scope in a header gives each
// translation unit its own reference, but they all alias the single
// GetInstance() object — confirm this is intentional.
static Worker<float> &worker = Worker<float>::GetInstance();
  338. } // namespace ps
  339. } // namespace mindspore
  340. #endif // MINDSPORE_CCSRC_PS_WORKER_H_