Browse Source

!3289 Move abstract function to core abstract folder.

Merge pull request !3289 from ZhangQinghua/master
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
a4f447af6c
11 changed files with 35 additions and 74 deletions
  1. +1
    -1
      mindspore/ccsrc/common.h
  2. +1
    -1
      mindspore/ccsrc/frontend/operator/composite/composite.cc
  3. +1
    -1
      mindspore/ccsrc/frontend/operator/composite/map.cc
  4. +1
    -1
      mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc
  5. +1
    -1
      mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
  6. +24
    -2
      mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
  7. +1
    -1
      mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h
  8. +1
    -39
      mindspore/core/abstract/abstract_function.cc
  9. +3
    -25
      mindspore/core/abstract/abstract_function.h
  10. +0
    -1
      mindspore/core/abstract/abstract_value.h
  11. +1
    -1
      tests/ut/cpp/operator/composite_test.cc

+ 1
- 1
mindspore/ccsrc/common.h View File

@@ -25,7 +25,7 @@

#include "abstract/dshape.h"
#include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/abstract_function.h"
#include "abstract/abstract_function.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/parse_base.h"


+ 1
- 1
mindspore/ccsrc/frontend/operator/composite/composite.cc View File

@@ -25,7 +25,7 @@
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/abstract_function.h"
#include "abstract/abstract_function.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "frontend/operator/cc_implementations.h"


+ 1
- 1
mindspore/ccsrc/frontend/operator/composite/map.cc View File

@@ -23,7 +23,7 @@
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/abstract_function.h"
#include "abstract/abstract_function.h"
#include "abstract/dshape.h"
#include "pybind_api/api_register.h"
#include "debug/trace.h"


+ 1
- 1
mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc View File

@@ -25,7 +25,7 @@
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/abstract_function.h"
#include "abstract/abstract_function.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "frontend/operator/cc_implementations.h"


+ 1
- 1
mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc View File

@@ -23,7 +23,7 @@
#include "./common.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/do_signature.h"
#include "pipeline/jit/static_analysis/abstract_function.h"
#include "abstract/abstract_function.h"
#include "utils/graph_utils.h"
#include "utils/log_adapter.h"
#include "utils/profile.h"


+ 24
- 2
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc View File

@@ -434,8 +434,30 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TypedPrimiti
// Forward to specific subclass of FunctionWrapper.
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
MS_EXCEPTION_IF_NULL(func);
EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this());
return evaluator;
if (func->isa<PrimitiveAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<PrimitiveAbstractClosure>>());
} else if (func->isa<FuncGraphAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<FuncGraphAbstractClosure>>());
} else if (func->isa<MetaFuncGraphAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<MetaFuncGraphAbstractClosure>>());
} else if (func->isa<JTransformedAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
} else if (func->isa<VirtualAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<VirtualAbstractClosure>>());
} else if (func->isa<PartialAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<PartialAbstractClosure>>());
} else if (func->isa<TypedPrimitiveAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<TypedPrimitiveAbstractClosure>>());
} else if (func->isa<AbstractFuncAtom>()) {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom";
} else if (func->isa<AbstractFuncUnion>()) {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion";
} else if (func->isa<DummyAbstractClosure>()) {
MS_LOG(EXCEPTION) << "A dummy function cannot eval";
} else {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction";
}
return nullptr;
}

EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {


+ 1
- 1
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h View File

@@ -35,7 +35,7 @@
#include "ir/anf.h"
#include "ir/primitive_py.h"
#include "abstract/analysis_context.h"
#include "pipeline/jit/static_analysis/abstract_function.h"
#include "abstract/abstract_function.h"
#include "pipeline/jit/parse/parse.h"

namespace mindspore {


mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc → mindspore/core/abstract/abstract_function.cc View File

@@ -14,12 +14,10 @@
* limitations under the License.
*/

#include "pipeline/jit/static_analysis/abstract_function.h"
#include "abstract/abstract_function.h"

#include <vector>

#include "pipeline/jit/static_analysis/static_analysis.h"

namespace mindspore {
namespace abstract {
class Evaluator;
@@ -134,11 +132,6 @@ std::size_t AbstractFuncUnion::hash() const {
return hash_sum;
}

EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<PrimitiveAbstractClosure>());
}

bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<PrimitiveAbstractClosure>()) {
return false;
@@ -152,11 +145,6 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {

std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); }

EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<FuncGraphAbstractClosure>());
}

bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<FuncGraphAbstractClosure>()) {
return false;
@@ -181,11 +169,6 @@ std::string FuncGraphAbstractClosure::ToString() const {
return ss.str();
}

EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<MetaFuncGraphAbstractClosure>());
}

bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<MetaFuncGraphAbstractClosure>()) {
return false;
@@ -229,11 +212,6 @@ std::size_t PartialAbstractClosure::hash() const {
return hash_value;
}

EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<PartialAbstractClosure>());
}

std::string PartialAbstractClosure::ToString() const {
std::ostringstream buffer;
buffer << "PartialAbstractClosure(" << fn_->ToString() << "(";
@@ -244,11 +222,6 @@ std::string PartialAbstractClosure::ToString() const {
return buffer.str();
}

EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<JTransformedAbstractClosure>());
}

bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<JTransformedAbstractClosure>()) {
return false;
@@ -265,11 +238,6 @@ std::size_t JTransformedAbstractClosure::hash() const {
return hash_value;
}

EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<VirtualAbstractClosure>());
}

bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<VirtualAbstractClosure>()) {
return false;
@@ -306,12 +274,6 @@ std::string VirtualAbstractClosure::ToString() const {
return buffer.str();
}

EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);

return engine->_GetEvaluatorFor(shared_from_base<TypedPrimitiveAbstractClosure>());
}

bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<TypedPrimitiveAbstractClosure>()) {
return false;

mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h → mindspore/core/abstract/abstract_function.h View File

@@ -16,8 +16,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_
#ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
#define MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_

#include <memory>
#include <string>
@@ -35,10 +35,6 @@ class AbstractFuncAtom : public AbstractFunction {
MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction)

AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); }
EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom";
}

AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
bool operator==(const AbstractFunction &other) const override;
@@ -56,9 +52,6 @@ class AbstractFuncUnion : public AbstractFunction {
std::string ToString() const override;

AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; }
EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion";
}
bool IsSuperSet(const AbstractFunctionPtr &other);
AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
@@ -80,8 +73,6 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom {
~PrimitiveAbstractClosure() override = default;
MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom)

EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;

PrimitivePtr prim() { return prim_; }

AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
@@ -114,8 +105,6 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
~FuncGraphAbstractClosure() override = default;
MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom)

EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;

FuncGraphPtr func_graph() { return func_graph_; }

AnalysisContextPtr context() const override { return context_; }
@@ -146,8 +135,6 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {

AnalysisContextPtr context() const override { return kDummyAnalysisContext; }

EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;

ScopePtr GetScope() { return scope_; }

AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); }
@@ -172,8 +159,6 @@ class PartialAbstractClosure : public AbstractFuncAtom {
~PartialAbstractClosure() override = default;
MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)

EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;

AbstractFunctionPtr fn() { return fn_; }
AbstractBasePtrList args() { return args_spec_list_; }
AnfNodePtr node() { return node_.lock(); }
@@ -199,7 +184,6 @@ class JTransformedAbstractClosure : public AbstractFuncAtom {
explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
~JTransformedAbstractClosure() override = default;
MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom)
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;

AbstractFuncAtomPtr fn() { return fn_; }
AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); }
@@ -224,8 +208,6 @@ class VirtualAbstractClosure : public AbstractFuncAtom {
~VirtualAbstractClosure() override = default;
MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom)

EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;

AbstractBasePtrList args_spec_list() { return args_spec_list_; }

AbstractBasePtr output() { return output_; }
@@ -254,8 +236,6 @@ class TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
~TypedPrimitiveAbstractClosure() override = default;
MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom)

EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;

PrimitivePtr prim() { return prim_; }
AbstractBasePtrList args_spec_list() { return args_spec_list_; }
AbstractBasePtr output() { return output_; }
@@ -280,8 +260,6 @@ class DummyAbstractClosure : public AbstractFuncAtom {
~DummyAbstractClosure() override = default;
MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom)

EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; }

AbstractFunctionPtr Copy() const override { return std::make_shared<DummyAbstractClosure>(); }
bool operator==(const AbstractFunction &other) const override;

@@ -300,4 +278,4 @@ struct AbstractFunctionEqual {
};
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_
#endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_

+ 0
- 1
mindspore/core/abstract/abstract_value.h View File

@@ -193,7 +193,6 @@ class AbstractFunction : public AbstractBase {

static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list);

virtual EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) = 0;
virtual AnfNodePtr tracking_id() const { return nullptr; }
virtual void set_tracking_id(AnfNodePtr) {}
virtual AnalysisContextPtr context() const { return nullptr; }


+ 1
- 1
tests/ut/cpp/operator/composite_test.cc View File

@@ -21,7 +21,7 @@
#include "frontend/operator/composite/composite.h"
#include "frontend/operator/ops.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "pipeline/jit/static_analysis/abstract_function.h"
#include "abstract/abstract_function.h"
#include "debug/trace.h"

namespace mindspore {


Loading…
Cancel
Save