You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

worker.cc 37 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/worker.h"
  17. #include "pipeline/jit/pipeline.h"
  18. namespace mindspore {
  19. namespace ps {
  20. void Worker::Run() {
  21. std::lock_guard<std::mutex> lock(running_mutex_);
  22. server_num_ = PSContext::instance()->initial_server_num();
  23. if (running_) {
  24. MS_LOG(INFO) << "'Worker is already running.";
  25. return;
  26. }
  27. if (!PSContext::instance()->is_worker()) {
  28. MS_LOG(EXCEPTION) << "The role is not worker.";
  29. }
  30. Initialize();
  31. worker_node_.set_event_callback([&](const core::NodeEvent &event) {
  32. if ((event == core::NodeEvent::CLUSTER_TIMEOUT) ||
  33. (event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) {
  34. MS_LOG(WARNING) << "Trigger timeout event:" << event << " begin to exit the system!";
  35. Finalize();
  36. exit(0);
  37. }
  38. });
  39. MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
  40. worker_node_.Start();
  41. MS_LOG(INFO) << "Worker connected successfully.";
  42. running_ = true;
  43. }
  44. void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes) {
  45. if (keys.size() == 0) {
  46. MS_LOG(EXCEPTION) << "key size should be greater than zero";
  47. }
  48. if (key_to_optimId_.count(keys[0]) == 0) {
  49. MS_LOG(EXCEPTION) << "no optim id found for key" << keys[0];
  50. }
  51. Key key = keys[0];
  52. int64_t optim_id = key_to_optimId_[key];
  53. MS_LOG(INFO) << "The key is:" << key << " the optim_id:" << optim_id;
  54. bool is_sparse = false;
  55. if (optim_id == 1 || optim_id == 2 || optim_id == 3) {
  56. is_sparse = true;
  57. }
  58. int64_t grad_index = -1;
  59. int64_t indice_index = -1;
  60. // Sparse adam gradient
  61. if (optim_id == 1 || optim_id == 2) {
  62. grad_index = 6;
  63. indice_index = 7;
  64. // Sparse ftrl gradient
  65. } else if (optim_id == 3) {
  66. grad_index = 0;
  67. indice_index = 1;
  68. }
  69. size_t total_size = std::accumulate(sizes.begin(), sizes.end(), 0, std::plus<int64_t>());
  70. std::vector<float> total_buffer(total_size, 0);
  71. size_t offset = 0;
  72. for (size_t i = 0; i < sizes.size(); i++) {
  73. void *dst_data = total_buffer.data() + offset / sizeof(float);
  74. void *src_data = reinterpret_cast<void *>(addrs[i]);
  75. MS_EXCEPTION_IF_NULL(dst_data);
  76. MS_EXCEPTION_IF_NULL(src_data);
  77. int size = sizes[i] * sizeof(float);
  78. auto ret = memcpy_s(dst_data, size, src_data, size);
  79. if (ret != 0) {
  80. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  81. return;
  82. }
  83. offset += size;
  84. }
  85. MS_LOG(INFO) << "The total size is:" << total_size;
  86. while (running_ && (!IsReadyForPush(keys[0]))) {
  87. continue;
  88. }
  89. std::vector<int> sizes_int;
  90. (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
  91. [](const int64_t &value) { return static_cast<int>(value); });
  92. if (!is_sparse) {
  93. PushData(std::vector<Key>(keys), total_buffer, std::vector<int>(sizes_int), kPushCmd);
  94. } else {
  95. std::vector<int64_t> &var_shape = key_to_optim_shapes_[key][0];
  96. int64_t first_dim_size = var_shape[0];
  97. int64_t outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies<int64_t>());
  98. MS_LOG(DEBUG) << "The keys:" << keys << " the total_buffer:" << total_buffer << " the sizes_int:" << sizes_int
  99. << " the grad_index:" << grad_index << " the indice_index:" << indice_index
  100. << " the first_dim_size:" << first_dim_size << " the outer_dim_size" << outer_dim_size;
  101. PushSparseData(std::vector<Key>(keys), total_buffer, std::vector<int>(sizes_int), grad_index, indice_index,
  102. first_dim_size, outer_dim_size);
  103. }
  104. }
  105. void Worker::Pull(const size_t key, void *dev_addr, const size_t size) {
  106. MS_EXCEPTION_IF_NULL(dev_addr);
  107. std::vector<float> variables(size / sizeof(float), 0);
  108. while (running_ && (!IsReadyForPull(key))) {
  109. continue;
  110. }
  111. PullData({key}, &variables, nullptr, kPullCmd);
  112. MS_LOG(DEBUG) << "The variables:" << variables << " the size is:" << size;
  113. size_t dst_size = size;
  114. size_t src_size = size;
  115. auto ret = memcpy_s(dev_addr, dst_size, variables.data(), src_size);
  116. if (ret != 0) {
  117. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  118. return;
  119. }
  120. }
  121. size_t Worker::SetParamKey(const std::string &param_name) {
  122. size_t key = UINT64_MAX;
  123. if (param_to_key_.count(param_name)) {
  124. key = param_to_key_[param_name];
  125. MS_LOG(INFO) << param_name << " key is already set: key value is " << key;
  126. } else {
  127. key = key_cnt_++;
  128. param_to_key_[param_name] = key;
  129. MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name;
  130. }
  131. return key;
  132. }
  133. size_t Worker::GetParamKey(const std::string &param_name) {
  134. size_t key = kInvalidKey;
  135. if (param_to_key_.find(param_name) != param_to_key_.end()) {
  136. key = param_to_key_[param_name];
  137. MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key;
  138. }
  139. return key;
  140. }
  141. void Worker::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
  142. MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
  143. param_to_init_in_server_[param_name] = init_in_server;
  144. }
  145. bool Worker::GetParamInitInServer(const std::string &param_name) {
  146. if (param_to_init_in_server_.count(param_name) == 0) {
  147. return false;
  148. }
  149. return param_to_init_in_server_[param_name];
  150. }
  151. void Worker::SetKeyOptimId(size_t key, const std::string &optimizer_name) {
  152. MS_LOG(INFO) << "SetKeyOptimId key is:" << key << " optimizer_name:" << optimizer_name;
  153. key_to_optimId_[key] = Util::optimizer_id(optimizer_name);
  154. }
  155. void Worker::SetOptimInputShapes(size_t key, const ShapeVector &shape) {
  156. if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) {
  157. key_to_optim_shapes_[key] = {shape};
  158. } else {
  159. key_to_optim_shapes_[key].push_back(shape);
  160. }
  161. }
