You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

matmul_info.cc 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/ops_info/matmul_info.h"
  17. #include <algorithm>
  18. #include <functional>
  19. #include <memory>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include "ir/value.h"
  24. #include "parallel/auto_parallel/graph_costmodel.h"
  25. #include "parallel/device_manager.h"
  26. #include "parallel/device_matrix.h"
  27. #include "parallel/tensor_layout/tensor_redistribution.h"
  28. namespace mindspore {
  29. namespace parallel {
  30. void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool transpose_b,
  31. Shape *dev_matrix_shape) {
  32. MS_EXCEPTION_IF_NULL(dev_matrix_shape);
  33. size_t mat_a_size = mat_a_strategy.size();
  34. size_t mat_b_size = mat_b_strategy.size();
  35. if (mat_a_size >= mat_b_size) {
  36. // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32]
  37. // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
  38. // [2],[4] in the example above
  39. for (size_t i = 0; i < SECOND_FROM_END(mat_a_size); ++i) {
  40. dev_matrix_shape->push_back(mat_a_strategy.at(i));
  41. }
  42. } else {
  43. // for example: mat_a_strategy:[8,16], mat_b_strategy:[2,4,16,32]
  44. // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
  45. // [2],[4] in the example above
  46. for (size_t i = 0; i < SECOND_FROM_END(mat_b_size); ++i) {
  47. dev_matrix_shape->push_back(mat_b_strategy.at(i));
  48. }
  49. }
  50. // [8],[16] in the example above
  51. dev_matrix_shape->push_back(mat_a_strategy.at(SECOND_FROM_END(mat_a_size)));
  52. dev_matrix_shape->push_back(mat_a_strategy.back());
  53. // [32] in the example above
  54. if (!transpose_b) {
  55. dev_matrix_shape->push_back(mat_b_strategy.back());
  56. } else {
  57. dev_matrix_shape->push_back(mat_b_strategy.at(SECOND_FROM_END(mat_b_size)));
  58. }
  59. }
  60. Status MatMulBase::GetAttrs() {
  61. if (attrs_.size() < MATMUL_ATTRS_SIZE) {
  62. MS_LOG(ERROR) << name_ << " : The size of attrs small than 2.";
  63. return FAILED;
  64. }
  65. auto transpose_a_iter = attrs_.find(TRANSPOSE_A);
  66. if (transpose_a_iter != attrs_.end()) {
  67. MS_EXCEPTION_IF_NULL(transpose_a_iter->second);
  68. if (transpose_a_iter->second->isa<BoolImm>()) {
  69. transpose_a_ = transpose_a_iter->second->cast<BoolImmPtr>()->value();
  70. } else {
  71. MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool.";
  72. return FAILED;
  73. }
  74. }
  75. auto transpose_b_iter = attrs_.find(TRANSPOSE_B);
  76. if (transpose_b_iter != attrs_.end()) {
  77. MS_EXCEPTION_IF_NULL(transpose_b_iter->second);
  78. if (transpose_b_iter->second->isa<BoolImm>()) {
  79. transpose_b_ = transpose_b_iter->second->cast<BoolImmPtr>()->value();
  80. } else {
  81. MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool.";
  82. return FAILED;
  83. }
  84. }
  85. // infer inputs dimension size
  86. if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
  87. MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
  88. return FAILED;
  89. }
  90. mat_a_dimension_ = inputs_shape_.at(0).size();
  91. mat_b_dimension_ = inputs_shape_.at(1).size();
  92. return SUCCESS;
  93. }
  94. Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) {
  95. size_t long_size = long_strategy.size();
  96. size_t short_size = short_strategy.size();
  97. if (long_size < short_size) {
  98. MS_LOG(ERROR) << "Size error, the size of long strategy is " << long_size << ", the size of short strategy is "
  99. << short_size;
  100. return FAILED;
  101. }
  102. size_t len_diff = long_size - short_size;
  103. for (size_t j = 0; j < SECOND_FROM_END(short_size); ++j) {
  104. if (long_strategy.at(len_diff + j) != short_strategy.at(j)) {
  105. MS_LOG(ERROR) << "Strategies of relevant dimensions are not equal, long strategy is "
  106. << ShapeToString(long_strategy) << ", short strategy is " << ShapeToString(short_strategy);
  107. return FAILED;
  108. }
  109. }
  110. return SUCCESS;
  111. }
