/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/static_analysis/analysis_context.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>

#include "utils/symbolic.h"
#include "debug/trace.h"

namespace mindspore {
namespace abstract {
// Returns the child context of `parent` for calling `fg` with `args_spec_list`.
// Children are cached in parent->children_cache_ through weak_ptr, so a cache entry may refer to a context that has
// already been destroyed. A live cached context is reused; otherwise a new context is created and the cache entry is
// (re)written.
AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent, FuncGraphPtr fg,
                                               const AbstractBasePtrList &args_spec_list) {
  auto children_context_map_iter = parent->children_cache_.find(fg);
  if (children_context_map_iter != parent->children_cache_.end()) {
    // Bind by reference: the inner map is only read here, copying it on every lookup is wasted work.
    const auto &children_context_map = children_context_map_iter->second;
    auto children_context_iter = children_context_map.find(args_spec_list);
    if (children_context_iter != children_context_map.end()) {
      AnalysisContextPtr cached_context = children_context_iter->second.lock();
      // lock() yields nullptr once the cached context has expired; rebuild it instead of returning nullptr.
      if (cached_context != nullptr) {
        return cached_context;
      }
    }
  }
  AnalysisContextPtr context_new = std::make_shared<AnalysisContext>(parent, fg, args_spec_list);
  // The new context refers to itself through its own parent_cache_, so use weak_ptr to break the reference cycle.
  auto weak_context = std::weak_ptr<AnalysisContext>(context_new);
  context_new->parent_cache_[fg] = weak_context;
  parent->children_cache_[fg][args_spec_list] = weak_context;
  return context_new;
}

// Creates (or reuses) the context for calling `func_graph` with `args_spec_list` from this context.
// The parent of the resulting context is the context this one has cached for func_graph->parent().
// Raises an exception (MS_LOG(EXCEPTION)) when no live context is cached for that parent graph.
AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph,
                                                        const AbstractBasePtrList &args_spec_list) {
  FuncGraphPtr graph_parent = func_graph->parent();
  auto iter = parent_cache_.find(graph_parent);
  AnalysisContextPtr parent_context = nullptr;
  if (iter != parent_cache_.end()) {
    parent_context = iter->second.lock();
  }
  // If this happens, it is a bug in the code, but we raise an exception to preserve the scene for debugging.
  if (parent_context == nullptr) {
    std::ostringstream oss;
    oss << "BUG: cannot found parent_context in current context: " << this->ToString()
        << ", func_graph: " << func_graph->ToString() << ", graph_parent: ";
    if (graph_parent != nullptr) {
      oss << graph_parent->ToString();
    } else {
      oss << "nullptr";
    }
    MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
  }
  return NewContext(parent_context, func_graph, args_spec_list);
}

// Returns the context cached in parent_cache_ for `func_graph` itself or, if there is no entry for it, the one
// cached for func_graph->parent().
// Raises an exception (MS_LOG(EXCEPTION)) when the chosen entry is missing or has expired.
AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) {
  auto p_iter = parent_cache_.find(func_graph);
  AnalysisContextPtr parent_context = nullptr;
  if (p_iter != parent_cache_.end()) {
    parent_context = p_iter->second.lock();
  } else {
    auto iter_parent = parent_cache_.find(func_graph->parent());
    if (iter_parent != parent_cache_.end()) {
      parent_context = iter_parent->second.lock();
    }
  }
  // If this happens, it is a bug in the code, but we raise an exception to preserve the scene for debugging.
  if (parent_context == nullptr) {
    std::ostringstream oss;
    oss << "BUG: Filter graph failed: " << func_graph->ToString() << ", graph_parent: ";
    if (func_graph->parent() != nullptr) {
      oss << func_graph->parent()->ToString();
    } else {
      oss << "nullptr";
    }
    oss << " parent_cache_: {";
    // Iterate by const reference: copying each entry would also copy (and atomically ref-count) its weak_ptr.
    for (const auto &entry : parent_cache_) {
      if (entry.first == nullptr) {
        oss << " [graph: nullptr";
      } else {
        oss << " [graph: " << entry.first->ToString();
      }
      // A cached weak_ptr may have expired; check it so that building this diagnostic cannot itself crash and
      // hide the original error.
      AnalysisContextPtr cached_context = entry.second.lock();
      if (cached_context != nullptr) {
        oss << ", context: " << cached_context->ToString() << "]";
      } else {
        oss << ", context: expired]";
      }
    }
    oss << "}";
    MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
  }
  return parent_context;
}

// Creates a root context with no parent, no func graph and no arguments. It registers itself in its own
// parent_cache_ under the nullptr graph key, so graphs whose parent() is nullptr resolve to it.
AnalysisContextPtr AnalysisContext::DummyContext() {
  AnalysisContextPtr dummy_context = std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());
  dummy_context->parent_cache_[nullptr] = std::weak_ptr<AnalysisContext>(dummy_context);
  return dummy_context;
}

