You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

activation_info.cc 24 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/ops_info/activation_info.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <vector>
  20. #include <utility>
  21. #include "ir/value.h"
  22. #include "parallel/auto_parallel/costmodel.h"
  23. #include "parallel/device_matrix.h"
  24. #include "parallel/strategy.h"
  25. namespace mindspore {
  26. namespace parallel {
  27. Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) {
  28. if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
  29. if (is_auto_parallel_) {
  30. MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
  31. } else {
  32. MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
  33. }
  34. return FAILED;
  35. }
  36. return SUCCESS;
  37. }
  38. Status Activation::CheckStrategy(const StrategyPtr &strategy) {
  39. if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
  40. if (is_auto_parallel_) {
  41. MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
  42. } else {
  43. MS_LOG(ERROR) << name_ << " : Invalid strategy.";
  44. }
  45. return FAILED;
  46. }
  47. return SUCCESS;
  48. }
  49. Status ActivationInfo::GetAttrs() {
  50. if (attrs_.size() < ACTIVATION_ATTR_SIZE) {
  51. MS_LOG(ERROR) << name_ << " : The size of attrs small than 1.";
  52. return FAILED;
  53. }
  54. if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
  55. MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
  56. << outputs_shape_.size() << "is wrong.";
  57. return FAILED;
  58. }
  59. auto iter = attrs_.find(ACTIVATION_TYPE);
  60. if (iter != attrs_.end()) {
  61. MS_EXCEPTION_IF_NULL(iter->second);
  62. if (iter->second->isa<StringImm>()) {
  63. std::string val = iter->second->cast<StringImmPtr>()->value();
  64. if ((val != RELU_TYPE) && (val != RELU6_TYPE) && (val != SIGMOID_TYPE)) {
  65. MS_LOG(ERROR) << name_ << " : Activation type is wrong.";
  66. return FAILED;
  67. }
  68. } else {
  69. MS_LOG(ERROR) << name_ << " : The value of activation_type is not string.";
  70. return FAILED;
  71. }
  72. }
  73. return SUCCESS;
  74. }
  75. Status ActivationOther::GetAttrs() {
  76. if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
  77. MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
  78. << outputs_shape_.size() << "is wrong.";
  79. return FAILED;
  80. }
  81. return SUCCESS;
  82. }
  83. Status Activation::GenerateStrategies(int32_t stage_id) {
  84. if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
  85. MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
  86. << outputs_shape_.size() << "is wrong.";
  87. return FAILED;
  88. }
  89. is_auto_parallel_ = true;
  90. Shape input0_split(inputs_shape_[0].size(), 1);
  91. Shapes splittable_inputs = {input0_split};
  92. std::vector<StrategyPtr> sp_vector;
  93. if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
  94. MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
  95. return FAILED;
  96. }
  97. size_t success = 0;
  98. for (auto &sp : sp_vector) {
  99. if (SetCostUnderStrategy(sp) == SUCCESS) {
  100. success++;
  101. MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
  102. PrintStrategy(sp);
  103. }
  104. }
  105. return SUCCESS;
  106. }