// Validates a sharding strategy for MatMul:
//  - the strategy values are legal for the input shapes;
//  - each input has one split factor per tensor dimension;
//  - the reduction axes of mat_a and mat_b are split identically;
//  - the broadcast (batch) dimensions of both strategies agree.
Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
    // Auto-parallel probes many candidate strategies, so an invalid one is
    // routine there and only logged at DEBUG level.
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
    } else {
      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
    }
    return FAILED;
  }

  std::vector<Dimensions> stra = strategy->GetInputDim();
  Dimensions mat_a_strategy = stra.at(0);
  Dimensions mat_b_strategy = stra.at(1);
  size_t mat_a_size = mat_a_strategy.size();
  size_t mat_b_size = mat_b_strategy.size();
  // Each input's strategy must provide exactly one split factor per
  // tensor dimension (dimensions cached by GetAttrs()).
  if ((mat_a_size != mat_a_dimension_) || (mat_b_size != mat_b_dimension_)) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong.";
    } else {
      MS_LOG(ERROR) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong.";
    }
    return FAILED;
  }

  // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32]
  // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
  // The reduction axis ([16] in the example above) must be split the same
  // way in both operands; with transpose_b it is mat_b's LAST axis instead.
  if (!transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.at(SECOND_FROM_END(mat_b_size)))) {
    MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
    return FAILED;
  } else if (transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.back())) {
    MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
    return FAILED;
  }

  // The longer strategy carries all batch dimensions; check the shorter one
  // against the longer one's trailing dimensions.
  if (mat_a_size >= mat_b_size) {
    if (CheckRelevantDimension(mat_a_strategy, mat_b_strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
      return FAILED;
    }
  } else {
    if (CheckRelevantDimension(mat_b_strategy, mat_a_strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
      return FAILED;
    }
  }
  return SUCCESS;
}
  157. Status MatMulBase::InferDevMatrixShape() {
  158. std::vector<Dimensions> stra = strategy_->GetInputDim();
  159. Dimensions mat_a_strategy = stra.at(0);
  160. Dimensions mat_b_strategy = stra.at(1);
  161. SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_);
  162. return SUCCESS;
  163. }
  164. // all-reduce weight's grad
  165. Status MatMulBase::InferMirrorOps() {
  166. mirror_ops_.clear();
  167. Shape mat_b_tensor_map = inputs_tensor_map_[1];
  168. std::vector<Group> mat_b_group;
  169. if (CreateGroupByTensorMap(mat_b_tensor_map, &mat_b_group) != SUCCESS) {
  170. return FAILED;
  171. }
  172. OperatorVector op_for_inputs; // op_for_inputs is empty
  173. OperatorVector op_for_weight;
  174. if (mat_b_group.empty()) {
  175. MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
  176. return SUCCESS;
  177. } else {
  178. op_for_weight = CreateMirrorOps(mat_b_group[0].name(), mat_b_group[0].GetDevNum());
  179. mirror_ops_.push_back(op_for_inputs);
  180. mirror_ops_.push_back(op_for_weight);
  181. MS_LOG(INFO) << name_ << " : Create the mirror ops for weight success, group is " << mat_b_group[0].name();
  182. }
  183. return SUCCESS;
  184. }
  185. Status MatMulBase::InferForwardCommunication() {
  186. forward_op_.clear();
  187. size_t dimension = dev_matrix_shape_.size();
  188. size_t relevant_dimension_index = SECOND_FROM_END(dimension);
  189. // Relevant dimension is not split and all reduce is not required
  190. if (dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) {
  191. MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
  192. return SUCCESS;
  193. }
  194. std::vector<Group> group_list;
  195. if (CreateGroupByDim(relevant_dimension_index, &group_list) != SUCCESS) {
  196. MS_LOG(ERROR) << name_ << " : Infer forward communication, create group failed.";
  197. return FAILED;
  198. } else if (group_list.empty()) {
  199. MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
  200. return SUCCESS;
  201. }
  202. Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
  203. forward_op_.push_back(op);
  204. MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
  205. return SUCCESS;
  206. }
  207. // dev_matrix_shape: [a, b, c, d, e], then output strategy: [a, b, c, e];
  208. Dimensions GetOutputStrategy(const Shape &dev_matrix_shape, int32_t repeated_calculation_num) {
  209. Dimensions output_strategy = dev_matrix_shape;
  210. if (repeated_calculation_num > 1) {
  211. // move the first dimension(repeated_calc_num_)
  212. (void)output_strategy.erase(output_strategy.begin());
  213. }
  214. // delete the second-to-last element
  215. (void)output_strategy.erase(output_strategy.begin() +
  216. static_cast<different_type>(SECOND_FROM_END(output_strategy.size())));
  217. return output_strategy;
  218. }
