You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

prelu_info.cc 8.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/ops_info/prelu_info.h"
  17. #include <memory>
  18. #include <utility>
  19. #include <vector>
  20. #include "parallel/device_manager.h"
  21. #include "parallel/device_matrix.h"
  22. #include "parallel/step_parallel.h"
  23. #include "utils/convert_utils.h"
  24. #include "utils/log_adapter.h"
  25. namespace mindspore {
  26. namespace parallel {
  27. /*
  28. * prelu has 2 input
  29. * A: A float tensor of shape [NCHW] representing the output of the preview layer.
  30. * w: Float Tensor, w > 0: there is only two shapes are legitimate: 1, or the number of channels at input.
  31. * the strategy of w should equal to the channel dimension of strategy of A, or equal to 1
  32. */
  33. Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
  34. if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
  35. if (is_auto_parallel_) {
  36. MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
  37. } else {
  38. MS_LOG(ERROR) << name_ << ": Invalid strategy.";
  39. }
  40. return FAILED;
  41. }
  42. std::vector<Dimensions> stra = strategy->GetInputDim();
  43. if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) {
  44. if (is_auto_parallel_) {
  45. MS_LOG(DEBUG) << name_ << ": Invalid strategy size.";
  46. } else {
  47. MS_LOG(ERROR) << name_ << ": Invalid strategy size.";
  48. }
  49. return FAILED;
  50. }
  51. if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0] && inputs_shape_[1][0] != 1) {
  52. if (is_auto_parallel_) {
  53. MS_LOG(DEBUG) << name_ << ": Invalid channel strategy.";
  54. } else {
  55. MS_LOG(ERROR) << name_ << ": Invalid channel strategy.";
  56. }
  57. return FAILED;
  58. }
  59. return SUCCESS;
  60. }
  61. /*
  62. * device matrix is same with the strategy matrix
  63. */
  64. Status PReLUInfo::InferDevMatrixShape() {
  65. std::vector<Dimensions> stra = strategy_->GetInputDim();
  66. Dimensions input_strategy = stra.at(0);
  67. input_strategy_ = input_strategy;
  68. dev_matrix_shape_ = input_strategy;
  69. return SUCCESS;
  70. }
  71. Status PReLUInfo::InferMirrorOps() {
  72. Shape param_tensor_map = inputs_tensor_map_[1];
  73. std::vector<Group> param_group;
  74. if (CreateGroupByTensorMap(param_tensor_map, &param_group) != SUCCESS) {
  75. return FAILED;
  76. } else if (param_group.empty()) {
  77. MS_LOG(INFO) << name_ << ": The mirror ops is empty.";
  78. return SUCCESS;
  79. }
  80. OperatorVector op_for_param;
  81. op_for_param = CreateMirrorOps(param_group[0].name(), param_group[0].GetDevNum());
  82. // op_for_inputs is empty
  83. OperatorVector op_for_inputs;
  84. mirror_ops_.push_back(op_for_inputs);
  85. mirror_ops_.push_back(op_for_param);
  86. std::string group_name = param_group[0].name();
  87. MS_LOG(INFO) << name_ << ": The mirror ops group is " << group_name;
  88. return SUCCESS;
  89. }