/**
 * Registers the shard layout of an embedding table: `row_count` rows are split
 * across all servers via Util::LocalShard, and server i owns the inclusive row
 * range [begin, end]. No-op if the key is already initialized.
 */
void Worker::AddEmbeddingTable(const Key &key, const size_t &row_count) {
  bool has_init = IsKeyInit(key);
  if (has_init) {
    return;
  }
  uint64_t begin = 0;
  uint64_t end = 0;
  for (int64_t i = 0; i < server_num_; i++) {
    int64_t local_row_cnt = Util::LocalShard(row_count, i, server_num_);
    MS_LOG(DEBUG) << "The row_count:" << row_count << " the local_row_cnt:" << local_row_cnt;
    if (i == 0) {
      // First shard starts at row 0.
      // NOTE(review): if local_row_cnt were 0 this would wrap to UINT64_MAX —
      // presumably every shard is non-empty; confirm Util::LocalShard's contract.
      end = local_row_cnt - 1;
    } else {
      // Each subsequent shard continues immediately after the previous one.
      begin = end + 1;
      end += local_row_cnt;
    }
    EmbeddingTableShardMetadata range(begin, end);
    if (embedding_table_ranges_.count(key) == 0) {
      embedding_table_ranges_[key] = std::make_shared<std::vector<EmbeddingTableShardMetadata>>();
      MS_EXCEPTION_IF_NULL(embedding_table_ranges_[key]);
    }
    embedding_table_ranges_[key]->push_back(range);
  }
  embedding_row_cnt_[key] = row_count;
}
  187. void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
  188. const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape,
  189. const ParamInitInfoMessage &info) {
  190. bool has_init = IsKeyInit(key);
  191. if (has_init) {
  192. MS_LOG(DEBUG) << "The key embedding table of key " << key << " is initialized.";
  193. return;
  194. }
  195. EmbeddingTableMeta embedding_table_meta;
  196. embedding_table_meta.set_key(key);
  197. *embedding_table_meta.mutable_input_shape() = {input_shape.begin(), input_shape.end()};
  198. *embedding_table_meta.mutable_indices_shape() = {indices_shape.begin(), indices_shape.end()};
  199. *embedding_table_meta.mutable_output_shape() = {output_shape.begin(), output_shape.end()};
  200. *embedding_table_meta.mutable_info() = info;
  201. std::string kv_data = embedding_table_meta.SerializeAsString();
  202. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  203. int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
  204. if (ret != 0) {
  205. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  206. return;
  207. }
  208. worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), kInitEmbeddingsCmd);
  209. }
/**
 * Initializes one parameter (and its optimizer state) on the parameter-server
 * side from a graph parameter node and its tensor value.
 *
 * Silently returns if the parameter never got a key (it does not live on the
 * PS). Weight data is only uploaded when the parameter is NOT flagged for
 * server-side init and the data-prefetch cache is disabled.
 */
