Browse Source

Return a new abstract without tracking_id for fg ValueNode in CSE.

tags/v1.1.0
Zhang Qinghua 5 years ago
parent
commit
077bde0767
5 changed files with 13 additions and 10 deletions
  1. +10
    -6
      mindspore/ccsrc/frontend/optimizer/cse.cc
  2. +1
    -1
      mindspore/ccsrc/frontend/optimizer/cse.h
  3. +0
    -2
      mindspore/ccsrc/frontend/optimizer/cse_pass.h
  4. +0
    -1
      mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
  5. +2
    -0
      mindspore/core/abstract/abstract_function.h

+ 10
- 6
mindspore/ccsrc/frontend/optimizer/cse.cc View File

@@ -32,19 +32,23 @@ using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractFunctionPtr;

BasePtr AbsOf(const AnfNodePtr &node) {
BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
MS_EXCEPTION_IF_NULL(node);
auto node_abs = node->abstract();
// in testcase: TestOptOpt.CSE, node->abstract() is null;
// In testcase: TestOptOpt.CSE, node->abstract() is null.
if (node_abs == nullptr) {
return kAnyValue;
}
// Ignore the tracking_id and prim pointer hash;
if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) {
// Ignore the tracking_id and prim pointer hash.
auto prim_abs = node_abs->cast<abstract::PrimitiveAbstractClosurePtr>();
return prim_abs->prim();
} else if (ignore_fg_abs_tracking_id && node_abs->isa<abstract::FuncGraphAbstractClosure>()) {
// Ignore the tracking_id.
auto new_fg_abs = node_abs->cast<abstract::AbstractFunctionPtr>()->Copy();
new_fg_abs->set_tracking_id(nullptr);
return new_fg_abs;
}

return node_abs;
}

@@ -68,7 +72,7 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
ValueNodePtr value_node = node->cast<ValueNodePtr>();
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
h = hash_combine(value->hash(), (AbsOf(value_node)->hash()));
h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
} else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
@@ -134,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool chec
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
auto main_value = GetValueNode(main);
auto node_value = GetValueNode(node);
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
return (AbsOf(main, true) == AbsOf(node, true)) && (*main_value == *node_value);
} else if (main->isa<CNode>() && node->isa<CNode>()) {
auto c_main = main->cast<CNodePtr>();
auto c_node = node->cast<CNodePtr>();


+ 1
- 1
mindspore/ccsrc/frontend/optimizer/cse.h View File

@@ -46,7 +46,7 @@ class CSE {
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const;
};

BasePtr AbsOf(const AnfNodePtr &node);
BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id = false);
} // namespace opt
} // namespace mindspore



+ 0
- 2
mindspore/ccsrc/frontend/optimizer/cse_pass.h View File

@@ -44,8 +44,6 @@ class CSEPass : public CSE {
private:
bool report_changes_;
};

BasePtr AbsOf(const AnfNodePtr &node);
} // namespace opt
} // namespace mindspore



+ 0
- 1
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc View File

@@ -467,7 +467,6 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
} else {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction";
}
return nullptr;
}

EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {


+ 2
- 0
mindspore/core/abstract/abstract_function.h View File

@@ -113,6 +113,8 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {

AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }

void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); }

AbstractFunctionPtr Copy() const override {
return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_, tracking_id());
}


Loading…
Cancel
Save