You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim_statement.cc 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/static_analysis/param_validator.h"
  17. #include "pipeline/static_analysis/prim.h"
  18. #include "operator/ops.h"
  19. #include "pipeline/static_analysis/utils.h"
  20. #include "utils/symbolic.h"
  21. namespace mindspore {
  22. namespace abstract {
  23. AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
  24. const AbstractBasePtrList &args_spec_list) {
  25. // Inputs: a pointer to an AbstractBase object
  26. if (args_spec_list.size() != 1) {
  27. MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? "
  28. "while the input size is "
  29. << args_spec_list.size() << ".";
  30. }
  31. AbstractBasePtr abs_base = args_spec_list[0];
  32. return abs_base;
  33. }
  34. AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
  35. const AbstractBasePtrList &args_spec_list) {
  36. // Inputs: a pointer to an AbstractBase object
  37. if (args_spec_list.size() != 1) {
  38. MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size()
  39. << ".";
  40. }
  41. AbstractBasePtr abs_base = args_spec_list[0];
  42. MS_EXCEPTION_IF_NULL(abs_base);
  43. TypePtr type = abs_base->BuildType();
  44. return std::make_shared<AbstractType>(type);
  45. }
  46. AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  47. const AbstractBasePtrList &args_spec_list) {
  48. // Inputs: a pointer to an AbstractBase object and a pointer to a Type
  49. const std::string op_name = primitive->name();
  50. CheckArgsSize(op_name, args_spec_list, 2);
  51. AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1);
  52. auto mode_v = abs_type->GetValueTrack();
  53. MS_EXCEPTION_IF_NULL(mode_v);
  54. if (!mode_v->isa<Type>()) {
  55. MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed.";
  56. }
  57. TypePtr mode_t = mode_v->cast<TypePtr>();
  58. MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  59. bool v = IsSubtype(args_spec_list[0], mode_t);
  60. return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
  61. }
  62. AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  63. const AbstractBasePtrList &args_spec_list) {
  64. // Inputs: two tensors.
  65. const std::string op_name = primitive->name();
  66. CheckArgsSize(op_name, args_spec_list, 2);
  67. AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  68. AbstractTensorPtr input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  69. ShapePtr x_shp = input_x->shape();
  70. auto x_shp_value = x_shp->shape();
  71. ShapePtr y_shp = input_y->shape();
  72. auto y_shp_value = y_shp->shape();
  73. // Should be matrix which shape size is 2.
  74. if (x_shp_value.size() != 2 || y_shp_value.size() != 2) {
  75. MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are "
  76. << x_shp_value.size() << ", " << y_shp_value.size() << " ";
  77. }
  78. if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) {
  79. MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}";
  80. }
  81. auto x_element = input_x->element();
  82. MS_EXCEPTION_IF_NULL(x_element);
  83. (void)x_element->Join(input_y->element());
  84. auto param = {x_shp_value[0], y_shp_value[1]};
  85. return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
  86. }
  87. AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
  88. const AbstractBasePtrList &args_spec_list) {
  89. // Inputs: condition, true branch, false branch
  90. if (args_spec_list.size() != 3) {
  91. MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
  92. << ".";
  93. }
  94. auto cond = args_spec_list[0];
  95. auto tb = args_spec_list[1];
  96. auto fb = args_spec_list[2];
  97. MS_EXCEPTION_IF_NULL(cond);
  98. ValuePtr v = cond->GetValueTrack();
  99. MS_EXCEPTION_IF_NULL(v);
  100. if (v->isa<AnyValue>()) {
  101. MS_EXCEPTION_IF_NULL(tb);
  102. return tb->Join(fb);
  103. }
  104. if (v->isa<Scalar>()) {
  105. if (v->cast<ScalarPtr>()->IsOne()) {
  106. return tb;
  107. } else {
  108. return fb;
  109. }
  110. }
  111. MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString();
  112. }
  113. std::vector<ValuePtr> GetSupportedTargetValue() {
  114. std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
  115. return list;
  116. }
  117. bool SupportedIsTargetValue(const ValuePtr t) {
  118. auto list = GetSupportedTargetValue();
  119. auto match = std::any_of(list.begin(), list.end(), [&t](const ValuePtr &v) { return *v == *t; });
  120. return match;
  121. }
  122. AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  123. const AbstractBasePtrList &args_spec_list) {
  124. // statement: x is t
  125. // Inputs: x, t
  126. const std::string op_name = primitive->name();
  127. CheckArgsSize(op_name, args_spec_list, 2);
  128. ValuePtr t = args_spec_list[1]->BuildValue();
  129. if (!SupportedIsTargetValue(t)) {
  130. MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString()
  131. << " for statement is, supported list is:None, False, True ";
  132. }
  133. ValuePtr x = args_spec_list[0]->BuildValue();
  134. return std::make_shared<AbstractScalar>(*t == *x);
  135. }
  136. AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  137. const AbstractBasePtrList &args_spec_list) {
  138. // statement: x is not t
  139. // Inputs: x, t
  140. const std::string op_name = primitive->name();
  141. CheckArgsSize(op_name, args_spec_list, 2);
  142. ValuePtr t = args_spec_list[1]->BuildValue();
  143. if (!SupportedIsTargetValue(t)) {
  144. MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString()
  145. << " for statement is not, supported list is:None, False, True ";
  146. }
  147. ValuePtr x = args_spec_list[0]->BuildValue();
  148. return std::make_shared<AbstractScalar>(!(*t == *x));
  149. }
  150. bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
  151. const std::string op_name = primitive->name();
  152. CheckArgsSize(op_name, args_spec_list, 2);
  153. auto key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
  154. auto dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 1);
  155. ValuePtr key_value = key->BuildValue();
  156. if (!key_value->isa<StringImm>()) {
  157. MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
  158. }
  159. auto key_str = GetValue<std::string>(key_value);
  160. std::vector<AbstractAttribute> dict_elems = dict->elements();
  161. auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
  162. [key_str](const AbstractAttribute &item) { return item.first == key_str; });
  163. return it != dict_elems.end();
  164. }
  165. AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  166. const AbstractBasePtrList &args_spec_list) {
  167. // statement: x in t
  168. // Inputs: x, t
  169. return std::make_shared<AbstractScalar>(IsInDict(primitive, args_spec_list));
  170. }
  171. AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  172. const AbstractBasePtrList &args_spec_list) {
  173. // statement: x not in t
  174. // Inputs: x, t
  175. return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
  176. }
  177. } // namespace abstract
  178. } // namespace mindspore