You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim_statement.cc 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/static_analysis/param_validator.h"
  17. #include "pipeline/static_analysis/prim.h"
  18. #include "operator/ops.h"
  19. #include "pipeline/static_analysis/utils.h"
  20. #include "utils/symbolic.h"
  21. namespace mindspore {
  22. namespace abstract {
  23. AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
  24. const AbstractBasePtrList &args_spec_list) {
  25. // Inputs: a pointer to an AbstractBase object
  26. if (args_spec_list.size() != 1) {
  27. MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? "
  28. "while the input size is "
  29. << args_spec_list.size() << ".";
  30. }
  31. AbstractBasePtr abs_base = args_spec_list[0];
  32. return abs_base;
  33. }
  34. AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
  35. const AbstractBasePtrList &args_spec_list) {
  36. // Inputs: a pointer to an AbstractBase object
  37. if (args_spec_list.size() != 1) {
  38. MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size()
  39. << ".";
  40. }
  41. AbstractBasePtr abs_base = args_spec_list[0];
  42. MS_EXCEPTION_IF_NULL(abs_base);
  43. TypePtr type = abs_base->BuildType();
  44. return std::make_shared<AbstractType>(type);
  45. }
  46. AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  47. const AbstractBasePtrList &args_spec_list) {
  48. // Inputs: a pointer to an AbstractBase object and a pointer to a Type
  49. const std::string op_name = primitive->name();
  50. CheckArgsSize(op_name, args_spec_list, 2);
  51. AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1);
  52. auto mode_v = abs_type->GetValueTrack();
  53. MS_EXCEPTION_IF_NULL(mode_v);
  54. if (!mode_v->isa<Type>()) {
  55. MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed.";
  56. }
  57. TypePtr mode_t = mode_v->cast<TypePtr>();
  58. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  59. bool v = IsSubtype(args_spec_list[0], mode_t);
  60. return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
  61. }
  62. AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  63. const AbstractBasePtrList &args_spec_list) {
  64. // Inputs: two tensors.
  65. const std::string op_name = primitive->name();
  66. CheckArgsSize(op_name, args_spec_list, 2);
  67. AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  68. AbstractTensorPtr input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  69. ShapePtr x_shp = input_x->shape();
  70. auto x_shp_value = x_shp->shape();
  71. ShapePtr y_shp = input_y->shape();
  72. auto y_shp_value = y_shp->shape();
  73. // Should be matrix which shape size is 2.
  74. if (x_shp_value.size() != 2 || y_shp_value.size() != 2) {
  75. MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are "
  76. << x_shp_value.size() << ", " << y_shp_value.size() << " ";
  77. }
  78. if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) {
  79. MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}";
  80. }
  81. auto x_element = input_x->element();
  82. MS_EXCEPTION_IF_NULL(x_element);
  83. (void)x_element->Join(input_y->element());
  84. auto param = {x_shp_value[0], y_shp_value[1]};
  85. return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
  86. }
  87. AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
  88. const AbstractBasePtrList &args_spec_list) {
  89. // Inputs: condition, true branch, false branch
  90. if (args_spec_list.size() != 3) {
  91. MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
  92. << ".";
  93. }
  94. auto cond = args_spec_list[0];
  95. auto tb = args_spec_list[1];
  96. auto fb = args_spec_list[2];
  97. MS_EXCEPTION_IF_NULL(cond);
  98. ValuePtr v = cond->GetValueTrack();
  99. MS_EXCEPTION_IF_NULL(v);
  100. // for tensor as condition, keeps both true and false branch.
  101. if (v->isa<AnyValue>() || cond->isa<AbstractTensor>()) {
  102. MS_EXCEPTION_IF_NULL(tb);
  103. return tb->Join(fb);
  104. }
  105. if (v->isa<Scalar>()) {
  106. if (v->cast<ScalarPtr>()->IsOne()) {
  107. return tb;
  108. } else {
  109. return fb;
  110. }
  111. }
  112. MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString();
  113. }
  114. AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  115. const AbstractBasePtrList &args_spec_list) {
  116. // Inputs: index, branch
  117. const std::string op_name = primitive->name();
  118. abstract::CheckArgsSize(op_name, args_spec_list, 2);
  119. (void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  120. AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  121. AbstractBasePtrList branches = branches_abs->elements();
  122. const size_t maximum_layer_num = 1000;
  123. if (branches.size() < 0 || branches.size() > maximum_layer_num) {
  124. MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
  125. << branches.size() << " branches.";
  126. }
  127. for (size_t i = 0; i < branches.size(); i++) {
  128. MS_EXCEPTION_IF_NULL(branches[i]);
  129. if (!branches[i]->isa<AbstractFunction>()) {
  130. MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got "
  131. << branches[i]->ToString() << " as the " << i << "th element.";
  132. }
  133. }
  134. auto b = branches[0];
  135. for (size_t i = 1; i < branches.size(); i++) {
  136. b = b->Join(branches[i]);
  137. }
  138. return b;
  139. }
  140. std::vector<ValuePtr> GetSupportedTargetValue() {
  141. std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
  142. return list;
  143. }
  144. bool SupportedIsTargetValue(const ValuePtr t) {
  145. auto list = GetSupportedTargetValue();
  146. auto match = std::any_of(list.begin(), list.end(), [&t](const ValuePtr &v) { return *v == *t; });
  147. return match;
  148. }
  149. AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  150. const AbstractBasePtrList &args_spec_list) {
  151. // statement: x is t
  152. // Inputs: x, t
  153. const std::string op_name = primitive->name();
  154. CheckArgsSize(op_name, args_spec_list, 2);
  155. ValuePtr t = args_spec_list[1]->BuildValue();
  156. if (!SupportedIsTargetValue(t)) {
  157. MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString()
  158. << " for statement is, supported list is:None, False, True ";
  159. }
  160. ValuePtr x = args_spec_list[0]->BuildValue();
  161. return std::make_shared<AbstractScalar>(*t == *x);
  162. }
  163. AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  164. const AbstractBasePtrList &args_spec_list) {
  165. // statement: x is not t
  166. // Inputs: x, t
  167. const std::string op_name = primitive->name();
  168. CheckArgsSize(op_name, args_spec_list, 2);
  169. ValuePtr t = args_spec_list[1]->BuildValue();
  170. if (!SupportedIsTargetValue(t)) {
  171. MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString()
  172. << " for statement is not, supported list is:None, False, True ";
  173. }
  174. ValuePtr x = args_spec_list[0]->BuildValue();
  175. return std::make_shared<AbstractScalar>(!(*t == *x));
  176. }
  177. bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
  178. const std::string op_name = primitive->name();
  179. CheckArgsSize(op_name, args_spec_list, 2);
  180. auto key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  181. auto dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 1);
  182. ValuePtr key_value = key->BuildValue();
  183. if (!key_value->isa<StringImm>()) {
  184. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  185. }
  186. auto key_str = GetValue<std::string>(key_value);
  187. std::vector<AbstractAttribute> dict_elems = dict->elements();
  188. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  189. [key_str](const AbstractAttribute &item) { return item.first == key_str; });
  190. return it != dict_elems.end();
  191. }
  192. AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  193. const AbstractBasePtrList &args_spec_list) {
  194. // statement: x in t
  195. // Inputs: x, t
  196. return std::make_shared<AbstractScalar>(IsInDict(primitive, args_spec_list));
  197. }
  198. AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  199. const AbstractBasePtrList &args_spec_list) {
  200. // statement: x not in t
  201. // Inputs: x, t
  202. return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
  203. }
  204. AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  205. const AbstractBasePtrList &args_spec_list) {
  206. // statement: isconstant(x)
  207. // Inputs: x
  208. if (args_spec_list.size() != 1) {
  209. MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1";
  210. }
  211. ValuePtr v = args_spec_list[0]->BuildValue();
  212. return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
  213. }
  214. } // namespace abstract
  215. } // namespace mindspore