// True when this context has no parent, no func graph and no arguments (the shape built by DummyContext()).
bool AnalysisContext::IsDummyContext() {
  return parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty();
}

// NOTE(review): unlike DummyContext(), this instance does not explicitly register itself in parent_cache_[nullptr],
// so NewFuncGraphContext()/Filter() called on it may not find a parent context — confirm this is intended.
const AnalysisContextPtr kDummyAnalysisContext =
  std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());

// Two contexts are equal when they have the same func graph, element-wise equal argument abstracts, and parents that
// are either both null, the same object, or equal by content (compared recursively).
bool AnalysisContext::operator==(const AnalysisContext &other) const {
  if (func_graph_ != other.func_graph_) {
    return false;
  }
  if (args_spec_list_.size() != other.args_spec_list_.size()) {
    return false;
  }
  // Compare parents: identical pointers (including both null) are equal; otherwise both must be non-null and
  // equal by content.
  if (parent_ != other.parent_) {
    if (parent_ == nullptr || other.parent_ == nullptr || !(*parent_ == *other.parent_)) {
      return false;
    }
  }
  for (std::size_t i = 0; i < args_spec_list_.size(); i++) {
    if (!(*args_spec_list_[i] == *other.args_spec_list_[i])) {
      return false;
    }
  }
  return true;
}

// brief The key which controls graph cloning in Specialize.
//
// Originally, Specialize used the context directly as the key for cloning graphs, so a graph was cloned once per
// context, i.e. once for every distinct caller node, argument list and set of free values. To reduce the number of
// cloned graphs, this `SpecializeKey` method controls which graphs can be reused.
// Graphs called with different SymbolicKeys are reused. The abstracts of SymbolicKey parameters are joined and stored
// in the intermediate_abstract. A joined SymbolicKey causes Poly code in eval, so a reused graph with a SymbolicKey
// parameter should be inlined in the `opt` pipeline before the next renormalize.
// Graphs called with different shapes must not be reused, because the combination of `shape` and `Fill` relies on
// the correct shape to specialize a tensor constant.
//
// Concretely: every AbstractScalar argument whose value track is a SymbolicKeyInstance, and every AbstractRef
// argument, is replaced by its Broaden() result; all other arguments are kept as-is. The returned context has the
// same func graph and parent as this one.
AnalysisContextPtr AnalysisContext::SpecializeKey() const {
  AbstractBasePtrList args_broad_shp;
  (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(args_broad_shp),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         if (arg->isa<AbstractScalar>()) {
                           auto val = arg->GetValueTrack();
                           if (val->isa<SymbolicKeyInstance>()) {
                             auto scalar_spec = dyn_cast<AbstractScalar>(arg);
                             auto ret_spec = scalar_spec->Broaden();
                             return ret_spec;
                           }
                         }
                         if (arg->isa<AbstractRef>()) {
                           MS_LOG(DEBUG) << "refkey broaden";
                           auto arg_spec = dyn_cast<AbstractRef>(arg);
                           auto ret_spec = arg_spec->Broaden();
                           return ret_spec;
                         }
                         return arg;
                       });
  // Constructed with a null parent, then parent_ is assigned directly.
  // NOTE(review): presumably this avoids whatever the constructor does with a non-null parent — confirm.
  AnalysisContextPtr context_new = std::make_shared<AnalysisContext>(nullptr, func_graph_, args_broad_shp);
  context_new->parent_ = parent_;
  return context_new;
}

// Hash combining the parent chain's hashes (recursively) and the func graph's hash; arguments are not included.
std::size_t AnalysisContext::hash() {
  std::size_t hash_value = 0;
  // hash() recursion exit condition: the chain ends at a context whose parent_ is nullptr.
  if (parent_ != nullptr) {
    hash_value = hash_combine(hash_value, parent_->hash());
  }
  if (func_graph_ != nullptr) {
    hash_value = hash_combine(hash_value, func_graph_->hash());
  }
  return hash_value;
}

// Human-readable description: func graph (if any), indexed argument abstracts, then the parent context (recursively).
std::string AnalysisContext::ToString() const {
  std::ostringstream buffer;
  buffer << "{";
  if (func_graph_ != nullptr) {
    buffer << "Func Graph: " << func_graph_->ToString();
  }
  buffer << " Args: ";
  int i = 0;
  for (const auto &arg : args_spec_list_) {
    buffer << "[" << i << "]: " << arg->ToString() << ", ";
    i++;
  }
  if (parent_ != nullptr) {
    buffer << "Parent: " << parent_->ToString();
  }
  buffer << "}";
  return buffer.str();
}
}  // namespace abstract
}  // namespace mindspore