// Validate a sharding strategy for Softmax: the strategy must pass the
// generic value check, and every dimension named by axis_ must stay
// un-split (softmax reduces along those axes).
Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
    // Invalid candidates are routine during auto-parallel search, so log quietly.
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
    } else {
      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
    }
    return FAILED;
  }
  std::vector<Dimensions> stra = strategy->GetInputDim();
  Dimensions input_strategy = stra.at(0);
  for (auto &element : axis_) {
    int32_t axis_index = element;
    if (element < 0) {
      // Normalize a negative axis to its positive equivalent.
      size_t input_dim = inputs_shape_.at(0).size();
      axis_index = static_cast<int32_t>(input_dim) + element;
    }
    int32_t axis_strategy = input_strategy.at(IntToSize(axis_index));
    // Dimension corresponding to axis is un-splittable
    if (axis_strategy != MIN_SLICE_NUM) {
      if (is_auto_parallel_) {
        MS_LOG(DEBUG) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
      } else {
        MS_LOG(ERROR) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
      }
      return FAILED;
    }
  }
  return SUCCESS;
}
// Parse the softmax `axis` attribute, which may be a single int or a tuple
// of ints, into axis_; then validate shape counts and that every axis lies
// in the valid range [-dim, dim - 1].
Status Softmax::GetAttrs() {
  if (attrs_.size() < SOFTMAX_ATTR_SIZE) {
    MS_LOG(ERROR) << name_ << " : The size of attrs small than 1.";
    return FAILED;
  }
  auto iter = attrs_.find(AXIS);
  if (iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<Int32Imm>()) {  // the axis is a number
      int32_t axis_element = iter->second->cast<Int32ImmPtr>()->value();
      axis_.push_back(axis_element);
      MS_LOG(INFO) << name_ << " : The axis is int, value is " << axis_element;
    } else if (iter->second->isa<ValueTuple>()) {  // the axis is a tuple
      ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
      if (value_tuple == nullptr) {
        MS_LOG(ERROR) << name_ << " : The value_tuple is nullptr.";
        return FAILED;
      }
      std::vector<ValuePtr> value_vector = value_tuple->value();
      // Unpack every tuple element into axis_ as int32.
      (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_),
                           [](const ValuePtr &value) { return static_cast<int32_t>(GetValue<int>(value)); });
      if (axis_.empty()) {
        MS_LOG(ERROR) << name_ << " : The axis tuple is empty.";
        return FAILED;
      }
      MS_LOG(INFO) << name_ << " : The axis is tuple, value is " << ShapeToString(axis_);
    } else {
      MS_LOG(ERROR) << name_ << " : The value of axis is not int or tuple int.";
      return FAILED;
    }
  }
  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
    return FAILED;
  }
  // for example: tensor dimension is 4, then axis range [-4, 3]
  int32_t dim = SizeToInt(inputs_shape_.at(0).size());
  auto it =
    std::find_if(axis_.begin(), axis_.end(), [dim](int32_t element) { return ((element >= dim) || (element < -dim)); });
  if (it != axis_.end()) {
    MS_LOG(ERROR) << name_ << " : The axis(" << *it << ") is out of range[" << -dim << ", " << dim - 1 << "].";
    return FAILED;
  }
  return SUCCESS;
}
  182. Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) {
  183. if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
  184. if (is_auto_parallel_) {
  185. MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
  186. } else {
  187. MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
  188. }
  189. return FAILED;
  190. }
  191. return SUCCESS;
  192. }
  193. Status Softmax::GenerateStrategies(int32_t stage_id) {
  194. if (GetAttrs() != SUCCESS) {
  195. MS_LOG(ERROR) << name_ << " : GetAttrs failed.";
  196. return FAILED;
  197. }
  198. if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
  199. MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
  200. return FAILED;
  201. }
  202. is_auto_parallel_ = true;
  203. Shape input0_split;
  204. (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1);
  205. for (auto &element : axis_) {
  206. int32_t axis_index = element;
  207. if (element < 0) {
  208. size_t input_dim = inputs_shape_.at(0).size();
  209. axis_index = static_cast<int32_t>(input_dim) + element;
  210. }
  211. input0_split[IntToSize(axis_index)] = 0;
  212. }
  213. Shapes splittable_inputs = {input0_split};
  214. std::vector<StrategyPtr> sp_vector;
  215. if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
  216. MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs failed.";
  217. return FAILED;
  218. }
  219. size_t success = 0;
  220. for (auto &sp : sp_vector) {
  221. if (SetCostUnderStrategy(sp) == SUCCESS) {
  222. success++;
  223. MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy.";
  224. PrintStrategy(sp);
  225. }
  226. }
  227. return SUCCESS;
  228. }
  229. Status ActivationBase::InferDevMatrixShape() {
  230. std::vector<Dimensions> stra = strategy_->GetInputDim();
  231. Dimensions input_strategy = stra.at(0);
  232. dev_matrix_shape_ = input_strategy;
  233. return SUCCESS;
  234. }
  235. Status ActivationBase::InferMirrorOps() {
  236. mirror_ops_.clear();
  237. Shape tensor_map = inputs_tensor_map_[0];
  238. std::vector<Group> group;
  239. if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
  240. MS_LOG(ERROR) << name_ << " : Create group failed.";
  241. return FAILED;
  242. }
  243. OperatorVector mirror_op;
  244. if (group.empty()) {
  245. MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
  246. return SUCCESS;
  247. } else {
  248. mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
  249. mirror_ops_.push_back(mirror_op);
  250. std::string group_name = group[0].name();
  251. MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
  252. }
  253. return SUCCESS;
  254. }
