You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
#include "abstract/utils.h"

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

#include "utils/symbolic.h"
#include "abstract/param_validator.h"
  24. namespace mindspore {
  25. namespace abstract {
// Size in bytes of each supported numeric TypeId; consumed by TypeIdSize().
// The generic kNumberTypeInt / kNumberTypeUInt / kNumberTypeFloat entries
// default to the 4-byte variant.
const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},    {kNumberTypeInt, 4},     {kNumberTypeInt8, 1},
                                           {kNumberTypeInt16, 2},   {kNumberTypeInt32, 4},   {kNumberTypeInt64, 8},
                                           {kNumberTypeUInt, 4},    {kNumberTypeUInt8, 1},   {kNumberTypeUInt16, 2},
                                           {kNumberTypeUInt32, 4},  {kNumberTypeUInt64, 8},  {kNumberTypeFloat, 4},
                                           {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
  31. ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) {
  32. MS_EXCEPTION_IF_NULL(value1);
  33. MS_EXCEPTION_IF_NULL(value2);
  34. if (*value1 == *value2) {
  35. return value1;
  36. }
  37. return kAnyValue;
  38. }
  39. TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) {
  40. MS_EXCEPTION_IF_NULL(type1);
  41. MS_EXCEPTION_IF_NULL(type2);
  42. if (*type1 == *type2) {
  43. return type1;
  44. }
  45. return kAnyType;
  46. }
  47. ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
  48. MS_EXCEPTION_IF_NULL(shape1);
  49. MS_EXCEPTION_IF_NULL(shape2);
  50. if (*shape1 == *shape2) {
  51. return shape1;
  52. }
  53. // lengths of two shapes are not same, join failed
  54. if (shape1->shape().size() != shape2->shape().size()) {
  55. // special case: shape(1), shape() -> shape(1)
  56. if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) {
  57. return shape1;
  58. }
  59. if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) {
  60. return shape2;
  61. }
  62. MS_EXCEPTION(ValueError) << "Unsupported shape join. shape1 = " << shape1->ToString()
  63. << ", shape2 = " << shape2->ToString();
  64. }
  65. ShapeVector dims;
  66. bool has_dynamic_shape = false;
  67. dims.resize(shape1->shape().size());
  68. for (std::size_t i = 0; i < shape1->shape().size(); i++) {
  69. if (shape1->shape()[i] == shape2->shape()[i]) {
  70. dims[i] = shape1->shape()[i];
  71. if (shape1->shape()[i] == Shape::SHP_ANY) {
  72. has_dynamic_shape = true;
  73. }
  74. } else {
  75. dims[i] = Shape::SHP_ANY;
  76. has_dynamic_shape = true;
  77. }
  78. }
  79. if (!has_dynamic_shape) {
  80. return std::make_shared<Shape>(dims);
  81. }
  82. // calculate dynamic shape
  83. ShapeVector min_dims(dims.size());
  84. ShapeVector max_dims(dims.size());
  85. for (size_t i = 0; i < dims.size(); ++i) {
  86. if (dims[i] != Shape::SHP_ANY) {
  87. min_dims[i] = max_dims[i] = dims[i];
  88. continue;
  89. }
  90. if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
  91. min_dims[i] = std::min(shape1->shape()[i], shape2->shape()[i]);
  92. max_dims[i] = std::max(shape1->shape()[i], shape2->shape()[i]);
  93. continue;
  94. }
  95. if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
  96. if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
  97. MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
  98. << " has dynamic shape, but does not have min/max shape info.";
  99. }
  100. min_dims[i] = std::min(shape1->min_shape()[i], shape2->shape()[i]);
  101. max_dims[i] = std::max(shape1->max_shape()[i], shape2->shape()[i]);
  102. continue;
  103. }
  104. if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) {
  105. if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
  106. MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
  107. << " has dynamic shape, but does not have min/max shape info.";
  108. }
  109. min_dims[i] = std::min(shape1->shape()[i], shape2->min_shape()[i]);
  110. max_dims[i] = std::max(shape1->shape()[i], shape2->max_shape()[i]);
  111. continue;
  112. }
  113. // both shapes contains dynamic shape
  114. if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
  115. MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
  116. << " has dynamic shape, but does not have min/max shape info.";
  117. }
  118. if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
  119. MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
  120. << " has dynamic shape, but does not have min/max shape info.";
  121. }
  122. min_dims[i] = std::min(shape1->min_shape()[i], shape2->min_shape()[i]);
  123. max_dims[i] = std::max(shape1->max_shape()[i], shape2->max_shape()[i]);
  124. }
  125. return std::make_shared<Shape>(dims, min_dims, max_dims);
  126. }
  127. AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
  128. if (args_spec_list.size() < 1) {
  129. MS_LOG(EXCEPTION) << "AbstractJoin requires at least 1 params, while the input size is " << args_spec_list.size()
  130. << ".";
  131. }
  132. AbstractBasePtr arg_spec_tmp = args_spec_list[0];
  133. MS_EXCEPTION_IF_NULL(arg_spec_tmp);
  134. for (auto arg_spec : args_spec_list) {
  135. arg_spec_tmp = arg_spec_tmp->Join(arg_spec);
  136. MS_EXCEPTION_IF_NULL(arg_spec_tmp);
  137. }
  138. return arg_spec_tmp;
  139. }
  140. AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2) {
  141. if (spec1.size() != spec2.size()) {
  142. MS_LOG(EXCEPTION) << "Join failed as list don't have the same size. spec1: " << ::mindspore::ToString(spec1)
  143. << ", spec2: " << ::mindspore::ToString(spec2);
  144. }
  145. AbstractBasePtrList joined_list;
  146. bool changes = false;
  147. for (std::size_t i = 0; i < spec1.size(); i++) {
  148. auto joined_elem = spec1[i]->Join(spec2[i]);
  149. if (joined_elem != spec1[i]) {
  150. changes = true;
  151. }
  152. joined_list.push_back(joined_elem);
  153. }
  154. if (!changes) {
  155. return spec1;
  156. }
  157. return joined_list;
  158. }
  159. AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
  160. AbstractFunctionPtr f_spec = dyn_cast<AbstractFunction>(spec);
  161. if (f_spec != nullptr) {
  162. return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  163. }
  164. return spec->Clone();
  165. }
  166. namespace {
  167. // Join all types in args_type_list;
  168. TypePtr TypeJoin(const TypePtrList &args_type_list) {
  169. if (args_type_list.empty()) {
  170. MS_LOG(EXCEPTION) << "args_type_list is empty";
  171. }
  172. TypePtr type_tmp = args_type_list[0];
  173. for (std::size_t i = 1; i < args_type_list.size(); i++) {
  174. type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
  175. }
  176. return type_tmp;
  177. }
  178. } // namespace
