You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

activation_info.h 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
  17. #define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
  #include <ir/value.h>
  #include <memory>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  #include "parallel/auto_parallel/operator_costmodel.h"
  #include "parallel/ops_info/operator_info.h"
  #include "parallel/strategy.h"
  26. namespace mindspore {
  27. namespace parallel {
  28. class ActivationBase : public OperatorInfo {
  29. public:
  30. ActivationBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  31. const PrimitiveAttrs &attrs, OperatorCostPtr cost)
  32. : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
  33. ~ActivationBase() override = default;
  34. Status Init(const StrategyPtr &strategy) override;
  35. Status InitForCostModel(const StrategyPtr &strategy) override;
  36. protected:
  37. Status InferMirrorOps() override;
  38. Status InferForwardCommunication() override;
  39. Status InferTensorMap() override;
  40. Status InferTensorInfo() override;
  41. Status InferDevMatrixShape() override;
  42. };
  43. class Activation : public ActivationBase {
  44. public:
  45. Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  46. const PrimitiveAttrs &attrs)
  47. : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
  48. ~Activation() override = default;
  49. Status GenerateStrategies(int32_t stage_id) override;
  50. Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
  51. protected:
  52. Status CheckStrategy(const StrategyPtr &strategy) override;
  53. };
  54. class ActivationInfo : public Activation {
  55. public:
  56. ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  57. const PrimitiveAttrs &attrs)
  58. : Activation(name, inputs_shape, outputs_shape, attrs) {}
  59. ~ActivationInfo() override = default;
  60. protected:
  61. Status GetAttrs() override; // activation_type: relu, relu6, sigmoid
  62. };
  63. class ActivationOther : public Activation {
  64. public:
  65. ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  66. const PrimitiveAttrs &attrs)
  67. : Activation(name, inputs_shape, outputs_shape, attrs) {}
  68. ~ActivationOther() override = default;
  69. protected:
  70. Status GetAttrs() override;
  71. };
  72. class GeluInfo : public ActivationOther {
  73. public:
  74. GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  75. const PrimitiveAttrs &attrs)
  76. : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  77. ~GeluInfo() override = default;
  78. };
  79. class TanhInfo : public ActivationOther {
  80. public:
  81. TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  82. const PrimitiveAttrs &attrs)
  83. : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  84. ~TanhInfo() override = default;
  85. };
  86. class Softmax : public ActivationBase {
  87. public:
  88. explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  89. const PrimitiveAttrs &attrs)
  90. : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {}
  91. ~Softmax() override = default;
  92. Status GenerateStrategies(int32_t stage_id) override;
  93. Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
  94. protected:
  95. Status CheckStrategy(const StrategyPtr &strategy) override;
  96. Status GetAttrs() override;
  97. private:
  98. std::vector<int32_t> axis_;
  99. };
  100. class SoftmaxInfo : public Softmax {
  101. public:
  102. SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  103. const PrimitiveAttrs &attrs)
  104. : Softmax(name, inputs_shape, outputs_shape, attrs) {}
  105. ~SoftmaxInfo() override = default;
  106. };
  107. class LogSoftmaxInfo : public Softmax {
  108. public:
  109. LogSoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  110. const PrimitiveAttrs &attrs)
  111. : Softmax(name, inputs_shape, outputs_shape, attrs) {}
  112. ~LogSoftmaxInfo() override = default;
  113. };
  114. class ReLUInfo : public ActivationOther {
  115. public:
  116. ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  117. const PrimitiveAttrs &attrs)
  118. : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  119. ~ReLUInfo() override = default;
  120. };
  121. class CastInfo : public ActivationOther {
  122. public:
  123. CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  124. const PrimitiveAttrs &attrs)
  125. : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  126. ~CastInfo() override = default;
  127. protected:
  128. Status InferMirrorOps() override;
  129. };
  130. class SqrtInfo : public ActivationOther {
  131. public:
  132. SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  133. const PrimitiveAttrs &attrs)
  134. : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  135. ~SqrtInfo() override = default;
  136. };
  137. class NegInfo : public ActivationOther {
  138. public:
  139. NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
  140. : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  141. ~NegInfo() override = default;
  142. };
  143. class ExpandDimsInfo : public ActivationOther {
  144. public:
  145. ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  146. const PrimitiveAttrs &attrs)
  147. : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  148. ~ExpandDimsInfo() override = default;
  149. protected:
  150. Status GetAttrs() override;
  151. Status InferTensorMap() override;
  152. Status InferTensorInfo() override;
  153. Status InferMirrorOps() override;
  154. Status InferTensorStrategy();
  155. private:
  156. int32_t positive_axis_ = -1;
  157. Strategys inputs_strategy_;
  158. Strategys outputs_strategy_;
  159. };
  160. class SqueezeInfo : public ActivationOther {
  161. public:
  162. SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  163. const PrimitiveAttrs &attrs)
  164. : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  165. ~SqueezeInfo() override = default;
  166. protected:
  167. Status InferAxis(const ValueTuplePtr &value_tuple);
  168. Status GetAttrs() override;
  169. Status InferReplaceOps(const StrategyPtr &strategy);
  170. Status InferTensorMap() override;
  171. Status InferTensorInfo() override;
  172. Status Init(const StrategyPtr &strategy) override;
  173. private:
  174. ValueTuplePtr axis_;
  175. };
  176. class SquareInfo : public ActivationOther {
  177. public:
  178. SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
  179. const PrimitiveAttrs &attrs)
  180. : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  181. ~SquareInfo() override = default;
  182. };
  183. } // namespace parallel
  184. } // namespace mindspore
  185. #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_