// Elementwise activations need no forward communication: each device
// operates on its own slice independently.
Status ActivationBase::InferForwardCommunication() {
  // do nothing
  return SUCCESS;
}
  259. Status ActivationBase::InferTensorMap() {
  260. std::vector<int32_t> tensor_map_index;
  261. size_t size = inputs_shape_.at(0).size();
  262. // such as 4: tensor_map_index [3,2,1,0]
  263. for (size_t i = 0; i < size; ++i) {
  264. tensor_map_index.push_back((int32_t)(size - i - 1));
  265. }
  266. inputs_tensor_map_.push_back(tensor_map_index);
  267. outputs_tensor_map_.push_back(tensor_map_index);
  268. return SUCCESS;
  269. }
// Build the tensor layout/slice info for the single input; the output is
// identical to the input for elementwise activations.
Status ActivationBase::InferTensorInfo() {
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);
  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  // The output strategy mirrors the input strategy (same shape, same split).
  Strategys outputs_strategy = {inputs_strategy.at(0)};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    return FAILED;
  }
  Shape input_slice_shape = inputs_slice_shape.at(0);
  TensorLayout input_tensor_layout;
  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
    return FAILED;
  }
  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(input_tensor_info);  // the same as input
  return SUCCESS;
}
  290. Status ActivationBase::Init(const StrategyPtr &strategy) {
  291. if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
  292. MS_LOG(ERROR) << name_ << " : Init failed.";
  293. return FAILED;
  294. }
  295. MS_LOG(INFO) << name_ << " : Init success.";
  296. return SUCCESS;
  297. }
  298. Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) {
  299. if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
  300. if (is_auto_parallel_) {
  301. MS_LOG(DEBUG) << name_ << " : Init for cost model failed.";
  302. } else {
  303. MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
  304. }
  305. return FAILED;
  306. }
  307. MS_LOG(INFO) << name_ << " : Init for cost model success.";
  308. return SUCCESS;
  309. }
  310. Status CastInfo::InferMirrorOps() {
  311. mirror_ops_.clear();
  312. Shape tensor_map = inputs_tensor_map_[0];
  313. std::vector<Group> group;
  314. if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
  315. MS_LOG(ERROR) << name_ << " : Create group failed.";
  316. return FAILED;
  317. }
  318. OperatorVector mirror_op;
  319. OperatorVector op_for_value;
  320. if (group.empty()) {
  321. MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
  322. return SUCCESS;
  323. } else {
  324. mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
  325. mirror_ops_.push_back(mirror_op);
  326. mirror_ops_.push_back(op_for_value);
  327. std::string group_name = group[0].name();
  328. MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
  329. }
  330. return SUCCESS;
  331. }
  332. Status ExpandDimsInfo::GetAttrs() {
  333. if (input_value_.size() != EXPANDDIMS_INPUT_SIZE) {
  334. MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size();
  335. return FAILED;
  336. }
  337. if (!input_value_.back()->isa<Int32Imm>()) {
  338. MS_LOG(ERROR) << name_ << ": The type of axis is not int";
  339. return FAILED;
  340. }
  341. int32_t axis = GetValue<int32_t>(input_value_.back());
  342. if (inputs_shape_.empty()) {
  343. MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
  344. return FAILED;
  345. }
  346. int32_t dim = SizeToInt(inputs_shape_[0].size());
  347. if ((axis > dim) || (axis < -dim - 1)) {
  348. MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim - 1 << ", " << dim << "]";
  349. return FAILED;
  350. }
  351. if (axis < 0) {
  352. positive_axis_ = dim + axis + 1;
  353. } else {
  354. positive_axis_ = axis;
  355. }
  356. MS_LOG(INFO) << name_ << ": The axis is " << axis << ", and the positive axis is " << positive_axis_;
  357. return SUCCESS;
  358. }
  359. Status ExpandDimsInfo::InferTensorMap() {
  360. if (inputs_shape_.empty()) {
  361. MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
  362. return FAILED;
  363. }
  364. // for example: if the dimension of input is 3, and the axis is 2,
  365. // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0]
  366. std::vector<int32_t> input_tensor_map, output_tensor_map;
  367. size_t size = inputs_shape_[0].size();
  368. for (size_t i = 0; i < size; ++i) {
  369. input_tensor_map.push_back(SizeToInt(size - i - 1));
  370. }
  371. inputs_tensor_map_.push_back(input_tensor_map);
  372. output_tensor_map = input_tensor_map;
  373. if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(size))) {
  374. MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
  375. return FAILED;
  376. }
  377. (void)output_tensor_map.insert(output_tensor_map.begin() + positive_axis_, NO_SPLIT_MAP);
  378. outputs_tensor_map_.push_back(output_tensor_map);
  379. MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
  380. << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
  381. return SUCCESS;
  382. }
  383. Status ExpandDimsInfo::InferTensorStrategy() {
  384. if (strategy_ == nullptr) {
  385. MS_LOG(ERROR) << name_ << ": The strategy is null";
  386. return FAILED;
  387. }
  388. inputs_strategy_ = strategy_->GetInputDim();
  389. if (inputs_strategy_.empty()) {
  390. MS_LOG(ERROR) << name_ << ": The strategy is empty";
  391. return FAILED;
  392. }
  393. Shape output_strategy = inputs_strategy_[0];
  394. if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(output_strategy.size()))) {
  395. MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
  396. return FAILED;
  397. }
  398. (void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY);
  399. outputs_strategy_ = {output_strategy};
  400. return SUCCESS;
  401. }
