You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

abstract_function.cc 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "pipeline/static_analysis/abstract_function.h"
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include "pipeline/static_analysis/analysis_context.h"
#include "pipeline/static_analysis/static_analysis.h"
  20. namespace mindspore {
  21. namespace abstract {
  // Forward declarations of analysis types used (via their Ptr aliases) below.
  // NOTE(review): the included static_analysis headers presumably already declare these.
  class AnalysisEngine;
  class Evaluator;
  24. AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) {
  25. if (func_list.size() == 1) {
  26. return func_list[0];
  27. }
  28. return std::make_shared<AbstractFuncUnion>(func_list);
  29. }
  30. AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) {
  31. auto this_func = shared_from_base<AbstractFuncAtom>();
  32. if (other->isa<AbstractFuncAtom>()) {
  33. if (*this_func == *other) {
  34. return this_func;
  35. }
  36. return std::make_shared<AbstractFuncUnion>(this_func, other);
  37. }
  38. auto other_union = dyn_cast<AbstractFuncUnion>(other);
  39. if (other_union->IsSuperSet(this_func)) {
  40. return other;
  41. }
  42. return std::make_shared<AbstractFuncUnion>(this_func, other);
  43. }
  44. void AbstractFuncAtom::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
  45. visit_func(const_cast<AbstractFuncAtom *>(this)->shared_from_base<AbstractFuncAtom>());
  46. }
  47. bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; }
  48. AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) { func_list_ = func_list; }
  49. AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) {
  50. AbstractFuncAtomPtrList new_func_list;
  51. auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); };
  52. first->Visit(build_func_list);
  53. second->Visit(build_func_list);
  54. func_list_ = new_func_list;
  55. }
  56. std::string AbstractFuncUnion::ToString() const {
  57. std::ostringstream buffer;
  58. buffer << "AbstractFuncUnion({";
  59. int i = 0;
  60. for (const auto &func : func_list_) {
  61. MS_EXCEPTION_IF_NULL(func);
  62. buffer << "[" << i << "]: " << func->ToString() << ", ";
  63. i++;
  64. }
  65. buffer << "})";
  66. return buffer.str();
  67. }
  68. bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) {
  69. MS_EXCEPTION_IF_NULL(other);
  70. std::vector<bool> is_in_list;
  71. auto build_in_list = [this, &is_in_list](const AbstractFuncAtomPtr &func) {
  72. auto iter = find(func_list_.begin(), func_list_.end(), func);
  73. if (iter == func_list_.end()) {
  74. is_in_list.push_back(false);
  75. }
  76. return true;
  77. };
  78. other->Visit(build_in_list);
  79. return std::all_of(is_in_list.begin(), is_in_list.end(), [](bool is_in) { return is_in; });
  80. }
  81. AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) {
  82. auto this_func = shared_from_base<AbstractFunction>();
  83. if (other->isa<AbstractFuncAtom>()) {
  84. if (IsSuperSet(other)) {
  85. return this_func;
  86. }
  87. return std::make_shared<AbstractFuncUnion>(this_func, other);
  88. }
  89. auto other_union = dyn_cast<AbstractFuncUnion>(other);
  90. if (other_union->IsSuperSet(this_func)) {
  91. return other;
  92. }
  93. return std::make_shared<AbstractFuncUnion>(this_func, other);
  94. }
  95. void AbstractFuncUnion::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
  96. for (AbstractFuncAtomPtr poss : func_list_) {
  97. visit_func(poss);
  98. }
  99. }
  100. bool AbstractFuncUnion::operator==(const AbstractFunction &other) const {
  101. if (!other.isa<AbstractFuncUnion>()) {
  102. return false;
  103. }
  104. auto other_union = static_cast<const AbstractFuncUnion *>(&other);
  105. if (func_list_.size() != other_union->func_list_.size()) {
  106. return false;
  107. }
  108. if (func_list_ == other_union->func_list_) {
  109. return true;
  110. }
  111. return false;
  112. }
  113. std::size_t AbstractFuncUnion::hash() const {
  114. std::size_t hash_sum = 0;
  115. for (auto f : func_list_) {
  116. hash_sum = hash_combine(hash_sum, f->hash());
  117. }
  118. return hash_sum;
  119. }
  120. EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  121. MS_EXCEPTION_IF_NULL(engine);
  122. return engine->_GetEvaluatorFor(shared_from_base<PrimitiveAbstractClosure>());
  123. }
  124. bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
  125. if (!other.isa<PrimitiveAbstractClosure>()) {
  126. return false;
  127. }
  128. auto other_prim = static_cast<const PrimitiveAbstractClosure *>(&other);
  129. if (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id()) {
  130. return true;
  131. }
  132. return false;
  133. }
  134. std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); }
  135. EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  136. MS_EXCEPTION_IF_NULL(engine);
  137. return engine->_GetEvaluatorFor(shared_from_base<FuncGraphAbstractClosure>());
  138. }
  139. bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
  140. if (!other.isa<FuncGraphAbstractClosure>()) {
  141. return false;
  142. }
  143. auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other);
  144. if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) {
  145. return true;
  146. }
  147. return false;
  148. }
  149. std::size_t FuncGraphAbstractClosure::hash() const {
  150. auto hash_value = hash_combine(tid(), func_graph_->hash());
  151. hash_value = hash_combine(hash_value, context_->hash());
  152. return hash_value;
  153. }
  154. std::string FuncGraphAbstractClosure::ToString() const {
  155. std::stringstream ss;
  156. ss << "FuncGraphAbstractClosure: " << this << "FuncGraph: " << func_graph_.get() << ", " << func_graph_->ToString()
  157. << "; Context: " << context_.get() << context_->ToString();
  158. return ss.str();
  159. }
  160. EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  161. MS_EXCEPTION_IF_NULL(engine);
  162. return engine->_GetEvaluatorFor(shared_from_base<MetaFuncGraphAbstractClosure>());
  163. }
  164. bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
  165. if (!other.isa<MetaFuncGraphAbstractClosure>()) {
  166. return false;
  167. }
  168. auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other);
  169. if (meta_func_graph_ == other_meta_fg->meta_func_graph_) {
  170. return true;
  171. }
  172. return false;
  173. }
  174. std::size_t MetaFuncGraphAbstractClosure::hash() const {
  175. auto hash_value = hash_combine(tid(), meta_func_graph_->hash());
  176. return hash_value;
  177. }
  178. std::string MetaFuncGraphAbstractClosure::ToString() const {
  179. return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name();
  180. }
  181. bool PartialAbstractClosure::operator==(const AbstractFunction &other) const {
  182. if (!other.isa<PartialAbstractClosure>()) {
  183. return false;
  184. }
  185. auto other_partial = static_cast<const PartialAbstractClosure *>(&other);
  186. if (fn_ != other_partial->fn_) {
  187. return false;
  188. }
  189. if (args_spec_list_.size() != other_partial->args_spec_list_.size()) {
  190. return false;
  191. }
  192. if (args_spec_list_ == other_partial->args_spec_list_) {
  193. return true;
  194. }
  195. return false;
  196. }
  197. std::size_t PartialAbstractClosure::hash() const {
  198. auto hash_value = hash_combine(tid(), fn_->hash());
  199. hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
  200. return hash_value;
  201. }
  202. EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  203. MS_EXCEPTION_IF_NULL(engine);
  204. return engine->_GetEvaluatorFor(shared_from_base<PartialAbstractClosure>());
  205. }
  206. std::string PartialAbstractClosure::ToString() const {
  207. std::ostringstream buffer;
  208. buffer << "PartialAbstractClosure(" << fn_->ToString() << "(";
  209. for (auto arg : args_spec_list_) {
  210. buffer << arg->ToString() << ", ";
  211. }
  212. buffer << "))";
  213. return buffer.str();
  214. }
  215. EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  216. MS_EXCEPTION_IF_NULL(engine);
  217. return engine->_GetEvaluatorFor(shared_from_base<JTransformedAbstractClosure>());
  218. }
  219. bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
  220. if (!other.isa<JTransformedAbstractClosure>()) {
  221. return false;
  222. }
  223. auto other_transformed = static_cast<const JTransformedAbstractClosure *>(&other);
  224. if (fn_ == other_transformed->fn_) {
  225. return true;
  226. }
  227. return false;
  228. }
  229. std::size_t JTransformedAbstractClosure::hash() const {
  230. auto hash_value = hash_combine(tid(), fn_->hash());
  231. return hash_value;
  232. }
  233. EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  234. MS_EXCEPTION_IF_NULL(engine);
  235. return engine->_GetEvaluatorFor(shared_from_base<VirtualAbstractClosure>());
  236. }
  237. bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
  238. if (!other.isa<VirtualAbstractClosure>()) {
  239. return false;
  240. }
  241. auto other_virtual = static_cast<const VirtualAbstractClosure *>(&other);
  242. if (output_ != other_virtual->output_) {
  243. return false;
  244. }
  245. if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) {
  246. return false;
  247. }
  248. if (args_spec_list_ == other_virtual->args_spec_list_) {
  249. return true;
  250. }
  251. return false;
  252. }
  253. std::size_t VirtualAbstractClosure::hash() const {
  254. auto hash_value = hash_combine(tid(), output_->hash());
  255. hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
  256. return hash_value;
  257. }
  258. std::string VirtualAbstractClosure::ToString() const {
  259. std::ostringstream buffer;
  260. buffer << "VirtualAbstractClosure(args: {";
  261. int i = 0;
  262. for (const auto &arg : args_spec_list_) {
  263. MS_EXCEPTION_IF_NULL(arg);
  264. buffer << "[" << i << "]: " << arg->ToString() << ", ";
  265. i++;
  266. }
  267. buffer << "}, output: " << output_->ToString() << ")";
  268. return buffer.str();
  269. }
  270. EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
  271. MS_EXCEPTION_IF_NULL(engine);
  272. return engine->_GetEvaluatorFor(shared_from_base<TypedPrimitiveAbstractClosure>());
  273. }
  274. bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
  275. if (!other.isa<TypedPrimitiveAbstractClosure>()) {
  276. return false;
  277. }
  278. auto other_typed = static_cast<const TypedPrimitiveAbstractClosure *>(&other);
  279. if (output_ != other_typed->output_) {
  280. return false;
  281. }
  282. if (prim_ != other_typed->prim_) {
  283. return false;
  284. }
  285. if (args_spec_list_.size() != other_typed->args_spec_list_.size()) {
  286. return false;
  287. }
  288. if (args_spec_list_ == other_typed->args_spec_list_) {
  289. return true;
  290. }
  291. return false;
  292. }
  293. std::size_t TypedPrimitiveAbstractClosure::hash() const {
  294. auto hash_value = hash_combine(tid(), prim_->hash());
  295. hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_));
  296. return hash_value;
  297. }
  298. std::string TypedPrimitiveAbstractClosure::ToString() const {
  299. std::ostringstream buffer;
  300. buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {";
  301. int i = 0;
  302. for (const auto &arg : args_spec_list_) {
  303. MS_EXCEPTION_IF_NULL(arg);
  304. buffer << "[" << i << "]: " << arg->ToString() << ", ";
  305. i++;
  306. }
  307. buffer << "}, output: " << output_->ToString() << ")";
  308. return buffer.str();
  309. }
  310. bool DummyAbstractClosure::operator==(const AbstractFunction &other) const {
  311. if (!other.isa<DummyAbstractClosure>()) {
  312. return false;
  313. }
  314. return true;
  315. }
  316. } // namespace abstract
  317. } // namespace mindspore