You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.cc 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/static_analysis/utils.h"
  19. #include <string>
  20. #include <sstream>
  21. #include <memory>
  22. #include "utils/symbolic.h"
  23. #include "pipeline/static_analysis/param_validator.h"
  24. namespace mindspore {
  25. namespace abstract {
  26. ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) {
  27. MS_EXCEPTION_IF_NULL(value1);
  28. MS_EXCEPTION_IF_NULL(value2);
  29. if (*value1 == *value2) {
  30. return value1;
  31. }
  32. return kAnyValue;
  33. }
  34. TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) {
  35. MS_EXCEPTION_IF_NULL(type1);
  36. MS_EXCEPTION_IF_NULL(type2);
  37. if (*type1 == *type2) {
  38. return type1;
  39. }
  40. return kAnyType;
  41. }
  42. ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
  43. MS_EXCEPTION_IF_NULL(shape1);
  44. MS_EXCEPTION_IF_NULL(shape2);
  45. if (*shape1 == *shape2) {
  46. return shape1;
  47. }
  48. if (shape1->shape().size() != shape2->shape().size()) {
  49. MS_LOG(WARNING) << "Unsupported shape join. shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString();
  50. return shape1;
  51. }
  52. std::vector<int> dims;
  53. dims.resize(shape1->shape().size());
  54. for (std::size_t i = 0; i < shape1->shape().size(); i++) {
  55. if (shape1->shape()[i] == shape2->shape()[i]) {
  56. dims[i] = shape1->shape()[i];
  57. } else {
  58. dims[i] = Shape::SHP_ANY;
  59. }
  60. }
  61. return std::make_shared<Shape>(dims);
  62. }
  63. AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
  64. if (args_spec_list.size() < 1) {
  65. MS_LOG(EXCEPTION) << "AbstractJoin requires at least 1 params, while the input size is " << args_spec_list.size()
  66. << ".";
  67. }
  68. AbstractBasePtr arg_spec_tmp = args_spec_list[0];
  69. MS_EXCEPTION_IF_NULL(arg_spec_tmp);
  70. for (auto arg_spec : args_spec_list) {
  71. arg_spec_tmp = arg_spec_tmp->Join(arg_spec);
  72. MS_EXCEPTION_IF_NULL(arg_spec_tmp);
  73. }
  74. return arg_spec_tmp;
  75. }
  76. AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2) {
  77. if (spec1.size() != spec2.size()) {
  78. MS_LOG(EXCEPTION) << "Join failed as list don't have the same size. spec1: " << ::mindspore::ToString(spec1)
  79. << ", spec2: " << ::mindspore::ToString(spec2);
  80. }
  81. AbstractBasePtrList joined_list;
  82. bool changes = false;
  83. for (std::size_t i = 0; i < spec1.size(); i++) {
  84. auto joined_elem = spec1[i]->Join(spec2[i]);
  85. if (joined_elem != spec1[i]) {
  86. changes = true;
  87. }
  88. joined_list.push_back(joined_elem);
  89. }
  90. if (!changes) {
  91. return spec1;
  92. }
  93. return joined_list;
  94. }
  95. AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
  96. AbstractFunctionPtr f_spec = dyn_cast<AbstractFunction>(spec);
  97. if (f_spec != nullptr) {
  98. return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  99. }
  100. return spec->Clone();
  101. }
  102. namespace {
  103. // Join all types in args_type_list;
  104. TypePtr TypeJoin(const TypePtrList &args_type_list) {
  105. if (args_type_list.empty()) {
  106. MS_LOG(EXCEPTION) << "args_type_list is empty";
  107. }
  108. TypePtr type_tmp = args_type_list[0];
  109. for (std::size_t i = 1; i < args_type_list.size(); i++) {
  110. type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
  111. }
  112. return type_tmp;
  113. }
  114. } // namespace
  115. bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
  116. // As x and predicate both are mindspore type staticly, here we only to judge whether
  117. // x is predicate or is a subclass of predicate.
  118. return IsIdentidityOrSubclass(x, expected_type);
  119. }
  120. TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
  121. MS_EXCEPTION_IF_NULL(predicate);
  122. for (auto arg_type : args_type_list) {
  123. MS_EXCEPTION_IF_NULL(arg_type);
  124. if (!CheckType(predicate, arg_type)) {
  125. MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
  126. }
  127. }
  128. return TypeJoin(args_type_list);
  129. }
  130. int GetPositiveAxis(int axis_value, size_t increment) {
  131. if (axis_value < 0) {
  132. axis_value = axis_value + SizeToInt(increment);
  133. }
  134. if (axis_value < 0) {
  135. MS_LOG(EXCEPTION) << "axis_value should not still <0";
  136. }
  137. return axis_value;
  138. }
  139. // Return if two shapes can be broadcast.
  140. // Broadcast shape is placed in broadcast_output_shape.
  141. std::vector<int> RealBroadcast(const std::string &op, std::vector<int> x_shape, std::vector<int> y_shape) {
  142. std::reverse(x_shape.begin(), x_shape.end());
  143. std::reverse(y_shape.begin(), y_shape.end());
  144. // Fill a placeholder value 1 which will be replaced later.
  145. size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size();
  146. y_shape.resize(std_len, 1);
  147. x_shape.resize(std_len, 1);
  148. std::vector<int> broadcast_shape;
  149. for (size_t i = 0; i < std_len; i++) {
  150. int x_i = x_shape[i]; // i-th dimension of x
  151. int y_i = y_shape[i]; // i-th dimension of y
  152. int output_i = 0; // i-th dimension of the output
  153. if (x_i == y_i) {
  154. output_i = x_i;
  155. } else if (x_i == 1) {
  156. output_i = y_i;
  157. } else if (y_i == 1) {
  158. output_i = x_i;
  159. } else {
  160. MS_LOG(EXCEPTION)
  161. << "" << op
  162. << " evaluator the shape of first tensor and the shape of second tensor do not meet the broadcasting "
  163. "requirements";
  164. }
  165. broadcast_shape.push_back(output_i);
  166. }
  167. std::reverse(broadcast_shape.begin(), broadcast_shape.end());
  168. return broadcast_shape;
  169. }
  170. ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x,
  171. const AbstractTensorPtr &tensor_y) {
  172. mindspore::abstract::ShapePtr tensor_x_shape = tensor_x->shape();
  173. mindspore::abstract::ShapePtr tensor_y_shape = tensor_y->shape();
  174. // if is the same shape ,just return the x_shape
  175. if (*tensor_x_shape == *tensor_y_shape) {
  176. return tensor_x_shape;
  177. }
  178. auto x_shape = tensor_x_shape->shape();
  179. auto y_shape = tensor_y_shape->shape();
  180. return std::make_shared<Shape>(RealBroadcast(op, x_shape, y_shape));
  181. }
  182. } // namespace abstract
  183. } // namespace mindspore