void Worker::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(input_node);
  auto pk_node = input_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(pk_node);
  const std::string &param_name = pk_node->fullname_with_scope();
  void *param_data = tensor->data_c();
  size_t param_size = LongToSize(tensor->data().nbytes());
  size_t param_key = GetParamKey(param_name);
  if (param_key == kInvalidKey) {
    // Not a PS-managed parameter — nothing to do.
    MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned.";
    return;
  }
  // The node's param_info decides whether the server initializes the weight itself.
  bool init_in_server = false;
  auto param_info_ptr = pk_node->param_info();
  if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) {
    init_in_server = true;
  }
  SetParamInitInServer(param_name, init_in_server);
  bool init = IsKeyInit(param_key);
  if (!init) {
    MS_LOG(INFO) << "Init parameter key " << param_key << " and optimizer in parameter server side for " << param_name
                 << ", whether init in server: " << init_in_server;
    // Choose the owning server before any data is pushed.
    AddKeyToServerId(param_key);
    if (!PsDataPrefetch::GetInstance().cache_enable()) {
      if (!init_in_server) {
        // Weight payload travels in a single message, so its size must fit an int.
        if (param_size > INT_MAX) {
          MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is "
                            << param_size;
        }
        InitPSParamData({param_key}, param_data, param_size);
      }
      // Order matters: optimizer id first, then its input shapes (which also
      // marks the key initialized).
      InitPSOptimId(param_key);
      InitPSOptimInputShapes(param_key);
    }
  }
}
/**
 * Performs a distributed embedding lookup for table `key`.
 *
 * The lookup ids are partitioned per server shard, sent to the servers that own
 * at least one id, and the per-server responses are reassembled into
 * `lookup_result` in the original order of `lookup_ids` (duplicate ids reuse
 * the same returned row).
 *
 * @param lookup_result Pre-sized output buffer: lookup_ids.size() rows of
 *                      lookup_result->size() / lookup_ids.size() floats each.
 *                      NOTE(review): assumes lookup_ids is non-empty — the row
 *                      length below divides by its size; confirm callers.
 */