// Build TensorInfo (layout + global/slice shapes) for ExpandDims' input and
// output, using the strategies derived by InferTensorStrategy.
Status ExpandDimsInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
    return FAILED;
  }
  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
    return FAILED;
  }
  Shape input_shape = inputs_shape_[0];
  Shape output_shape = outputs_shape_[0];
  // infer slice shape
  // Populates inputs_strategy_/outputs_strategy_ (output = input with the new
  // axis un-split).
  if (InferTensorStrategy() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer tensor strategy failed";
    return FAILED;
  }
  Shapes inputs_slice_shape, outputs_slice_shape;
  if (InferSliceShape(inputs_strategy_, outputs_strategy_, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
    return FAILED;
  }
  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
    return FAILED;
  }
  Shape input_slice_shape = inputs_slice_shape[0];
  Shape output_slice_shape = outputs_slice_shape[0];
  // Layouts are built from the device matrix plus each tensor's map.
  TensorLayout input_tensor_layout, output_tensor_layout;
  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
    return FAILED;
  }
  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
    return FAILED;
  }
  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}
  444. Status ExpandDimsInfo::InferMirrorOps() {
  445. mirror_ops_.clear();
  446. if (inputs_tensor_map_.empty()) {
  447. MS_LOG(ERROR) << name_ << ": The tensor map of inputs is empty";
  448. return FAILED;
  449. }
  450. std::vector<Group> group;
  451. if (CreateGroupByTensorMap(inputs_tensor_map_[0], &group) != SUCCESS) {
  452. MS_LOG(ERROR) << name_ << ": Create group failed";
  453. return FAILED;
  454. }
  455. if (group.empty()) {
  456. MS_LOG(INFO) << name_ << ": No need to create mirror ops";
  457. return SUCCESS;
  458. }
  459. OperatorVector mirror_op, placeholder_op;
  460. mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
  461. mirror_ops_.push_back(mirror_op);
  462. mirror_ops_.push_back(placeholder_op);
  463. MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
  464. return SUCCESS;
  465. }
  466. Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) {
  467. std::vector<int32_t> axis;
  468. auto axis_list = value_tuple->value();
  469. if (inputs_shape_.empty()) {
  470. MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
  471. return FAILED;
  472. }
  473. Shape input_shape = inputs_shape_.at(0);
  474. size_t input_size = input_shape.size();
  475. // if axis tuple is empty, we should exclude the axis that the corresponding slice shape is 1.
  476. if (axis_list.empty()) {
  477. for (size_t i = 0; i < input_size; ++i) {
  478. if (input_shape[i] == 1) {
  479. axis.push_back(i);
  480. }
  481. }
  482. axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
  483. return SUCCESS;
  484. }
  485. // convert negative axis to positive.
  486. for (auto &dim : axis_list) {
  487. if (!dim->isa<Int32Imm>()) {
  488. MS_LOG(ERROR) << name_ << ": The type of axis is not int";
  489. return FAILED;
  490. }
  491. int32_t dim_value = GetValue<int32_t>(dim);
  492. int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value;
  493. axis.push_back(positive_value);
  494. }
  495. axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
  496. return SUCCESS;
  497. }
  498. Status SqueezeInfo::GetAttrs() {
  499. auto iter = attrs_.find(AXIS);
  500. if (iter == attrs_.end()) {
  501. MS_LOG(ERROR) << name_ << ": Can't find axis attribute.";
  502. return FAILED;
  503. }
  504. MS_EXCEPTION_IF_NULL(iter->second);
  505. auto value_tuple = iter->second->cast<ValueTuplePtr>();
  506. MS_EXCEPTION_IF_NULL(value_tuple);
  507. InferAxis(value_tuple);
  508. attrs_[AXIS] = axis_;
  509. return SUCCESS;
  510. }
  511. Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) {
  512. Attr attr = std::make_pair(AXIS, axis_);
  513. OperatorAttrs attrs = {attr};
  514. OperatorParams params;
  515. OperatorArgs args = std::make_pair(attrs, params);
  516. replace_op_ = {std::make_pair(SQUEEZE, args)};
  517. return SUCCESS;
  518. }