// Check whether x is expected_type itself or a subclass of it.
// Both arguments are MindSpore static types.
bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
  return IsIdentidityOrSubclass(x, expected_type);
}
  184. TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
  185. MS_EXCEPTION_IF_NULL(predicate);
  186. for (auto arg_type : args_type_list) {
  187. MS_EXCEPTION_IF_NULL(arg_type);
  188. if (!CheckType(predicate, arg_type)) {
  189. MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
  190. }
  191. }
  192. return TypeJoin(args_type_list);
  193. }
  194. int64_t GetPositiveAxis(int64_t axis_value, size_t increment) {
  195. if (axis_value < 0) {
  196. axis_value = axis_value + SizeToLong(increment);
  197. }
  198. if (axis_value < 0) {
  199. MS_LOG(EXCEPTION) << "axis_value should not still <0";
  200. }
  201. return axis_value;
  202. }
  203. // Return if two shapes can be broadcast.
  204. // Broadcast shape is placed in broadcast_output_shape.
  205. ShapeVector RealBroadcast(const std::string &op, ShapeVector x_shape, ShapeVector y_shape) {
  206. std::reverse(x_shape.begin(), x_shape.end());
  207. std::reverse(y_shape.begin(), y_shape.end());
  208. // Fill a placeholder value 1 which will be replaced later.
  209. size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size();
  210. y_shape.resize(std_len, 1);
  211. x_shape.resize(std_len, 1);
  212. ShapeVector broadcast_shape;
  213. for (size_t i = 0; i < std_len; i++) {
  214. int64_t x_i = x_shape[i]; // i-th dimension of x
  215. int64_t y_i = y_shape[i]; // i-th dimension of y
  216. int64_t output_i = 0; // i-th dimension of the output
  217. if (x_i == y_i) {
  218. output_i = x_i;
  219. } else if (x_i == 1) {
  220. output_i = y_i;
  221. } else if (y_i == 1) {
  222. output_i = x_i;
  223. } else {
  224. MS_LOG(EXCEPTION)
  225. << op
  226. << " evaluator the shape of first tensor and the shape of second tensor do not meet the broadcasting "
  227. "requirements";
  228. }
  229. broadcast_shape.push_back(output_i);
  230. }
  231. std::reverse(broadcast_shape.begin(), broadcast_shape.end());
  232. return broadcast_shape;
  233. }
  234. ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) {
  235. int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
  236. if (dlen < 0) {
  237. for (int i = 0; i < -dlen; ++i) {
  238. (void)shpx.insert(shpx.begin(), 1);
  239. }
  240. } else if (dlen > 0) {
  241. for (int i = 0; i < dlen; i++) {
  242. (void)shpy.insert(shpy.begin(), 1);
  243. }
  244. }
  245. if (shpx.size() != shpy.size()) {
  246. MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size().";
  247. }
  248. ShapeVector shp;
  249. for (size_t i = 0; i < shpx.size(); i++) {
  250. auto a = shpx[i];
  251. auto b = shpy[i];
  252. if (a == 1) {
  253. shp.push_back(b);
  254. } else if (b == 1) {
  255. shp.push_back(a);
  256. } else if (a == -1) {
  257. shp.push_back(b);
  258. } else if (b == -1) {
  259. shp.push_back(a);
  260. } else if (a == b) {
  261. shp.push_back(a);
  262. } else {
  263. return ShapeVector();
  264. }
  265. }
  266. return shp;
  267. }
  268. ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x,
  269. const AbstractTensorPtr &tensor_y) {
  270. mindspore::abstract::ShapePtr tensor_x_shape = tensor_x->shape();
  271. mindspore::abstract::ShapePtr tensor_y_shape = tensor_y->shape();
  272. // if is the same shape ,just return the x_shape
  273. if (*tensor_x_shape == *tensor_y_shape) {
  274. return tensor_x_shape;
  275. }
  276. auto x_shape = tensor_x_shape->shape();
  277. auto y_shape = tensor_y_shape->shape();
  278. return std::make_shared<Shape>(RealBroadcast(op, x_shape, y_shape));
  279. }
  280. size_t TypeIdSize(const TypeId data_type) {
  281. const size_t unsupported_type_error = 0;
  282. auto iter = type_map.find(data_type);
  283. if (iter != type_map.end()) {
  284. return iter->second;
  285. }
  286. return unsupported_type_error;
  287. }
  288. size_t ShapeSize(const std::vector<size_t> &shape) {
  289. return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
  290. }
  291. void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) {
  292. *min_shape = (*min_shape).empty() ? shape : *min_shape;
  293. *max_shape = (*max_shape).empty() ? shape : *max_shape;
  294. }
  295. int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name) {
  296. int64_t num_segments_value = 0;
  297. if (args_spec_list[2]->isa<AbstractTensor>()) { // num_segments is Tensor
  298. auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>();
  299. MS_EXCEPTION_IF_NULL(num_segments);
  300. auto num_segments_value_ptr = num_segments->BuildValue();
  301. MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
  302. auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
  303. MS_EXCEPTION_IF_NULL(num_segments_tensor);
  304. if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
  305. num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
  306. } else {
  307. num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
  308. }
  309. } else if (args_spec_list[2]->isa<AbstractScalar>()) { // num_segments is Scalar
  310. auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
  311. if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
  312. num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
  313. } else {
  314. num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
  315. }
  316. } else {
  317. MS_LOG(EXCEPTION) << "num_segments incorrect type in " << op_name;
  318. }
  319. return num_segments_value;
  320. }
  321. } // namespace abstract
  322. } // namespace mindspore