void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
                                 int64_t cmd) {
  MS_EXCEPTION_IF_NULL(lookup_result);
  EmbeddingTableLookup embedding_table_lookup;
  embedding_table_lookup.set_key(key);
  *embedding_table_lookup.mutable_keys() = {lookup_ids.begin(), lookup_ids.end()};
  // Split the request by shard; messages.at(i).first marks whether server i
  // owns any of the requested ids.
  PartitionEmbeddingMessages messages;
  lookup_partitioner_(embedding_table_lookup, &messages, {});
  std::vector<uint32_t> rank_ids;
  std::vector<DataPtr> data;
  std::vector<size_t> sizes;
  for (size_t i = 0; i < messages.size(); i++) {
    if (messages.at(i).first) {
      rank_ids.push_back(i);
      std::string kv_data = messages.at(i).second.SerializeAsString();
      std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
      int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
      if (ret != 0) {
        MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
      data.push_back(res);
      sizes.push_back(kv_data.length());
    }
  }
  // Synchronous scatter/gather: resp holds one reply per contacted server.
  std::vector<VectorPtr> resp;
  worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp);
  // Number of floats per looked-up id, derived from the output buffer size.
  int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
  std::unordered_map<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map;
  std::shared_ptr<std::vector<float>> values = std::make_shared<std::vector<float>>();
  std::shared_ptr<std::vector<Key>> keys = std::make_shared<std::vector<Key>>();
  int64_t value_offset = 0;
  // Flatten all replies into parallel (keys, values) sequences.
  // NOTE(review): ParseFromArray's return value is not checked — a malformed
  // reply would be silently skipped/partial; confirm this is acceptable.
  for (size_t i = 0; i < resp.size(); ++i) {
    KVMessage message;
    message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size());
    for (auto j = 0; j < message.values_size(); j++) {
      values->push_back(message.values(j));
    }
    for (auto k = 0; k < message.keys_size(); k++) {
      const Key &key = message.keys(k);
      keys->push_back(key);
    }
  }
  // Index each returned id to its row inside the flattened value buffer.
  for (size_t i = 0; i < keys->size(); i++) {
    const Key &key = keys->at(i);
    float *addr = values->data() + value_offset;
    value_offset += single_id_len;
    id_addr_map[key] = std::make_shared<std::pair<float *, int64_t>>(std::make_pair(addr, single_id_len));
  }
  float *result_addr = lookup_result->data();
  MS_EXCEPTION_IF_NULL(result_addr);
  int64_t offset = 0;
  size_t dst_size = 0;
  size_t src_size = 0;
  void *dst_data = nullptr;
  void *src_data = nullptr;
  // Copy rows back out in the caller's original id order; ids no server
  // returned leave their slot untouched (zero-initialized by the caller).
  for (size_t i = 0; i < lookup_ids.size(); i++) {
    if (id_addr_map.count(lookup_ids[i]) == 0) {
      offset += single_id_len;
      continue;
    }
    const Key &key = static_cast<Key>(lookup_ids[i]);
    auto &pair = id_addr_map[key];
    int64_t size = single_id_len * sizeof(float);
    dst_size = size;
    src_size = size;
    dst_data = result_addr + offset;
    src_data = pair->first;
    MS_EXCEPTION_IF_NULL(dst_data);
    MS_EXCEPTION_IF_NULL(src_data);
    auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
      return;
    }
    offset += single_id_len;
  }
}
  325. void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
  326. const std::vector<float> &vals) {
  327. KVMessage kvs;
  328. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  329. *kvs.mutable_len() = {lookup_ids.begin(), lookup_ids.end()};
  330. *kvs.mutable_values() = {vals.begin(), vals.end()};
  331. PartitionKVMessages messages;
  332. update_embedding_partitioner_(kvs, &messages, {});
  333. std::vector<uint32_t> rank_ids;
  334. std::vector<DataPtr> data;
  335. std::vector<size_t> sizes;
  336. for (size_t i = 0; i < messages.size(); i++) {
  337. if (messages.at(i).first) {
  338. rank_ids.push_back(i);
  339. std::string kv_data = messages.at(i).second.SerializeAsString();
  340. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  341. int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
  342. if (ret != 0) {
  343. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  344. return;
  345. }
  346. data.push_back(res);
  347. sizes.push_back(kv_data.length());
  348. }
  349. }
  350. worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, kUpdateEmbeddingsCmd);
  351. }
  352. void Worker::Finalize() {
  353. if (running_) {
  354. MS_LOG(INFO) << "Worker starts finalizing...";
  355. KVMessage kvs;
  356. kvs.add_keys(0);
  357. kvs.add_values(0.0f);
  358. std::string kv_data = kvs.SerializeAsString();
  359. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  360. int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
  361. if (ret != 0) {
  362. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  363. return;
  364. }
  365. worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), kFinalizeCmd);
  366. worker_node_.Finish();
  367. worker_node_.Stop();
  368. running_ = false;
  369. MS_LOG(INFO) << "Worker finalized successfully.";
  370. }
  371. }
  372. void Worker::Initialize() {
  373. lookup_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  374. LookupIdPartitioner(send, partition, attrs);
  375. };
  376. worker_init_embedding_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  377. WorkerInitEmbeddingPartitioner(send, partition, attrs);
  378. };
  379. round_robin_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  380. RoundRobinPartitioner(send, partition, attrs);
  381. };
  382. sparse_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  383. SparsePartitioner(send, partition, attrs);
  384. };
  385. update_embedding_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  386. UpdateEmbeddingPartitioner(send, partition, attrs);
  387. };
  388. broadcast_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  389. BroadcastPartitioner(send, partition, attrs);
  390. };
  391. }
  392. bool Worker::IsKeyInit(const size_t key) {
  393. if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) {
  394. return false;
  395. }
  396. return true;
  397. }
  398. void Worker::AddKeyToServerId(const Key &key) { AddKeyByHashMod(key); }
  399. void Worker::AddKeyByHashMod(const Key &key) {
  400. if (server_num_ == 0) {
  401. MS_LOG(EXCEPTION) << "Server number is invalid:0";
  402. }
  403. key_to_server_id_[key] = static_cast<int64_t>(key % server_num_);
  404. MS_LOG(INFO) << "The server id of key " << key << " is " << key_to_server_id_[key];
  405. }
  406. void Worker::InitPSOptimId(const size_t param_key) {
  407. MS_LOG(INFO) << "InitPSOptimId key is:" << param_key;
  408. if (key_to_optimId_.count(param_key) == 0) {
  409. MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key;
  410. }
  411. int64_t optim_id = key_to_optimId_[param_key];
  412. std::vector<Key> keys = {param_key};
  413. std::vector<float> optim_id_vals = {static_cast<float>(optim_id)};
  414. std::vector<int> optim_id_lens = {SizeToInt(optim_id_vals.size())};
  415. MS_LOG(INFO) << "The keys is" << keys << " the optim_id_vals is: " << optim_id_vals
  416. << " optim_id_lens is:" << optim_id_lens;
  417. PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd);
  418. }
  419. void Worker::InitPSOptimInputShapes(const size_t key) {
  420. std::vector<Key> keys;
  421. std::vector<int> shape_len;
  422. std::vector<float> all_shape;
  423. std::vector<ShapeVector> shapes = key_to_optim_shapes_[key];
  424. for (auto shape : shapes) {
  425. keys.push_back(key);
  426. if (shape.size() == 0) {
  427. shape_len.push_back(1);
  428. all_shape.push_back(1);
  429. } else {
  430. shape_len.push_back(SizeToLong(shape.size()));
  431. std::transform(shape.begin(), shape.end(), std::back_inserter(all_shape),
  432. [](size_t dim) -> float { return static_cast<float>(dim); });
  433. }
  434. }
  435. MS_LOG(INFO) << "keys:" << keys;
  436. MS_LOG(INFO) << "shape_len:" << shape_len;
  437. MS_LOG(INFO) << "all_shape:" << all_shape;
  438. if (!init_keys_[key]) {
  439. init_keys_[key] = true;
  440. }
  441. PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd);
  442. }
  443. void Worker::InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size) {
  444. MS_EXCEPTION_IF_NULL(origin_addr);
  445. std::vector<float> addr{reinterpret_cast<float *>(origin_addr),
  446. reinterpret_cast<float *>(origin_addr) + size / sizeof(float)};
  447. std::vector<Key> key(keys);
  448. std::vector<int> lens;
  449. lens.push_back(addr.size());
  450. MS_LOG(INFO) << "the keys are:" << keys;
  451. MS_LOG(INFO) << "the values are:" << addr;
  452. PushData(key, addr, lens, kInitWeightsCmd);
  453. init_keys_[key[0]] = true;
  454. }
  455. bool Worker::IsReadyForPush(const Key &key) {
  456. std::vector<float> result(1, 0);
  457. PullData({key}, &result, nullptr, kCheckReadyForPushCmd);
  458. MS_LOG(INFO) << "key:" << key;
  459. if (result[0] > 0) {
  460. MS_LOG(INFO) << "IsReadyForPush:";
  461. return true;
  462. } else {
  463. MS_LOG(INFO) << "IsReadyForPush:";
  464. return false;
  465. }
  466. }
  467. bool Worker::IsReadyForPull(const Key &key) {
  468. std::vector<float> result(1, 0);
  469. PullData({key}, &result, nullptr, kCheckReadyForPullCmd);
  470. if (result[0] > 0) {
  471. MS_LOG(INFO) << "IsReadyForPull";
  472. return true;
  473. } else {
  474. MS_LOG(INFO) << "IsReadyForPull";
  475. return false;
  476. }
  477. }
  478. void Worker::PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
  479. const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
  480. const size_t segment_size, float *gradient, int *indices) {
  481. MS_EXCEPTION_IF_NULL(all_indice);
  482. MS_EXCEPTION_IF_NULL(gradient);
  483. MS_EXCEPTION_IF_NULL(indices);
  484. int64_t offset = 0;
  485. int64_t index = 0;
  486. size_t segment_data_size = segment_size * sizeof(float);
  487. size_t dst_size;
  488. size_t src_size;
  489. void *dst_data = nullptr;
  490. void *src_data = nullptr;
  491. for (auto &pair : indice_to_grads) {
  492. if (distinct_ids.count(pair.first) == 0) {
  493. continue;
  494. }
  495. indices[index++] = pair.first;
  496. dst_size = segment_data_size;
  497. src_size = segment_data_size;
  498. dst_data = gradient + offset;
  499. src_data = pair.second;
  500. MS_EXCEPTION_IF_NULL(dst_data);
  501. MS_EXCEPTION_IF_NULL(src_data);
  502. auto ret = memcpy_s(gradient + offset, dst_size, pair.second, src_size);
  503. if (ret != 0) {
  504. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  505. return;
  506. }
  507. offset += segment_size;
  508. }
  509. }
