You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

arrangement.cc 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/tensor_layout/arrangement.h"
  17. #include <algorithm>
  18. #include <iostream>
  19. #include <utility>
  20. #include "common/utils.h"
  21. #include "parallel/status.h"
  22. #include "parallel/tensor_layout/shape_util.h"
  23. #include "utils/convert_utils.h"
  24. #include "utils/log_adapter.h"
  25. namespace mindspore {
  26. namespace parallel {
  27. Status Arrangement::Init(const std::vector<int32_t> &array) {
  28. Status status = Array::Init(array);
  29. if (status != Status::SUCCESS) {
  30. return Status::FAILED;
  31. }
  32. if (!IsValidArrangement()) {
  33. MS_LOG(ERROR) << "invalid arrangement " << this->ToString();
  34. return Status::FAILED;
  35. }
  36. ComputeSize();
  37. return Status::SUCCESS;
  38. }
  39. bool Arrangement::IsValidArrangement() {
  40. return !std::any_of(array_.begin(), array_.end(), [](int32_t value) { return value <= 0; });
  41. }
  42. void Arrangement::ComputeSize() {
  43. size_ = 1;
  44. for (auto &value : array_) {
  45. size_ *= value;
  46. }
  47. }
  48. /*
  49. * if GetDimSize() = 0, return []
  50. * if value <= array_[0], return [value]
  51. * if array_[0] < value <= size_[i], return [shape[0], shape[1], ..., shape[i-1], value/size_[i-1]],
  52. * where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1],
  53. * if value > size_, return []
  54. */
  55. std::vector<int32_t> Arrangement::GetFrontElementByValue(int32_t value) const {
  56. std::vector<int32_t> out;
  57. if (GetDimSize() == 0) {
  58. return out;
  59. }
  60. if (value <= size_) {
  61. int32_t size = 1;
  62. uint32_t shape_list_idx = 0;
  63. while (size < value) {
  64. size *= array_[shape_list_idx];
  65. if (size <= value) {
  66. out.push_back(array_[shape_list_idx]);
  67. } else {
  68. if (size == 0) {
  69. MS_LOG(ERROR) << "The size is 0";
  70. out.clear();
  71. return out;
  72. }
  73. out.push_back(value * array_[shape_list_idx] / size);
  74. }
  75. shape_list_idx++;
  76. }
  77. }
  78. return out;
  79. }
  80. std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListRemoveLeft(
  81. const std::vector<Arrangement> &expand_list) const {
  82. if (expand_list.size() != GetDimSize()) {
  83. return nullptr;
  84. }
  85. std::vector<int32_t> new_shape;
  86. for (uint32_t i = 0; i < expand_list.size(); i++) {
  87. std::vector<int32_t> expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i));
  88. if (expand_shape.empty()) {
  89. new_shape.push_back(GetDimByIdx(i));
  90. } else {
  91. (void)new_shape.insert(new_shape.end(), expand_shape.begin(), expand_shape.end());
  92. }
  93. }
  94. Arrangement arrangement_new;
  95. (void)arrangement_new.Init(new_shape);
  96. return std::make_shared<Arrangement>(arrangement_new);
  97. }
  98. /*
  99. * example:
  100. * expand_shape = [4, 2, 2, 2]
  101. * array_ = [8, 4],
  102. * arrangement_list = [[4, 2], [2, 2]]
  103. */
  104. std::shared_ptr<std::vector<Arrangement>> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const {
  105. int32_t size = 1;
  106. uint32_t ind = 0;
  107. std::vector<Arrangement> arrangement_list;
  108. std::vector<int32_t> shape;
  109. for (uint32_t i = 0; i < expand_shape.GetDimSize(); i++) {
  110. size *= expand_shape.GetDimByIdx(i);
  111. if (size > GetDimByIdx(ind)) {
  112. MS_LOG(ERROR) << "invalid expand_shape";
  113. return nullptr;
  114. } else if (size < GetDimByIdx(ind)) {
  115. shape.push_back(expand_shape.GetDimByIdx(i));
  116. continue;
  117. } else {
  118. shape.push_back(expand_shape.GetDimByIdx(i));
  119. Arrangement arrangement;
  120. (void)arrangement.Init(shape);
  121. arrangement_list.push_back(arrangement);
  122. shape.clear();
  123. ind++;
  124. size = 1;
  125. }
  126. }
  127. if (ind != GetDimSize()) {
  128. MS_LOG(ERROR) << "invalid expand_shape";
  129. return nullptr;
  130. }
  131. auto arrangement_new = std::make_shared<std::vector<Arrangement>>(arrangement_list);
  132. return arrangement_new;
  133. }
  134. std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::GetExpandShapeListPair(
  135. const Arrangement &expand_shape) const {
  136. std::shared_ptr<std::vector<Arrangement>> expand_shape_list_ptr = GetExpandShapeList(expand_shape);
  137. if (expand_shape_list_ptr == nullptr) {
  138. return nullptr;
  139. }
  140. std::vector<int32_t> expand_num_list_shape;
  141. (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(),
  142. std::back_inserter(expand_num_list_shape),
  143. [](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); });
  144. Arrangement expand_num_list;
  145. Status status = expand_num_list.Init(expand_num_list_shape);
  146. if (status != Status::SUCCESS) {
  147. return nullptr;
  148. }
  149. auto out_value = std::make_pair(*expand_shape_list_ptr, expand_num_list);
  150. return std::make_shared<std::pair<std::vector<Arrangement>, Arrangement>>(out_value);
  151. }
  152. std::vector<int32_t> Arrangement::ComputeReverseAccumulateSumInReverseOrder() const {
  153. std::vector<int32_t> shape_accum;
  154. int32_t size = 0;
  155. for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) {
  156. shape_accum.push_back(size);
  157. size += *iter;
  158. }
  159. return shape_accum;
  160. }
  161. std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListReserveLeft(
  162. const std::vector<Arrangement> &expand_list) const {
  163. if (expand_list.size() != GetDimSize()) {
  164. return nullptr;
  165. }
  166. std::vector<int32_t> new_shape;
  167. for (uint32_t i = 0; i < expand_list.size(); i++) {
  168. if (expand_list[i].GetDimSize() >= 1) {
  169. int32_t size = 1;
  170. for (uint32_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) {
  171. new_shape.push_back(expand_list[i].GetDimByIdx(k));
  172. size *= expand_list[i].GetDimByIdx(k);
  173. }
  174. new_shape.push_back(GetDimByIdx(i) / size);
  175. } else {
  176. new_shape.push_back(GetDimByIdx(i));
  177. }
  178. }
  179. Arrangement arrangement_new;
  180. (void)arrangement_new.Init(new_shape);
  181. return std::make_shared<Arrangement>(arrangement_new);
  182. }
  183. std::shared_ptr<Arrangement> Arrangement::GetUnifiedShape(const Arrangement &in2) const {
  184. std::vector<int64_t> in1_accum;
  185. Status status = ShapeToAccumulateProduct(array_, &in1_accum);
  186. if (status != Status::SUCCESS) {
  187. return nullptr;
  188. }
  189. std::vector<int64_t> in2_accum;
  190. status = ShapeToAccumulateProduct(in2.array(), &in2_accum);
  191. if (status != Status::SUCCESS) {
  192. return nullptr;
  193. }
  194. std::vector<int64_t> out_accum;
  195. status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
  196. if (status != Status::SUCCESS) {
  197. return nullptr;
  198. }
  199. std::vector<int32_t> out_shape;
  200. status = AccumulateProductToShape(out_accum, &out_shape);
  201. if (status != Status::SUCCESS) {
  202. return nullptr;
  203. }
  204. Arrangement out;
  205. status = out.Init(out_shape);
  206. if (status != Status::SUCCESS) {
  207. return nullptr;
  208. }
  209. return std::make_shared<Arrangement>(out);
  210. }
  211. std::vector<size_t> Arrangement::GetSqueezeIdx() const {
  212. std::vector<size_t> out;
  213. for (size_t i = 0; i < GetDimSize(); i++) {
  214. if (GetDimByIdx(SizeToUint(i)) == 1) {
  215. out.push_back(i);
  216. }
  217. }
  218. return out;
  219. }
  220. Arrangement Arrangement::GetSqueezeArrangement() const {
  221. std::vector<int32_t> out_shape(array_.size());
  222. auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int32_t value) { return value != 1; });
  223. out_shape.resize(LongToSize(std::distance(out_shape.begin(), it)));
  224. // if all elements are 1, out_shape = {1}
  225. if (out_shape.empty()) {
  226. MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation";
  227. out_shape.push_back(1);
  228. }
  229. Arrangement out;
  230. (void)out.Init(out_shape);
  231. return out;
  232. }
  233. } // namespace parallel
  234. } // namespace mindspore