
optimizer_info.cc

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ps/optimizer_info.h"
#include <map>
#include <memory>
#include <string>
#include <functional>
#include <numeric>  // std::accumulate
#include "ps/util.h"

namespace mindspore {
namespace ps {
void OptimizerInfo::AddWorkspace(const AddressPtr &workspace) {
  MS_EXCEPTION_IF_NULL(workspace);
  workspaces_.push_back(workspace);
}

const std::vector<AddressPtr> &OptimizerInfo::inputs() const { return inputs_; }

const std::vector<AddressPtr> &OptimizerInfo::workspaces() const { return workspaces_; }

const std::vector<AddressPtr> &OptimizerInfo::outputs() const { return outputs_; }

bool OptimizerInfo::IsSparse() const { return false; }

const size_t OptimizerInfo::indice_size() const { return 0; }

size_t OptimizerInfo::grad_index() { return 0; }

size_t OptimizerInfo::indices_index() { return 0; }
template <typename T>
void OptimizerInfo::UpdateOptimInputValue(const std::string &optim_type, const std::string &input_name, void *data,
                                          const Lengths &lens) {
  MS_EXCEPTION_IF_NULL(data);
  if (kOptimToOriginIdx.count(optim_type) == 0 || kOptimToPSSendIdx.count(optim_type) == 0) {
    MS_LOG(EXCEPTION) << "Optimizer type " << optim_type << " is not supported.";
  }
  const OptimOriginIdx &origin_input_map = kOptimToOriginIdx.at(optim_type);
  const OptimPSSendIdx &ps_send_index_map = kOptimToPSSendIdx.at(optim_type);
  if (ps_send_index_map.count(input_name) == 0 || origin_input_map.count(input_name) == 0) {
    MS_LOG(EXCEPTION) << "Optimizer " << optim_type << " has no input for " << input_name;
  }
  size_t origin_index = origin_input_map.at(input_name);
  size_t ps_send_index = ps_send_index_map.at(input_name);
  if (ps_send_index >= lens.size() || origin_index >= inputs_.size()) {
    MS_LOG(EXCEPTION) << "Index is out of bound for optimizer " << optim_type << ", origin_index:" << origin_index
                      << ", ps_send_index:" << ps_send_index;
  }
  EXC_IF_VEC_IDX_OOB(lens, ps_send_index);
  size_t size = IntToSize(lens[ps_send_index]) * sizeof(T);
  // The worker packs all inputs into one flat buffer; this input's offset is the sum of
  // the lengths of every input that precedes it.
  int offset = std::accumulate(lens.begin(), lens.begin() + SizeToInt(ps_send_index), 0, std::plus<int>());
  AddressPtr optim_input = inputs_[origin_index];
  MS_EXCEPTION_IF_NULL(optim_input);
  void *dst_data = optim_input->addr;
  T *src_data = reinterpret_cast<T *>(data) + offset;
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  int64_t ret = memcpy_s(optim_input->addr, optim_input->size, src_data, size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret << ")";
  }
}
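// Worked example of the offset arithmetic above (illustrative values, not taken from
// this file): suppose the worker packs the sparse Adam inputs into one flat buffer
//   values = [beta1_power, beta2_power, lr, beta1, beta2, eps, grad..., indices...]
// with lens = {1, 1, 1, 1, 1, 1, 728, 91}. Updating "lr" (ps_send_index = 2) yields
//   offset = lens[0] + lens[1] = 2         // measured in elements of T
//   size   = lens[2] * sizeof(float) = 4   // bytes copied by memcpy_s
// so the single float values[2] is copied into inputs_[origin_index]->addr.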
void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
  MS_EXCEPTION_IF_NULL(gradient()->addr);
  float *accum_grad_data = reinterpret_cast<float *>(gradient()->addr);
  size_t size = gradient()->size / sizeof(float);
  size_t grad_index = this->grad_index();
  size_t grad_offset = 0;
  for (size_t i = 0; i < grad_index; i++) {
    grad_offset += IntToSize(lengths[i]);
  }
  float *grad_data = const_cast<float *>(values.data()) + grad_offset;
  MS_EXCEPTION_IF_NULL(grad_data);
// CHECK_EQ comes from a dependency that defines symbols under the `google` namespace;
// the temporary redefinition below presumably sidesteps a namespace clash.
#define google mindspore_private
  CHECK_EQ(size, IntToSize(lengths[grad_index]));
#undef google
  for (size_t i = 0; i < size; i++) {
    accum_grad_data[i] += grad_data[i];
  }
}
void DenseOptimInfo::ComputeMean(const std::vector<std::vector<size_t>> &, size_t n, size_t, size_t) {
  if (n > 1) {
    MS_EXCEPTION_IF_NULL(gradient()->addr);
    float *accum_grad_data = reinterpret_cast<float *>(gradient()->addr);
    size_t size = gradient()->size / sizeof(float);
    for (size_t i = 0; i < size; i++) {
      accum_grad_data[i] /= n;
    }
  }
}

void DenseOptimInfo::Reset() {
  MS_EXCEPTION_IF_NULL(gradient()->addr);
  int64_t ret = memset_s(gradient()->addr, gradient()->size, 0x00, gradient()->size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memset_s error, errno(" << ret << ")";
  }
}
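// End-to-end sketch of the dense path (hypothetical numbers): with n = 2 workers each
// pushing a gradient of length 3, Accumulate() runs once per push, e.g.
//   worker 0 sends {1.0, 2.0, 3.0}, worker 1 sends {3.0, 4.0, 5.0}
// leaving accum_grad_data = {4.0, 6.0, 8.0}; ComputeMean(..., n = 2, ...) then halves
// it to {2.0, 3.0, 4.0}, and Reset() zeroes the buffer for the next iteration.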
void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
  // Append grad data to the end
  MS_EXCEPTION_IF_NULL(gradient()->addr);
  float *accum_grad_data = reinterpret_cast<float *>(gradient()->addr);
  size_t grad_index = this->grad_index();
  size_t grad_offset = 0;
  for (size_t i = 0; i < grad_index; i++) {
    grad_offset += IntToSize(lengths[i]);
  }
  float *incr_grad_data = const_cast<float *>(values.data()) + grad_offset;
  MS_EXCEPTION_IF_NULL(incr_grad_data);
  size_t incr_grad_size = IntToSize(lengths[grad_index]) * sizeof(float);
  size_t dst_size = incr_grad_size;
  size_t src_size = incr_grad_size;
  void *dst_data = accum_grad_data + grads_offset_;
  void *src_data = incr_grad_data;
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  int64_t ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret << ")";
  }
  grads_offset_ += IntToSize(lengths[grad_index]);
  gradient()->size += incr_grad_size;

  // Append indice data to the end
  MS_EXCEPTION_IF_NULL(indices()->addr);
  int *accum_indices_data = reinterpret_cast<int *>(indices()->addr);
  MS_EXCEPTION_IF_NULL(accum_indices_data);
  size_t indices_index = this->indices_index();
  size_t indice_offset = 0;
  for (size_t i = 0; i < indices_index; i++) {
    indice_offset += IntToSize(lengths[i]);
  }
  void *incr_indice_data_temp = const_cast<float *>(values.data()) + indice_offset;
  MS_EXCEPTION_IF_NULL(incr_indice_data_temp);
  int *incr_indice_data = reinterpret_cast<int *>(incr_indice_data_temp);
  MS_EXCEPTION_IF_NULL(incr_indice_data);
  size_t incr_indice_size = IntToSize(lengths[indices_index]);
  size_t incr_indice_data_size = incr_indice_size * sizeof(int);
  dst_size = incr_indice_data_size;
  src_size = incr_indice_data_size;
  dst_data = accum_indices_data + indices_offset_;
  src_data = incr_indice_data;
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  auto ret2 = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret2 != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret2 << ")";
  }
  indices_offset_ += IntToSize(lengths[indices_index]);
  indices()->size += incr_indice_data_size;
}
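// Buffer layout after two Accumulate() calls, with made-up sizes for illustration:
// unlike the dense path, sparse gradients are concatenated rather than summed. Starting
// from a Reset() state, pushes of 3 and then 2 rows (segment_size floats per row) give
//   gradient(): [rows from push 0 | rows from push 1]   grads_offset_   = 5 * segment_size
//   indices():  [ids  from push 0 | ids  from push 1]   indices_offset_ = 5
// Duplicate row ids are left in place here and merged later by ComputeMean().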
void SparseOptimInfo::ComputeMean(const std::vector<std::vector<size_t>> &shapes, size_t n, size_t server_num,
                                  size_t rank_id) {
  if (n == 0 || indices()->size == 0) {
    MS_LOG(EXCEPTION) << "n is 0 or the size of indices is 0.";
  }
  size_t indices_size = static_cast<size_t>(indices()->size / sizeof(int));
  size_t segment_size = gradient()->size / indices()->size;
  std::vector<float> new_grad(indices_size * segment_size);
  std::vector<int> new_indices(indices_size);
  mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
  if (shapes.size() < 2 || shapes[1].empty()) {
    MS_LOG(EXCEPTION) << "No input shape found";
  }
  auto input_shapes = shapes[1];
  if (input_shapes.size() == 0) {
    MS_LOG(EXCEPTION) << "Invalid input shapes";
  }
  size_t first_dim_size = input_shapes.front();
  size_t outer_dim_size = segment_size;
  if (first_dim_size == 0 || outer_dim_size == 0) {
    MS_LOG(ERROR) << "Invalid first dim size";
  }
  MS_EXCEPTION_IF_NULL(gradient()->addr);
  MS_EXCEPTION_IF_NULL(indices()->addr);
  float *grad_data = reinterpret_cast<float *>(gradient()->addr);
  int *indices_data = reinterpret_cast<int *>(indices()->addr);
  if (sharded_) {
    // The table is sharded row-wise across servers, so translate global row ids into
    // this server's local row space before reducing.
    size_t original_row_count = input_shapes.front();
    if (original_row_count > 0) {
      size_t offset = 0;
      std::map<int64_t, int64_t> rank_dims =
        Util::AllRankLocalShard(SizeToLong(original_row_count), SizeToLong(rank_id), SizeToLong(server_num));
      for (size_t i = 0; i < rank_id; i++) {
        if (rank_dims.count(i) == 0) {
          MS_LOG(EXCEPTION) << "No local shard number for rank " << i;
        }
        offset += LongToSize(rank_dims[i]);
      }
      for (size_t j = 0; j < indices_size; j++) {
        indices_data[j] -= SizeToInt(offset);
      }
    }
  }
  Util::ReduceSparseGradient(grad_data, indices_data, indices_size, segment_size, first_dim_size, outer_dim_size,
                             &unique_sparse_grad);
  size_t reduced_grad_size = unique_sparse_grad.indices_size_ * segment_size * sizeof(float);
  MS_EXCEPTION_IF_NULL(unique_sparse_grad.value_);
  int ret = memcpy_s(gradient()->addr, gradient()->size, unique_sparse_grad.value_, reduced_grad_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret << ")";
  }
  size_t reduced_indice_size = unique_sparse_grad.indices_size_ * sizeof(int);
  MS_EXCEPTION_IF_NULL(unique_sparse_grad.indices_);
  ret = memcpy_s(indices()->addr, indices()->size, unique_sparse_grad.indices_, reduced_indice_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret << ")";
  }
  gradient()->size = reduced_grad_size;
  indices()->size = reduced_indice_size;
  for (size_t i = 0; i < unique_sparse_grad.indices_size_ * segment_size; i++) {
    grad_data[i] = grad_data[i] / n;
  }
}
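// Sketch of the shard translation above, under assumed numbers: for a table of 10 rows
// split across server_num = 2 servers, Util::AllRankLocalShard might assign
// rank_dims = {0: 5, 1: 5}. On rank_id = 1 the offset is 5, so global row ids
// {5, 7, 9} become local ids {0, 2, 4} before ReduceSparseGradient() merges duplicate
// rows and the final loop divides the merged gradient by n.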
void SparseOptimInfo::Reset() {
  gradient()->size = 0;
  indices()->size = 0;
  grads_offset_ = 0;
  indices_offset_ = 0;
}

MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr &accumulate,
                                     const AddressPtr &learning_rate, const AddressPtr &gradient,
                                     const AddressPtr &momentum) {
  MS_EXCEPTION_IF_NULL(weight);
  MS_EXCEPTION_IF_NULL(accumulate);
  MS_EXCEPTION_IF_NULL(learning_rate);
  MS_EXCEPTION_IF_NULL(gradient);
  MS_EXCEPTION_IF_NULL(momentum);
  inputs_.push_back(weight);
  inputs_.push_back(accumulate);
  inputs_.push_back(learning_rate);
  inputs_.push_back(gradient);
  inputs_.push_back(momentum);
}

void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) {
  UpdateOptimInputValue<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens);
}

const size_t SparseOptimInfo::indice_size() const { return indices_offset_; }
const AddressPtr &MomentumOptimInfo::gradient() {
  size_t origin_grad_index = kMomentumOriginIdx.at("grad");
  EXC_IF_VEC_IDX_OOB(inputs_, origin_grad_index);
  MS_EXCEPTION_IF_NULL(inputs_[origin_grad_index]);
  return inputs_[origin_grad_index];
}

// Momentum is a dense optimizer with no separate indices input; the grad slot is
// returned as a stand-in.
const AddressPtr &MomentumOptimInfo::indices() {
  size_t origin_grad_index = kMomentumOriginIdx.at("grad");
  EXC_IF_VEC_IDX_OOB(inputs_, origin_grad_index);
  MS_EXCEPTION_IF_NULL(inputs_[origin_grad_index]);
  return inputs_[origin_grad_index];
}

size_t MomentumOptimInfo::grad_index() {
  size_t ps_grad_index = kMomentumPSSendIdx.at("grad");
  return ps_grad_index;
}
SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v,
                                         const AddressPtr &beta1_power, const AddressPtr &beta2_power,
                                         const AddressPtr &learning_rate, const AddressPtr &beta1,
                                         const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad,
                                         const AddressPtr &indices, bool sharded) {
  MS_EXCEPTION_IF_NULL(weight);
  MS_EXCEPTION_IF_NULL(m);
  MS_EXCEPTION_IF_NULL(v);
  MS_EXCEPTION_IF_NULL(beta1_power);
  MS_EXCEPTION_IF_NULL(beta2_power);
  MS_EXCEPTION_IF_NULL(learning_rate);
  MS_EXCEPTION_IF_NULL(beta1);
  MS_EXCEPTION_IF_NULL(beta2);
  MS_EXCEPTION_IF_NULL(epsilon);
  MS_EXCEPTION_IF_NULL(grad);
  MS_EXCEPTION_IF_NULL(indices);
  inputs_.push_back(weight);
  inputs_.push_back(m);
  inputs_.push_back(v);
  inputs_.push_back(beta1_power);
  inputs_.push_back(beta2_power);
  inputs_.push_back(learning_rate);
  inputs_.push_back(beta1);
  inputs_.push_back(beta2);
  inputs_.push_back(epsilon);
  inputs_.push_back(grad);
  inputs_.push_back(indices);
  // Offsets are element counts (not bytes), so later appends in Accumulate() start
  // right after the data already present.
  grads_offset_ = grad->size / sizeof(float);
  indices_offset_ = indices->size / sizeof(int);
  sharded_ = sharded;
}

void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) {
  UpdateOptimInputValue<float>(kSparseAdam, "beta1_power", const_cast<float *>(values.data()), lens);
  UpdateOptimInputValue<float>(kSparseAdam, "beta2_power", const_cast<float *>(values.data()), lens);
  UpdateOptimInputValue<float>(kSparseAdam, "lr", const_cast<float *>(values.data()), lens);
  UpdateOptimInputValue<float>(kSparseAdam, "beta1", const_cast<float *>(values.data()), lens);
  UpdateOptimInputValue<float>(kSparseAdam, "beta2", const_cast<float *>(values.data()), lens);
  UpdateOptimInputValue<float>(kSparseAdam, "eps", const_cast<float *>(values.data()), lens);
}
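// Note: Update() refreshes only the scalar hyperparameters. Each call slices the same
// flat `values` buffer at a different ps_send_index (see UpdateOptimInputValue above);
// the grad and indices segments of that buffer are consumed separately by Accumulate().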
const AddressPtr &SparseAdamOptimInfo::gradient() {
  size_t origin_grad_index = kSparseAdamOriginIdx.at("grad");
  EXC_IF_VEC_IDX_OOB(inputs_, origin_grad_index);
  MS_EXCEPTION_IF_NULL(inputs_[origin_grad_index]);
  return inputs_[origin_grad_index];
}

const AddressPtr &SparseAdamOptimInfo::indices() {
  size_t origin_indices_index = kSparseAdamOriginIdx.at("indices");
  EXC_IF_VEC_IDX_OOB(inputs_, origin_indices_index);
  MS_EXCEPTION_IF_NULL(inputs_[origin_indices_index]);
  return inputs_[origin_indices_index];
}

bool SparseAdamOptimInfo::IsSparse() const { return true; }

size_t SparseAdamOptimInfo::grad_index() {
  size_t ps_grad_index = kSparseAdamPSSendIdx.at("grad");
  return ps_grad_index;
}

size_t SparseAdamOptimInfo::indices_index() {
  size_t ps_indices_index = kSparseAdamPSSendIdx.at("indices");
  return ps_indices_index;
}

SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear,
                                         const AddressPtr &grad, const AddressPtr &indices, bool sharded) {
  MS_EXCEPTION_IF_NULL(weight);
  MS_EXCEPTION_IF_NULL(accum);
  MS_EXCEPTION_IF_NULL(linear);
  MS_EXCEPTION_IF_NULL(grad);
  MS_EXCEPTION_IF_NULL(indices);
  inputs_.push_back(weight);
  inputs_.push_back(accum);
  inputs_.push_back(linear);
  inputs_.push_back(grad);
  inputs_.push_back(indices);
  grads_offset_ = grad->size / sizeof(float);
  indices_offset_ = indices->size / sizeof(int);
  sharded_ = sharded;
}

const AddressPtr &SparseFtrlOptimInfo::gradient() {
  size_t origin_grad_index = kSparseFtrlOriginIdx.at("grad");
  EXC_IF_VEC_IDX_OOB(inputs_, origin_grad_index);
  MS_EXCEPTION_IF_NULL(inputs_[origin_grad_index]);
  return inputs_[origin_grad_index];
}

const AddressPtr &SparseFtrlOptimInfo::indices() {
  size_t origin_indices_index = kSparseFtrlOriginIdx.at("indices");
  EXC_IF_VEC_IDX_OOB(inputs_, origin_indices_index);
  MS_EXCEPTION_IF_NULL(inputs_[origin_indices_index]);
  return inputs_[origin_indices_index];
}

bool SparseFtrlOptimInfo::IsSparse() const { return true; }

size_t SparseFtrlOptimInfo::grad_index() {
  size_t ps_grad_index = kSparseFtrlPSSendIdx.at("grad");
  return ps_grad_index;
}

size_t SparseFtrlOptimInfo::indices_index() {
  size_t ps_indices_index = kSparseFtrlPSSendIdx.at("indices");
  return ps_indices_index;
}
}  // namespace ps
}  // namespace mindspore