Merge pull request !3370 from xychow/fix-context-duptags/v0.7.0-beta
| @@ -36,6 +36,11 @@ BasePtr AbsOf(const AnfNodePtr &node) { | |||||
| if (node_abs == nullptr) { | if (node_abs == nullptr) { | ||||
| return kAnyValue; | return kAnyValue; | ||||
| } | } | ||||
| // Ignore the tracking_id and prim pointer hash; | |||||
| if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) { | |||||
| auto prim_abs = node_abs->cast<abstract::PrimitiveAbstractClosurePtr>(); | |||||
| return prim_abs->prim(); | |||||
| } | |||||
| return node_abs; | return node_abs; | ||||
| } | } | ||||
| @@ -470,7 +470,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { | |||||
| MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); | MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(func); | MS_EXCEPTION_IF_NULL(func); | ||||
| if (func->tracking_id() == nullptr) { | |||||
| if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() || | |||||
| func->isa<abstract::FuncGraphAbstractClosure>()) { | |||||
| EvaluatorPtr evaluator = _GetEvaluatorFor(func); | EvaluatorPtr evaluator = _GetEvaluatorFor(func); | ||||
| return evaluator; | return evaluator; | ||||
| } | } | ||||
| @@ -639,12 +640,12 @@ EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { | |||||
| } | } | ||||
| abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph, | abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph, | ||||
| const abstract::AnalysisContextPtr &context) { | |||||
| const abstract::AnalysisContextPtr &context, const AnfNodePtr &anf_node) { | |||||
| AnalysisContextPtr temp_context = context; | AnalysisContextPtr temp_context = context; | ||||
| if (temp_context == nullptr) { | if (temp_context == nullptr) { | ||||
| temp_context = abstract::AnalysisContext::DummyContext(); | temp_context = abstract::AnalysisContext::DummyContext(); | ||||
| } | } | ||||
| return std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context); | |||||
| return std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context, anf_node); | |||||
| } | } | ||||
| abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) { | abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) { | ||||
| @@ -652,7 +653,8 @@ abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_ | |||||
| if (anf_node == nullptr) { | if (anf_node == nullptr) { | ||||
| meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph); | meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph); | ||||
| } else { | } else { | ||||
| meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node->scope()); | |||||
| meta_func_graph_fn = | |||||
| std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node, anf_node->scope()); | |||||
| } | } | ||||
| return meta_func_graph_fn; | return meta_func_graph_fn; | ||||
| } | } | ||||
| @@ -663,14 +665,14 @@ abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, con | |||||
| } | } | ||||
| AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) { | AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) { | ||||
| if (value->isa<FuncGraph>()) { | |||||
| auto func_graph = value->cast<FuncGraphPtr>(); | |||||
| return MakeAbstractClosure(func_graph, context); | |||||
| } | |||||
| AnfNodePtr anf_node = nullptr; | AnfNodePtr anf_node = nullptr; | ||||
| if (conf != nullptr) { | if (conf != nullptr) { | ||||
| anf_node = conf->node(); | anf_node = conf->node(); | ||||
| } | } | ||||
| if (value->isa<FuncGraph>()) { | |||||
| auto func_graph = value->cast<FuncGraphPtr>(); | |||||
| return MakeAbstractClosure(func_graph, context, anf_node); | |||||
| } | |||||
| if (value->isa<MetaFuncGraph>()) { | if (value->isa<MetaFuncGraph>()) { | ||||
| auto meta_func_graph = value->cast<MetaFuncGraphPtr>(); | auto meta_func_graph = value->cast<MetaFuncGraphPtr>(); | ||||
| return MakeAbstractClosure(meta_func_graph, anf_node); | return MakeAbstractClosure(meta_func_graph, anf_node); | ||||
| @@ -232,7 +232,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||||
| const PrimEvaluatorMap &prim_constructors_; | const PrimEvaluatorMap &prim_constructors_; | ||||
| FuncGraphManagerPtr func_graph_manager_; | FuncGraphManagerPtr func_graph_manager_; | ||||
| std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_; | |||||
| std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_; | |||||
| AnfNodeConfigMap anfnode_config_map_; | AnfNodeConfigMap anfnode_config_map_; | ||||
| // Use a list to trace multiple evaluators. | // Use a list to trace multiple evaluators. | ||||
| std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_; | std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_; | ||||
| @@ -143,14 +143,23 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { | |||||
| return false; | return false; | ||||
| } | } | ||||
| std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } | |||||
| std::size_t PrimitiveAbstractClosure::hash() const { | |||||
| auto hash_value = hash_combine(tid(), prim_->hash()); | |||||
| // Keep in sync with operator==() which compares the prim_ pointer; | |||||
| hash_value = hash_combine(hash_value, std::hash<Primitive *>{}(prim_.get())); | |||||
| if (tracking_id() != nullptr) { | |||||
| hash_value = hash_combine(hash_value, tracking_id()->hash()); | |||||
| } | |||||
| return hash_value; | |||||
| } | |||||
| bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { | bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { | ||||
| if (!other.isa<FuncGraphAbstractClosure>()) { | if (!other.isa<FuncGraphAbstractClosure>()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other); | auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other); | ||||
| if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) { | |||||
| if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ && | |||||
| tracking_id() == other_fg->tracking_id()) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| return false; | return false; | ||||
| @@ -159,9 +168,11 @@ bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { | |||||
| std::size_t FuncGraphAbstractClosure::hash() const { | std::size_t FuncGraphAbstractClosure::hash() const { | ||||
| auto hash_value = hash_combine(tid(), func_graph_->hash()); | auto hash_value = hash_combine(tid(), func_graph_->hash()); | ||||
| hash_value = hash_combine(hash_value, context_->hash()); | hash_value = hash_combine(hash_value, context_->hash()); | ||||
| if (tracking_id() != nullptr) { | |||||
| hash_value = hash_combine(hash_value, tracking_id()->hash()); | |||||
| } | |||||
| return hash_value; | return hash_value; | ||||
| } | } | ||||
| std::string FuncGraphAbstractClosure::ToString() const { | std::string FuncGraphAbstractClosure::ToString() const { | ||||
| std::stringstream ss; | std::stringstream ss; | ||||
| ss << "FuncGraphAbstractClosure: " | ss << "FuncGraphAbstractClosure: " | ||||
| @@ -174,7 +185,7 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other); | auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other); | ||||
| if (meta_func_graph_ == other_meta_fg->meta_func_graph_) { | |||||
| if (meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id()) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| return false; | return false; | ||||
| @@ -182,6 +193,9 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con | |||||
| std::size_t MetaFuncGraphAbstractClosure::hash() const { | std::size_t MetaFuncGraphAbstractClosure::hash() const { | ||||
| auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); | auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); | ||||
| if (tracking_id() != nullptr) { | |||||
| hash_value = hash_combine(hash_value, tracking_id()->hash()); | |||||
| } | |||||
| return hash_value; | return hash_value; | ||||
| } | } | ||||
| @@ -92,13 +92,15 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom { | |||||
| // one reference cycle example is Graph::set_output() input0 local variable. | // one reference cycle example is Graph::set_output() input0 local variable. | ||||
| AnfNodeWeakPtr tracking_id_; | AnfNodeWeakPtr tracking_id_; | ||||
| }; | }; | ||||
| using PrimitiveAbstractClosurePtr = std::shared_ptr<PrimitiveAbstractClosure>; | |||||
| class FuncGraphAbstractClosure : public AbstractFuncAtom { | class FuncGraphAbstractClosure : public AbstractFuncAtom { | ||||
| public: | public: | ||||
| // Represents a Graph in a certain Context. | // Represents a Graph in a certain Context. | ||||
| // context: The context, or Context.empty() | // context: The context, or Context.empty() | ||||
| FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) | |||||
| : func_graph_(func_graph), context_(context) { | |||||
| FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, | |||||
| const AnfNodePtr &tracking_id = nullptr) | |||||
| : func_graph_(func_graph), context_(context), tracking_id_(AnfNodeWeakPtr(tracking_id)) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(context); | MS_EXCEPTION_IF_NULL(context); | ||||
| } | } | ||||
| @@ -109,8 +111,10 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { | |||||
| AnalysisContextPtr context() const override { return context_; } | AnalysisContextPtr context() const override { return context_; } | ||||
| AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } | |||||
| AbstractFunctionPtr Copy() const override { | AbstractFunctionPtr Copy() const override { | ||||
| return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_); | |||||
| return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_, tracking_id()); | |||||
| } | } | ||||
| bool operator==(const AbstractFunction &other) const override; | bool operator==(const AbstractFunction &other) const override; | ||||
| @@ -121,13 +125,22 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { | |||||
| private: | private: | ||||
| FuncGraphPtr func_graph_; | FuncGraphPtr func_graph_; | ||||
| AnalysisContextPtr context_; | AnalysisContextPtr context_; | ||||
| // To discriminate different usage of same graph by using this tracking_id, | |||||
| // so different tracking_id will produce different FuncGraphAbstractClosure, | |||||
| // different FuncGraphEvaluator. | |||||
| // Espcecially usefull for recursive func graph call, so it will not mess up | |||||
| // the graph_context_ in FuncGraphEvaluator. | |||||
| // Notes: Be careful to use nullptr for this variable. | |||||
| // store it as weak_ptr to break reference cycle. | |||||
| AnfNodeWeakPtr tracking_id_; | |||||
| }; | }; | ||||
| using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>; | using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>; | ||||
| class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { | class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { | ||||
| public: | public: | ||||
| explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope) | |||||
| : meta_func_graph_(meta_func_graph), scope_(scope) {} | |||||
| explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, | |||||
| const AnfNodePtr &tracking_id = nullptr, const ScopePtr &scope = kDefaultScope) | |||||
| : meta_func_graph_(meta_func_graph), tracking_id_(AnfNodeWeakPtr(tracking_id)), scope_(scope) {} | |||||
| ~MetaFuncGraphAbstractClosure() override = default; | ~MetaFuncGraphAbstractClosure() override = default; | ||||
| MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) | MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) | ||||
| @@ -137,7 +150,11 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { | |||||
| ScopePtr GetScope() { return scope_; } | ScopePtr GetScope() { return scope_; } | ||||
| AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); } | |||||
| AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } | |||||
| AbstractFunctionPtr Copy() const override { | |||||
| return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_, tracking_id()); | |||||
| } | |||||
| bool operator==(const AbstractFunction &other) const override; | bool operator==(const AbstractFunction &other) const override; | ||||
| std::size_t hash() const override; | std::size_t hash() const override; | ||||
| @@ -145,6 +162,9 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { | |||||
| private: | private: | ||||
| MetaFuncGraphPtr meta_func_graph_; | MetaFuncGraphPtr meta_func_graph_; | ||||
| // refer the comment in FuncGraphAbstractClosure; | |||||
| // store it as weak_ptr to break reference cycle. | |||||
| AnfNodeWeakPtr tracking_id_; | |||||
| ScopePtr scope_; | ScopePtr scope_; | ||||
| }; | }; | ||||
| using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>; | using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>; | ||||
| @@ -67,3 +67,62 @@ def test_assign_in_while(): | |||||
| z = Tensor(np.random.randn(*input_shape).astype(np.float32)) | z = Tensor(np.random.randn(*input_shape).astype(np.float32)) | ||||
| net = Net(input_shape) | net = Net(input_shape) | ||||
| net(x, y, z) | net(x, y, z) | ||||
| def test_dup_context(): | |||||
| ''' different func_with_fv in net1 and net2 should produce 2 different FuncGraphAbstractClosure and | |||||
| Evaluator. | |||||
| ''' | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def construct(self, x): | |||||
| def identity(f): | |||||
| return f | |||||
| def func_with_fv(): | |||||
| return x | |||||
| def net1(): | |||||
| local_func = identity(func_with_fv) | |||||
| out = local_func() + 20.0 | |||||
| return out | |||||
| def net2(): | |||||
| local_func = identity(func_with_fv) | |||||
| out = local_func() + 15.0 | |||||
| return out | |||||
| return net1() + net2() | |||||
| Net()(5.0) | |||||
| def test_maybe_poly_func(): | |||||
| ''' different func_with_fv in net1 and net2 may produce poly node. ''' | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def construct(self, x, y, z): | |||||
| def identity(f, inp): | |||||
| return f(inp) | |||||
| def func_with_fv(yy): | |||||
| return (x, yy) | |||||
| def make_call(): | |||||
| out1 = identity(func_with_fv, y) | |||||
| out2 = identity(func_with_fv, z) | |||||
| return (out1, out2) | |||||
| return make_call() | |||||
| y_input = Tensor(np.array([1, 2]).astype(np.int32)) | |||||
| z_input = Tensor(np.array([[2, 2], [3, 3]]).astype(np.int32)) | |||||
| Net()(1, y_input, z_input) | |||||