You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

loss_info.cc 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/ops_info/loss_info.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <utility>
  20. #include <vector>
  21. #include "ir/value.h"
  22. #include "parallel/device_matrix.h"
  23. #include "parallel/strategy.h"
  24. #include "parallel/tensor_layout/tensor_redistribution.h"
  25. namespace mindspore {
  26. namespace parallel {
  27. Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) {
  28. if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
  29. if (is_auto_parallel_) {
  30. MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
  31. } else {
  32. MS_LOG(ERROR) << name_ << " : Invalid strategy.";
  33. }
  34. return FAILED;
  35. }
  36. std::vector<Dimensions> stra = strategy->GetInputDim();
  37. Dimensions input_strategy = stra.at(0);
  38. Dimensions label_strategy = stra.at(1);
  39. if (input_strategy != label_strategy) {
  40. MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
  41. return FAILED;
  42. }
  43. int32_t axis_index = axis_;
  44. if (axis_ < 0) {
  45. size_t input_dim = inputs_shape_.at(0).size();
  46. axis_index = static_cast<int32_t>(input_dim) + axis_;
  47. }
  48. int32_t input_axis_strategy = input_strategy.at(IntToSize(axis_index));
  49. int32_t label_axis_strategy = label_strategy.at(IntToSize(axis_index));
  50. // Dimension corresponding to axis is un-splittable
  51. if ((input_axis_strategy != MIN_SLICE_NUM) && (label_axis_strategy != MIN_SLICE_NUM)) {
  52. if (is_auto_parallel_) {
  53. MS_LOG(DEBUG) << name_
  54. << " : The strategy corresponding to axis dimension is not 1, input: " << input_axis_strategy
  55. << ", label: " << label_axis_strategy;
  56. } else {
  57. MS_LOG(ERROR) << name_
  58. << " : The strategy corresponding to axis dimension is not 1, input: " << input_axis_strategy
  59. << ", label: " << label_axis_strategy;
  60. }
  61. return FAILED;
  62. }
  63. return SUCCESS;
  64. }
  65. Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() {
  66. if ((inputs_shape_.size() != SoftmaxCrossEntropyWithLogitsInputsSize) ||
  67. (outputs_shape_.size() != SoftmaxCrossEntropyWithLogitsOutputsSize)) {
  68. MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
  69. return FAILED;
  70. }
  71. return SUCCESS;
  72. }
  73. Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() {
  74. std::vector<Dimensions> stra = strategy_->GetInputDim();
  75. Dimensions input_strategy = stra.at(0);
  76. dev_matrix_shape_ = input_strategy;
  77. return SUCCESS;
  78. }
  79. Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorMap() {
  80. std::vector<int32_t> tensor_map_index;
  81. size_t size = inputs_shape_[0].size();
  82. // such as 4: tensor_map_index [3,2,1,0]
  83. for (size_t i = 0; i < size; ++i) {
  84. tensor_map_index.push_back((int32_t)(size - i - 1));
  85. }
  86. std::vector<int32_t> first_output_tensor_map = {tensor_map_index[0]};
  87. inputs_tensor_map_.push_back(tensor_map_index); // input
  88. inputs_tensor_map_.push_back(tensor_map_index); // label
  89. outputs_tensor_map_.push_back(first_output_tensor_map); // output-0
  90. outputs_tensor_map_.push_back(tensor_map_index); // output-1
  91. return SUCCESS;
  92. }
  93. Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorInfo() {
  94. // infer tensor shape
  95. Shape input_shape = inputs_shape_.at(0);
  96. Shape first_output_shape = outputs_shape_.at(0);
  97. // infer slice shape
  98. Shapes inputs_slice_shape, outputs_slice_shape;
  99. Strategys inputs_strategy = strategy_->GetInputDim();
  100. Strategys outputs_strategy = {{inputs_strategy[0][0]}, inputs_strategy.at(0)};
  101. if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
  102. return FAILED;
  103. }
  104. Shape input_slice_shape = inputs_slice_shape.at(0);
  105. Shape first_output_slice_shape = outputs_slice_shape.at(0);
  106. TensorMap input_tensor_map = inputs_tensor_map_.at(0);
  107. TensorMap first_output_tensor_map = outputs_tensor_map_.at(0);
  108. TensorLayout input_tensor_layout, first_output_tensor_layout;
  109. if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, input_tensor_map, input_shape) != SUCCESS) ||
  110. (first_output_tensor_layout.InitFromVector(dev_matrix_shape_, first_output_tensor_map, first_output_shape) !=
  111. SUCCESS)) {
  112. return FAILED;
  113. }
  114. TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  115. TensorInfo first_output_tensor_info(first_output_tensor_layout, first_output_shape, first_output_slice_shape);
  116. inputs_tensor_info_.push_back(input_tensor_info); // input
  117. inputs_tensor_info_.push_back(input_tensor_info); // label
  118. outputs_tensor_info_.push_back(first_output_tensor_info); // output-0
  119. outputs_tensor_info_.push_back(input_tensor_info); // output-1
  120. return SUCCESS;
  121. }
  122. // There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload the function.
  123. Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() {
  124. if (outputs_tensor_map_.size() != 2) {
  125. MS_LOG(ERROR) << name_ << " : The size of outputs tensor map " << outputs_tensor_map_.size() << " is error.";
  126. return FAILED;
  127. }
  128. as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[1]);
  129. MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
  130. << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[1]) << ", as_loss_divisor_ is "
  131. << as_loss_divisor_;
  132. return SUCCESS;
  133. }
  134. Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr &strategy) {
  135. if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
  136. MS_LOG(ERROR) << name_ << " : Init failed.";
  137. return FAILED;
  138. }
  139. MS_LOG(INFO) << name_ << " : Init success.";
  140. return SUCCESS;
  141. }
  142. Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr &strategy) {
  143. if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
  144. if (is_auto_parallel_) {
  145. MS_LOG(DEBUG) << name_ << " : Init for cost model failed.";
  146. } else {
  147. MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
  148. }
  149. return FAILED;
  150. }
  151. MS_LOG(INFO) << name_ << " : Init for cost model success.";
  152. return SUCCESS;
  153. }
  154. void SoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
  155. for (size_t i = 0; i < inputs_shape_.size(); ++i) {
  156. split_flag_list_[i] = true;
  157. }
  158. }
  159. Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) {
  160. if (GetAttrs() != SUCCESS) {
  161. MS_LOG(ERROR) << name_ << " : GetAttrs failed.";
  162. return FAILED;
  163. }
  164. int32_t axis_index = axis_;
  165. if (axis_ < 0) {
  166. size_t input_dim = inputs_shape_[0].size();
  167. axis_index = static_cast<int32_t>(input_dim) + axis_;
  168. }
  169. is_auto_parallel_ = true;
  170. Shape input0_split;
  171. (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1);
  172. input0_split[IntToSize(axis_index)] = 0;
  173. Shapes splittable_inputs = {input0_split, input0_split};
  174. std::vector<StrategyPtr> sp_vector;
  175. if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
  176. MS_LOG(ERROR) << name_ << " : Generate strategies failed.";
  177. return FAILED;
  178. }
  179. size_t success = 0;
  180. for (auto &sp : sp_vector) {
  181. if (SetCostUnderStrategy(sp) == SUCCESS) {
  182. success++;
  183. MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy.";
  184. PrintStrategy(sp);
  185. }
  186. }
  187. return SUCCESS;
  188. }
  189. Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
  190. PrintStrategy(strategy);
  191. if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
  192. if (is_auto_parallel_) {
  193. MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
  194. } else {
  195. MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
  196. }
  197. return FAILED;
  198. }
  199. return SUCCESS;
  200. }
  201. } // namespace parallel
  202. } // namespace mindspore