|
- /**
- * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- *
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #ifndef PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
- #define PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
-
- #include <list>
- #include <memory>
- #include <string>
- #include <unordered_map>
- #include <vector>
- #include <utility>
- #include <map>
-
- #ifdef DEBUG
- #include <stack>
- #endif
-
- #include "utils/log_adapter.h"
- #include "ir/anf.h"
- #include "ir/primitive.h"
- #include "pipeline/static_analysis/analysis_context.h"
- #include "pipeline/static_analysis/abstract_function.h"
- #include "pipeline/parse/parse.h"
-
- namespace mindspore {
- namespace abstract {
- // define attribute value map
- using AttrValueMap = std::unordered_map<std::string, ValuePtr>;
- using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
-
- // the class to save evaluated result: abstract value and modified attribute
- class EvalResult : public Base {
- public:
- EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {}
- ~EvalResult() override = default;
- MS_DECLARE_PARENT(EvalResult, Base);
- AbstractBasePtr abstract() { return abstract_; }
- AttrValueMapPtr attribute() { return attribute_; }
-
- private:
- AbstractBasePtr abstract_;
- AttrValueMapPtr attribute_;
- };
-
- using EvalResultPtr = std::shared_ptr<EvalResult>;
- // Superclass for AnfNodeConfig and VirtualConfig.
- class Config : public Base {
- public:
- Config() = default;
- ~Config() override = default;
- MS_DECLARE_PARENT(Config, Base);
- virtual EvalResultPtr GetEvaluatedValue() = 0;
- };
-
- // Config will be stored in AnalysisCache
- using ConfigPtr = std::shared_ptr<Config>;
- using ConfigPtrList = std::vector<ConfigPtr>;
-
- // Config to a certain node in a certain context.
- class AnfNodeConfig : public Config {
- public:
- AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context)
- : Config(), engine_(std::weak_ptr<AnalysisEngine>(engine)), node_(node) {
- FuncGraphPtr fg;
- if (IsValueNode<FuncGraph>(node)) {
- auto v = node->cast<ValueNodePtr>();
- fg = v->value()->cast<FuncGraphPtr>();
- } else {
- fg = node->func_graph();
- }
- context_ = nullptr;
- if (context != nullptr) {
- context_ = context->Filter(fg);
- }
- }
-
- ~AnfNodeConfig() override = default;
- MS_DECLARE_PARENT(AnfNodeConfig, Config);
-
- EvalResultPtr GetEvaluatedValue() override;
-
- AnalysisContextPtr context() const { return context_; }
-
- AnfNodePtr node() const { return node_; }
-
- AnalysisEnginePtr engine() const { return engine_.lock(); }
-
- // used by unordered_map;
- bool operator==(const AnfNodeConfig &other) const {
- // compare node with pointer, context with pointer except DummyContext as it's created by make_shared;
- // context should not be nullptr;
- if (context_->IsDummyContext() && other.context_->IsDummyContext()) {
- return true;
- }
- return (node_ == other.node_) && (context_ == other.context_);
- }
-
- std::string ToString() const override {
- std::ostringstream buffer;
- buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString();
- return buffer.str();
- }
-
- private:
- // AnalysisEngine is global.
- // As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use
- // weak_ptr to break Config cycle.
- std::weak_ptr<AnalysisEngine> engine_;
- AnfNodePtr node_;
- AnalysisContextPtr context_;
- };
-
- using AnfNodeConfigPtr = std::shared_ptr<AnfNodeConfig>;
-
- struct AnfNodeConfigHasher {
- std::size_t operator()(const AnfNodeConfigPtr conf) const;
- };
-
- struct AnfNodeConfigEqual {
- bool operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const;
- };
-
- class VirtualConfig : public Config {
- public:
- explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {}
-
- ~VirtualConfig() override = default;
- MS_DECLARE_PARENT(VirtualConfig, Config);
- EvalResultPtr GetEvaluatedValue() override {
- return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>());
- }
-
- private:
- AbstractBasePtr abstract_;
- };
-
- // AnalysisCache
- class AnalysisCache {
- public:
- AnalysisCache() = default;
- ~AnalysisCache() = default;
- void Clear() { cache_.clear(); }
- void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg);
- EvalResultPtr GetValue(const AnfNodeConfigPtr &conf);
-
- private:
- std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_;
- };
-
- using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
- using AnfNodeConfigMap =
- std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
-
- struct AnalysisResult {
- EvalResultPtr inferred;
- AnalysisContextPtr context;
- };
-
- using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator;
-
- class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
- public:
- AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
- : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {}
- ~AnalysisEngine() = default;
-
- // func_graph: The func_graph to analyze.
- // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
- AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
- EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf);
- // Return the Evaluator for the given function.
- EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
-
- AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
- EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
- // Infer the result of fn(args).
- EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
- void Clear();
- void ClearEvaluatorCache();
- AnalysisCache &cache() { return cache_; }
- AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) {
- return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context);
- }
- // Overloaded function.
- EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &fn);
- EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<PartialAbstractClosure> &fn);
- EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &fn);
- EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &fn);
- EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn);
- EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &);
- EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn);
-
- FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; }
- const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; }
-
- // Set the analysis result for orig to the result for new.
- // This sets an entry in anfnode_config_map from orig to new.
- EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
- // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
- (void)anfnode_config_map_.emplace(orig_conf, new_conf);
- MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
- << ", to new_conf: " << new_conf->node()->DebugString();
- return GetEvaluatedValue(new_conf);
- }
- const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
-
- AnalysisCache cache_;
- std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
-
- private:
- void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
- EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
- const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
- bool *continue_flag);
- EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs);
-
- const PrimEvaluatorMap &prim_constructors_;
- FuncGraphManagerPtr func_graph_manager_;
- std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_;
- AnfNodeConfigMap anfnode_config_map_;
- // Use a list to trace multiple evaluators.
- std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
- std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
-
- AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
- const ConfigPtrList &args_conf_list);
- EvalResultPtr Eval(const AnfNodeConfigPtr &conf);
- EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn);
- EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
- const ConfigPtrList &args_conf_list);
- EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
- const ConfigPtrList &args_conf_list);
-
- #ifdef DEBUG
- std::vector<AnfNodePtr> compute_conf_stack_;
- #endif
- };
-
- // Translate the value to an abstract value.
- // Arguments:
- // value: The value to convert.
- // context: The context in which the value was found, used if the value is a Graph.
- // conf: The Config to the valuenode we are converting, if there is one,
- // so that we can generate a tracking_id.
- AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr,
- const AnfNodeConfigPtr &conf = nullptr);
-
- // Convert a value to an abstract value.
- // Arguments:
- // v: The value to convert.
- // broaden: If True, concrete values will be made more abstract, so e.g.
- // the value 1234 would become ANYTHING.
- AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false);
-
- template <typename T>
- AbstractBasePtr FromValue(const T &value, bool broaden = false) {
- return FromValueInside(MakeValue(value), broaden);
- }
-
- EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
- } // namespace abstract
- } // namespace mindspore
-
- #endif // PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_
|