You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they may include dashes ('-') and can be up to 35 characters long.

utils.cc 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
#include "abstract/utils.h"

#include <algorithm>
#include <memory>
#include <sstream>
#include <string>

#include "utils/symbolic.h"
#include "abstract/param_validator.h"
#include "utils/shape_utils.h"
  25. namespace mindspore {
  26. namespace abstract {
  27. ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) {
  28. MS_EXCEPTION_IF_NULL(value1);
  29. MS_EXCEPTION_IF_NULL(value2);
  30. if (*value1 == *value2) {
  31. return value1;
  32. }
  33. return kAnyValue;
  34. }
  35. TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) {
  36. MS_EXCEPTION_IF_NULL(type1);
  37. MS_EXCEPTION_IF_NULL(type2);
  38. if (*type1 == *type2) {
  39. return type1;
  40. }
  41. return kAnyType;
  42. }
  43. ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
  44. MS_EXCEPTION_IF_NULL(shape1);
  45. MS_EXCEPTION_IF_NULL(shape2);
  46. if (*shape1 == *shape2) {
  47. return shape1;
  48. }
  49. // lengths of two shapes are not same, join failed
  50. if (shape1->shape().size() != shape2->shape().size()) {
  51. // special case: shape(1), shape() -> shape(1)
  52. if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) {
  53. return shape1;
  54. }
  55. if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) {
  56. return shape2;
  57. }
  58. MS_EXCEPTION(ValueError) << "Unsupported shape join. shape1 = " << shape1->ToString()
  59. << ", shape2 = " << shape2->ToString();
  60. }
  61. ShapeVector dims;
  62. bool has_dynamic_shape = false;
  63. dims.resize(shape1->shape().size());
  64. for (std::size_t i = 0; i < shape1->shape().size(); i++) {
  65. if (shape1->shape()[i] == shape2->shape()[i]) {
  66. dims[i] = shape1->shape()[i];
  67. if (shape1->shape()[i] == Shape::SHP_ANY) {
  68. has_dynamic_shape = true;
  69. }
  70. } else {
  71. dims[i] = Shape::SHP_ANY;
  72. has_dynamic_shape = true;
  73. }
  74. }
  75. if (!has_dynamic_shape) {
  76. return std::make_shared<Shape>(dims);
  77. }
  78. // calculate dynamic shape
  79. ShapeVector min_dims(dims.size());
  80. ShapeVector max_dims(dims.size());
  81. for (size_t i = 0; i < dims.size(); ++i) {
  82. if (dims[i] != Shape::SHP_ANY) {
  83. min_dims[i] = max_dims[i] = dims[i];
  84. continue;
  85. }
  86. if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
  87. min_dims[i] = std::min(shape1->shape()[i], shape2->shape()[i]);
  88. max_dims[i] = std::max(shape1->shape()[i], shape2->shape()[i]);
  89. continue;
  90. }
  91. if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
  92. if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
  93. MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
  94. << " has dynamic shape, but does not have min/max shape info.";
  95. }
  96. min_dims[i] = std::min(shape1->min_shape()[i], shape2->shape()[i]);
  97. max_dims[i] = std::max(shape1->max_shape()[i], shape2->shape()[i]);
  98. continue;
  99. }
  100. if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) {
  101. if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
  102. MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
  103. << " has dynamic shape, but does not have min/max shape info.";
  104. }
  105. min_dims[i] = std::min(shape1->shape()[i], shape2->min_shape()[i]);
  106. max_dims[i] = std::max(shape1->shape()[i], shape2->max_shape()[i]);
  107. continue;
  108. }
  109. // both shapes contains dynamic shape
  110. if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
  111. MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
  112. << " has dynamic shape, but does not have min/max shape info.";
  113. }
  114. if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
  115. MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
  116. << " has dynamic shape, but does not have min/max shape info.";
  117. }
  118. min_dims[i] = std::min(shape1->min_shape()[i], shape2->min_shape()[i]);
  119. max_dims[i] = std::max(shape1->max_shape()[i], shape2->max_shape()[i]);
  120. }
  121. return std::make_shared<Shape>(dims, min_dims, max_dims);
  122. }
  123. AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
  124. if (args_spec_list.size() < 1) {
  125. MS_LOG(EXCEPTION) << "AbstractJoin requires at least 1 params, while the input size is " << args_spec_list.size()
  126. << ".";
  127. }
  128. AbstractBasePtr arg_spec_tmp = args_spec_list[0];
  129. MS_EXCEPTION_IF_NULL(arg_spec_tmp);
  130. for (auto arg_spec : args_spec_list) {
  131. arg_spec_tmp = arg_spec_tmp->Join(arg_spec);
  132. MS_EXCEPTION_IF_NULL(arg_spec_tmp);
  133. }
  134. return arg_spec_tmp;
  135. }
  136. AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2) {
  137. if (spec1.size() != spec2.size()) {
  138. MS_LOG(EXCEPTION) << "Join failed as list don't have the same size. spec1: " << ::mindspore::ToString(spec1)
  139. << ", spec2: " << ::mindspore::ToString(spec2);
  140. }
  141. AbstractBasePtrList joined_list;
  142. bool changes = false;
  143. for (std::size_t i = 0; i < spec1.size(); i++) {
  144. auto joined_elem = spec1[i]->Join(spec2[i]);
  145. if (joined_elem != spec1[i]) {
  146. changes = true;
  147. }
  148. joined_list.push_back(joined_elem);
  149. }
  150. if (!changes) {
  151. return spec1;
  152. }
  153. return joined_list;
  154. }
  155. AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
  156. AbstractFunctionPtr f_spec = dyn_cast<AbstractFunction>(spec);
  157. if (f_spec != nullptr) {
  158. return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  159. }
  160. return spec->Clone();
  161. }
  162. namespace {
  163. // Join all types in args_type_list;
  164. TypePtr TypeJoin(const TypePtrList &args_type_list) {
  165. if (args_type_list.empty()) {
  166. MS_LOG(EXCEPTION) << "args_type_list is empty";
  167. }
  168. TypePtr type_tmp = args_type_list[0];
  169. for (std::size_t i = 1; i < args_type_list.size(); i++) {
  170. type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
  171. }
  172. return type_tmp;
  173. }
  174. } // namespace
  175. bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
  176. // As x and predicate both are mindspore type staticly, here we only to judge whether
  177. // x is predicate or is a subclass of predicate.
  178. return IsIdentidityOrSubclass(x, expected_type);
  179. }
  180. TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
  181. MS_EXCEPTION_IF_NULL(predicate);
  182. for (auto arg_type : args_type_list) {
  183. MS_EXCEPTION_IF_NULL(arg_type);
  184. if (!CheckType(predicate, arg_type)) {
  185. MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
  186. }
  187. }
  188. return TypeJoin(args_type_list);
  189. }
  190. int GetPositiveAxis(int axis_value, size_t increment) {
  191. if (axis_value < 0) {
  192. axis_value = axis_value + SizeToInt(increment);
  193. }
  194. if (axis_value < 0) {
  195. MS_LOG(EXCEPTION) << "axis_value should not still <0";
  196. }
  197. return axis_value;
  198. }
  199. // Return if two shapes can be broadcast.
  200. // Broadcast shape is placed in broadcast_output_shape.
  201. ShapeVector RealBroadcast(const std::string &op, ShapeVector x_shape, ShapeVector y_shape) {
  202. std::reverse(x_shape.begin(), x_shape.end());
  203. std::reverse(y_shape.begin(), y_shape.end());
  204. // Fill a placeholder value 1 which will be replaced later.
  205. size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size();
  206. y_shape.resize(std_len, 1);
  207. x_shape.resize(std_len, 1);
  208. ShapeVector broadcast_shape;
  209. for (size_t i = 0; i < std_len; i++) {
  210. int x_i = x_shape[i]; // i-th dimension of x
  211. int y_i = y_shape[i]; // i-th dimension of y
  212. int output_i = 0; // i-th dimension of the output
  213. if (x_i == y_i) {
  214. output_i = x_i;
  215. } else if (x_i == 1) {
  216. output_i = y_i;
  217. } else if (y_i == 1) {
  218. output_i = x_i;
  219. } else {
  220. MS_LOG(EXCEPTION)
  221. << op
  222. << " evaluator the shape of first tensor and the shape of second tensor do not meet the broadcasting "
  223. "requirements";
  224. }
  225. broadcast_shape.push_back(output_i);
  226. }
  227. std::reverse(broadcast_shape.begin(), broadcast_shape.end());
  228. return broadcast_shape;
  229. }
  230. ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) {
  231. int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
  232. if (dlen < 0) {
  233. for (int i = 0; i < -dlen; ++i) {
  234. (void)shpx.insert(shpx.begin(), 1);
  235. }
  236. } else if (dlen > 0) {
  237. for (int i = 0; i < dlen; i++) {
  238. (void)shpy.insert(shpy.begin(), 1);
  239. }
  240. }
  241. if (shpx.size() != shpy.size()) {
  242. MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size().";
  243. }
  244. ShapeVector shp;
  245. for (size_t i = 0; i < shpx.size(); i++) {
  246. auto a = shpx[i];
  247. auto b = shpy[i];
  248. if (a == 1) {
  249. shp.push_back(b);
  250. } else if (b == 1) {
  251. shp.push_back(a);
  252. } else if (a == -1) {
  253. shp.push_back(b);
  254. } else if (b == -1) {
  255. shp.push_back(a);
  256. } else if (a == b) {
  257. shp.push_back(a);
  258. } else {
  259. return ShapeVector();
  260. }
  261. }
  262. return shp;
  263. }
  264. ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x,
  265. const AbstractTensorPtr &tensor_y) {
  266. mindspore::abstract::ShapePtr tensor_x_shape = tensor_x->shape();
  267. mindspore::abstract::ShapePtr tensor_y_shape = tensor_y->shape();
  268. // if is the same shape ,just return the x_shape
  269. if (*tensor_x_shape == *tensor_y_shape) {
  270. return tensor_x_shape;
  271. }
  272. auto x_shape = tensor_x_shape->shape();
  273. auto y_shape = tensor_y_shape->shape();
  274. return std::make_shared<Shape>(RealBroadcast(op, x_shape, y_shape));
  275. }
  276. } // namespace abstract
  277. } // namespace mindspore