You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

analysis_context.cc 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/static_analysis/analysis_context.h"
  17. #include <algorithm>
  18. #include "utils/symbolic.h"
  19. #include "debug/trace.h"
  20. namespace mindspore {
  21. namespace abstract {
  22. AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent, FuncGraphPtr fg,
  23. const AbstractBasePtrList &args_spec_list) {
  24. auto children_context_map_iter = parent->children_cache_.find(fg);
  25. if (children_context_map_iter != parent->children_cache_.end()) {
  26. auto children_context_map = children_context_map_iter->second;
  27. auto children_context_iter = children_context_map.find(args_spec_list);
  28. if (children_context_iter != children_context_map.end()) {
  29. return children_context_iter->second.lock();
  30. }
  31. }
  32. AnalysisContextPtr context_new = std::make_shared<AnalysisContext>(parent, fg, args_spec_list);
  33. // Reference to myself, so use weak_ptr to break reference cycle.
  34. auto weak_context = std::weak_ptr<AnalysisContext>(context_new);
  35. context_new->parent_cache_[fg] = weak_context;
  36. parent->children_cache_[fg][args_spec_list] = weak_context;
  37. return context_new;
  38. }
  39. AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph,
  40. const AbstractBasePtrList &args_spec_list) {
  41. FuncGraphPtr graph_parent = func_graph->parent();
  42. auto iter = parent_cache_.find(graph_parent);
  43. AnalysisContextPtr parent_context = nullptr;
  44. if (iter != parent_cache_.end()) {
  45. parent_context = iter->second.lock();
  46. }
  47. // if this happen, it will be bug in code. but we raise exception to keep the scene.
  48. if (parent_context == nullptr) {
  49. std::ostringstream oss;
  50. oss << "BUG: cannot found parent_context in current context: " << this->ToString()
  51. << ", func_graph: " << func_graph->ToString() << ", graph_parent: ";
  52. if (graph_parent != nullptr) {
  53. oss << graph_parent->ToString();
  54. } else {
  55. oss << "nullptr";
  56. }
  57. MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
  58. }
  59. return NewContext(parent_context, func_graph, args_spec_list);
  60. }
  61. AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) {
  62. auto p_iter = parent_cache_.find(func_graph);
  63. AnalysisContextPtr parent_context = nullptr;
  64. if (p_iter != parent_cache_.end()) {
  65. parent_context = p_iter->second.lock();
  66. } else {
  67. auto iter_parent = parent_cache_.find(func_graph->parent());
  68. if (iter_parent != parent_cache_.end()) {
  69. parent_context = iter_parent->second.lock();
  70. }
  71. }
  72. // if this happen, it will be bug in code. but we raise exception to keep the scene.
  73. if (parent_context == nullptr) {
  74. std::ostringstream oss;
  75. oss << "BUG: Filter graph failed: " << func_graph->ToString() << ", graph_parent: ";
  76. if (func_graph->parent() != nullptr) {
  77. oss << func_graph->parent()->ToString();
  78. } else {
  79. oss << "nullptr";
  80. }
  81. oss << " parent_cache_: {";
  82. for (auto iter : parent_cache_) {
  83. if (iter.first == nullptr) {
  84. oss << " [graph: nullptr";
  85. } else {
  86. oss << " [graph: " << iter.first->ToString();
  87. }
  88. // iter.second cannot be nullptr even iter.first is nullptr as it will
  89. // always be a Context() object.
  90. oss << ", context: " << iter.second.lock()->ToString() << "]";
  91. }
  92. oss << "}";
  93. MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
  94. }
  95. return parent_context;
  96. }
  97. AnalysisContextPtr AnalysisContext::DummyContext() {
  98. AnalysisContextPtr dummy_context = std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());
  99. dummy_context->parent_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context);
  100. return dummy_context;
  101. }
  102. bool AnalysisContext::IsDummyContext() {
  103. if (parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty()) {
  104. return true;
  105. }
  106. return false;
  107. }
  108. const AnalysisContextPtr kDummyAnalysisContext =
  109. std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());
  110. bool AnalysisContext::operator==(const AnalysisContext &other) const {
  111. if (func_graph_ != other.func_graph_) {
  112. return false;
  113. }
  114. if (args_spec_list_.size() != other.args_spec_list_.size()) {
  115. return false;
  116. }
  117. if (((parent_ == nullptr) && (other.parent_ != nullptr)) || ((parent_ != nullptr) && (other.parent_ == nullptr))) {
  118. return false;
  119. }
  120. // Compare parent with content.
  121. bool is_parent_equal = false;
  122. if (parent_ == other.parent_) {
  123. is_parent_equal = true;
  124. } else if (*parent_ == *other.parent_) {
  125. is_parent_equal = true;
  126. } else {
  127. return false;
  128. }
  129. for (std::size_t i = 0; i < args_spec_list_.size(); i++) {
  130. if (!(*args_spec_list_[i] == *other.args_spec_list_[i])) {
  131. return false;
  132. }
  133. }
  134. return is_parent_equal;
  135. }
  136. // brief The key which controls the graph cloning in Specialize.
  137. //
  138. // Originally, specialize use context directly as the key for cloning graph. The graph will be cloned multiple times
  139. // for different context, which means the graph is called from different node with different arguments and different
  140. // free values. In order to decrease the number of cloned graphs, we add this `SpecializeKey` method to control what
  141. // graph can be reused.
  142. // The graph called with different SymbolicKey will be reused. The abstract of SymbolicKey parameter will be joined
  143. // and stored in the intermediate_abstract. The joined SymbolicKey would cause Poly Code in eval, thus the reused
  144. // graph with SymbolicKey parameter should be inlined in `opt` pipeline before the next renormalize.
  145. // The graph called with different shape should not be reused, because the combination of `shape` and `Fill` relies
  146. // on correct shape to specialize a tensor constant.
  147. AnalysisContextPtr AnalysisContext::SpecializeKey() const {
  148. AbstractBasePtrList args_broad_shp;
  149. (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(args_broad_shp),
  150. [](const AbstractBasePtr &arg) -> AbstractBasePtr {
  151. if (arg->isa<AbstractScalar>()) {
  152. auto val = arg->GetValueTrack();
  153. if (val->isa<SymbolicKeyInstance>()) {
  154. auto scalar_spec = dyn_cast<AbstractScalar>(arg);
  155. auto ret_spec = scalar_spec->Broaden();
  156. return ret_spec;
  157. }
  158. }
  159. if (arg->isa<AbstractRef>()) {
  160. MS_LOG(DEBUG) << "refkey broaden";
  161. auto arg_spec = dyn_cast<AbstractRef>(arg);
  162. auto ret_spec = arg_spec->Broaden();
  163. return ret_spec;
  164. }
  165. return arg;
  166. });
  167. AnalysisContextPtr context_new = std::make_shared<AnalysisContext>(nullptr, func_graph_, args_broad_shp);
  168. context_new->parent_ = parent_;
  169. return context_new;
  170. }
  171. std::size_t AnalysisContext::hash() {
  172. std::size_t hash_value = 0;
  173. // hash() recursion exit condition.
  174. if (parent_ != nullptr) {
  175. hash_value = hash_combine(hash_value, parent_->hash());
  176. }
  177. if (func_graph_ != nullptr) {
  178. hash_value = hash_combine(hash_value, func_graph_->hash());
  179. }
  180. return hash_value;
  181. }
  182. std::string AnalysisContext::ToString() const {
  183. std::ostringstream buffer;
  184. buffer << "{";
  185. if (func_graph_ != nullptr) {
  186. buffer << "Func Graph: " << func_graph_->ToString();
  187. }
  188. buffer << " Args: ";
  189. int i = 0;
  190. for (const auto &arg : args_spec_list_) {
  191. buffer << "[" << i << "]: " << arg->ToString() << ", ";
  192. i++;
  193. }
  194. if (parent_ != nullptr) {
  195. buffer << "Parent: " << parent_->ToString();
  196. }
  197. buffer << "}";
  198. return buffer.str();
  199. }
  200. } // namespace abstract
  201. } // namespace mindspore