Merge pull request !3289 from ZhangQinghua/mastertags/v0.7.0-beta
| @@ -25,7 +25,7 @@ | |||
| #include "abstract/dshape.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "pipeline/jit/static_analysis/abstract_function.h" | |||
| #include "abstract/abstract_function.h" | |||
| #include "pipeline/jit/parse/python_adapter.h" | |||
| #include "pipeline/jit/parse/parse.h" | |||
| #include "pipeline/jit/parse/parse_base.h" | |||
| @@ -25,7 +25,7 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "pipeline/jit/static_analysis/abstract_function.h" | |||
| #include "abstract/abstract_function.h" | |||
| #include "abstract/dshape.h" | |||
| #include "abstract/param_validator.h" | |||
| #include "frontend/operator/cc_implementations.h" | |||
| @@ -23,7 +23,7 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "pipeline/jit/static_analysis/abstract_function.h" | |||
| #include "abstract/abstract_function.h" | |||
| #include "abstract/dshape.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "debug/trace.h" | |||
| @@ -25,7 +25,7 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "pipeline/jit/static_analysis/abstract_function.h" | |||
| #include "abstract/abstract_function.h" | |||
| #include "abstract/dshape.h" | |||
| #include "abstract/param_validator.h" | |||
| #include "frontend/operator/cc_implementations.h" | |||
| @@ -23,7 +23,7 @@ | |||
| #include "./common.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/operator/composite/do_signature.h" | |||
| #include "pipeline/jit/static_analysis/abstract_function.h" | |||
| #include "abstract/abstract_function.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/profile.h" | |||
| @@ -434,8 +434,30 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TypedPrimiti | |||
| // Forward to specific subclass of FunctionWrapper. | |||
| EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) { | |||
| MS_EXCEPTION_IF_NULL(func); | |||
| EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this()); | |||
| return evaluator; | |||
| if (func->isa<PrimitiveAbstractClosure>()) { | |||
| return _GetEvaluatorFor(func->cast<std::shared_ptr<PrimitiveAbstractClosure>>()); | |||
| } else if (func->isa<FuncGraphAbstractClosure>()) { | |||
| return _GetEvaluatorFor(func->cast<std::shared_ptr<FuncGraphAbstractClosure>>()); | |||
| } else if (func->isa<MetaFuncGraphAbstractClosure>()) { | |||
| return _GetEvaluatorFor(func->cast<std::shared_ptr<MetaFuncGraphAbstractClosure>>()); | |||
| } else if (func->isa<JTransformedAbstractClosure>()) { | |||
| return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>()); | |||
| } else if (func->isa<VirtualAbstractClosure>()) { | |||
| return _GetEvaluatorFor(func->cast<std::shared_ptr<VirtualAbstractClosure>>()); | |||
| } else if (func->isa<PartialAbstractClosure>()) { | |||
| return _GetEvaluatorFor(func->cast<std::shared_ptr<PartialAbstractClosure>>()); | |||
| } else if (func->isa<TypedPrimitiveAbstractClosure>()) { | |||
| return _GetEvaluatorFor(func->cast<std::shared_ptr<TypedPrimitiveAbstractClosure>>()); | |||
| } else if (func->isa<AbstractFuncAtom>()) { | |||
| MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom"; | |||
| } else if (func->isa<AbstractFuncUnion>()) { | |||
| MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion"; | |||
| } else if (func->isa<DummyAbstractClosure>()) { | |||
| MS_LOG(EXCEPTION) << "A dummy function cannot eval"; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction"; | |||
| } | |||
| return nullptr; | |||
| } | |||
| EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { | |||
| @@ -35,7 +35,7 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/primitive_py.h" | |||
| #include "abstract/analysis_context.h" | |||
| #include "pipeline/jit/static_analysis/abstract_function.h" | |||
| #include "abstract/abstract_function.h" | |||
| #include "pipeline/jit/parse/parse.h" | |||
| namespace mindspore { | |||
| @@ -14,12 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "pipeline/jit/static_analysis/abstract_function.h" | |||
| #include "abstract/abstract_function.h" | |||
| #include <vector> | |||
| #include "pipeline/jit/static_analysis/static_analysis.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| class Evaluator; | |||
| @@ -134,11 +132,6 @@ std::size_t AbstractFuncUnion::hash() const { | |||
| return hash_sum; | |||
| } | |||
| EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { | |||
| MS_EXCEPTION_IF_NULL(engine); | |||
| return engine->_GetEvaluatorFor(shared_from_base<PrimitiveAbstractClosure>()); | |||
| } | |||
| bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| if (!other.isa<PrimitiveAbstractClosure>()) { | |||
| return false; | |||
| @@ -152,11 +145,6 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } | |||
| EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { | |||
| MS_EXCEPTION_IF_NULL(engine); | |||
| return engine->_GetEvaluatorFor(shared_from_base<FuncGraphAbstractClosure>()); | |||
| } | |||
| bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| if (!other.isa<FuncGraphAbstractClosure>()) { | |||
| return false; | |||
| @@ -181,11 +169,6 @@ std::string FuncGraphAbstractClosure::ToString() const { | |||
| return ss.str(); | |||
| } | |||
| EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { | |||
| MS_EXCEPTION_IF_NULL(engine); | |||
| return engine->_GetEvaluatorFor(shared_from_base<MetaFuncGraphAbstractClosure>()); | |||
| } | |||
| bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| if (!other.isa<MetaFuncGraphAbstractClosure>()) { | |||
| return false; | |||
| @@ -229,11 +212,6 @@ std::size_t PartialAbstractClosure::hash() const { | |||
| return hash_value; | |||
| } | |||
| EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { | |||
| MS_EXCEPTION_IF_NULL(engine); | |||
| return engine->_GetEvaluatorFor(shared_from_base<PartialAbstractClosure>()); | |||
| } | |||
| std::string PartialAbstractClosure::ToString() const { | |||
| std::ostringstream buffer; | |||
| buffer << "PartialAbstractClosure(" << fn_->ToString() << "("; | |||
| @@ -244,11 +222,6 @@ std::string PartialAbstractClosure::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { | |||
| MS_EXCEPTION_IF_NULL(engine); | |||
| return engine->_GetEvaluatorFor(shared_from_base<JTransformedAbstractClosure>()); | |||
| } | |||
| bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| if (!other.isa<JTransformedAbstractClosure>()) { | |||
| return false; | |||
| @@ -265,11 +238,6 @@ std::size_t JTransformedAbstractClosure::hash() const { | |||
| return hash_value; | |||
| } | |||
| EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { | |||
| MS_EXCEPTION_IF_NULL(engine); | |||
| return engine->_GetEvaluatorFor(shared_from_base<VirtualAbstractClosure>()); | |||
| } | |||
| bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| if (!other.isa<VirtualAbstractClosure>()) { | |||
| return false; | |||
| @@ -306,12 +274,6 @@ std::string VirtualAbstractClosure::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { | |||
| MS_EXCEPTION_IF_NULL(engine); | |||
| return engine->_GetEvaluatorFor(shared_from_base<TypedPrimitiveAbstractClosure>()); | |||
| } | |||
| bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| if (!other.isa<TypedPrimitiveAbstractClosure>()) { | |||
| return false; | |||
| @@ -16,8 +16,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ | |||
| #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ | |||
| #ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ | |||
| #define MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| @@ -35,10 +35,6 @@ class AbstractFuncAtom : public AbstractFunction { | |||
| MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction) | |||
| AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); } | |||
| EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { | |||
| MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom"; | |||
| } | |||
| AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; | |||
| void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final; | |||
| bool operator==(const AbstractFunction &other) const override; | |||
| @@ -56,9 +52,6 @@ class AbstractFuncUnion : public AbstractFunction { | |||
| std::string ToString() const override; | |||
| AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; } | |||
| EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { | |||
| MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion"; | |||
| } | |||
| bool IsSuperSet(const AbstractFunctionPtr &other); | |||
| AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; | |||
| void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final; | |||
| @@ -80,8 +73,6 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom { | |||
| ~PrimitiveAbstractClosure() override = default; | |||
| MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom) | |||
| EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; | |||
| PrimitivePtr prim() { return prim_; } | |||
| AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } | |||
| @@ -114,8 +105,6 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| ~FuncGraphAbstractClosure() override = default; | |||
| MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom) | |||
| EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; | |||
| FuncGraphPtr func_graph() { return func_graph_; } | |||
| AnalysisContextPtr context() const override { return context_; } | |||
| @@ -146,8 +135,6 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| AnalysisContextPtr context() const override { return kDummyAnalysisContext; } | |||
| EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; | |||
| ScopePtr GetScope() { return scope_; } | |||
| AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); } | |||
| @@ -172,8 +159,6 @@ class PartialAbstractClosure : public AbstractFuncAtom { | |||
| ~PartialAbstractClosure() override = default; | |||
| MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) | |||
| EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; | |||
| AbstractFunctionPtr fn() { return fn_; } | |||
| AbstractBasePtrList args() { return args_spec_list_; } | |||
| AnfNodePtr node() { return node_.lock(); } | |||
| @@ -199,7 +184,6 @@ class JTransformedAbstractClosure : public AbstractFuncAtom { | |||
| explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} | |||
| ~JTransformedAbstractClosure() override = default; | |||
| MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom) | |||
| EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; | |||
| AbstractFuncAtomPtr fn() { return fn_; } | |||
| AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); } | |||
| @@ -224,8 +208,6 @@ class VirtualAbstractClosure : public AbstractFuncAtom { | |||
| ~VirtualAbstractClosure() override = default; | |||
| MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom) | |||
| EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; | |||
| AbstractBasePtrList args_spec_list() { return args_spec_list_; } | |||
| AbstractBasePtr output() { return output_; } | |||
| @@ -254,8 +236,6 @@ class TypedPrimitiveAbstractClosure : public AbstractFuncAtom { | |||
| ~TypedPrimitiveAbstractClosure() override = default; | |||
| MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom) | |||
| EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; | |||
| PrimitivePtr prim() { return prim_; } | |||
| AbstractBasePtrList args_spec_list() { return args_spec_list_; } | |||
| AbstractBasePtr output() { return output_; } | |||
| @@ -280,8 +260,6 @@ class DummyAbstractClosure : public AbstractFuncAtom { | |||
| ~DummyAbstractClosure() override = default; | |||
| MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom) | |||
| EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; } | |||
| AbstractFunctionPtr Copy() const override { return std::make_shared<DummyAbstractClosure>(); } | |||
| bool operator==(const AbstractFunction &other) const override; | |||
| @@ -300,4 +278,4 @@ struct AbstractFunctionEqual { | |||
| }; | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ | |||
| #endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ | |||
| @@ -193,7 +193,6 @@ class AbstractFunction : public AbstractBase { | |||
| static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list); | |||
| virtual EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) = 0; | |||
| virtual AnfNodePtr tracking_id() const { return nullptr; } | |||
| virtual void set_tracking_id(AnfNodePtr) {} | |||
| virtual AnalysisContextPtr context() const { return nullptr; } | |||
| @@ -21,7 +21,7 @@ | |||
| #include "frontend/operator/composite/composite.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "pipeline/jit/static_analysis/prim.h" | |||
| #include "pipeline/jit/static_analysis/abstract_function.h" | |||
| #include "abstract/abstract_function.h" | |||
| #include "debug/trace.h" | |||
| namespace mindspore { | |||