You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

optimizer_info_builder.cc 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/optimizer_info_builder.h"
  17. #include <vector>
  18. #include <memory>
  19. #include <functional>
  20. #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
  21. namespace mindspore {
  22. namespace ps {
  23. using mindspore::kernel::ps::SparseApplyFtrlPSKernel;
  24. OptimizerInfo *OptimizerInfoBuilder::Build(const std::shared_ptr<PServerKernel> &pserver_kernel,
  25. const WeightPtr &weight, const Keys &keys, const Values &values,
  26. const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num) {
  27. MS_EXCEPTION_IF_NULL(pserver_kernel);
  28. MS_EXCEPTION_IF_NULL(inputs_shape);
  29. OptimizerInfo *optim_info = BuildInputs(weight, keys, values, lens, inputs_shape, worker_num, pserver_kernel);
  30. MS_EXCEPTION_IF_NULL(optim_info);
  31. std::vector<size_t> ws_sizes = pserver_kernel->workspace_sizes();
  32. BuildWorkspaces(optim_info, ws_sizes, worker_num);
  33. BuildOutputs(optim_info, worker_num);
  34. return optim_info;
  35. }
  36. void OptimizerInfoBuilder::BuildWorkspaces(OptimizerInfo *info, const std::vector<size_t> &ws_sizes,
  37. size_t worker_num) {
  38. for (size_t i = 0; i < ws_sizes.size(); i++) {
  39. size_t size = ws_sizes[i];
  40. AddressPtr workspace = std::make_shared<kernel::Address>();
  41. MS_EXCEPTION_IF_NULL(workspace);
  42. workspace->addr = new float[size];
  43. MS_EXCEPTION_IF_NULL(workspace->addr);
  44. workspace->size = size;
  45. info->AddWorkspace(workspace);
  46. }
  47. }
  48. template <typename T>
  49. AddressPtr OptimizerInfoBuilder::GenInputAddrPtr(const std::string &optim_type, const std::string &input_name,
  50. void *ps_data, const Lengths &ps_lens,
  51. const InputsShapePtr &inputs_shape) {
  52. MS_EXCEPTION_IF_NULL(ps_data);
  53. // Take note of that the data type maybe inconsistent in ps_data.
  54. MS_LOG(INFO) << "Get input address pointer for optimizer:" << optim_type << ", input name:" << input_name;
  55. AddressPtr addr_ptr = std::make_shared<kernel::Address>();
  56. MS_EXCEPTION_IF_NULL(addr_ptr);
  57. if (kOptimToOriginIdx.count(optim_type) == 0 || kOptimToPSSendIdx.count(optim_type) == 0) {
  58. MS_LOG(EXCEPTION) << "Optimizer type " << optim_type << " in not supported.";
  59. }
  60. const OptimOriginIdx &origin_input_map = kOptimToOriginIdx.at(optim_type);
  61. const OptimPSSendIdx &ps_send_index_map = kOptimToPSSendIdx.at(optim_type);
  62. if (ps_send_index_map.count(input_name) == 0 || origin_input_map.count(input_name) == 0) {
  63. MS_LOG(EXCEPTION) << "Optimizer " << optim_type << " has no input for " << input_name;
  64. }
  65. size_t ps_index = ps_send_index_map.at(input_name);
  66. if (ps_index == INDEX_NOT_SEND) {
  67. MS_LOG(EXCEPTION) << "Input " << input_name << " is not supposed to be sent to PS.";
  68. }
  69. size_t addr_data_size, addr_data_offset;
  70. if (inputs_shape != nullptr) {
  71. // addr_data_size should be calculated by inputs_shape if it's passed.
  72. size_t origin_index = origin_input_map.at(input_name);
  73. EXC_IF_VEC_IDX_OOB((*inputs_shape), origin_index);
  74. auto shape = *((*inputs_shape)[origin_index]);
  75. addr_data_size = std::accumulate(shape.begin(), shape.end(), worker_num_, std::multiplies<size_t>());
  76. } else {
  77. EXC_IF_VEC_IDX_OOB(ps_lens, ps_index);
  78. addr_data_size = ps_lens[ps_index];
  79. }
  80. addr_data_offset = std::accumulate(ps_lens.begin(), ps_lens.begin() + ps_index, 0, std::plus<int>());
  81. // The size in ps_lens instead of addr_data_size is the size of real data.
  82. T *buffer = new T[addr_data_size];
  83. addr_ptr->size = ps_lens[ps_index] * sizeof(T);
  84. addr_ptr->addr = buffer;
  85. size_t dst_size = addr_ptr->size;
  86. size_t src_size = addr_ptr->size;
  87. void *dst_data = addr_ptr->addr;
  88. void *src_data = reinterpret_cast<T *>(ps_data) + addr_data_offset;
  89. MS_EXCEPTION_IF_NULL(dst_data);
  90. MS_EXCEPTION_IF_NULL(src_data);
  91. int ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  92. if (ret != 0) {
  93. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  94. delete[] buffer;
  95. buffer = nullptr;
  96. return nullptr;
  97. }
  98. return addr_ptr;
  99. }
  100. OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
  101. const Lengths &lens, const InputsShapePtr &inputs_shape,
  102. size_t worker_num, const std::shared_ptr<PServerKernel> &) {
  103. AddressPtr weight_addr = std::make_shared<kernel::Address>();
  104. MS_EXCEPTION_IF_NULL(weight_addr);
  105. weight_addr->addr = weight->data();
  106. weight_addr->size = weight->size() * sizeof(float);
  107. AddressPtr accumulate = std::make_shared<kernel::Address>();
  108. MS_EXCEPTION_IF_NULL(accumulate);
  109. accumulate->addr = new float[weight->size()];
  110. MS_EXCEPTION_IF_NULL(accumulate->addr);
  111. accumulate->size = weight->size() * sizeof(float);
  112. int ret = memset_s(accumulate->addr, accumulate->size, 0x00, accumulate->size);
  113. if (ret != 0) {
  114. MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
  115. delete[] reinterpret_cast<float *>(accumulate->addr);
  116. accumulate->addr = nullptr;
  117. return nullptr;
  118. }
  119. AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", values.data(), lens);
  120. AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", values.data(), lens);
  121. AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", values.data(), lens);
  122. return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum);
  123. }
  124. OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
  125. const Lengths &lens, const InputsShapePtr &inputs_shape,
  126. size_t worker_num, const std::shared_ptr<PServerKernel> &) {
  127. AddressPtr weight_addr = std::make_shared<kernel::Address>();
  128. MS_EXCEPTION_IF_NULL(weight_addr);
  129. weight_addr->addr = weight->data();
  130. weight_addr->size = weight->size() * sizeof(float);
  131. AddressPtr m = std::make_shared<kernel::Address>();
  132. MS_EXCEPTION_IF_NULL(m);
  133. m->addr = new float[weight->size()];
  134. MS_EXCEPTION_IF_NULL(m->addr);
  135. m->size = weight->size() * sizeof(float);
  136. int ret = memset_s(m->addr, m->size, 0x00, m->size);
  137. if (ret != 0) {
  138. MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
  139. delete[] reinterpret_cast<float *>(m->addr);
  140. m->addr = nullptr;
  141. return nullptr;
  142. }
  143. AddressPtr v = std::make_shared<kernel::Address>();
  144. MS_EXCEPTION_IF_NULL(v);
  145. v->addr = new float[weight->size()];
  146. MS_EXCEPTION_IF_NULL(v->addr);
  147. v->size = weight->size() * sizeof(float);
  148. ret = memset_s(v->addr, v->size, 0x00, v->size);
  149. if (ret != 0) {
  150. MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
  151. delete[] reinterpret_cast<float *>(v->addr);
  152. v->addr = nullptr;
  153. delete[] reinterpret_cast<float *>(m->addr);
  154. m->addr = nullptr;
  155. return nullptr;
  156. }
  157. AddressPtr beta1_power = GenInputAddrPtr<float>(kSparseAdam, "beta1_power", values.data(), lens);
  158. AddressPtr beta2_power = GenInputAddrPtr<float>(kSparseAdam, "beta2_power", values.data(), lens);
  159. AddressPtr learning_rate = GenInputAddrPtr<float>(kSparseAdam, "lr", values.data(), lens);
  160. AddressPtr beta1 = GenInputAddrPtr<float>(kSparseAdam, "beta1", values.data(), lens);
  161. AddressPtr beta2 = GenInputAddrPtr<float>(kSparseAdam, "beta2", values.data(), lens);
  162. AddressPtr epsilon = GenInputAddrPtr<float>(kSparseAdam, "eps", values.data(), lens);
  163. AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", values.data(), lens, inputs_shape);
  164. AddressPtr indices = GenInputAddrPtr<float>(kSparseAdam, "indices", values.data(), lens, inputs_shape);
  165. return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
  166. grad, indices);
  167. }
  168. OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
  169. const Lengths &lens, const InputsShapePtr &inputs_shape,
  170. size_t worker_num,
  171. const std::shared_ptr<PServerKernel> &pserver_kernel) {
  172. MS_EXCEPTION_IF_NULL(inputs_shape);
  173. AddressPtr weight_addr = std::make_shared<kernel::Address>();
  174. MS_EXCEPTION_IF_NULL(weight_addr);
  175. weight_addr->addr = weight->data();
  176. weight_addr->size = weight->size() * sizeof(float);
  177. AddressPtr accum = std::make_shared<kernel::Address>();
  178. MS_EXCEPTION_IF_NULL(accum);
  179. accum->addr = new float[weight->size()];
  180. MS_EXCEPTION_IF_NULL(accum->addr);
  181. accum->size = weight->size() * sizeof(float);
  182. for (size_t i = 0; i < weight->size(); i++) {
  183. float *tmp = reinterpret_cast<float *>(accum->addr);
  184. tmp[i] = std::dynamic_pointer_cast<SparseApplyFtrlPSKernel>(pserver_kernel)->init_accum();
  185. }
  186. AddressPtr linear = std::make_shared<kernel::Address>();
  187. MS_EXCEPTION_IF_NULL(linear);
  188. linear->addr = new float[weight->size()];
  189. MS_EXCEPTION_IF_NULL(linear->addr);
  190. int ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
  191. if (ret != 0) {
  192. MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")";
  193. delete[] reinterpret_cast<float *>(linear->addr);
  194. linear->addr = nullptr;
  195. return nullptr;
  196. }
  197. linear->size = weight->size() * sizeof(float);
  198. AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", values.data(), lens, inputs_shape);
  199. AddressPtr indices = GenInputAddrPtr<float>(kSparseFtrl, "indices", values.data(), lens, inputs_shape);
  200. return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices);
  201. }
  202. } // namespace ps
  203. } // namespace mindspore