/**
 * Rebuilds a flattened sparse-push value buffer: copies every non-gradient,
 * non-indices field verbatim from `original_data`, then writes the reduced
 * gradient and the reduced indices into their slots.
 *
 * @param lengths      Element count of each field; offsets are derived from it.
 * @param grad_index   Position of the gradient field within `lengths`.
 * @param indice_index Position of the indices field within `lengths`.
 * @param reduced_data Output buffer, pre-sized to the sum of `lengths`.
 */
void Worker::BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
                              const float *original_data, const float *grads, int *indices,
                              std::vector<float> *reduced_data) {
  MS_EXCEPTION_IF_NULL(original_data);
  MS_EXCEPTION_IF_NULL(grads);
  MS_EXCEPTION_IF_NULL(indices);
  MS_EXCEPTION_IF_NULL(reduced_data);
  int64_t offset = 0;
  size_t dst_size = 0;
  size_t src_size = 0;
  void *dst_data = nullptr;
  void *src_data = nullptr;
  // Pass 1: copy every field except gradient and indices straight through,
  // advancing a shared element offset over all fields.
  for (size_t i = 0; i < lengths.size(); i++) {
    if (i != grad_index && i != indice_index) {
      int data_size = lengths[i] * sizeof(float);
      dst_size = data_size;
      src_size = data_size;
      dst_data = reduced_data->data() + offset;
      src_data = const_cast<float *>(original_data) + offset;
      MS_EXCEPTION_IF_NULL(dst_data);
      MS_EXCEPTION_IF_NULL(src_data);
      auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
      if (ret != 0) {
        MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
    }
    offset += lengths[i];
  }
  // Fill the reduced gradient
  int64_t grad_offset = 0;
  for (size_t i = 0; i < grad_index; i++) {
    grad_offset += lengths[i];
  }
  int64_t data_size = lengths[grad_index] * sizeof(float);
  dst_size = data_size;
  src_size = data_size;
  dst_data = reduced_data->data() + grad_offset;
  src_data = const_cast<float *>(grads);
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }
  // Fill the reduced indice
  // The indices field is assumed to sit directly after the gradient field
  // (offset = grad_offset + gradient length) — TODO confirm with callers.
  int64_t indice_offset = grad_offset + lengths[grad_index];
  data_size = lengths[indice_index] * sizeof(float);
  float *indice_data = reduced_data->data() + indice_offset;
  dst_size = data_size;
  src_size = data_size;
  dst_data = indice_data;
  // int ids are bit-copied into the float buffer (memcpy, no numeric
  // conversion); the receiver presumably reinterprets them back as ints.
  src_data = indices;
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }
}
  572. void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
  573. int cmd, int64_t priority) {
  574. KVMessage kvs;
  575. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  576. *kvs.mutable_values() = {vals.begin(), vals.end()};
  577. *kvs.mutable_len() = {lens.begin(), lens.end()};
  578. MS_LOG(INFO) << "the result is:" << embedding_table_ranges_.count(keys[0]);
  579. if (embedding_table_ranges_.count(keys[0])) {
  580. if (cmd == kInitWeightsCmd) {
  581. SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {});
  582. } else {
  583. std::string kv_data = kvs.SerializeAsString();
  584. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  585. int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
  586. if (ret != 0) {
  587. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  588. return;
  589. }
  590. worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), cmd);
  591. }
  592. } else {
  593. SendForPush(cmd, kvs, round_robin_partitioner_, {});
  594. }
  595. }
  596. void Worker::PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
  597. size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size) {
  598. KVMessage kvs;
  599. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  600. *kvs.mutable_values() = {vals.begin(), vals.end()};
  601. *kvs.mutable_len() = {lens.begin(), lens.end()};
  602. if (embedding_table_ranges_.count(keys[0])) {
  603. std::map<int64_t, int64_t> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
  604. SendForPush(kPushCmd, kvs, sparse_partitioner_, attrs);
  605. } else {
  606. SendForPush(kPushCmd, kvs, round_robin_partitioner_, {});
  607. }
  608. }
  609. void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens, int cmd,
  610. int64_t priority) {
  611. MS_EXCEPTION_IF_NULL(vals);
  612. KVMessage kvs;
  613. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  614. if (embedding_table_ranges_.count(keys[0])) {
  615. SendForPull(cmd, kvs, broadcast_partitioner_, {}, vals, lens);
  616. } else {
  617. SendForPull(cmd, kvs, round_robin_partitioner_, {}, vals, lens);
  618. }
  619. }
