You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

analysis_context.cc 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pipeline/static_analysis/analysis_context.h"
  17. #include <algorithm>
  18. #include "utils/symbolic.h"
  19. #include "debug/trace.h"
  20. namespace mindspore {
  21. namespace abstract {
  22. AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph,
  23. const AbstractBasePtrList &args_spec_list) {
  24. FuncGraphPtr graph_parent = func_graph->parent();
  25. auto iter = parent_cache_.find(graph_parent);
  26. AnalysisContextPtr parent_context = nullptr;
  27. if (iter != parent_cache_.end()) {
  28. parent_context = iter->second.lock();
  29. }
  30. // if this happen, it will be bug in code. but we raise exception to keep the scene.
  31. if (parent_context == nullptr) {
  32. std::ostringstream oss;
  33. oss << "BUG: cannot found parent_context in current context: " << this->ToString()
  34. << ", func_graph: " << func_graph->ToString() << ", graph_parent: ";
  35. if (graph_parent != nullptr) {
  36. oss << graph_parent->ToString();
  37. } else {
  38. oss << "nullptr";
  39. }
  40. MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
  41. }
  42. return NewContext(parent_context, func_graph, args_spec_list);
  43. }
  44. AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) {
  45. auto p_iter = parent_cache_.find(func_graph);
  46. AnalysisContextPtr parent_context = nullptr;
  47. if (p_iter != parent_cache_.end()) {
  48. parent_context = p_iter->second.lock();
  49. } else {
  50. auto iter_parent = parent_cache_.find(func_graph->parent());
  51. if (iter_parent != parent_cache_.end()) {
  52. parent_context = iter_parent->second.lock();
  53. }
  54. }
  55. // if this happen, it will be bug in code. but we raise exception to keep the scene.
  56. if (parent_context == nullptr) {
  57. std::ostringstream oss;
  58. oss << "BUG: Filter graph failed: " << func_graph->ToString() << ", graph_parent: ";
  59. if (func_graph->parent() != nullptr) {
  60. oss << func_graph->parent()->ToString();
  61. } else {
  62. oss << "nullptr";
  63. }
  64. oss << " parent_cache_: {";
  65. for (auto iter : parent_cache_) {
  66. if (iter.first == nullptr) {
  67. oss << " [graph: nullptr";
  68. } else {
  69. oss << " [graph: " << iter.first->ToString();
  70. }
  71. // iter.second cannot be nullptr even iter.first is nullptr as it will
  72. // always be a Context() object.
  73. oss << ", context: " << iter.second.lock()->ToString() << "]";
  74. }
  75. oss << "}";
  76. MS_LOG(EXCEPTION) << "" << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
  77. }
  78. return parent_context;
  79. }
  80. AnalysisContextPtr AnalysisContext::DummyContext() {
  81. AnalysisContextPtr dummy_context = std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());
  82. dummy_context->parent_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context);
  83. return dummy_context;
  84. }
  85. const AnalysisContextPtr kDummyAnalysisContext =
  86. std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());
  87. bool AnalysisContext::operator==(const AnalysisContext &other) const {
  88. if (func_graph_ != other.func_graph_) {
  89. return false;
  90. }
  91. if (args_spec_list_.size() != other.args_spec_list_.size()) {
  92. return false;
  93. }
  94. if (((parent_ == nullptr) && (other.parent_ != nullptr)) || ((parent_ != nullptr) && (other.parent_ == nullptr))) {
  95. return false;
  96. }
  97. // Compare parent with content.
  98. bool is_parent_equal = false;
  99. if (parent_ == other.parent_) {
  100. is_parent_equal = true;
  101. } else if (*parent_ == *other.parent_) {
  102. is_parent_equal = true;
  103. } else {
  104. return false;
  105. }
  106. for (std::size_t i = 0; i < args_spec_list_.size(); i++) {
  107. if (!(*args_spec_list_[i] == *other.args_spec_list_[i])) {
  108. return false;
  109. }
  110. }
  111. return is_parent_equal;
  112. }
  113. // brief The key which controls the graph cloning in Specialize.
  114. //
  115. // Originally, specialize use context directly as the key for cloning graph. The graph will be cloned multiple times
  116. // for different context, which means the graph is called from different node with different arguments and different
  117. // free values. In order to decrease the number of cloned graphs, we add this `SpecializeKey` method to control what
  118. // graph can be reused.
  119. // The graph called with different SymbolicKey will be reused. The abstract of SymbolicKey parameter will be joined
  120. // and stored in the intermediate_abstract. The joined SymbolicKey would cause Poly Code in infer, thus the reused
  121. // graph with SymbolicKey parameter should be inlined in `opt` pipeline before the next renormalize.
  122. // The graph called with different shape should not be reused, because the combination of `shape` and `Fill` relies
  123. // on correct shape to specialize a tensor constant.
  124. AnalysisContextPtr AnalysisContext::SpecializeKey() const {
  125. AbstractBasePtrList args_broad_shp;
  126. (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(args_broad_shp),
  127. [](const AbstractBasePtr &arg) -> AbstractBasePtr {
  128. if (arg->isa<AbstractScalar>()) {
  129. auto val = arg->GetValueTrack();
  130. if (val->isa<SymbolicKeyInstance>()) {
  131. auto scalar_spec = dyn_cast<AbstractScalar>(arg);
  132. auto ret_spec = scalar_spec->Broaden();
  133. ret_spec->set_value(kAnyValue);
  134. return ret_spec;
  135. }
  136. }
  137. if (arg->isa<AbstractRef>()) {
  138. MS_LOG(DEBUG) << "refkey broaden";
  139. auto arg_spec = dyn_cast<AbstractRef>(arg);
  140. auto ret_spec = arg_spec->Broaden();
  141. return ret_spec;
  142. }
  143. return arg;
  144. });
  145. AnalysisContextPtr context_new = std::make_shared<AnalysisContext>(nullptr, func_graph_, args_broad_shp);
  146. context_new->parent_ = parent_;
  147. return context_new;
  148. }
  149. std::size_t AnalysisContext::hash() {
  150. std::size_t hash_value = 0;
  151. // hash() recursion exit condition.
  152. if (parent_ != nullptr) {
  153. hash_value = hash_combine(hash_value, parent_->hash());
  154. }
  155. if (func_graph_ != nullptr) {
  156. hash_value = hash_combine(hash_value, func_graph_->hash());
  157. }
  158. return hash_value;
  159. }
  160. std::string AnalysisContext::ToString() const {
  161. std::ostringstream buffer;
  162. buffer << "{";
  163. if (func_graph_ != nullptr) {
  164. buffer << "Func Graph: " << func_graph_->ToString();
  165. }
  166. buffer << " Args: ";
  167. int i = 0;
  168. for (const auto &arg : args_spec_list_) {
  169. buffer << "[" << i << "]: " << arg->ToString() << ", ";
  170. i++;
  171. }
  172. if (parent_ != nullptr) {
  173. buffer << "Parent: " << parent_->ToString();
  174. }
  175. buffer << "}";
  176. return buffer.str();
  177. }
  178. } // namespace abstract
  179. } // namespace mindspore