You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

matmul_info.cc 25 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/ops_info/matmul_info.h"
  17. #include <algorithm>
  18. #include <functional>
  19. #include <memory>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include "ir/value.h"
  24. #include "parallel/auto_parallel/graph_costmodel.h"
  25. #include "parallel/device_manager.h"
  26. #include "parallel/device_matrix.h"
  27. #include "parallel/tensor_layout/tensor_redistribution.h"
  28. namespace mindspore {
  29. namespace parallel {
  30. void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool transpose_b,
  31. Shape *dev_matrix_shape) {
  32. MS_EXCEPTION_IF_NULL(dev_matrix_shape);
  33. size_t mat_a_size = mat_a_strategy.size();
  34. size_t mat_b_size = mat_b_strategy.size();
  35. if (mat_a_size >= mat_b_size) {
  36. // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32]
  37. // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
  38. // [2],[4] in the example above
  39. for (size_t i = 0; i < SECOND_FROM_END(mat_a_size); ++i) {
  40. dev_matrix_shape->push_back(mat_a_strategy.at(i));
  41. }
  42. } else {
  43. // for example: mat_a_strategy:[8,16], mat_b_strategy:[2,4,16,32]
  44. // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
  45. // [2],[4] in the example above
  46. for (size_t i = 0; i < SECOND_FROM_END(mat_b_size); ++i) {
  47. dev_matrix_shape->push_back(mat_b_strategy.at(i));
  48. }
  49. }
  50. // [8],[16] in the example above
  51. dev_matrix_shape->push_back(mat_a_strategy.at(SECOND_FROM_END(mat_a_size)));
  52. dev_matrix_shape->push_back(mat_a_strategy.back());
  53. // [32] in the example above
  54. if (!transpose_b) {
  55. dev_matrix_shape->push_back(mat_b_strategy.back());
  56. } else {
  57. dev_matrix_shape->push_back(mat_b_strategy.at(SECOND_FROM_END(mat_b_size)));
  58. }
  59. }
  60. Status MatMulBase::GetAttrs() {
  61. if (attrs_.size() < MATMUL_ATTRS_SIZE) {
  62. MS_LOG(ERROR) << name_ << " : The size of attrs small than 2.";
  63. return FAILED;
  64. }
  65. auto transpose_a_iter = attrs_.find(TRANSPOSE_A);
  66. if (transpose_a_iter != attrs_.end()) {
  67. MS_EXCEPTION_IF_NULL(transpose_a_iter->second);
  68. if (transpose_a_iter->second->isa<BoolImm>()) {
  69. transpose_a_ = transpose_a_iter->second->cast<BoolImmPtr>()->value();
  70. } else {
  71. MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool.";
  72. return FAILED;
  73. }
  74. }
  75. auto transpose_b_iter = attrs_.find(TRANSPOSE_B);
  76. if (transpose_b_iter != attrs_.end()) {
  77. MS_EXCEPTION_IF_NULL(transpose_b_iter->second);
  78. if (transpose_b_iter->second->isa<BoolImm>()) {
  79. transpose_b_ = transpose_b_iter->second->cast<BoolImmPtr>()->value();
  80. } else {
  81. MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool.";
  82. return FAILED;
  83. }
  84. }
  85. auto forward_reduce_scatter_iter = attrs_.find(FORWARD_REDUCE_SCATTER);
  86. if (forward_reduce_scatter_iter != attrs_.end()) {
  87. MS_EXCEPTION_IF_NULL(forward_reduce_scatter_iter->second);
  88. if (forward_reduce_scatter_iter->second->isa<BoolImm>()) {
  89. forward_reduce_scatter_ = forward_reduce_scatter_iter->second->cast<BoolImmPtr>()->value();
  90. } else {
  91. MS_LOG(ERROR) << name_ << " : The value of forward reduce scatter is not bool.";
  92. return FAILED;
  93. }
  94. }
  95. // infer inputs dimension size
  96. if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
  97. MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
  98. return FAILED;
  99. }
  100. mat_a_dimension_ = inputs_shape_.at(0).size();
  101. mat_b_dimension_ = inputs_shape_.at(1).size();
  102. return SUCCESS;
  103. }
  104. Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) {
  105. size_t long_size = long_strategy.size();
  106. size_t short_size = short_strategy.size();
  107. if (long_size < short_size) {
  108. MS_LOG(ERROR) << "Size error, the size of long strategy is " << long_size << ", the size of short strategy is "
  109. << short_size;
  110. return FAILED;
  111. }
  112. size_t len_diff = long_size - short_size;
  113. for (size_t j = 0; j < SECOND_FROM_END(short_size); ++j) {
  114. if (long_strategy.at(len_diff + j) != short_strategy.at(j)) {
  115. MS_LOG(ERROR) << "Strategies of relevant dimensions are not equal, long strategy is "
  116. << ShapeToString(long_strategy) << ", short strategy is " << ShapeToString(short_strategy);
  117. return FAILED;
  118. }
  119. }
  120. return SUCCESS;
  121. }