/**
 * Partitions an embedding lookup request per server shard.
 * For each shard range, collects the (de-duplicated) lookup ids falling inside
 * [begin, end] into that shard's message; partition->at(i).first records
 * whether shard i received any ids at all.
 */
void Worker::LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
                                 const std::map<int64_t, int64_t> &attrs) {
  MS_EXCEPTION_IF_NULL(partition);
  const Key &key = send.key();
  const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  partition->resize(ranges.size());
  for (size_t i = 0; i < ranges.size(); i++) {
    const EmbeddingTableShardMetadata &range = ranges[i];
    const auto &begin = range.begin();
    const auto &end = range.end();
    // unordered_set both filters to this shard's range and drops duplicates.
    std::unordered_set<int32_t> unique_ids;
    auto &kvs = partition->at(i).second;
    kvs.set_key(key);
    std::for_each(send.keys().begin(), send.keys().end(), [&](int32_t lookup_id) {
      // Shard ranges are inclusive on both ends.
      if (lookup_id >= SizeToInt(begin) && lookup_id <= SizeToInt(end)) {
        unique_ids.insert(lookup_id);
      }
    });
    MS_LOG(DEBUG) << "The unique ids size is:" << unique_ids.size();
    for (const auto &lookup_id : unique_ids) {
      // Each id gets a zero placeholder value the server will fill in.
      kvs.add_keys(lookup_id);
      kvs.add_values(0.0f);
    }
    // Mark whether this shard has anything to do.
    if (kvs.keys().empty()) {
      partition->at(i).first = false;
    } else {
      partition->at(i).first = true;
    }
  }
}
  650. void Worker::SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
  651. const std::map<int64_t, int64_t> &attrs) {
  652. MS_EXCEPTION_IF_NULL(partition);
  653. // Init variables
  654. float *data = const_cast<float *>(send.values().data());
  655. if (attrs.count(0) == 0 || attrs.count(1) == 0 || attrs.count(2) == 0 || attrs.count(3) == 0) {
  656. MS_LOG(EXCEPTION) << "Invalid attrs keys";
  657. }
  658. auto iter = attrs.find(0);
  659. size_t grad_index = static_cast<size_t>(iter->second);
  660. iter = attrs.find(1);
  661. size_t indice_index = static_cast<size_t>(iter->second);
  662. iter = attrs.find(2);
  663. size_t first_dim_size = static_cast<size_t>(iter->second);
  664. iter = attrs.find(3);
  665. size_t outer_dim_size = static_cast<size_t>(iter->second);
  666. int grad_size = send.len()[grad_index];
  667. int indice_size = send.len()[indice_index];
  668. int segment_size = grad_size / indice_size;
  669. int64_t grad_offset = 0;
  670. int64_t indice_offset = 0;
  671. for (size_t i = 0; i < grad_index; i++) {
  672. grad_offset += send.len()[i];
  673. }
  674. for (size_t j = 0; j < indice_index; j++) {
  675. indice_offset += send.len()[j];
  676. }
  677. float *grad_data = data + grad_offset;
  678. void *indice_data_temp = data + indice_offset;
  679. int *indice_data = reinterpret_cast<int *>(indice_data_temp);
  680. // Build the mappings of indice to gradient
  681. std::vector<std::pair<int, float *>> indice_to_grads;
  682. for (int i = 0; i < indice_size; i++) {
  683. int indice = indice_data[i];
  684. float *grad = grad_data + i * segment_size;
  685. indice_to_grads.push_back(std::make_pair(indice, grad));
  686. }
  687. const Key &key = send.keys()[0];
  688. const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  689. partition->resize(ranges.size());
  690. // Construct reduced sparse data for each server
  691. for (size_t i = 0; i < ranges.size(); i++) {
  692. const EmbeddingTableShardMetadata &range = ranges[i];
  693. const auto &begin = range.begin();
  694. const auto &end = range.end();
  695. auto &kvs = partition->at(i).second;
  696. *kvs.mutable_keys() = {send.keys().begin(), send.keys().end()};
  697. *kvs.mutable_len() = {send.len().begin(), send.len().end()};
  698. // Prepare the sparse gradient and indice
  699. std::vector<int> indice_ids;
  700. std::unordered_set<int> distinct_ids;
  701. for (int j = 0; j < indice_size; j++) {
  702. size_t indice = static_cast<size_t>(indice_data[j]);
  703. if (indice >= begin && indice <= end) {
  704. indice_ids.push_back(indice);
  705. distinct_ids.insert(indice);
  706. }
  707. }
  708. size_t indices_size = indice_ids.size();
  709. if (indices_size > 0) {
  710. int partition_segment_size = indices_size * segment_size;
  711. std::vector<float> src_grad_data(partition_segment_size);
  712. std::vector<int> src_indice_data(indices_size);
  713. PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data.data(),
  714. src_indice_data.data());
  715. // Reduce the sparse gradient and indice
  716. std::vector<float> new_grad(partition_segment_size);
  717. std::vector<int> new_indices(indices_size);
  718. mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
  719. Util::ReduceSparseGradient(src_grad_data.data(), src_indice_data.data(), indices_size, segment_size,
  720. first_dim_size, outer_dim_size, &unique_sparse_grad);
  721. // Update the length of reduce sparse gradient and indice
  722. std::vector<int> reduced_lens;
  723. reduced_lens = {kvs.len().begin(), kvs.len().end()};
  724. reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
  725. reduced_lens[indice_index] = unique_sparse_grad.indices_size_;
  726. // Build the sparse value to be sent
  727. size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus<int>());
  728. std::vector<float> reduced_data(total_size, 0);
  729. BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
  730. unique_sparse_grad.indices_, &reduced_data);
  731. *kvs.mutable_len() = {reduced_lens.begin(), reduced_lens.end()};
  732. *kvs.mutable_values() = {reduced_data.begin(), reduced_data.end()};
  733. }
  734. if (indices_size == 0) {
  735. std::vector<float> no_keys;
  736. std::vector<float> no_vals;
  737. std::vector<float> no_lens;
  738. no_keys.push_back(key);
  739. no_vals.push_back(-100);
  740. *kvs.mutable_values() = {no_vals.begin(), no_vals.end()};
  741. *kvs.mutable_len() = {no_lens.begin(), no_lens.end()};
  742. }
  743. partition->at(i).first = true;
  744. }
  745. }
  746. void Worker::RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  747. const std::map<int64_t, int64_t> &attrs) {
  748. MS_EXCEPTION_IF_NULL(partition);
  749. partition->resize(server_num_);
  750. auto keys = send.keys();
  751. auto values = send.values();
  752. auto lens = send.len();
  753. MS_LOG(INFO) << "the key size is:" << send.keys_size() << " the values size is:" << send.values_size()
  754. << " the lens:" << send.len_size();
  755. int64_t len;
  756. Key param_key;
  757. for (int i = 0; i < send.keys_size(); i++) {
  758. param_key = keys[i];
  759. int64_t server_id = key_to_server_id_[param_key];
  760. if (!partition->at(server_id).first) {
  761. partition->at(server_id).first = true;
  762. }
  763. KVMessage &server_kv_pairs = partition->at(server_id).second;
  764. server_kv_pairs.add_keys(param_key);
  765. if (values.empty()) {
  766. continue;
  767. }
  768. len = lens[i];
  769. int64_t offset = std::accumulate(lens.begin(), lens.begin() + i, 0);
  770. auto val_begin = values.begin() + offset;
  771. auto val_end = val_begin + len;
  772. for (auto it = val_begin; it != val_end; ++it) {
  773. server_kv_pairs.add_values(*it);
  774. }
  775. server_kv_pairs.add_len(len);
  776. }
  777. }
  778. void Worker::WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
  779. const std::map<int64_t, int64_t> &attrs) {
  780. MS_EXCEPTION_IF_NULL(partition);
  781. partition->resize(server_num_);
  782. auto keys = send.keys();
  783. auto values = send.values();
  784. auto lens = send.len();
  785. size_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
  786. const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[keys[0]]);
  787. for (size_t i = 0; i < ranges.size(); i++) {
  788. size_t offset_begin = ranges[i].begin() * col_cnt;
  789. size_t offset_end = (ranges[i].end() + 1) * col_cnt;
  790. KVMessage kvs;
  791. *kvs.mutable_keys() = keys;
  792. *kvs.mutable_values() = {values.begin() + offset_begin, values.begin() + offset_end};
  793. kvs.add_len(offset_end - offset_begin);
  794. partition->at(i).first = true;
  795. partition->at(i).second = kvs;
  796. }
  797. }
  798. void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  799. const std::map<int64_t, int64_t> &attrs) {
  800. MS_EXCEPTION_IF_NULL(partition);
  801. const float *embedding_vals = send.values().data();
  802. const int *lookup_ids = send.len().data();
  803. size_t val_size = send.values_size();
  804. size_t id_size = send.len_size();
  805. size_t embedding_dim = val_size / id_size;
  806. const Key &key = send.keys()[0];
  807. const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  808. partition->resize(ranges.size());
  809. for (size_t i = 0; i < ranges.size(); i++) {
  810. const EmbeddingTableShardMetadata &range = ranges[i];
  811. const auto &begin = range.begin();
  812. const auto &end = range.end();
  813. auto &kvs = partition->at(i).second;
  814. kvs.add_keys(key);
  815. for (size_t j = 0; j < id_size; j++) {
  816. auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
  817. if (lookup_id >= begin && lookup_id <= end) {
  818. kvs.add_keys(lookup_id);
  819. for (size_t k = 0; k < embedding_dim; k++) {
  820. kvs.add_values(embedding_vals[j * embedding_dim + k]);
  821. }
  822. }
  823. }
  824. if (kvs.keys_size() <= 1) {
  825. partition->at(i).first = false;
  826. } else {
  827. partition->at(i).first = true;
  828. }
  829. }
  830. }
  831. void Worker::BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  832. const std::map<int64_t, int64_t> &attrs) {
  833. MS_EXCEPTION_IF_NULL(partition);
  834. partition->resize(server_num_);
  835. for (int64_t i = 0; i < server_num_; i++) {
  836. partition->at(i).first = true;
  837. partition->at(i).second = send;
  838. }
  839. }
  840. void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
  841. const std::map<int64_t, int64_t> &attrs) {
  842. PartitionKVMessages messages;
  843. partitioner(send, &messages, attrs);
  844. std::vector<uint32_t> rank_ids;
  845. std::vector<DataPtr> data;
  846. std::vector<size_t> sizes;
  847. for (size_t i = 0; i < messages.size(); i++) {
  848. if (messages.at(i).first) {
  849. rank_ids.push_back(i);
  850. std::string kv_data = messages.at(i).second.SerializeAsString();
  851. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  852. int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
  853. if (ret != 0) {
  854. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  855. return;
  856. }
  857. data.push_back(res);
  858. sizes.push_back(kv_data.length());
  859. }
  860. }
  861. worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd);
  862. }
  863. void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
  864. const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens) {
  865. PartitionKVMessages messages;
  866. partitioner(send, &messages, {});
  867. std::vector<uint32_t> rank_ids;
  868. std::vector<DataPtr> data;
  869. std::vector<size_t> sizes;
  870. for (size_t i = 0; i < messages.size(); i++) {
  871. if (messages.at(i).first) {
  872. rank_ids.push_back(i);
  873. std::string kv_data = messages.at(i).second.SerializeAsString();
  874. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  875. int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
  876. if (ret != 0) {
  877. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  878. return;
  879. }
  880. data.push_back(res);
  881. sizes.push_back(kv_data.length());
  882. }
  883. }
  884. std::vector<VectorPtr> resp;
  885. worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp);
  886. vals->clear();
  887. for (size_t i = 0; i < resp.size(); ++i) {
  888. KVMessage message;
  889. message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size());
  890. std::copy(message.values().begin(), message.values().end(), std::back_inserter(*vals));
  891. if (lens) {
  892. lens->clear();
  893. std::copy(message.len().begin(), message.len().end(), std::back_inserter(*lens));
  894. }
  895. }
  896. }
  897. } // namespace ps
  898. } // namespace mindspore