You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

arithmetic_info.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/ops_info/arithmetic_info.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <utility>
  20. #include <vector>
  21. #include "parallel/device_matrix.h"
  22. #include "parallel/strategy.h"
  23. #include "parallel/tensor_layout/tensor_redistribution.h"
  24. namespace mindspore {
  25. namespace parallel {
  26. Shape ExpendShape(const Shape &bigger_size_shape, Shape smaller_size_shape) {
  27. size_t insert_num = bigger_size_shape.size() - smaller_size_shape.size();
  28. for (size_t num = 0; num < insert_num; ++num) {
  29. (void)smaller_size_shape.insert(smaller_size_shape.begin(), 1);
  30. }
  31. return smaller_size_shape;
  32. }
  33. Shapes ArithmeticBase::InferExpendShape() {
  34. Shape input_a_shape = inputs_shape_.at(0);
  35. Shape input_b_shape = inputs_shape_.at(1);
  36. Shapes input_shapes;
  37. size_t input_a_size = input_a_shape.size();
  38. size_t input_b_size = input_b_shape.size();
  39. if (input_a_size > input_b_size) {
  40. input_shapes.push_back(input_a_shape);
  41. input_shapes.push_back(ExpendShape(input_a_shape, input_b_shape));
  42. } else if (input_a_size < input_b_size) {
  43. input_shapes.push_back(ExpendShape(input_b_shape, input_a_shape));
  44. input_shapes.push_back(input_b_shape);
  45. } else {
  46. input_shapes.push_back(input_a_shape);
  47. input_shapes.push_back(input_b_shape);
  48. }
  49. return input_shapes;
  50. }
  51. std::vector<Dimensions> ExpendStrategy(const StrategyPtr &strategy) {
  52. std::vector<Dimensions> expend_strategy;
  53. std::vector<Dimensions> stra = strategy->GetInputDim();
  54. Dimensions sub_a_strategy = stra.at(0);
  55. Dimensions sub_b_strategy = stra.at(1);
  56. size_t input_a_size = sub_a_strategy.size();
  57. size_t input_b_size = sub_b_strategy.size();
  58. if (input_a_size > input_b_size) {
  59. expend_strategy.push_back(sub_a_strategy);
  60. expend_strategy.push_back(ExpendShape(sub_a_strategy, sub_b_strategy));
  61. } else if (input_a_size < input_b_size) {
  62. expend_strategy.push_back(ExpendShape(sub_b_strategy, sub_a_strategy));
  63. expend_strategy.push_back(sub_b_strategy);
  64. } else {
  65. expend_strategy = stra;
  66. }
  67. return expend_strategy;
  68. }
  69. Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
  70. if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
  71. if (is_auto_parallel_) {
  72. MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
  73. } else {
  74. MS_LOG(ERROR) << name_ << " : Invalid strategy.";
  75. }
  76. return FAILED;
  77. }
  78. Shapes input_shapes = InferExpendShape();
  79. std::vector<Dimensions> expend_strategy = ExpendStrategy(strategy);
  80. Dimensions sub_a_strategy = expend_strategy.at(0);
  81. Dimensions sub_b_strategy = expend_strategy.at(1);
  82. Shape input_a_shape = input_shapes.at(0);
  83. Shape input_b_shape = input_shapes.at(1);
  84. for (size_t i = 0; i < input_a_shape.size(); ++i) {
  85. if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != 1) && (input_b_shape[i] != 1)) {
  86. if (is_auto_parallel_) {
  87. MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
  88. } else {
  89. MS_LOG(ERROR) << name_ << " : Invalid strategy.";
  90. }
  91. return FAILED;
  92. }
  93. }
  94. return SUCCESS;
  95. }
  96. Status ArithmeticBase::InferDevMatrixShape() {
  97. std::vector<Dimensions> expend_strategy = ExpendStrategy(strategy_);
  98. Dimensions sub_a_strategy = expend_strategy.at(0);
  99. Dimensions sub_b_strategy = expend_strategy.at(1);
  100. Shape dev_shape;
  101. for (size_t i = 0; i < sub_a_strategy.size(); ++i) {
  102. if (sub_a_strategy[i] != sub_b_strategy[i]) {
  103. dev_shape.push_back(sub_a_strategy[i] * sub_b_strategy[i]);
  104. } else {
  105. dev_shape.push_back(sub_a_strategy[i]);
  106. }
  107. }
  108. dev_matrix_shape_ = dev_shape;
  109. return SUCCESS;
  110. }
  111. TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shape) {
  112. TensorMap tensor_map_index;
  113. for (size_t i = 0; i < strategy.size(); ++i) {
  114. if (strategy[i] == dev_matrix_shape[i]) {
  115. tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(strategy.size())) - i));
  116. } else {
  117. tensor_map_index.push_back(-1);
  118. }
  119. }
  120. return tensor_map_index;
  121. }
  122. TensorMap SetTensorMap(const Shape &strategy_expend, const Shape &dev_matrix_shape, const Shape &strategy) {
  123. TensorMap expend_map = SetExpendTensorMap(strategy_expend, dev_matrix_shape);
  124. size_t dev_matrix_size = dev_matrix_shape.size();
  125. size_t strategy_size = strategy.size();
  126. if (dev_matrix_size != strategy_size) {
  127. (void)expend_map.erase(expend_map.begin(),
  128. expend_map.begin() + static_cast<different_type>(dev_matrix_size - strategy_size));
  129. }
  130. return expend_map;
  131. }
  132. void ArithmeticBase::ReComputeBatchSplitFlagList() {
  133. Shapes expend_shapes = InferExpendShape();
  134. Shape expend_a_shape = expend_shapes.at(0);
  135. Shape expend_b_shape = expend_shapes.at(1);
  136. if (expend_a_shape.size() != expend_b_shape.size()) {
  137. MS_LOG(EXCEPTION) << name_ << " : Recompute batch split flag list is wrong.";
  138. }
  139. if (expend_a_shape.empty()) {
  140. split_flag_list_[0] = false;
  141. split_flag_list_[1] = false;
  142. return;
  143. }
  144. (expend_a_shape.at(0) != 1) ? (split_flag_list_[0] = true) : (split_flag_list_[0] = false);
  145. (expend_b_shape.at(0) != 1) ? (split_flag_list_[1] = true) : (split_flag_list_[1] = false);
  146. }
  147. Status ArithmeticBase::InferTensorMap() {
  148. std::vector<int32_t> tensor_map_index;
  149. std::vector<Dimensions> expend_strategy = ExpendStrategy(strategy_);
  150. Dimensions sub_a_expend_strategy = expend_strategy.at(0);
  151. Dimensions sub_b_expend_strategy = expend_strategy.at(1);
  152. Strategys stra = strategy_->GetInputDim();
  153. Dimensions sub_a_strategy = stra.at(0);
  154. Dimensions sub_b_strategy = stra.at(1);
  155. for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
  156. tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_expend_strategy.size())) - i));
  157. }
  158. Shape dev_shape;
  159. for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
  160. if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) {
  161. dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]);
  162. } else {
  163. dev_shape.push_back(sub_a_expend_strategy[i]);
  164. }
  165. }
  166. inputs_tensor_map_.push_back(SetTensorMap(sub_a_expend_strategy, dev_shape, sub_a_strategy));
  167. inputs_tensor_map_.push_back(SetTensorMap(sub_b_expend_strategy, dev_shape, sub_b_strategy));
  168. outputs_tensor_map_.push_back(tensor_map_index);
  169. return SUCCESS;
  170. }
  171. Status ArithmeticBase::InferMirrorOps() {
  172. mirror_ops_.clear();
  173. Shape input_a_tensor_map = inputs_tensor_map_.at(0);
  174. Shape input_b_tensor_map = inputs_tensor_map_.at(1);
  175. std::vector<Group> input_a_group, input_b_group;
  176. if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
  177. MS_LOG(ERROR) << name_ << " : Create group for input a failed.";
  178. return FAILED;
  179. }
  180. if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) {
  181. MS_LOG(ERROR) << name_ << " : Create group for input b failed.";
  182. return FAILED;
  183. }
  184. OperatorVector op_for_input_a, op_for_input_b;
  185. if (input_a_group.empty() && input_b_group.empty()) {
  186. MS_LOG(INFO) << name_ << " : The mirror group is empty.";
  187. return SUCCESS;
  188. }
  189. if (!input_a_group.empty()) {
  190. op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
  191. MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name();
  192. }
  193. if (!input_b_group.empty()) {
  194. op_for_input_b = CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum());
  195. MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name();
  196. }
  197. mirror_ops_.push_back(op_for_input_a);
  198. mirror_ops_.push_back(op_for_input_b);
  199. return SUCCESS;
  200. }
  201. Status ArithmeticBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout,
  202. const Shape &dev_matrix_array) {
  203. if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) {
  204. MS_LOG(ERROR) << name_ << " : The layout is null.";
  205. return FAILED;
  206. }
  207. TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0);
  208. TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1);
  209. TensorMap out_tensor_map_array = outputs_tensor_map_.at(0);
  210. Shape input_a_shape_array = inputs_shape_.at(0);
  211. Shape input_b_shape_array = inputs_shape_.at(1);
  212. Shape out_shape_array = outputs_shape_.at(0);
  213. TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout;
  214. if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) !=
  215. SUCCESS) {
  216. MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed.";
  217. return FAILED;
  218. }
  219. if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) !=
  220. SUCCESS) {
  221. MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed.";
  222. return FAILED;
  223. }
  224. if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) {
  225. MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed.";
  226. return FAILED;
  227. }
  228. inputs_layout->push_back(input_a_tensor_layout);
  229. inputs_layout->push_back(input_b_tensor_layout);
  230. outputs_layout->push_back(out_tensor_layout);
  231. return SUCCESS;
  232. }
  233. Status ArithmeticBase::InferTensorInfo() {
  234. // infer tensor shape
  235. Shape input_a_shape = inputs_shape_.at(0);
  236. Shape input_b_shape = inputs_shape_.at(1);
  237. Shape output_shape = outputs_shape_.at(0);
  238. // infer slice shape
  239. Shapes inputs_slice_shape, outputs_slice_shape;
  240. std::vector<Dimensions> expend_strategy = ExpendStrategy(strategy_);
  241. Dimensions sub_a_expend_strategy = expend_strategy.at(0);
  242. Dimensions sub_b_expend_strategy = expend_strategy.at(1);
  243. Strategys inputs_strategy = strategy_->GetInputDim();
  244. Shape dev_shape;
  245. for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
  246. if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) {
  247. dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]);
  248. } else {
  249. dev_shape.push_back(sub_a_expend_strategy[i]);
  250. }
  251. }
  252. Strategys outputs_strategy = {dev_shape};
  253. if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
  254. return FAILED;
  255. }
  256. Shape input_a_slice_shape = inputs_slice_shape.at(0);
  257. Shape input_b_slice_shape = inputs_slice_shape.at(1);
  258. Shape output_slice_shape = outputs_slice_shape.at(0);
  259. // infer tensor layout
  260. TensorLayouts inputs_layout, outputs_layout;
  261. if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) {
  262. MS_LOG(ERROR) << name_ << " : Infer tensor layout failed.";
  263. return FAILED;
  264. }
  265. TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape);
  266. TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape);
  267. TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape);
  268. inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a
  269. inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b
  270. outputs_tensor_info_.push_back(out_tensor_info); // output
  271. return SUCCESS;
  272. }
  273. Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) {
  274. if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
  275. if (is_auto_parallel_) {
  276. MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
  277. } else {
  278. MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
  279. }
  280. return FAILED;
  281. }
  282. return SUCCESS;
  283. }
  284. Status ArithmeticBase::GenerateStrategies(int32_t stage_id) {
  285. Shape input0_split(inputs_shape_[0].size(), 1);
  286. Shape input1_split(inputs_shape_[1].size(), 1);
  287. Shapes splittable_inputs = {input0_split, input1_split};
  288. std::vector<StrategyPtr> sp_vector;
  289. is_auto_parallel_ = true;
  290. if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
  291. MS_LOG(ERROR) << name_ << " : Generate strategies with broadcast failed.";
  292. return FAILED;
  293. }
  294. MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success.";
  295. size_t success = 0;
  296. for (auto &sp : sp_vector) {
  297. PrintStrategy(sp);
  298. if (SetCostUnderStrategy(sp) == SUCCESS) {
  299. success++;
  300. MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy.";
  301. PrintStrategy(sp);
  302. }
  303. }
  304. return SUCCESS;
  305. }
  306. Status ArithmeticBase::Init(const StrategyPtr &strategy) {
  307. if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
  308. MS_LOG(ERROR) << name_ << " : Init failed.";
  309. return FAILED;
  310. }
  311. MS_LOG(INFO) << name_ << " : Init success.";
  312. return SUCCESS;
  313. }
  314. Status ArithmeticBase::InitForCostModel(const StrategyPtr &strategy) {
  315. if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
  316. if (is_auto_parallel_) {
  317. MS_LOG(DEBUG) << name_ << " : Init for cost model failed.";
  318. } else {
  319. MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
  320. }
  321. return FAILED;
  322. }
  323. MS_LOG(INFO) << name_ << " : Init for cost model success.";
  324. return SUCCESS;
  325. }
  326. } // namespace parallel
  327. } // namespace mindspore