// Builds the tensor maps that bind each dimension of mat_a, mat_b and the
// output to an axis of the device matrix (axis indices count down from the
// highest device-matrix dimension; -1 would mean "not mapped").
Status MatMulBase::InferTensorMap() {
  size_t size = dev_matrix_shape_.size();
  if (repeated_calc_num_ > 1) {
    // move the first dimension(repeated_calc_num_), just for the convenience of tensor-map's calculation
    size = dev_matrix_shape_.size() - 1;
  }

  std::vector<int32_t> tensor_map_index;
  // such as 5: tensor_map_index [4,3,2,1,0]
  for (size_t i = 0; i < size; ++i) {
    tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i));
  }

  // infer output tensor map: [4,3,2,0], delete the second-from-end element
  // (the reduction axis never appears in the output tensor).
  TensorMap output_tensor_map = tensor_map_index;
  (void)output_tensor_map.erase(output_tensor_map.begin() + static_cast<different_type>(SECOND_FROM_END(size)));

  // infer mat_a tensor map
  // for example: mat_a_dimension is 4, mat_a tensor map:[4,3,2,1]
  TensorMap mat_a_tensor_map = tensor_map_index;
  // delete last one element (mat_b's column axis is not part of mat_a)
  mat_a_tensor_map.pop_back();
  // delete the first (dev_matrix_size - 1 - mat_a_dimension) elements
  // so only mat_a's own batch dimensions remain.
  (void)mat_a_tensor_map.erase(
    mat_a_tensor_map.begin(),
    mat_a_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(size) - mat_a_dimension_));

  // infer mat_b tensor map
  TensorMap mat_b_tensor_map = tensor_map_index;
  // delete the third-to-last element (mat_a's row axis is not part of mat_b)
  (void)mat_b_tensor_map.erase(mat_b_tensor_map.begin() + static_cast<different_type>(THIRD_FROM_END(size)));
  // delete the first (dev_matrix_size - 1 - mat_b_dimension) elements
  (void)mat_b_tensor_map.erase(
    mat_b_tensor_map.begin(),
    mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(size) - mat_b_dimension_));
  if (transpose_b_) {
    // swap the last two elements: a transposed mat_b stores its reduction
    // axis last, so the map must follow.
    int32_t last_value = mat_b_tensor_map.back();
    mat_b_tensor_map.pop_back();
    (void)mat_b_tensor_map.insert(
      mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(mat_b_tensor_map.size())), last_value);
  }

  inputs_tensor_map_.push_back(mat_a_tensor_map);
  inputs_tensor_map_.push_back(mat_b_tensor_map);
  outputs_tensor_map_.push_back(output_tensor_map);
  return SUCCESS;
}
  262. Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
  263. TensorLayout mat_a_layout, mat_b_layout, output_layout;
  264. if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
  265. (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) ||
  266. (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
  267. return FAILED;
  268. }
  269. inputs_layout->push_back(mat_a_layout);
  270. inputs_layout->push_back(mat_b_layout);
  271. outputs_layout->push_back(output_layout);
  272. return SUCCESS;
  273. }
  274. Status MatMulBase::InferTensorInfo() {
  275. // infer tensor shape
  276. Shape mat_a_shape = inputs_shape_.at(0);
  277. Shape mat_b_shape = inputs_shape_.at(1);
  278. Shape output_shape = outputs_shape_.at(0);
  279. // infer slice shape
  280. Shapes inputs_slice_shape, outputs_slice_shape;
  281. Dimensions output_strategy = GetOutputStrategy(dev_matrix_shape_, repeated_calc_num_);
  282. Strategys inputs_strategy = strategy_->GetInputDim();
  283. Strategys outputs_strategy = {output_strategy};
  284. if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
  285. return FAILED;
  286. }
  287. Shape mat_a_slice_shape = inputs_slice_shape.at(0);
  288. Shape mat_b_slice_shape = inputs_slice_shape.at(1);
  289. Shape output_slice_shape = outputs_slice_shape.at(0);
  290. // infer tensor layout
  291. TensorLayouts inputs_layout, outputs_layout;
  292. if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
  293. return FAILED;
  294. }
  295. TensorLayout mat_a_layout = inputs_layout.at(0);
  296. TensorLayout mat_b_layout = inputs_layout.at(1);
  297. TensorLayout output_layout = outputs_layout.at(0);
  298. TensorInfo mat_a_tensor_info(mat_a_layout, mat_a_shape, mat_a_slice_shape);
  299. TensorInfo mat_b_tensor_info(mat_b_layout, mat_b_shape, mat_b_slice_shape);
  300. TensorInfo output_tensor_info(output_layout, output_shape, output_slice_shape);
  301. inputs_tensor_info_.push_back(mat_a_tensor_info);
  302. inputs_tensor_info_.push_back(mat_b_tensor_info);
  303. outputs_tensor_info_.push_back(output_tensor_info);
  304. return SUCCESS;
  305. }
  306. Status MatMulBase::Init(const StrategyPtr &strategy) {
  307. if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
  308. MS_LOG(ERROR) << name_ << " : Init failed.";
  309. return FAILED;
  310. }
  311. MS_LOG(INFO) << name_ << " : Init success.";
  312. return SUCCESS;
  313. }
  314. Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) {
  315. if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
  316. if (is_auto_parallel_) {
  317. MS_LOG(DEBUG) << name_ << " : Init for cost model failed.";
  318. } else {
  319. MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
  320. }
  321. return FAILED;
  322. }
  323. MS_LOG(INFO) << name_ << " : Init for cost model success.";
  324. return SUCCESS;
  325. }
  326. Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) {
  327. if (input->size() < 2) {
  328. MS_LOG(ERROR) << name_ << " : The size of inputs small than 2.";
  329. return FAILED;
  330. }
  331. auto last_1st_value = input->at(input->size() - 1);
  332. auto last_2nd_value = input->at(input->size() - 2);
  333. input->pop_back();
  334. input->pop_back();
  335. input->push_back(last_1st_value);
  336. input->push_back(last_2nd_value);
  337. return SUCCESS;
  338. }