// Validate a sharding strategy for MatMul. Checks, in order: the generic
// strategy-value check, the strategy dimension counts against the input
// ranks, the equality of the reduction-dimension split (whose position in
// mat_b depends on transpose_b_), the equality of the aligned leading
// dimensions, and finally that forward_reduce_scatter_ is only kept for the
// 2-D x 2-D case.
Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
    // Under auto-parallel search many candidates are invalid, so only
    // DEBUG-log there to avoid flooding the error log.
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
    } else {
      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
    }
    return FAILED;
  }
  std::vector<Dimensions> stra = strategy->GetInputDim();
  Dimensions mat_a_strategy = stra.at(0);
  Dimensions mat_b_strategy = stra.at(1);
  size_t mat_a_size = mat_a_strategy.size();
  size_t mat_b_size = mat_b_strategy.size();
  if ((mat_a_size != mat_a_dimension_) || (mat_b_size != mat_b_dimension_)) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong.";
    } else {
      MS_LOG(ERROR) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong.";
    }
    return FAILED;
  }
  // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32]
  // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
  // [16] in the example above
  // The split of mat_a's last dimension must equal the split of mat_b's
  // reduction dimension (second-from-end normally, last when transposed).
  if (!transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.at(SECOND_FROM_END(mat_b_size)))) {
    MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
    return FAILED;
  } else if (transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.back())) {
    MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
    return FAILED;
  }
  // Compare the aligned leading dimensions, passing the longer strategy first.
  if (mat_a_size >= mat_b_size) {
    if (CheckRelevantDimension(mat_a_strategy, mat_b_strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
      return FAILED;
    }
  } else {
    if (CheckRelevantDimension(mat_b_strategy, mat_a_strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
      return FAILED;
    }
  }
  // Forward reduce-scatter only supports 2-D x 2-D matmul; fall back with a
  // warning rather than failing the strategy.
  if ((mat_a_dimension_ != 2 || mat_b_dimension_ != 2) && forward_reduce_scatter_) {
    MS_LOG(WARNING) << name_
                    << ": The dimension of mat a and mat b must be 2 in forward reduce scatter mode, "
                       "setting the forward reduce scatter mode to false here";
    forward_reduce_scatter_ = false;
  }
  return SUCCESS;
}
  173. Status MatMulBase::InferDevMatrixShape() {
  174. std::vector<Dimensions> stra = strategy_->GetInputDim();
  175. Dimensions mat_a_strategy = stra.at(0);
  176. Dimensions mat_b_strategy = stra.at(1);
  177. SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_);
  178. return SUCCESS;
  179. }
  180. // all-reduce weight's grad
  181. Status MatMulBase::InferMirrorOps() {
  182. mirror_ops_.clear();
  183. Shape mat_b_tensor_map = inputs_tensor_map_[1];
  184. std::vector<Group> mat_b_group;
  185. if (CreateGroupByTensorMap(mat_b_tensor_map, &mat_b_group) != SUCCESS) {
  186. return FAILED;
  187. }
  188. OperatorVector op_for_inputs; // op_for_inputs is empty
  189. OperatorVector op_for_weight;
  190. if (mat_b_group.empty()) {
  191. MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
  192. return SUCCESS;
  193. } else {
  194. op_for_weight = CreateMirrorOps(mat_b_group[0].name(), mat_b_group[0].GetDevNum());
  195. mirror_ops_.push_back(op_for_inputs);
  196. mirror_ops_.push_back(op_for_weight);
  197. MS_LOG(INFO) << name_ << " : Create the mirror ops for weight success, group is " << mat_b_group[0].name();
  198. }
  199. return SUCCESS;
  200. }
  201. Status MatMulBase::InferForwardCommunication() {
  202. forward_op_.clear();
  203. size_t dimension = dev_matrix_shape_.size();
  204. size_t relevant_dimension_index = SECOND_FROM_END(dimension);
  205. // Relevant dimension is not split and all reduce is not required
  206. if (dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) {
  207. MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
  208. return SUCCESS;
  209. }
  210. std::vector<Group> group_list;
  211. if (CreateGroupByDim(relevant_dimension_index, &group_list) != SUCCESS) {
  212. MS_LOG(ERROR) << name_ << " : Infer forward communication, create group failed.";
  213. return FAILED;
  214. } else if (group_list.empty()) {
  215. MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
  216. return SUCCESS;
  217. }
  218. Operator op;
  219. if (forward_reduce_scatter_) {
  220. op = CreateReduceScatterOp(REDUCE_OP_SUM, group_list[0].name());
  221. } else {
  222. op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
  223. }
  224. forward_op_.push_back(op);
  225. MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
  226. return SUCCESS;
  227. }