// PReLU inserts no forward communication operators: each device computes its
// output slice locally, so this is intentionally a no-op.
Status PReLUInfo::InferForwardCommunication() { return SUCCESS; }
  91. /*
  92. * the output tensor map is the same as the input tensor map
  93. */
  94. Status PReLUInfo::InferTensorMap() {
  95. TensorMap input_tensor_map;
  96. // such as 4: input_tensor_map [3,2,1,0]
  97. for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
  98. input_tensor_map.push_back((int32_t)(inputs_shape_[0].size() - i - 1));
  99. }
  100. TensorMap param_tensor_map;
  101. if (inputs_shape_[1][0] == 1) {
  102. param_tensor_map.push_back(-1);
  103. } else {
  104. param_tensor_map.push_back(input_tensor_map.at(1));
  105. }
  106. inputs_tensor_map_.push_back(input_tensor_map);
  107. inputs_tensor_map_.push_back(param_tensor_map);
  108. outputs_tensor_map_.push_back(input_tensor_map);
  109. return SUCCESS;
  110. }
  111. Dimensions PReLUInfo::GetOutputStrategy() {
  112. Dimensions output_strategy = input_strategy_;
  113. return output_strategy;
  114. }
  115. Status PReLUInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
  116. if (inputs_layout == nullptr || outputs_layout == nullptr) {
  117. MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null.";
  118. return FAILED;
  119. }
  120. TensorLayout input_layout, param_layout, output_layout;
  121. if ((input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
  122. (param_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) ||
  123. (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
  124. return FAILED;
  125. }
  126. inputs_layout->push_back(input_layout);
  127. inputs_layout->push_back(param_layout);
  128. outputs_layout->push_back(output_layout);
  129. return SUCCESS;
  130. }
  131. Status PReLUInfo::InferTensorInfo() {
  132. // infer tensor shape
  133. Shape input_shape = inputs_shape_.at(0);
  134. Shape param_shape = inputs_shape_.at(1);
  135. Shape output_shape = outputs_shape_.at(0);
  136. // infer slice shape
  137. Shapes inputs_slice_shape, outputs_slice_shape;
  138. Dimensions output_strategy = GetOutputStrategy();
  139. Strategys inputs_strategy = strategy_->GetInputDim();
  140. Strategys outputs_strategy = {output_strategy};
  141. if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
  142. return FAILED;
  143. }
  144. Shape input_slice_shape = inputs_slice_shape.at(0);
  145. Shape param_slice_shape = inputs_slice_shape.at(1);
  146. Shape output_slice_shape = outputs_slice_shape.at(0);
  147. // infer tensor layout
  148. TensorLayouts inputs_layout, outputs_layout;
  149. if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
  150. return FAILED;
  151. }
  152. TensorLayout input_layout = inputs_layout.at(0);
  153. TensorLayout param_layout = inputs_layout.at(1);
  154. TensorLayout output_layout = outputs_layout.at(0);
  155. TensorInfo input_tensor_info(input_layout, input_shape, input_slice_shape);
  156. TensorInfo param_tensor_info(param_layout, param_shape, param_slice_shape);
  157. TensorInfo output_tensor_info(output_layout, output_shape, output_slice_shape);
  158. inputs_tensor_info_.push_back(input_tensor_info);
  159. inputs_tensor_info_.push_back(param_tensor_info);
  160. outputs_tensor_info_.push_back(output_tensor_info);
  161. return SUCCESS;
  162. }
  163. Status PReLUInfo::GetAttrs() {
  164. if ((inputs_shape_.size() != PRELU_INPUTS_SIZE) || (outputs_shape_.size() != PRELU_OUTPUTS_SIZE)) {
  165. MS_LOG(ERROR) << name_ << ": Inputs shape size " << inputs_shape_.size() << " or outputs shape size "
  166. << outputs_shape_.size() << " is wrong.";
  167. return FAILED;
  168. }
  169. return SUCCESS;
  170. }
  171. Status PReLUInfo::Init(const StrategyPtr &strategy) {
  172. if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
  173. MS_LOG(ERROR) << name_ << ": Init failed.";
  174. return FAILED;
  175. }
  176. MS_LOG(INFO) << name_ << ": Init success.";
  177. return SUCCESS;
  178. }
  179. Status PReLUInfo::InitForCostModel(const StrategyPtr &strategy) {
  180. if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
  181. if (is_auto_parallel_) {
  182. MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
  183. } else {
  184. MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
  185. }
  186. return FAILED;
  187. }
  188. MS_LOG(INFO) << name_ << ": Init for cost model success.";
  189. return SUCCESS;
  190. }
  191. Status PReLUInfo::GenerateStrategies(int32_t stage_id) {
  192. if (inputs_shape_.size() != PRELU_INPUTS_SIZE) {
  193. return FAILED;
  194. }
  195. if (inputs_shape_[1].size() != PRELU_SECOND_INPUT_SIZE) {
  196. return FAILED;
  197. }
  198. is_auto_parallel_ = true;
  199. Shape input0_split;
  200. input0_split.emplace_back(1);
  201. input0_split.emplace_back(0);
  202. (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 2, 1);
  203. Shape input1_split(inputs_shape_[1].size(), 0);
  204. Shapes splittable_inputs = {input0_split, input1_split};
  205. std::vector<StrategyPtr> sp_vector;
  206. if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
  207. MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed";
  208. return FAILED;
  209. }
  210. size_t success = 0;
  211. for (auto &sp : sp_vector) {
  212. if (SetCostUnderStrategy(sp) == SUCCESS) {
  213. success++;
  214. MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
  215. PrintStrategy(sp);
  216. }
  217. }
  218. return SUCCESS;
  219. }
  220. Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
  221. if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
  222. if (is_auto_parallel_) {
  223. MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
  224. } else {
  225. MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
  226. }
  227. return FAILED;
  228. }
  229. return SUCCESS;
  230. }
  231. } // namespace parallel
  232. } // namespace mindspore