// Enumerates candidate sharding strategies for this MatMul by recursively
// assigning power-of-two split factors to every dimension of the combined
// shape, and records the cost of each valid candidate.
Status MatMulBase::GenerateStrategies(int32_t stage_id) {
  if (GetAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : GetAttrs failed.";
    return FAILED;
  }
  CheckGlobalDeviceManager();
  std::vector<int32_t> dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
  size_t dev_num = dev_list.size();
  Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1];
  // Normalize to un-transposed shapes so the combined shape below lines up.
  // NOTE(review): a FAILED swap is only logged, not propagated — presumably
  // the shapes are guaranteed to be at least 2-D here; verify upstream checks.
  if (transpose_a_) {
    if (SwapLastTwoElements(&input0_shape) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  if (transpose_b_) {
    if (SwapLastTwoElements(&input1_shape) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  // The shape of input0 (input1)
  // E.g., input0 = [100, 200, 300], input1 = [300, 400]
  // Combining the input0_shape and input1_shape
  // E.g., combined_shape = [100, 200, 300, 400]
  is_auto_parallel_ = true;
  size_t input1_shape_size = input1_shape.size(), input0_shape_size = input0_shape.size();
  Dimensions combined_partitions;
  Shape combined_shape;
  // In SwapLastTwoElements(), it is guaranteed that input0_shape.size() and input1_shape.size() are both larger than 2
  if (input0_shape.size() >= input1_shape.size()) {
    combined_shape = input0_shape;
    combined_shape.push_back(input1_shape[input1_shape.size() - 1]);
  } else {
    combined_shape = input1_shape;
    combined_shape.push_back(input0_shape[input0_shape.size() - 2]);
  }
  // Depth-first enumeration: one split factor per combined dimension; a
  // completed vector is converted to a Strategy and costed.
  std::function<void(uint32_t, size_t)> recursive = [&stage_id, &dev_num, &combined_partitions, &combined_shape,
                                                     &input1_shape_size, &recursive, &input0_shape_size,
                                                     this](uint32_t current_index, size_t n) {
    // Finishing the recursive steps, if the strategy is valid, then calculate the cost
    // for this operator under the strategy.
    if (current_index == combined_shape.size()) {
      StrategyPtr sp;
      if (this->PrepareStrategy(stage_id, dev_num, combined_partitions, input0_shape_size, input1_shape_size, &sp) ==
          FAILED) {
        return;
      }
      if (this->SetCostUnderStrategy(sp) == FAILED) {
        MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed.";
        return;
      }
    } else {
      MS_LOG(DEBUG) << name_ << " : The value input0_shape_size: " << input0_shape_size
                    << ", input1_shape_size: " << input1_shape_size;
      // Try every power-of-two factor i that divides both the remaining
      // device budget n and the current dimension's length.
      for (uint32_t i = 1; i <= n; i *= 2) {
        if (n % i == 0 && IntToSize(combined_shape[current_index]) % i == 0) {
          combined_partitions.push_back(i);
          recursive(current_index + 1, n / i);
          combined_partitions.pop_back();
        }
      }
    }
  };
  recursive(0, dev_num);
  // MS_LOG(EXCEPTION) throws: no viable strategy at all is a hard error.
  if (strategy_cost_.empty()) {
    MS_LOG(EXCEPTION) << name_ << " : No available strategy.";
  }
  return Status::SUCCESS;
}
// Converts a combined partition vector (built by GenerateStrategies over
// combined_shape = longer_input_shape + one extra dim) back into per-input
// partition vectors and wraps them in a Strategy.
Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num,
                                   mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size,
                                   size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) {
  int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int>());
  // FULLY_USE_DEVICES requires the split product to equal dev_num exactly;
  // otherwise any product not exceeding dev_num is acceptable.
  if (!FULLY_USE_DEVICES) {
    if (IntToSize(product) > dev_num) {
      return FAILED;
    }
  } else {
    if (IntToSize(product) != dev_num) {
      return FAILED;
    }
  }
  Dimensions input0_partitions, input1_partitions;
  if (input0_shape_size >= input1_shape_size) {
    // input0 owns the leading dims: its partitions are simply the first
    // input0_shape_size entries of the combined vector.
    for (size_t i = 0; i < input0_shape_size; ++i) {
      input0_partitions.push_back(combined_partitions[i]);
    }
    if (input1_shape_size == 2) {
      // A 2-D input1 takes the reduction dim and the appended column dim.
      input1_partitions.push_back(combined_partitions[combined_partitions.size() - 2]);
      input1_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
    } else {
      // input1_shape.size() > 2
      // Take the trailing entries, skipping input0's row dim (third from
      // the end of the combined vector).
      for (size_t j = combined_partitions.size() - input1_shape_size - 1; j < combined_partitions.size(); ++j) {
        if (j == combined_partitions.size() - 3) {
          continue;
        }
        input1_partitions.push_back(combined_partitions[j]);
      }
    }
  } else {
    // input1 owns the leading dims.
    for (size_t i = 0; i < input1_shape_size; ++i) {
      input1_partitions.push_back(combined_partitions[i]);
    }
    // input0's batch dims: everything before the last three combined entries.
    for (size_t j = combined_partitions.size() - input0_shape_size - 1; j < combined_partitions.size() - 3; ++j) {
      input0_partitions.push_back(combined_partitions[j]);
    }
    // NOTE(review): the appended dim (last entry) is taken as input0's row
    // split, followed by the reduction dim (third from the end) — matches
    // how combined_shape was built in GenerateStrategies; confirm if edited.
    input0_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
    input0_partitions.push_back(combined_partitions[combined_partitions.size() - 3]);
  }
  // Swap back to the transposed layout the real operands use.
  if (transpose_a_) {
    if (SwapLastTwoElements(&input0_partitions) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  if (transpose_b_) {
    if (SwapLastTwoElements(&input1_partitions) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  std::vector<Dimensions> stras;
  stras.push_back(input0_partitions);
  stras.push_back(input1_partitions);
  (*sp) = std::make_shared<Strategy>(stage_id, stras);
  return SUCCESS;
}
  463. void MatMulBase::InitTensorInfoForCost(std::vector<TensorInfo> *relica_inputs_tensor_vector) {
  464. TensorLayout tly;
  465. if (transpose_a_) {
  466. Shape replica_input0_shape(inputs_tensor_info_[0].shape());
  467. Shape replica_input0_slice_shape(inputs_tensor_info_[0].slice_shape());
  468. if (SwapLastTwoElements(&replica_input0_shape) == FAILED) {
  469. MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
  470. }
  471. if (SwapLastTwoElements(&replica_input0_slice_shape) == FAILED) {
  472. MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
  473. }
  474. TensorInfo replica_input0_info(tly, replica_input0_shape, replica_input0_slice_shape);
  475. relica_inputs_tensor_vector->push_back(replica_input0_info);
  476. } else {
  477. relica_inputs_tensor_vector->push_back(inputs_tensor_info_[0]);
  478. }
  479. if (transpose_b_) {
  480. Shape replica_input1_shape(inputs_tensor_info_[1].shape());
  481. Shape replica_input1_slice_shape(inputs_tensor_info_[1].slice_shape());
  482. if (SwapLastTwoElements(&replica_input1_shape) == FAILED) {
  483. MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
  484. }
  485. if (SwapLastTwoElements(&replica_input1_slice_shape) == FAILED) {
  486. MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
  487. }
  488. TensorInfo replica_input1_info(tly, replica_input1_shape, replica_input1_slice_shape);
  489. relica_inputs_tensor_vector->push_back(replica_input1_info);
  490. } else {
  491. relica_inputs_tensor_vector->push_back(inputs_tensor_info_[1]);
  492. }
  493. }
  494. Status MatMulBase::CheckForTensorSliceValid() const {
  495. if (!TENSOR_SLICE_ALIGNMENT_ENABLE) {
  496. return SUCCESS;
  497. }
  498. if (inputs_tensor_info_.empty()) {
  499. return FAILED;
  500. }
  501. for (auto &one_input_tensor : inputs_tensor_info_) {
  502. auto slice_shape = one_input_tensor.slice_shape();
  503. if ((IntToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) ||
  504. (IntToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) {
  505. return FAILED;
  506. }
  507. }
  508. return SUCCESS;
  509. }
// Initializes the operator under `strategy`, computes its computation and
// communication costs, and appends a StrategyWithCost entry to strategy_cost_.
Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
  if (InitForCostModel(strategy) == FAILED) {
    // Auto-parallel probes many strategies; a failed init is routine there.
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << " : Initialization under the strategy failed.";
    } else {
      MS_LOG(ERROR) << name_ << " : Initialization under the strategy failed.";
    }
    return FAILED;
  }
  PrintStrategy(strategy);
  // Check whether the tensor slice of input_tensor_info is valid or not
  if (CheckForTensorSliceValid() != SUCCESS) {
    MS_LOG(INFO) << name_ << " : The tensor slice is not valid under this strategy.";
    return FAILED;
  }
  // Here, a replicated inputs_ is constructed for the transposed TensorInfo.
  std::vector<TensorInfo> relica_inputs_tensor_vector;
  InitTensorInfoForCost(&relica_inputs_tensor_vector);
  int32_t stage_id = strategy->GetInputStage();
  // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
  // It does not matter whether the output tensor is transposed or not.
  double computation_cost =
    operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
  double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
  result->communication_without_parameter_ =
    operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
  // Partial-parameter communication interpolates between the forward-only
  // and total communication costs using COST_MODEL_GAMMA.
  result->communication_with_partial_para_ =
    result->communication_without_parameter_ +
    COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
  // Breaking ties for preferring data parallelization
  BreakingTiesForPerferringDataParallel(strategy, result);
  MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_
                << ", communication_cost: " << result->communication_cost_
                << ", communication_without_parameter_: " << result->communication_without_parameter_
                << ", communication_with_partial_para_: " << result->communication_with_partial_para_;
  // refine communication cost calculation for practice
  RefineForPracticalCost(result, false);
  std::shared_ptr<StrategyWithCost> swc =
    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
  swc->cost_list.push_back(result);
  strategy_cost_.emplace_back(swc);
  return SUCCESS;
}
  554. } // namespace parallel
  555. } // namespace mindspore