You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

dynamic_creator.h 5.4 kB

5 years ago
5 years ago
5 years ago
Lines 1–188
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_
  18. #include <map>
  19. #include <memory>
  20. #include <string>
  21. #include <utility>
  22. #include "frontend/parallel/ops_info/ops_info_head_files.h"
  23. #include "frontend/parallel/step_parallel.h"
  24. namespace mindspore {
  25. namespace parallel {
  26. #define REGISTER(className) \
  27. OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \
  28. return std::make_shared<className>(name, in, out, attrs); \
  29. } \
  30. RegisterAction className##Register(#className, (CreatFn)objectCreator##className);
  31. typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out,
  32. const PrimitiveAttrs &attrs);
  33. class DynCreator {
  34. public:
  35. ~DynCreator() = default;
  36. // creat static singleton dyn_creator instance
  37. static DynCreator &Instance() {
  38. static DynCreator fac = DynCreator();
  39. return fac;
  40. }
  41. // register
  42. void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
  43. // creator
  44. OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
  45. const PrimitiveAttrs &attrs, size_t count) {
  46. std::string op_name = name + std::to_string(count);
  47. auto iter = Function_map_.find(name);
  48. if (iter == Function_map_.end()) {
  49. MS_LOG(INFO) << name << " is not register yet";
  50. return nullptr;
  51. }
  52. return iter->second(op_name, shape_in, shape_out, attrs);
  53. }
  54. private:
  55. DynCreator() = default;
  56. std::map<std::string, CreatFn> Function_map_;
  57. };
  58. class RegisterAction {
  59. public:
  60. RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
  61. DynCreator::Instance().Regist(name, creatfn);
  62. }
  63. ~RegisterAction() = default;
  64. private:
  65. std::string name_;
  66. };
  67. // operator register
  68. REGISTER(MatMulInfo);
  69. REGISTER(GeluInfo);
  70. REGISTER(VirtualDatasetInfo);
  71. REGISTER(BatchParallelInfo);
  72. REGISTER(TanhInfo);
  73. REGISTER(SoftmaxInfo);
  74. REGISTER(LogSoftmaxInfo);
  75. REGISTER(ActivationInfo);
  76. REGISTER(SoftmaxCrossEntropyWithLogitsInfo);
  77. REGISTER(SubInfo);
  78. REGISTER(TensorAddInfo);
  79. REGISTER(BiasAddInfo);
  80. REGISTER(MulInfo);
  81. REGISTER(DivInfo);
  82. REGISTER(ModInfo);
  83. REGISTER(RealDivInfo);
  84. REGISTER(PowInfo);
  85. REGISTER(ExpInfo);
  86. REGISTER(OneHotInfo);
  87. REGISTER(EqualInfo);
  88. REGISTER(NotEqualInfo);
  89. REGISTER(LogInfo);
  90. REGISTER(CosInfo);
  91. REGISTER(ACosInfo);
  92. REGISTER(LogicalNotInfo);
  93. REGISTER(L2NormalizeInfo);
  94. REGISTER(LayerNormInfo);
  95. REGISTER(ReduceMaxInfo);
  96. REGISTER(ArgMaxWithValueInfo);
  97. REGISTER(ArgMinWithValueInfo);
  98. REGISTER(ReduceMeanInfo);
  99. REGISTER(ReduceSumInfo);
  100. REGISTER(ReduceMinInfo);
  101. REGISTER(TransposeInfo);
  102. REGISTER(PReLUInfo);
  103. REGISTER(DropoutDoMaskInfo);
  104. REGISTER(ReshapeInfo);
  105. REGISTER(FloorDivInfo);
  106. REGISTER(MaximumInfo);
  107. REGISTER(MinimumInfo);
  108. REGISTER(CastInfo);
  109. REGISTER(GreaterInfo);
  110. REGISTER(GreaterEqualInfo);
  111. REGISTER(LessEqualInfo);
  112. REGISTER(LessInfo);
  113. REGISTER(ApproximateEqualInfo);
  114. REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo);
  115. REGISTER(AssignSubInfo);
  116. REGISTER(FloorModInfo);
  117. REGISTER(AssignInfo);
  118. REGISTER(AssignAddInfo);
  119. REGISTER(Atan2Info);
  120. REGISTER(DivNoNanInfo);
  121. REGISTER(LogicalAndInfo);
  122. REGISTER(LogicalOrInfo);
  123. REGISTER(EluInfo);
  124. REGISTER(ReLUInfo);
  125. REGISTER(ReLU6Info);
  126. REGISTER(ReLUV2Info);
  127. REGISTER(SoftplusInfo);
  128. REGISTER(SoftsignInfo);
  129. REGISTER(GatherV2Info);
  130. REGISTER(SparseGatherV2Info);
  131. REGISTER(SqrtInfo);
  132. REGISTER(SigmoidInfo);
  133. REGISTER(GetNextInfo);
  134. REGISTER(NegInfo);
  135. REGISTER(AbsInfo);
  136. REGISTER(AcoshInfo);
  137. REGISTER(AsinInfo);
  138. REGISTER(AsinhInfo);
  139. REGISTER(AtanInfo);
  140. REGISTER(AtanhInfo);
  141. REGISTER(CeilInfo);
  142. REGISTER(CoshInfo);
  143. REGISTER(Expm1Info);
  144. REGISTER(Log1pInfo);
  145. REGISTER(SinInfo);
  146. REGISTER(SinhInfo);
  147. REGISTER(TanInfo);
  148. REGISTER(RsqrtInfo);
  149. REGISTER(InvInfo);
  150. REGISTER(ReciprocalInfo);
  151. REGISTER(RoundInfo);
  152. REGISTER(FloorInfo);
  153. REGISTER(SignInfo);
  154. REGISTER(ErfInfo);
  155. REGISTER(ErfcInfo);
  156. REGISTER(ZerosLikeInfo);
  157. REGISTER(OnesLikeInfo);
  158. REGISTER(BesselI0eInfo);
  159. REGISTER(BesselI1eInfo);
  160. REGISTER(BatchMatMulInfo);
  161. REGISTER(ExpandDimsInfo);
  162. REGISTER(SqueezeInfo);
  163. REGISTER(SigmoidCrossEntropyWithLogitsInfo);
  164. REGISTER(SquareInfo);
  165. REGISTER(GatherV2PInfo);
  166. REGISTER(EmbeddingLookupInfo);
  167. REGISTER(TileInfo);
  168. REGISTER(BroadcastToInfo);
  169. REGISTER(StridedSliceInfo);
  170. REGISTER(DropoutInfo);
  171. REGISTER(PackInfo);
  172. REGISTER(ConcatInfo);
  173. REGISTER(SplitInfo);
  174. } // namespace parallel
  175. } // namespace mindspore
  176. #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_