Merge pull request !3370 from xychow/fix-context-duptags/v0.7.0-beta
| @@ -36,6 +36,11 @@ BasePtr AbsOf(const AnfNodePtr &node) { | |||
| if (node_abs == nullptr) { | |||
| return kAnyValue; | |||
| } | |||
| // Ignore the tracking_id and prim pointer hash; | |||
| if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) { | |||
| auto prim_abs = node_abs->cast<abstract::PrimitiveAbstractClosurePtr>(); | |||
| return prim_abs->prim(); | |||
| } | |||
| return node_abs; | |||
| } | |||
| @@ -470,7 +470,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { | |||
| MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func); | |||
| if (func->tracking_id() == nullptr) { | |||
| if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() || | |||
| func->isa<abstract::FuncGraphAbstractClosure>()) { | |||
| EvaluatorPtr evaluator = _GetEvaluatorFor(func); | |||
| return evaluator; | |||
| } | |||
| @@ -639,12 +640,12 @@ EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { | |||
| } | |||
| abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph, | |||
| const abstract::AnalysisContextPtr &context) { | |||
| const abstract::AnalysisContextPtr &context, const AnfNodePtr &anf_node) { | |||
| AnalysisContextPtr temp_context = context; | |||
| if (temp_context == nullptr) { | |||
| temp_context = abstract::AnalysisContext::DummyContext(); | |||
| } | |||
| return std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context); | |||
| return std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context, anf_node); | |||
| } | |||
| abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) { | |||
| @@ -652,7 +653,8 @@ abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_ | |||
| if (anf_node == nullptr) { | |||
| meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph); | |||
| } else { | |||
| meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node->scope()); | |||
| meta_func_graph_fn = | |||
| std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node, anf_node->scope()); | |||
| } | |||
| return meta_func_graph_fn; | |||
| } | |||
| @@ -663,14 +665,14 @@ abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, con | |||
| } | |||
| AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) { | |||
| if (value->isa<FuncGraph>()) { | |||
| auto func_graph = value->cast<FuncGraphPtr>(); | |||
| return MakeAbstractClosure(func_graph, context); | |||
| } | |||
| AnfNodePtr anf_node = nullptr; | |||
| if (conf != nullptr) { | |||
| anf_node = conf->node(); | |||
| } | |||
| if (value->isa<FuncGraph>()) { | |||
| auto func_graph = value->cast<FuncGraphPtr>(); | |||
| return MakeAbstractClosure(func_graph, context, anf_node); | |||
| } | |||
| if (value->isa<MetaFuncGraph>()) { | |||
| auto meta_func_graph = value->cast<MetaFuncGraphPtr>(); | |||
| return MakeAbstractClosure(meta_func_graph, anf_node); | |||
| @@ -232,7 +232,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| const PrimEvaluatorMap &prim_constructors_; | |||
| FuncGraphManagerPtr func_graph_manager_; | |||
| std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_; | |||
| std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_; | |||
| AnfNodeConfigMap anfnode_config_map_; | |||
| // Use a list to trace multiple evaluators. | |||
| std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_; | |||
| @@ -143,14 +143,23 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| return false; | |||
| } | |||
| std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } | |||
| std::size_t PrimitiveAbstractClosure::hash() const { | |||
| auto hash_value = hash_combine(tid(), prim_->hash()); | |||
| // Keep in sync with operator==() which compares the prim_ pointer; | |||
| hash_value = hash_combine(hash_value, std::hash<Primitive *>{}(prim_.get())); | |||
| if (tracking_id() != nullptr) { | |||
| hash_value = hash_combine(hash_value, tracking_id()->hash()); | |||
| } | |||
| return hash_value; | |||
| } | |||
| bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| if (!other.isa<FuncGraphAbstractClosure>()) { | |||
| return false; | |||
| } | |||
| auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other); | |||
| if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) { | |||
| if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ && | |||
| tracking_id() == other_fg->tracking_id()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| @@ -159,9 +168,11 @@ bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| std::size_t FuncGraphAbstractClosure::hash() const { | |||
| auto hash_value = hash_combine(tid(), func_graph_->hash()); | |||
| hash_value = hash_combine(hash_value, context_->hash()); | |||
| if (tracking_id() != nullptr) { | |||
| hash_value = hash_combine(hash_value, tracking_id()->hash()); | |||
| } | |||
| return hash_value; | |||
| } | |||
| std::string FuncGraphAbstractClosure::ToString() const { | |||
| std::stringstream ss; | |||
| ss << "FuncGraphAbstractClosure: " | |||
| @@ -174,7 +185,7 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con | |||
| return false; | |||
| } | |||
| auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other); | |||
| if (meta_func_graph_ == other_meta_fg->meta_func_graph_) { | |||
| if (meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| @@ -182,6 +193,9 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con | |||
| std::size_t MetaFuncGraphAbstractClosure::hash() const { | |||
| auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); | |||
| if (tracking_id() != nullptr) { | |||
| hash_value = hash_combine(hash_value, tracking_id()->hash()); | |||
| } | |||
| return hash_value; | |||
| } | |||
| @@ -92,13 +92,15 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom { | |||
| // one reference cycle example is Graph::set_output() input0 local variable. | |||
| AnfNodeWeakPtr tracking_id_; | |||
| }; | |||
| using PrimitiveAbstractClosurePtr = std::shared_ptr<PrimitiveAbstractClosure>; | |||
| class FuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| public: | |||
| // Represents a Graph in a certain Context. | |||
| // context: The context, or Context.empty() | |||
| FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) | |||
| : func_graph_(func_graph), context_(context) { | |||
| FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, | |||
| const AnfNodePtr &tracking_id = nullptr) | |||
| : func_graph_(func_graph), context_(context), tracking_id_(AnfNodeWeakPtr(tracking_id)) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| } | |||
| @@ -109,8 +111,10 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| AnalysisContextPtr context() const override { return context_; } | |||
| AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } | |||
| AbstractFunctionPtr Copy() const override { | |||
| return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_); | |||
| return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_, tracking_id()); | |||
| } | |||
| bool operator==(const AbstractFunction &other) const override; | |||
| @@ -121,13 +125,22 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| private: | |||
| FuncGraphPtr func_graph_; | |||
| AnalysisContextPtr context_; | |||
| // To discriminate different usage of same graph by using this tracking_id, | |||
| // so different tracking_id will produce different FuncGraphAbstractClosure, | |||
| // different FuncGraphEvaluator. | |||
| // Espcecially usefull for recursive func graph call, so it will not mess up | |||
| // the graph_context_ in FuncGraphEvaluator. | |||
| // Notes: Be careful to use nullptr for this variable. | |||
| // store it as weak_ptr to break reference cycle. | |||
| AnfNodeWeakPtr tracking_id_; | |||
| }; | |||
| using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>; | |||
| class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| public: | |||
| explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope) | |||
| : meta_func_graph_(meta_func_graph), scope_(scope) {} | |||
| explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, | |||
| const AnfNodePtr &tracking_id = nullptr, const ScopePtr &scope = kDefaultScope) | |||
| : meta_func_graph_(meta_func_graph), tracking_id_(AnfNodeWeakPtr(tracking_id)), scope_(scope) {} | |||
| ~MetaFuncGraphAbstractClosure() override = default; | |||
| MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) | |||
| @@ -137,7 +150,11 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| ScopePtr GetScope() { return scope_; } | |||
| AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); } | |||
| AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } | |||
| AbstractFunctionPtr Copy() const override { | |||
| return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_, tracking_id()); | |||
| } | |||
| bool operator==(const AbstractFunction &other) const override; | |||
| std::size_t hash() const override; | |||
| @@ -145,6 +162,9 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { | |||
| private: | |||
| MetaFuncGraphPtr meta_func_graph_; | |||
| // refer the comment in FuncGraphAbstractClosure; | |||
| // store it as weak_ptr to break reference cycle. | |||
| AnfNodeWeakPtr tracking_id_; | |||
| ScopePtr scope_; | |||
| }; | |||
| using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>; | |||
| @@ -67,3 +67,62 @@ def test_assign_in_while(): | |||
| z = Tensor(np.random.randn(*input_shape).astype(np.float32)) | |||
| net = Net(input_shape) | |||
| net(x, y, z) | |||
| def test_dup_context(): | |||
| ''' different func_with_fv in net1 and net2 should produce 2 different FuncGraphAbstractClosure and | |||
| Evaluator. | |||
| ''' | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| def identity(f): | |||
| return f | |||
| def func_with_fv(): | |||
| return x | |||
| def net1(): | |||
| local_func = identity(func_with_fv) | |||
| out = local_func() + 20.0 | |||
| return out | |||
| def net2(): | |||
| local_func = identity(func_with_fv) | |||
| out = local_func() + 15.0 | |||
| return out | |||
| return net1() + net2() | |||
| Net()(5.0) | |||
| def test_maybe_poly_func(): | |||
| ''' different func_with_fv in net1 and net2 may produce poly node. ''' | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x, y, z): | |||
| def identity(f, inp): | |||
| return f(inp) | |||
| def func_with_fv(yy): | |||
| return (x, yy) | |||
| def make_call(): | |||
| out1 = identity(func_with_fv, y) | |||
| out2 = identity(func_with_fv, z) | |||
| return (out1, out2) | |||
| return make_call() | |||
| y_input = Tensor(np.array([1, 2]).astype(np.int32)) | |||
| z_input = Tensor(np.array([[2, 2], [3, 3]]).astype(np.int32)) | |||
| Net()(1, y_input, z_input) | |||