You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prim_others.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/static_analysis/param_validator.h"
  17. #include "pipeline/static_analysis/prim.h"
  18. #include "operator/ops.h"
  19. #include "pipeline/static_analysis/utils.h"
  20. #include "utils/symbolic.h"
  21. namespace mindspore {
  22. namespace abstract {
  23. AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  24. const AbstractBasePtrList &args_spec_list) {
  25. // An object of a subclass of AbstractBase
  26. CheckArgsSize(primitive->name(), args_spec_list, 1);
  27. return args_spec_list[0];
  28. }
  29. AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  30. const AbstractBasePtrList &args_spec_list) {
  31. // args: An object of AbstractFunction.
  32. CheckArgsSize(primitive->name(), args_spec_list, 1);
  33. MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString();
  34. AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
  35. if (x == nullptr) {
  36. return std::make_shared<AbstractJTagged>(args_spec_list[0]);
  37. }
  38. AbstractFuncAtomPtrList jv;
  39. auto build_jv = [&jv](const AbstractFuncAtomPtr &func) {
  40. auto j_closure = std::make_shared<JTransformedAbstractClosure>(func);
  41. jv.push_back(j_closure);
  42. };
  43. x->Visit(build_jv);
  44. return AbstractFunction::MakeAbstractFunction(jv);
  45. }
  46. AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  47. const AbstractBasePtrList &args_spec_list) {
  48. MS_EXCEPTION_IF_NULL(primitive);
  49. // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
  50. CheckArgsSize(primitive->name(), args_spec_list, 3);
  51. auto key = args_spec_list[1];
  52. auto dflt = args_spec_list[2];
  53. TypePtr type = key->GetTypeTrack();
  54. MS_EXCEPTION_IF_NULL(type);
  55. if (type->type_id() != kObjectTypeSymbolicKeyType) {
  56. MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
  57. }
  58. if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
  59. return dflt;
  60. }
  61. ValuePtr key_value_ptr = key->GetValueTrack();
  62. MS_EXCEPTION_IF_NULL(key_value_ptr);
  63. auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
  64. auto expected = key_value_track->abstract();
  65. MS_EXCEPTION_IF_NULL(expected);
  66. (void)expected->Join(dflt);
  67. return expected;
  68. }
  69. AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  70. const AbstractBasePtrList &args_spec_list) {
  71. // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
  72. CheckArgsSize(primitive->name(), args_spec_list, 3);
  73. auto key = args_spec_list[1];
  74. auto value = args_spec_list[2];
  75. ValuePtr key_value_ptr = key->GetValueTrack();
  76. MS_EXCEPTION_IF_NULL(key_value_ptr);
  77. auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
  78. if (key_value_track == nullptr) {
  79. MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: "
  80. << key_value_ptr->ToString();
  81. }
  82. auto expected = key_value_track->abstract();
  83. MS_EXCEPTION_IF_NULL(expected);
  84. (void)expected->Join(value);
  85. return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  86. }
  87. AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  88. const AbstractBasePtrList &args_spec_list) {
  89. // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
  90. CheckArgsSize(primitive->name(), args_spec_list, 2);
  91. return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  92. }
  93. AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &prim, const AbstractBasePtrList &) {
  94. ValuePtr name_value = prim->GetAttr("tag");
  95. auto name = name_value->cast<StringImmPtr>();
  96. if (name == nullptr) {
  97. MS_LOG(EXCEPTION) << "MakeRefKey attr tag sould be a String " << name_value->ToString() << ".";
  98. }
  99. auto refkey = std::make_shared<RefKey>(name->value());
  100. if (refkey == nullptr) {
  101. MS_LOG(EXCEPTION) << "MakeRefKey std::make_shared<RefKey> failed";
  102. }
  103. return refkey->ToAbstract();
  104. }
  105. AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &,
  106. const AbstractBasePtrList &args_spec_list) {
  107. // arguments: key, value, original value
  108. if (args_spec_list.size() != 3) {
  109. MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
  110. << ".";
  111. }
  112. TypePtr type = args_spec_list[0]->GetTypeTrack();
  113. if (type->type_id() != kObjectTypeRefKey) {
  114. MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
  115. }
  116. return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
  117. }
  118. AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
  119. const AbstractBasePtrList &args_spec_list) {
  120. // arguments: value
  121. if (args_spec_list.size() != 1) {
  122. MS_LOG(EXCEPTION) << "get_ref_key requires 1 parameters, while the input size is " << args_spec_list.size() << ".";
  123. }
  124. TypePtr type = args_spec_list[0]->GetTypeTrack();
  125. if (type->type_id() != kObjectTypeRef) {
  126. MS_LOG(EXCEPTION) << "First input of get_ref_key should be a Ref but a " << type->ToString();
  127. }
  128. return args_spec_list[0]->cast<AbstractRefPtr>()->ref();
  129. }
  130. AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &,
  131. const AbstractBasePtrList &args_spec_list) {
  132. // arguments: value
  133. if (args_spec_list.size() != 1) {
  134. MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size()
  135. << ".";
  136. }
  137. TypePtr type = args_spec_list[0]->GetTypeTrack();
  138. if (type->type_id() != kObjectTypeRef) {
  139. MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString();
  140. }
  141. return args_spec_list[0]->cast<AbstractRefPtr>()->ref();
  142. }
  143. AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &,
  144. const AbstractBasePtrList &args_spec_list) {
  145. // arguments: value
  146. if (args_spec_list.size() != 1) {
  147. MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size()
  148. << ".";
  149. }
  150. TypePtr type = args_spec_list[0]->GetTypeTrack();
  151. if (type->type_id() != kObjectTypeRef) {
  152. MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString();
  153. }
  154. return args_spec_list[0]->cast<AbstractRefPtr>()->ref_origin();
  155. }
  156. AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  157. const AbstractBasePtrList &args_spec_list) {
  158. // args: Two objects of a subclass of AbstractBase, key and value.
  159. CheckArgsSize(primitive->name(), args_spec_list, 2);
  160. TypePtr type = args_spec_list[0]->GetTypeTrack();
  161. MS_EXCEPTION_IF_NULL(type);
  162. if (type->type_id() != kObjectTypeRefKey && type->type_id() != kObjectTypeSymbolicKeyType) {
  163. MS_LOG(EXCEPTION) << "First input of StateSetItem should be a RefKey or SymbolicKeyType but a " << type->ToString();
  164. }
  165. return std::make_shared<AbstractScalar>(kAnyValue, kBool);
  166. }
  167. AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  168. const AbstractBasePtrList &args_spec_list) {
  169. if (args_spec_list.empty()) {
  170. MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0";
  171. }
  172. auto depends = args_spec_list[0]->Broaden();
  173. return depends;
  174. }
  175. bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
  176. if (x_shape.size() != y_shape.size()) {
  177. return false;
  178. }
  179. for (size_t i = 0; i < x_shape.size(); ++i) {
  180. if (GetValue<int>(x_shape[i]) != GetValue<int>(y_shape[i])) {
  181. return false;
  182. }
  183. }
  184. return true;
  185. }
  186. enum State {
  187. SAME,
  188. X_ONE,
  189. Y_ONE,
  190. };
  191. void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y,
  192. std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) {
  193. const size_t n = reverse_x.size();
  194. for (size_t i = 0; i < n; ++i) {
  195. State curr;
  196. const int32_t x_i = reverse_x[i];
  197. const int32_t y_i = reverse_y[i];
  198. const int reduce_idx = SizeToInt(n - 1 - i);
  199. if (x_i == y_i) {
  200. curr = SAME;
  201. } else if (x_i == 1) {
  202. grad_x_reduce_idx->push_back(reduce_idx);
  203. curr = X_ONE;
  204. } else if (y_i == 1) {
  205. grad_y_reduce_idy->push_back(reduce_idx);
  206. curr = Y_ONE;
  207. } else {
  208. MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs";
  209. }
  210. if (curr == SAME && x_i == 1) {
  211. grad_x_reduce_idx->push_back(reduce_idx);
  212. grad_y_reduce_idy->push_back(reduce_idx);
  213. continue;
  214. }
  215. }
  216. std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end());
  217. std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end());
  218. }
  219. AbstractBasePtr BroadcastGradientArgsDiff(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
  220. std::vector<int> reverse_x;
  221. std::vector<int> reverse_y;
  222. (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x),
  223. [](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); });
  224. (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y),
  225. [](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); });
  226. if (reverse_x.size() > reverse_y.size()) {
  227. reverse_y.resize(reverse_x.size(), 1);
  228. } else {
  229. reverse_x.resize(reverse_y.size(), 1);
  230. }
  231. std::vector<int> grad_x_reduce_idx;
  232. std::vector<int> grad_y_reduce_idy;
  233. ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy);
  234. AbstractBasePtrList abs_list_x;
  235. AbstractBasePtrList abs_list_y;
  236. (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x),
  237. [](int v) { return abstract::FromValue(v); });
  238. (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y),
  239. [](int v) { return abstract::FromValue(v); });
  240. auto x_reduce_idx = std::make_shared<AbstractTuple>(abs_list_x);
  241. auto y_reduce_idx = std::make_shared<AbstractTuple>(abs_list_y);
  242. AbstractBasePtrList elem_list;
  243. elem_list.push_back(x_reduce_idx);
  244. elem_list.push_back(y_reduce_idx);
  245. return std::make_shared<AbstractTuple>(elem_list);
  246. }
  247. AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  248. const AbstractBasePtrList &args_spec_list) {
  249. // this primitive get the index that need to reduce
  250. // input: x's shape and y's shape, inputs should be tuple
  251. // output: tuple of x and y 's reduce index, reduce index should be a tuple
  252. const std::string op_name = primitive->name();
  253. CheckArgsSize(op_name, args_spec_list, 2);
  254. auto arg_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  255. auto arg_y = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  256. ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>();
  257. MS_EXCEPTION_IF_NULL(arg_x_value);
  258. ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>();
  259. MS_EXCEPTION_IF_NULL(arg_y_value);
  260. const std::vector<ValuePtr> x_shape = arg_x_value->value();
  261. const std::vector<ValuePtr> y_shape = arg_y_value->value();
  262. bool is_same_shape = CompareShape(x_shape, y_shape);
  263. // if it is the same shape , do not need reduce , return empty tuple
  264. if (is_same_shape) {
  265. AbstractBasePtrList empty_list;
  266. auto x_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
  267. auto y_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
  268. AbstractBasePtrList elem_list;
  269. elem_list.push_back(x_reduce_idx);
  270. elem_list.push_back(y_reduce_idx);
  271. return std::make_shared<AbstractTuple>(elem_list);
  272. }
  273. return BroadcastGradientArgsDiff(x_shape, y_shape);
  274. }
  275. AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
  276. const AbstractBasePtrList &args_spec_list) {
  277. // args: Two objects of a subclass of AbstractBase
  278. CheckArgsSize(primitive->name(), args_spec_list, 2);
  279. auto arg_src = args_spec_list[0];
  280. auto arg_dst = args_spec_list[1];
  281. // control depend can not setup tuple of ops to tuple of ops dependency relation
  282. if (arg_src->isa<AbstractTuple>() && arg_dst->isa<AbstractTuple>()) {
  283. auto src_size = arg_src->cast<AbstractTuplePtr>()->size();
  284. auto dst_size = arg_src->cast<AbstractTuplePtr>()->size();
  285. if (src_size > 1 && dst_size > 1) {
  286. MS_LOG(EXCEPTION) << "Control depend can not setup operator dependcy relationship from tuple from tuple";
  287. }
  288. }
  289. return std::make_shared<AbstractScalar>(kAnyValue, kBool);
  290. }
  291. } // namespace abstract
  292. } // namespace mindspore