You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.h 6.5 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_COMMON_H_
  17. #define MINDSPORE_CCSRC_PS_COMMON_H_
#include <climits>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/ps.h"
  24. namespace mindspore {
  25. namespace ps {
  26. constexpr char kEnvCommType[] = "MS_COMM_TYPE";
  27. constexpr char kEnvInterface[] = "MS_INTERFACE";
  28. constexpr char kEnvPServerNum[] = "MS_SERVER_NUM";
  29. constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
  30. constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
  31. constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";
  32. constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE";
  33. constexpr char kDmlcInterface[] = "DMLC_INTERFACE";
  34. constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER";
  35. constexpr char kDmlcWorkerNum[] = "DMLC_NUM_WORKER";
  36. constexpr char kDmlcRole[] = "DMLC_ROLE";
  37. constexpr char kDmlcSchedulerHost[] = "DMLC_PS_ROOT_URI";
  38. constexpr char kDmlcSchedulerPort[] = "DMLC_PS_ROOT_PORT";
  39. constexpr char kCommTypeOfIBVerbs[] = "ibverbs";
  40. constexpr char kCommTypeOfTCP[] = "zmq";
  41. constexpr char kRoleOfPServer[] = "server";
  42. constexpr char kRoleOfWorker[] = "worker";
  43. constexpr char kRoleOfScheduler[] = "scheduler";
  44. constexpr char kLearningRate[] = "learning_rate";
  45. constexpr char kMomentum[] = "momentum";
  46. constexpr char kApplyMomentum[] = "ApplyMomentum";
  47. constexpr char kSparseAdam[] = "Adam";
  48. constexpr char kSparseLazyAdam[] = "LazyAdam";
  49. constexpr char kSparseFtrl[] = "Ftrl";
  50. constexpr char kApplyMomentumOp[] = "Momentum";
  51. constexpr char kSparseAdamOp[] = "Adam";
  52. constexpr char kSparseLazyAdamOp[] = "LazyAdam";
  53. constexpr char kSparseFtrlOp[] = "FTRL";
  54. constexpr int kInitWeightsCmd = 10;
  55. constexpr int kInitWeightToOptimIdCmd = 11;
  56. constexpr int kInitOptimInputsShapeCmd = 12;
  57. constexpr int kInitKeyToPushNodeIdCmd = 13;
  58. constexpr int kInitEmbeddingsCmd = 20;
  59. constexpr int kCheckReadyForPushCmd = 25;
  60. constexpr int kCheckReadyForPullCmd = 26;
  61. constexpr int kEmbeddingLookupCmd = 30;
  62. constexpr int kFinalizeCmd = 40;
  63. constexpr size_t kInvalidKey = UINT64_MAX;
  64. constexpr int kInvalidID = -1;
  65. using Key = ::ps::Key;
  66. using Keys = ::ps::SArray<Key>;
  67. using Values = ::ps::SArray<float>;
  68. using ValuesPtr = std::shared_ptr<Values>;
  69. using Weight = ::ps::SArray<float>;
  70. using Grad = ::ps::SArray<float>;
  71. using LookupIds = ::ps::SArray<Key>;
  72. using Lengths = ::ps::SArray<int>;
  73. using WeightPtr = std::shared_ptr<Weight>;
  74. using GradPtr = std::shared_ptr<Grad>;
  75. using InputsShape = std::vector<std::shared_ptr<std::vector<size_t>>>;
  76. using InputsShapePtr = std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>>;
  77. constexpr size_t INDEX_NOT_SEND = UINT_MAX;
  78. using OptimOriginIdx = std::map<std::string, size_t>;
  79. using OptimPSSendIdx = std::map<std::string, size_t>;
  80. const OptimOriginIdx kMomentumOriginIdx = {{"weight", 0}, {"accum", 1}, {"lr", 2}, {"grad", 3}, {"momentum", 4}};
  81. const OptimPSSendIdx kMomentumPSSendIdx = {
  82. {"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"lr", 0}, {"grad", 1}, {"momentum", 2}};
  83. const OptimOriginIdx kSparseAdamOriginIdx = {{"weight", 0}, {"m", 1}, {"v", 2}, {"beta1_power", 3},
  84. {"beta2_power", 4}, {"lr", 5}, {"beta1", 6}, {"beta2", 7},
  85. {"eps", 8}, {"grad", 9}, {"indices", 10}};
  86. const OptimPSSendIdx kSparseAdamPSSendIdx = {{"weight", INDEX_NOT_SEND},
  87. {"m", INDEX_NOT_SEND},
  88. {"v", INDEX_NOT_SEND},
  89. {"beta1_power", 0},
  90. {"beta2_power", 1},
  91. {"lr", 2},
  92. {"beta1", 3},
  93. {"beta2", 4},
  94. {"eps", 5},
  95. {"grad", 6},
  96. {"indices", 7}};
  97. const OptimOriginIdx kSparseFtrlOriginIdx = {{"weight", 0}, {"accum", 1}, {"linear", 2}, {"grad", 3}, {"indices", 4}};
  98. const OptimPSSendIdx kSparseFtrlPSSendIdx = {
  99. {"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"linear", INDEX_NOT_SEND}, {"grad", 0}, {"indices", 1}};
  100. const std::map<std::string, OptimOriginIdx> kOptimToOriginIdx = {{kApplyMomentum, kMomentumOriginIdx},
  101. {kSparseAdam, kSparseAdamOriginIdx},
  102. {kSparseLazyAdam, kSparseAdamOriginIdx},
  103. {kSparseFtrl, kSparseFtrlOriginIdx}};
  104. const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum, kMomentumPSSendIdx},
  105. {kSparseAdam, kSparseAdamPSSendIdx},
  106. {kSparseLazyAdam, kSparseAdamPSSendIdx},
  107. {kSparseFtrl, kSparseFtrlPSSendIdx}};
  108. #define EXC_IF_VEC_IDX_OOB(vec, idx) \
  109. { \
  110. size_t vec_size = vec.size(); \
  111. if (idx >= vec_size) { \
  112. MS_LOG(EXCEPTION) << "Vector " << #vec << " size is " << vec_size << ". So index " << idx \
  113. << " is out of bound."; \
  114. } \
  115. }
  116. } // namespace ps
  117. } // namespace mindspore
  118. #endif // MINDSPORE_CCSRC_PS_COMMON_H_