// Map each tensor dimension of the two inputs and the output onto an axis of
// the device matrix. Tensor-map entries count down from (size - 1); the maps
// are derived by deleting the axes each tensor does not carry from the full
// descending index list.
Status MatMulBase::InferTensorMap() {
  size_t size = dev_matrix_shape_.size();
  if (repeated_calc_num_ > 1) {
    // move the first dimension(repeated_calc_num_), just for the convenience of tensor-map's calculation
    size = dev_matrix_shape_.size() - 1;
  }

  std::vector<int32_t> tensor_map_index;
  // such as 5: tensor_map_index [4,3,2,1,0]
  for (size_t i = 0; i < size; ++i) {
    tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i));
  }

  // infer output tensor map: [4,3,2,0], delete the second-from-end element
  // (the reduction axis does not appear in the output)
  TensorMap output_tensor_map = tensor_map_index;
  (void)output_tensor_map.erase(output_tensor_map.begin() + static_cast<different_type>(SECOND_FROM_END(size)));

  // infer mat_a tensor map
  // for example: mat_a_dimension is 4, mat_a tensor map:[4,3,2,1]
  TensorMap mat_a_tensor_map = tensor_map_index;
  // delete last one element (the axis mat_a does not own)
  mat_a_tensor_map.pop_back();
  // delete the first (dev_matrix_size - 1 - mat_a_dimension) elements
  (void)mat_a_tensor_map.erase(
    mat_a_tensor_map.begin(),
    mat_a_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(size) - mat_a_dimension_));

  // infer mat_b tensor map
  TensorMap mat_b_tensor_map = tensor_map_index;
  // delete the third-to-last element (the axis mat_b does not own)
  (void)mat_b_tensor_map.erase(mat_b_tensor_map.begin() + static_cast<different_type>(THIRD_FROM_END(size)));
  // delete the first (dev_matrix_size - 1 - mat_b_dimension) elements
  (void)mat_b_tensor_map.erase(
    mat_b_tensor_map.begin(),
    mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(size) - mat_b_dimension_));
  if (transpose_b_) {
    // swap the last two elements: pop the last and re-insert it one slot earlier
    int32_t last_value = mat_b_tensor_map.back();
    mat_b_tensor_map.pop_back();
    (void)mat_b_tensor_map.insert(
      mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(mat_b_tensor_map.size())), last_value);
  }

  if (forward_reduce_scatter_) {
    if (dev_matrix_shape_.size() != 3) {
      // Fall back to the all-reduce path instead of failing.
      MS_LOG(WARNING) << name_
                      << ": The dimension of dev matrix shape must be 3 in forward reduce scatter mode, "
                         "setting the forward reduce scatter mode to false here";
      forward_reduce_scatter_ = false;
    } else if (outputs_shape_[0][0] % (dev_matrix_shape_[0] * dev_matrix_shape_[1]) != 0) {
      MS_LOG(WARNING) << name_
                      << ": The first dimension of output should be split by dev_matrix[0]*dev_matrix[1] in "
                         "forward reduce scatter mode, setting the forward reduce scatter mode to false here";
      forward_reduce_scatter_ = false;
    } else {
      // the forward reduce scatter only support that the dimension of output is 2
      output_tensor_map = {1, 0};
    }
  }

  inputs_tensor_map_.push_back(mat_a_tensor_map);
  inputs_tensor_map_.push_back(mat_b_tensor_map);
  outputs_tensor_map_.push_back(output_tensor_map);
  return SUCCESS;
}
  287. Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
  288. Shape output_dev_matrix_shape;
  289. if (forward_reduce_scatter_) {
  290. if (dev_matrix_shape_.size() != 3) {
  291. MS_LOG(ERROR) << "The size of origin dev matrix shape must be 3 in forward reduce scatter mode";
  292. return FAILED;
  293. }
  294. output_dev_matrix_shape = {dev_matrix_shape_[0] * dev_matrix_shape_[1], dev_matrix_shape_[2]};
  295. } else {
  296. output_dev_matrix_shape = dev_matrix_shape_;
  297. }
  298. TensorLayout mat_a_layout, mat_b_layout, output_layout;
  299. if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
  300. (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) ||
  301. (output_layout.InitFromVector(output_dev_matrix_shape, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
  302. return FAILED;
  303. }
  304. inputs_layout->push_back(mat_a_layout);
  305. inputs_layout->push_back(mat_b_layout);
  306. outputs_layout->push_back(output_layout);
  307. return SUCCESS;
  308. }
  309. Status MatMulBase::InferTensorInfo() {
  310. // infer tensor layout
  311. TensorLayouts inputs_layout, outputs_layout;
  312. if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
  313. return FAILED;
  314. }
  315. TensorLayout mat_a_layout = inputs_layout.at(0);
  316. TensorLayout mat_b_layout = inputs_layout.at(1);
  317. TensorLayout output_layout = outputs_layout.at(0);
  318. TensorInfo mat_a_tensor_info(mat_a_layout);
  319. TensorInfo mat_b_tensor_info(mat_b_layout);
  320. TensorInfo output_tensor_info(output_layout);
  321. inputs_tensor_info_.push_back(mat_a_tensor_info);
  322. inputs_tensor_info_.push_back(mat_b_tensor_info);
  323. outputs_tensor_info_.push_back(output_tensor_info);
  324. return SUCCESS;
  325. }
  326. Status MatMulBase::Init(const StrategyPtr &strategy) {
  327. if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
  328. MS_LOG(ERROR) << name_ << " : Init failed.";
  329. return FAILED;
  330. }
  331. if (forward_reduce_scatter_) {
  332. virtual_div_op_.clear();
  333. MS_LOG(INFO) << "The forward reduce scatter mode does not involve repeated calculation, clear the virtual div op";
  334. }
  335. MS_LOG(INFO) << name_ << " : Init success.";
  336. return SUCCESS;
  337. }
  338. Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) {
  339. if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
  340. if (is_auto_parallel_) {
  341. MS_LOG(DEBUG) << name_ << " : Init for cost model failed.";
  342. } else {
  343. MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
  344. }
  345. return FAILED;
  346. }
  347. MS_LOG(INFO) << name_ << " : Init for cost model success.";
  348. return SUCCESS;
  349. }
  350. Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) {
  351. if (input->size() < 2) {
  352. MS_LOG(ERROR) << name_ << " : The size of inputs small than 2.";
  353. return FAILED;
  354. }
  355. auto last_1st_value = input->at(input->size() - 1);
  356. auto last_2nd_value = input->at(input->size() - 2);
  357. input->pop_back();
  358. input->pop_back();
  359. input->push_back(last_1st_value);
  360. input->push_back(last_2nd_value);
  361. return SUCCESS;
  362. }
