You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

optimizer_info_builder.cc 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "ps/optimizer_info_builder.h"

#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
  21. namespace mindspore {
  22. namespace ps {
  23. using mindspore::kernel::ps::SparseApplyFtrlPSKernel;
  24. OptimizerInfo *OptimizerInfoBuilder::Build(const std::shared_ptr<PServerKernel> &pserver_kernel,
  25. const WeightPtr &weight, const Keys &keys, const Values &values,
  26. const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num,
  27. bool sharded) {
  28. MS_EXCEPTION_IF_NULL(pserver_kernel);
  29. MS_EXCEPTION_IF_NULL(inputs_shape);
  30. OptimizerInfo *optim_info =
  31. BuildInputs(weight, keys, values, lens, inputs_shape, worker_num, pserver_kernel, sharded);
  32. MS_EXCEPTION_IF_NULL(optim_info);
  33. std::vector<size_t> ws_sizes = pserver_kernel->workspace_sizes();
  34. BuildWorkspaces(optim_info, ws_sizes, worker_num);
  35. BuildOutputs(optim_info, worker_num);
  36. return optim_info;
  37. }
  38. void OptimizerInfoBuilder::BuildWorkspaces(OptimizerInfo *info, const std::vector<size_t> &ws_sizes,
  39. size_t worker_num) {
  40. for (size_t i = 0; i < ws_sizes.size(); i++) {
  41. size_t size = ws_sizes[i];
  42. AddressPtr workspace = std::make_shared<kernel::Address>();
  43. MS_EXCEPTION_IF_NULL(workspace);
  44. workspace->addr = new float[size];
  45. MS_EXCEPTION_IF_NULL(workspace->addr);
  46. workspace->size = size;
  47. info->AddWorkspace(workspace);
  48. }
  49. }
  50. template <typename T>
  51. AddressPtr OptimizerInfoBuilder::GenInputAddrPtr(const std::string &optim_type, const std::string &input_name,
  52. void *ps_data, const Lengths &ps_lens,
  53. const InputsShapePtr &inputs_shape) {
  54. MS_EXCEPTION_IF_NULL(ps_data);
  55. // Take note of that the data type maybe inconsistent in ps_data.
  56. MS_LOG(INFO) << "Get input address pointer for optimizer:" << optim_type << ", input name:" << input_name;
  57. AddressPtr addr_ptr = std::make_shared<kernel::Address>();
  58. MS_EXCEPTION_IF_NULL(addr_ptr);
  59. if (kOptimToOriginIdx.count(optim_type) == 0 || kOptimToPSSendIdx.count(optim_type) == 0) {
  60. MS_LOG(EXCEPTION) << "Optimizer type " << optim_type << " in not supported.";
  61. }
  62. const OptimOriginIdx &origin_input_map = kOptimToOriginIdx.at(optim_type);
  63. const OptimPSSendIdx &ps_send_index_map = kOptimToPSSendIdx.at(optim_type);
  64. if (ps_send_index_map.count(input_name) == 0 || origin_input_map.count(input_name) == 0) {
  65. MS_LOG(EXCEPTION) << "Optimizer " << optim_type << " has no input for " << input_name;
  66. }
  67. size_t ps_index = ps_send_index_map.at(input_name);
  68. if (ps_index == INDEX_NOT_SEND) {
  69. MS_LOG(EXCEPTION) << "Input " << input_name << " is not supposed to be sent to PS.";
  70. }
  71. size_t addr_data_size, addr_data_offset;
  72. if (inputs_shape != nullptr) {
  73. // addr_data_size should be calculated by inputs_shape if it's passed.
  74. size_t origin_index = origin_input_map.at(input_name);
  75. EXC_IF_VEC_IDX_OOB((*inputs_shape), origin_index);
  76. auto shape = *((*inputs_shape)[origin_index]);
  77. addr_data_size = std::accumulate(shape.begin(), shape.end(), worker_num_, std::multiplies<size_t>());
  78. } else {
  79. EXC_IF_VEC_IDX_OOB(ps_lens, ps_index);
  80. addr_data_size = ps_lens[ps_index];
  81. }
  82. addr_data_offset = std::accumulate(ps_lens.begin(), ps_lens.begin() + ps_index, 0, std::plus<int>());
  83. // The size in ps_lens instead of addr_data_size is the size of real data.
  84. T *buffer = new T[addr_data_size];
  85. addr_ptr->size = ps_lens[ps_index] * sizeof(T);
  86. addr_ptr->addr = buffer;
  87. size_t dst_size = addr_ptr->size;
  88. size_t src_size = addr_ptr->size;
  89. void *dst_data = addr_ptr->addr;
  90. void *src_data = reinterpret_cast<T *>(ps_data) + addr_data_offset;
  91. MS_EXCEPTION_IF_NULL(dst_data);
  92. MS_EXCEPTION_IF_NULL(src_data);
  93. int64_t ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  94. if (ret != 0) {
  95. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  96. delete[] buffer;
  97. buffer = nullptr;
  98. return nullptr;
  99. }
  100. return addr_ptr;
  101. }
  102. OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
  103. const Lengths &lens, const InputsShapePtr &inputs_shape,
  104. size_t worker_num, const std::shared_ptr<PServerKernel> &, bool) {
  105. AddressPtr weight_addr = std::make_shared<kernel::Address>();
  106. MS_EXCEPTION_IF_NULL(weight_addr);
  107. weight_addr->addr = weight->data();
  108. weight_addr->size = weight->size() * sizeof(float);
  109. AddressPtr accumulate = std::make_shared<kernel::Address>();
  110. MS_EXCEPTION_IF_NULL(accumulate);
  111. accumulate->addr = new float[weight->size()];
  112. MS_EXCEPTION_IF_NULL(accumulate->addr);
  113. accumulate->size = weight->size() * sizeof(float);
  114. int64_t ret = memset_s(accumulate->addr, accumulate->size, 0x00, accumulate->size);
  115. if (ret != 0) {
  116. MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
  117. delete[] reinterpret_cast<float *>(accumulate->addr);
  118. accumulate->addr = nullptr;
  119. return nullptr;
  120. }
  121. AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens);
  122. AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", const_cast<float *>(values.data()), lens);
  123. AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", const_cast<float *>(values.data()), lens);
  124. return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum);
  125. }
  126. OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
  127. const Lengths &lens, const InputsShapePtr &inputs_shape,
  128. size_t worker_num, const std::shared_ptr<PServerKernel> &,
  129. bool sharded) {
  130. AddressPtr weight_addr = std::make_shared<kernel::Address>();
  131. MS_EXCEPTION_IF_NULL(weight_addr);
  132. weight_addr->addr = weight->data();
  133. weight_addr->size = weight->size() * sizeof(float);
  134. AddressPtr m = std::make_shared<kernel::Address>();
  135. MS_EXCEPTION_IF_NULL(m);
  136. m->addr = new float[weight->size()];
  137. MS_EXCEPTION_IF_NULL(m->addr);
  138. m->size = weight->size() * sizeof(float);
  139. int64_t ret = memset_s(m->addr, m->size, 0x00, m->size);
  140. if (ret != 0) {
  141. MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
  142. delete[] reinterpret_cast<float *>(m->addr);
  143. m->addr = nullptr;
  144. return nullptr;
  145. }
  146. AddressPtr v = std::make_shared<kernel::Address>();
  147. MS_EXCEPTION_IF_NULL(v);
  148. v->addr = new float[weight->size()];
  149. MS_EXCEPTION_IF_NULL(v->addr);
  150. v->size = weight->size() * sizeof(float);
  151. ret = memset_s(v->addr, v->size, 0x00, v->size);
  152. if (ret != 0) {
  153. MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
  154. delete[] reinterpret_cast<float *>(v->addr);
  155. v->addr = nullptr;
  156. delete[] reinterpret_cast<float *>(m->addr);
  157. m->addr = nullptr;
  158. return nullptr;
  159. }
  160. AddressPtr beta1_power = GenInputAddrPtr<float>(kSparseAdam, "beta1_power", const_cast<float *>(values.data()), lens);
  161. AddressPtr beta2_power = GenInputAddrPtr<float>(kSparseAdam, "beta2_power", const_cast<float *>(values.data()), lens);
  162. AddressPtr learning_rate = GenInputAddrPtr<float>(kSparseAdam, "lr", const_cast<float *>(values.data()), lens);
  163. AddressPtr beta1 = GenInputAddrPtr<float>(kSparseAdam, "beta1", const_cast<float *>(values.data()), lens);
  164. AddressPtr beta2 = GenInputAddrPtr<float>(kSparseAdam, "beta2", const_cast<float *>(values.data()), lens);
  165. AddressPtr epsilon = GenInputAddrPtr<float>(kSparseAdam, "eps", const_cast<float *>(values.data()), lens);
  166. AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", const_cast<float *>(values.data()), lens, inputs_shape);
  167. AddressPtr indices =
  168. GenInputAddrPtr<float>(kSparseAdam, "indices", const_cast<float *>(values.data()), lens, inputs_shape);
  169. return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
  170. grad, indices, sharded);
  171. }
  172. OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
  173. const Lengths &lens, const InputsShapePtr &inputs_shape,
  174. size_t worker_num,
  175. const std::shared_ptr<PServerKernel> &pserver_kernel,
  176. bool sharded) {
  177. MS_EXCEPTION_IF_NULL(inputs_shape);
  178. AddressPtr weight_addr = std::make_shared<kernel::Address>();
  179. MS_EXCEPTION_IF_NULL(weight_addr);
  180. weight_addr->addr = weight->data();
  181. weight_addr->size = weight->size() * sizeof(float);
  182. AddressPtr accum = std::make_shared<kernel::Address>();
  183. MS_EXCEPTION_IF_NULL(accum);
  184. accum->addr = new float[weight->size()];
  185. MS_EXCEPTION_IF_NULL(accum->addr);
  186. accum->size = weight->size() * sizeof(float);
  187. for (size_t i = 0; i < weight->size(); i++) {
  188. float *tmp = reinterpret_cast<float *>(accum->addr);
  189. tmp[i] = std::dynamic_pointer_cast<SparseApplyFtrlPSKernel>(pserver_kernel)->init_accum();
  190. }
  191. AddressPtr linear = std::make_shared<kernel::Address>();
  192. MS_EXCEPTION_IF_NULL(linear);
  193. linear->addr = new float[weight->size()];
  194. MS_EXCEPTION_IF_NULL(linear->addr);
  195. int64_t ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
  196. if (ret != 0) {
  197. MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
  198. delete[] reinterpret_cast<float *>(linear->addr);
  199. linear->addr = nullptr;
  200. return nullptr;
  201. }
  202. linear->size = weight->size() * sizeof(float);
  203. AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", const_cast<float *>(values.data()), lens, inputs_shape);
  204. AddressPtr indices =
  205. GenInputAddrPtr<float>(kSparseFtrl, "indices", const_cast<float *>(values.data()), lens, inputs_shape);
  206. return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, sharded);
  207. }
  208. } // namespace ps
  209. } // namespace mindspore