You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim_statement.cc 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/static_analysis/param_validator.h"
  17. #include "pipeline/static_analysis/prim.h"
  18. #include "operator/ops.h"
  19. #include "pipeline/static_analysis/utils.h"
  20. #include "utils/symbolic.h"
  21. namespace mindspore {
  22. namespace abstract {
  23. AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
  24. const AbstractBasePtrList &args_spec_list) {
  25. // Inputs: a pointer to an AbstractBase object
  26. if (args_spec_list.size() != 1) {
  27. MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? "
  28. "while the input size is "
  29. << args_spec_list.size() << ".";
  30. }
  31. AbstractBasePtr abs_base = args_spec_list[0];
  32. return abs_base;
  33. }
  34. AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
  35. const AbstractBasePtrList &args_spec_list) {
  36. // Inputs: a pointer to an AbstractBase object
  37. if (args_spec_list.size() != 1) {
  38. MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size()
  39. << ".";
  40. }
  41. AbstractBasePtr abs_base = args_spec_list[0];
  42. MS_EXCEPTION_IF_NULL(abs_base);
  43. TypePtr type = abs_base->BuildType();
  44. return std::make_shared<AbstractType>(type);
  45. }
  46. AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  47. const AbstractBasePtrList &args_spec_list) {
  48. // Inputs: a pointer to an AbstractBase object and a pointer to a Type
  49. const std::string op_name = primitive->name();
  50. CheckArgsSize(op_name, args_spec_list, 2);
  51. AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1);
  52. auto mode_v = abs_type->GetValueTrack();
  53. MS_EXCEPTION_IF_NULL(mode_v);
  54. if (!mode_v->isa<Type>()) {
  55. MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed.";
  56. }
  57. TypePtr mode_t = mode_v->cast<TypePtr>();
  58. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  59. bool v = IsSubtype(args_spec_list[0], mode_t);
  60. return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
  61. }
  62. AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  63. const AbstractBasePtrList &args_spec_list) {
  64. // Inputs: two tensors.
  65. const std::string op_name = primitive->name();
  66. CheckArgsSize(op_name, args_spec_list, 2);
  67. AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  68. AbstractTensorPtr input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  69. ShapePtr x_shp = input_x->shape();
  70. auto x_shp_value = x_shp->shape();
  71. ShapePtr y_shp = input_y->shape();
  72. auto y_shp_value = y_shp->shape();
  73. // Should be matrix which shape size is 2.
  74. if (x_shp_value.size() != 2 || y_shp_value.size() != 2) {
  75. MS_LOG(EXCEPTION) << "" << op_name
  76. << " evaluator requires input two 2D tensors, while the dimensions of two tensors are "
  77. << x_shp_value.size() << ", " << y_shp_value.size() << " ";
  78. }
  79. if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) {
  80. MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}";
  81. }
  82. auto x_element = input_x->element();
  83. MS_EXCEPTION_IF_NULL(x_element);
  84. (void)x_element->Join(input_y->element());
  85. auto param = {x_shp_value[0], y_shp_value[1]};
  86. return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
  87. }
  88. AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
  89. const AbstractBasePtrList &args_spec_list) {
  90. // Inputs: condition, true branch, false branch
  91. if (args_spec_list.size() != 3) {
  92. MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
  93. << ".";
  94. }
  95. auto cond = args_spec_list[0];
  96. auto tb = args_spec_list[1];
  97. auto fb = args_spec_list[2];
  98. MS_EXCEPTION_IF_NULL(cond);
  99. ValuePtr v = cond->GetValueTrack();
  100. MS_EXCEPTION_IF_NULL(v);
  101. if (v->isa<AnyValue>()) {
  102. MS_EXCEPTION_IF_NULL(tb);
  103. return tb->Join(fb);
  104. }
  105. if (v->isa<Scalar>()) {
  106. if (v->cast<ScalarPtr>()->IsOne()) {
  107. return tb;
  108. } else {
  109. return fb;
  110. }
  111. }
  112. MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString();
  113. }
  114. std::vector<ValuePtr> GetSupportedTargetValue() {
  115. std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
  116. return list;
  117. }
  118. bool SupportedIsTargetValue(const ValuePtr t) {
  119. auto list = GetSupportedTargetValue();
  120. auto match = std::any_of(list.begin(), list.end(), [&t](const ValuePtr &v) { return *v == *t; });
  121. return match;
  122. }
  123. AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  124. const AbstractBasePtrList &args_spec_list) {
  125. // statement: x is t
  126. // Inputs: x, t
  127. const std::string op_name = primitive->name();
  128. CheckArgsSize(op_name, args_spec_list, 2);
  129. ValuePtr t = args_spec_list[1]->BuildValue();
  130. if (!SupportedIsTargetValue(t)) {
  131. MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString()
  132. << " for statement is, supported list is:None, False, True ";
  133. }
  134. ValuePtr x = args_spec_list[0]->BuildValue();
  135. return std::make_shared<AbstractScalar>(*t == *x);
  136. }
  137. AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  138. const AbstractBasePtrList &args_spec_list) {
  139. // statement: x is not t
  140. // Inputs: x, t
  141. const std::string op_name = primitive->name();
  142. CheckArgsSize(op_name, args_spec_list, 2);
  143. ValuePtr t = args_spec_list[1]->BuildValue();
  144. if (!SupportedIsTargetValue(t)) {
  145. MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString()
  146. << " for statement is not, supported list is:None, False, True ";
  147. }
  148. ValuePtr x = args_spec_list[0]->BuildValue();
  149. return std::make_shared<AbstractScalar>(!(*t == *x));
  150. }
  151. bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
  152. const std::string op_name = primitive->name();
  153. CheckArgsSize(op_name, args_spec_list, 2);
  154. auto key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  155. auto dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 1);
  156. ValuePtr key_value = key->BuildValue();
  157. if (!key_value->isa<StringImm>()) {
  158. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  159. }
  160. auto key_str = GetValue<std::string>(key_value);
  161. std::vector<AbstractAttribute> dict_elems = dict->elements();
  162. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  163. [key_str](const AbstractAttribute &item) { return item.first == key_str; });
  164. return it != dict_elems.end();
  165. }
  166. AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  167. const AbstractBasePtrList &args_spec_list) {
  168. // statement: x in t
  169. // Inputs: x, t
  170. return std::make_shared<AbstractScalar>(IsInDict(primitive, args_spec_list));
  171. }
  172. AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  173. const AbstractBasePtrList &args_spec_list) {
  174. // statement: x not in t
  175. // Inputs: x, t
  176. return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
  177. }
  178. } // namespace abstract
  179. } // namespace mindspore