
reshape_info.cc

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "parallel/ops_info/reshape_info.h"
#include <memory>
#include <vector>
#include "parallel/device_manager.h"
#include "parallel/device_matrix.h"
#include "parallel/step_parallel.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace parallel {
Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
    } else {
      MS_LOG(ERROR) << name_ << ": Invalid strategy.";
    }
    return FAILED;
  }
  size_t strategy_size = strategy->GetInputNumber();
  if (strategy_size != 1) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << ": Invalid strategy size " << strategy_size;
    } else {
      MS_LOG(ERROR) << name_ << ": Invalid strategy size " << strategy_size;
    }
    return FAILED;
  }
  std::vector<Dimensions> stra = strategy->GetInputDim();
  for (size_t i = 0; i < strategy_size; ++i) {
    Shape sub_strategy = stra.at(i);
    size_t strategy_len = sub_strategy.size();
    bool flag = false;
    for (size_t j = 0; j < strategy_len; ++j) {
      int32_t strategy_value = sub_strategy.at(j);
      if (strategy_value > 1) {
        if (flag) {
          if (is_auto_parallel_) {
            MS_LOG(DEBUG) << name_ << ": Only support batch parallel strategy.";
          } else {
            MS_LOG(ERROR) << name_ << ": Only support batch parallel strategy.";
          }
          return FAILED;
        }
        flag = true;
      }
    }
  }
  return SUCCESS;
}
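// Example: for an input of shape [128, 64, 32], strategy (8, 1, 1) passes the check above (only
// the batch dimension is split), while (8, 2, 1) fails because two dimensions have a value > 1.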
/*
 * Supports a parallel degree smaller than the device number: the duplicate device dimension is set
 * to the first dimension of the device matrix.
 * Only the batch-parallel Reshape operator in ReID is supported (the batch parallel degree can be
 * smaller than the device number).
 */
Status ReshapeInfo::InferDevMatrixShape() {
  std::vector<Dimensions> stra = strategy_->GetInputDim();
  input_strategy_ = stra.at(0);
  dev_matrix_shape_.push_back(input_strategy_[0]);
  return SUCCESS;
}
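// Example: strategy (8, 1, 1) yields dev_matrix_shape_ = [8]; if the batch parallel degree (8) is
// smaller than the device number, the duplicate device dimension becomes the first dimension of
// the device matrix, as described in the comment above.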
/*
 * There is no Parameter for the Reshape Primitive, so there is no need to do AllReduce.
 */
Status ReshapeInfo::InferMirrorOps() {
  mirror_ops_.clear();
  Shape input_tensor_map = input_layout_.tensor_map().array();
  std::vector<Group> input_group;
  if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed.";
    return FAILED;
  }
  if (input_group.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror ops are empty.";
    return SUCCESS;
  }
  OperatorVector op_for_input = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum());
  std::string group_name = input_group[0].name();
  MS_LOG(INFO) << name_ << ": Created the mirror ops for input_a successfully, group is " << group_name;
  mirror_ops_.push_back(op_for_input);
  OperatorVector op_for_input_empty;
  mirror_ops_.push_back(op_for_input_empty);
  return SUCCESS;
}
/*
 * There is no reduction dimension in the forward computation of the Reshape Primitive, so there is
 * no need to do AllReduce.
 */
Status ReshapeInfo::InferForwardCommunication() { return SUCCESS; }
/*
 * Gets the shape input of the Reshape Primitive; the result is saved in parameter_input_v_.
 * -1 is not supported.
 */
Status ReshapeInfo::GetParameterInput() {
  if (input_value_[1] == nullptr) {
    MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr.";
    return FAILED;
  }
  std::vector<ValuePtr> elements;
  ValueTuplePtr dim_tuple = input_value_[1]->cast<ValueTuplePtr>();
  if (dim_tuple == nullptr) {
    MS_LOG(ERROR) << name_ << ": input_value_[1] must be a ValueTuplePtr.";
    return FAILED;
  }
  elements = dim_tuple->value();
  if (elements.size() != outputs_shape_[0].size()) {
    MS_LOG(ERROR) << name_ << ": Elements size must be equal to outputs shape[0] size.";
    return FAILED;
  }
  for (auto &element : elements) {
    MS_EXCEPTION_IF_NULL(element);
    if (element->isa<Int32Imm>()) {
      int32_t axis = element->cast<Int32ImmPtr>()->value();
      parameter_input_v_.push_back(axis);
    } else {
      MS_LOG(ERROR) << name_ << ": The value of axis must be int32.";
      return FAILED;
    }
  }
  return SUCCESS;
}
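// Example: for Reshape(x, (128, 2048)), input_value_[1] is the constant tuple (128, 2048), and
// parameter_input_v_ becomes {128, 2048} after GetParameterInput succeeds.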
Status ReshapeInfo::ComputeReplaceOp() {
  RankList dev_list = global_device_list();
  TensorRedistribution tensor_redistribution(true, true);
  if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) {
    MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": input " << input_layout_.ToString();
  MS_LOG(INFO) << name_ << ": output " << output_layout_.ToString();
  MS_LOG(INFO) << name_ << ": dev_list " << dev_list.size();
  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
  if (redistribution_oplist_ptr == nullptr) {
    MS_LOG(ERROR) << name_ << ": InferTensorRedistributionOperatorList failed.";
    return FAILED;
  }
  replace_op_ = redistribution_oplist_ptr->first;
  replace_op_info_ = redistribution_oplist_ptr->second;
  MS_LOG(INFO) << name_ << ": replace op size = " << replace_op_.size();
  return SUCCESS;
}
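// Note: replace_op_ and replace_op_info_ carry the redistribution operators that realize this
// reshape by converting between input_layout_ and output_layout_ across dev_list; the parallel
// pass uses them in place of the original Reshape node.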
/*
 * The first dimension of the input and output tensor maps is set to the last dimension of the
 * device arrangement; all other dimensions are set to None.
 * Only the batch-parallel Reshape operator in ReID is supported (the batch parallel degree can be
 * smaller than the device number).
 */
Status ReshapeInfo::InferTensorMap() {
  if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
    MS_LOG(ERROR) << name_ << ": inputs shape and outputs shape size must be 1. inputs shape and outputs shape are "
                  << inputs_shape_.size() << " and " << outputs_shape_.size();
    return FAILED;
  }
  std::vector<int32_t> tensor_map_index_input;
  tensor_map_index_input.push_back(0);
  for (size_t j = 1; j < inputs_shape_[0].size(); ++j) {
    tensor_map_index_input.push_back(MAP_NONE);
  }
  inputs_tensor_map_.push_back(tensor_map_index_input);
  std::vector<int32_t> tensor_map_index_output;
  tensor_map_index_output.push_back(0);
  for (size_t j = 1; j < outputs_shape_[0].size(); ++j) {
    tensor_map_index_output.push_back(MAP_NONE);
  }
  outputs_tensor_map_.push_back(tensor_map_index_output);
  return SUCCESS;
}
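// Example: for a rank-4 input reshaped to rank 2, the input tensor map is
// [0, MAP_NONE, MAP_NONE, MAP_NONE] and the output tensor map is [0, MAP_NONE]; only the batch
// dimension is mapped onto the device matrix.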
/*
 * The output tensor strategy is the same as the input tensor strategy.
 * Only the batch-parallel Reshape operator in ReID is supported (the batch parallel degree can be
 * smaller than the device number).
 */
Strategys ReshapeInfo::GetOutputsStrategy() {
  Strategys outputs_strategy;
  std::vector<int32_t> strategy;
  strategy.push_back(input_strategy_[0]);
  for (size_t j = 1; j < outputs_shape_[0].size(); ++j) {
    strategy.push_back(1);
  }
  outputs_strategy.push_back(strategy);
  return outputs_strategy;
}
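// Example: with input_strategy_ = (8, 1, 1) and a rank-2 output, the output strategy is (8, 1):
// the batch split carries over and all remaining output dimensions stay unsplit.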
Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
  if (inputs_layout == nullptr || outputs_layout == nullptr) {
    MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null.";
    return FAILED;
  }
  Arrangement dev_matrix;
  Status status = dev_matrix.Init(dev_matrix_shape_);
  if (status != Status::SUCCESS) {
    return status;
  }
  // infer input tensor info
  Shape shape_array_in = inputs_shape_.at(0);
  TensorMap tensor_map_array_in = inputs_tensor_map_.at(0);
  TensorLayout tensor_layout_in;
  Map tensor_map_in;
  status = tensor_map_in.Init(tensor_map_array_in);
  if (status != Status::SUCCESS) {
    return status;
  }
  Arrangement shape_in;
  status = shape_in.Init(shape_array_in);
  if (status != Status::SUCCESS) {
    return status;
  }
  (void)tensor_layout_in.Init(dev_matrix, tensor_map_in, shape_in);
  inputs_layout->push_back(tensor_layout_in);
  // infer output tensor info
  Shape shape_array_out = outputs_shape_.at(0);
  TensorMap tensor_map_array_out = outputs_tensor_map_.at(0);
  TensorLayout tensor_layout_out;
  Map tensor_map_out;
  status = tensor_map_out.Init(tensor_map_array_out);
  if (status != Status::SUCCESS) {
    return status;
  }
  Arrangement shape_out;
  status = shape_out.Init(shape_array_out);
  if (status != Status::SUCCESS) {
    return status;
  }
  (void)tensor_layout_out.Init(dev_matrix, tensor_map_out, shape_out);
  outputs_layout->push_back(tensor_layout_out);
  input_layout_ = tensor_layout_in;
  output_layout_ = tensor_layout_out;
  return SUCCESS;
}
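// Example: with dev_matrix_shape_ = [8], input shape [128, 64], and tensor map [0, MAP_NONE], the
// resulting input layout slices the batch dimension across 8 devices, giving slices of shape
// [16, 64].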
Status ReshapeInfo::InferTensorInfo() {
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Strategys outputs_strategy = GetOutputsStrategy();
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    return FAILED;
  }
  TensorLayouts inputs_layout, outputs_layout;
  if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
    return FAILED;
  }
  TensorLayout tensor_layout_in = inputs_layout.at(0);
  TensorLayout tensor_layout_out = outputs_layout.at(0);
  Shape shape_array_in = inputs_shape_.at(0);
  Shape slice_shape_in = inputs_slice_shape.at(0);
  Shape shape_array_out = outputs_shape_.at(0);
  Shape slice_shape_out = outputs_slice_shape.at(0);
  TensorInfo tensor_info_in(tensor_layout_in, shape_array_in, slice_shape_in);
  TensorInfo tensor_info_out(tensor_layout_out, shape_array_out, slice_shape_out);
  inputs_tensor_info_.push_back(tensor_info_in);
  outputs_tensor_info_.push_back(tensor_info_out);
  return SUCCESS;
}
void ReshapeInfo::InferTensorInfoByLayout() {
  TensorInfo tensor_info_in(input_layout_);
  TensorInfo tensor_info_out(output_layout_);
  inputs_tensor_info_.push_back(tensor_info_in);
  outputs_tensor_info_.push_back(tensor_info_out);
}
/*
 * parameter_input_v_ is computed in this method
 */
Status ReshapeInfo::GetAttrs() { return GetParameterInput(); }
void ReshapeInfo::device_number(const StrategyPtr &strategy) {
  int32_t stage = 0;
  if (strategy != nullptr) {
    stage = strategy->GetInputStage();
  }
  CheckGlobalDeviceManager();
  global_device_list_ = g_device_manager->GetDeviceListByStageId(stage);
  dev_num_ = SizeToInt(global_device_list_.size());
  MS_ASSERT(dev_num_ > 0);
}
Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) {
  std::vector<int32_t> tensor_map_index;
  for (size_t i = 0; i < shape.size(); i++) {
    tensor_map_index.push_back(MAP_NONE);
  }
  Status status = layout->InitFromVector({dev_num_}, tensor_map_index, shape);
  if (status != Status::SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferDefaultLayout failed.";
    return status;
  }
  return Status::SUCCESS;
}
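// Example: with dev_num_ = 8 and shape [128, 64], InferDefaultLayout produces device arrangement
// [8] and tensor map [MAP_NONE, MAP_NONE], i.e. a fully replicated (unsplit) layout.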
/*
 * If a strategy is given, initialize with auto repeat calculation; otherwise use the externally
 * set input/output layouts, falling back to default replicated layouts, then compute the
 * replacement operator list.
 */
Status ReshapeInfo::Init(const StrategyPtr &strategy) {
  ResetQueueMember();
  device_number(strategy);
  if (strategy) {
    if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Init failed.";
      return FAILED;
    }
  } else {
    if (!input_layout_set_flag_) {
      MS_ASSERT(inputs_shape_.size() == 1);
      Status status = InferDefaultLayout(inputs_shape_.at(0), &input_layout_);
      if (status != SUCCESS) {
        MS_LOG(ERROR) << name_ << ": infer input default layout failed.";
        return status;
      }
    }
    if (!output_layout_set_flag_) {
      MS_ASSERT(outputs_shape_.size() == 1);
      Status status = InferDefaultLayout(outputs_shape_.at(0), &output_layout_);
      if (status != SUCCESS) {
        MS_LOG(ERROR) << name_ << ": infer output default layout failed.";
        return status;
      }
    }
    inputs_tensor_map_.push_back(input_layout_.tensor_map().array());
    outputs_tensor_map_.push_back(output_layout_.tensor_map().array());
    InferTensorInfoByLayout();
    // change dev_matrix_shape_ to input_layout_ device_arrangement before InferMirrorOps
    dev_matrix_shape_ = input_layout_.device_arrangement().array();
    if (InferMirrorOps() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
      return FAILED;
    }
    // change dev_matrix_shape_ to output_layout_ device_arrangement before InferVirtualDivOps
    dev_matrix_shape_ = output_layout_.device_arrangement().array();
    if (InferVirtualDivOps() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
      return FAILED;
    }
  }
  Status status = ComputeReplaceOp();
  if (status != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
    return status;
  }
  return SUCCESS;
}
Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
    } else {
      MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
    }
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}
Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
    } else {
      MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
    }
    return FAILED;
  }
  return SUCCESS;
}
Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
  if (GetAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GetAttrs failed.";
    return FAILED;
  }
  if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
    MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", "
                  << outputs_shape_.size();
    return FAILED;
  }
  is_auto_parallel_ = true;
  Shape input0_split;
  input0_split.emplace_back(1);
  (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 1, 0);
  Shapes splittable_inputs = {input0_split};
  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
    return FAILED;
  }
  size_t success = 0;
  for (auto &sp : sp_vector) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated strategy " << success << ".";
      PrintStrategy(sp);
    }
  }
  return SUCCESS;
}
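// Example: for an input of shape [128, 64, 32], splittable_inputs is {[1, 0, 0]}, so generated
// candidate strategies split only the batch dimension, e.g. (1, 1, 1), (2, 1, 1), or (4, 1, 1),
// depending on the device number at the given stage.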
}  // namespace parallel
}  // namespace mindspore