// Enumerate candidate sharding strategies by recursively factoring the
// device count across the dimensions of the combined input shape, recording
// a cost for every valid candidate. Logs an EXCEPTION when no strategy is
// available.
Status MatMulBase::GenerateStrategies(int32_t stage_id) {
  if (GetAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : GetAttrs failed.";
    return FAILED;
  }
  CheckGlobalDeviceManager();
  std::vector<int32_t> dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
  size_t dev_num = dev_list.size();
  Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1];
  // Normalize away the transposes so the trailing axes line up for the
  // combined-shape construction below; PrepareStrategy undoes this.
  if (transpose_a_) {
    if (SwapLastTwoElements(&input0_shape) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  if (transpose_b_) {
    if (SwapLastTwoElements(&input1_shape) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  // The shape of input0 (input1)
  // E.g., input0 = [100, 200, 300], input1 = [300, 400]
  // Combining the input0_shape and input1_shape
  // E.g., combined_shape = [100, 200, 300, 400]
  is_auto_parallel_ = true;
  size_t input1_shape_size = input1_shape.size(), input0_shape_size = input0_shape.size();
  Dimensions combined_partitions;
  Shape combined_shape;
  // In SwapLastTwoElements(), it is guaranteed that input0_shape.size() and input1_shape.size() are both larger than 2
  if (input0_shape.size() >= input1_shape.size()) {
    combined_shape = input0_shape;
    combined_shape.push_back(input1_shape[input1_shape.size() - 1]);
  } else {
    combined_shape = input1_shape;
    combined_shape.push_back(input0_shape[input0_shape.size() - 2]);
  }
  // Depth-first enumeration: at each combined dimension try every
  // power-of-two factor i of the remaining device count n that also divides
  // the dimension size; recurse with n / i devices left.
  std::function<void(uint32_t, size_t)> recursive = [&stage_id, &dev_num, &combined_partitions, &combined_shape,
                                                     &input1_shape_size, &recursive, &input0_shape_size,
                                                     this](uint32_t current_index, size_t n) {
    // Finishing the recursive steps, if the strategy is valid, then calculate the cost
    // for this operator under the strategy.
    if (current_index == combined_shape.size()) {
      StrategyPtr sp;
      if (this->PrepareStrategy(stage_id, dev_num, combined_partitions, input0_shape_size, input1_shape_size, &sp) ==
          FAILED) {
        return;
      }
      if (this->SetCostUnderStrategy(sp) == FAILED) {
        MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed.";
        return;
      }
    } else {
      MS_LOG(DEBUG) << name_ << " : The value input0_shape_size: " << input0_shape_size
                    << ", input1_shape_size: " << input1_shape_size;
      for (uint32_t i = 1; i <= n; i *= 2) {
        if (n % i == 0 && IntToSize(combined_shape[current_index]) % i == 0) {
          combined_partitions.push_back(i);
          recursive(current_index + 1, n / i);
          combined_partitions.pop_back();
        }
      }
    }
  };
  recursive(0, dev_num);
  if (strategy_cost_.empty()) {
    MS_LOG(EXCEPTION) << name_ << " : No available strategy.";
  }
  return Status::SUCCESS;
}
// Translate a flat partition vector over the combined shape (built by
// GenerateStrategies) back into per-input partition vectors and wrap them in
// a Strategy. Returns FAILED when the partition's device product violates
// the FULLY_USE_DEVICES policy.
Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num,
                                   mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size,
                                   size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) {
  int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int>());
  // Policy: either at most dev_num devices may be used, or exactly dev_num.
  if (!FULLY_USE_DEVICES) {
    if (IntToSize(product) > dev_num) {
      return FAILED;
    }
  } else {
    if (IntToSize(product) != dev_num) {
      return FAILED;
    }
  }
  Dimensions input0_partitions, input1_partitions;
  if (input0_shape_size >= input1_shape_size) {
    // input0 owns the leading input0_shape_size entries of the combined vector verbatim.
    for (size_t i = 0; i < input0_shape_size; ++i) {
      input0_partitions.push_back(combined_partitions[i]);
    }
    if (input1_shape_size == 2) {
      // input1 takes the reduction split and the appended last split.
      input1_partitions.push_back(combined_partitions[combined_partitions.size() - 2]);
      input1_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
    } else {
      // input1_shape.size() > 2
      // NOTE(review): the start index and the skip of the size-3 entry encode
      // how input1's dimensions align with the combined vector — verify the
      // off-by-one against GenerateStrategies' combined_shape construction.
      for (size_t j = combined_partitions.size() - input1_shape_size - 1; j < combined_partitions.size(); ++j) {
        if (j == combined_partitions.size() - 3) {
          continue;
        }
        input1_partitions.push_back(combined_partitions[j]);
      }
    }
  } else {
    // Symmetric case: input1 owns the leading entries of the combined vector.
    for (size_t i = 0; i < input1_shape_size; ++i) {
      input1_partitions.push_back(combined_partitions[i]);
    }
    // input0's trailing two splits are rebuilt from the combined vector's
    // last and third-from-last entries.
    for (size_t j = combined_partitions.size() - input0_shape_size - 1; j < combined_partitions.size() - 3; ++j) {
      input0_partitions.push_back(combined_partitions[j]);
    }
    input0_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
    input0_partitions.push_back(combined_partitions[combined_partitions.size() - 3]);
  }
  // Undo the transpose normalization applied in GenerateStrategies so the
  // strategy matches the operator's actual input layouts.
  if (transpose_a_) {
    if (SwapLastTwoElements(&input0_partitions) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  if (transpose_b_) {
    if (SwapLastTwoElements(&input1_partitions) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  std::vector<Dimensions> stras;
  stras.push_back(input0_partitions);
  stras.push_back(input1_partitions);
  (*sp) = std::make_shared<Strategy>(stage_id, stras);
  return SUCCESS;
}
  487. void MatMulBase::InitTensorInfoForCost(std::vector<TensorInfo> *relica_inputs_tensor_vector) {
  488. TensorLayout tly;
  489. if (transpose_a_) {
  490. Shape replica_input0_shape(inputs_tensor_info_[0].shape());
  491. Shape replica_input0_slice_shape(inputs_tensor_info_[0].slice_shape());
  492. if (SwapLastTwoElements(&replica_input0_shape) == FAILED) {
  493. MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
  494. }
  495. if (SwapLastTwoElements(&replica_input0_slice_shape) == FAILED) {
  496. MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
  497. }
  498. TensorInfo replica_input0_info(tly, replica_input0_shape, replica_input0_slice_shape);
  499. relica_inputs_tensor_vector->push_back(replica_input0_info);
  500. } else {
  501. relica_inputs_tensor_vector->push_back(inputs_tensor_info_[0]);
  502. }
  503. if (transpose_b_) {
  504. Shape replica_input1_shape(inputs_tensor_info_[1].shape());
  505. Shape replica_input1_slice_shape(inputs_tensor_info_[1].slice_shape());
  506. if (SwapLastTwoElements(&replica_input1_shape) == FAILED) {
  507. MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
  508. }
  509. if (SwapLastTwoElements(&replica_input1_slice_shape) == FAILED) {
  510. MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
  511. }
  512. TensorInfo replica_input1_info(tly, replica_input1_shape, replica_input1_slice_shape);
  513. relica_inputs_tensor_vector->push_back(replica_input1_info);
  514. } else {
  515. relica_inputs_tensor_vector->push_back(inputs_tensor_info_[1]);
  516. }
  517. }
  518. Status MatMulBase::CheckForTensorSliceValid() const {
  519. if (!TENSOR_SLICE_ALIGNMENT_ENABLE) {
  520. return SUCCESS;
  521. }
  522. if (inputs_tensor_info_.empty()) {
  523. return FAILED;
  524. }
  525. for (auto &one_input_tensor : inputs_tensor_info_) {
  526. auto slice_shape = one_input_tensor.slice_shape();
  527. if ((IntToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) ||
  528. (IntToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) {
  529. return FAILED;
  530. }
  531. }
  532. return SUCCESS;
  533. }
// Evaluate the cost of one candidate strategy: initialize the operator for
// the cost model, validate the resulting tensor slices, compute the
// computation/communication costs (on transpose-adjusted replica inputs),
// and append the resulting StrategyWithCost to strategy_cost_.
Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
  if (InitForCostModel(strategy) == FAILED) {
    // During auto-parallel search, failed candidates are expected; log quietly.
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << " : Initialization under the strategy failed.";
    } else {
      MS_LOG(ERROR) << name_ << " : Initialization under the strategy failed.";
    }
    return FAILED;
  }
  PrintStrategy(strategy);
  // Check whether the tensor slice of input_tensor_info is valid or not
  if (CheckForTensorSliceValid() != SUCCESS) {
    MS_LOG(INFO) << name_ << " : The tensor slice is not valid under this strategy.";
    return FAILED;
  }
  // Here, a replicated inputs_ is constructed for the transposed TensorInfo.
  std::vector<TensorInfo> relica_inputs_tensor_vector;
  InitTensorInfoForCost(&relica_inputs_tensor_vector);
  int32_t stage_id = strategy->GetInputStage();
  // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
  // It does not matter whether the output tensor is transposed or not.
  double computation_cost =
    operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
  double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
  result->communication_without_parameter_ =
    operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
  // Blend forward-only communication with total communication via the
  // COST_MODEL_GAMMA factor.
  result->communication_with_partial_para_ =
    result->communication_without_parameter_ +
    COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
  // Breaking ties for preferring data parallelization
  BreakingTiesForPerferringDataParallel(strategy, result);
  MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_
                << ", communication_cost: " << result->communication_cost_
                << ", communication_without_parameter_: " << result->communication_without_parameter_
                << ", communication_with_partial_para_: " << result->communication_with_partial_para_;
  // refine communication cost calculation for practice
  RefineForPracticalCost(result, false);
  result->communication_forward_ = result->communication_without_parameter_;
  std::shared_ptr<StrategyWithCost> swc =
    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
  swc->cost_list.push_back(result);
  strategy_cost_.emplace_back(swc);
  return SUCCESS;
}
  579. } // namespace parallel
  580. } // namespace mindspore