// Build tensor maps for Squeeze: the output map is the input map with the
// squeezed dimensions removed.
Status SqueezeInfo::InferTensorMap() {
  // for example: if the shape of input is [32, 32, 1], and the axis is (2, ),
  // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1]
  std::vector<int32_t> input_tensor_map, output_tensor_map;
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  size_t size = inputs_shape_[0].size();
  std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_);
  for (size_t i = 0; i < size; ++i) {
    size_t index = size - i - 1;
    // Dimensions listed in axis are dropped from the output map.
    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
    if (iter == axis.end()) {
      output_tensor_map.push_back(SizeToInt(index));
    }
    input_tensor_map.push_back(SizeToInt(index));
  }
  inputs_tensor_map_.push_back(input_tensor_map);
  outputs_tensor_map_.push_back(output_tensor_map);
  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
               << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
  return SUCCESS;
}
// Build TensorInfo (layout + global/slice shapes) for Squeeze's input and
// output; the output strategy is the input strategy minus the squeezed axes.
Status SqueezeInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
    return FAILED;
  }
  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
    return FAILED;
  }
  Shape input_shape = inputs_shape_[0];
  Shape output_shape = outputs_shape_[0];
  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Dimensions output_strategy;
  std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_);
  // Keep the strategy entry for every dimension that is NOT squeezed.
  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
    if (iter == axis.end()) {
      output_strategy.push_back(inputs_strategy[0].at(i));
    }
  }
  Strategys outputs_strategy = {output_strategy};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
    return FAILED;
  }
  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
    return FAILED;
  }
  Shape input_slice_shape = inputs_slice_shape[0];
  Shape output_slice_shape = outputs_slice_shape[0];
  // infer tensor layout
  TensorLayout input_tensor_layout, output_tensor_layout;
  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
    return FAILED;
  }
  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
    return FAILED;
  }
  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}
  592. Status SqueezeInfo::Init(const StrategyPtr &strategy) {
  593. if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
  594. MS_LOG(ERROR) << name_ << " : Init failed.";
  595. }
  596. if (InferReplaceOps(strategy) != SUCCESS) {
  597. MS_LOG(ERROR) << name_ << " : Infer replace ops failed";
  598. }
  599. MS_LOG(INFO) << name_ << " : Init success.";
  600. return SUCCESS;
  601. }
  602. } // namespace parallel
  603. } // namespace mindspore