Browse Source

Reuse the flags if the tuple node visited before.

feature/build-system-rewrite
Zhang Qinghua 4 years ago
parent
commit
198b79a24d
10 changed files with 44 additions and 39 deletions
  1. +1
    -1
      mindspore/ccsrc/debug/anf_ir_dump.cc
  2. +1
    -1
      mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.h
  3. +1
    -1
      mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
  4. +14
    -20
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
  5. +4
    -4
      mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
  6. +3
    -3
      mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
  7. +3
    -1
      mindspore/ccsrc/runtime/framework/control_node_parser.cc
  8. +2
    -2
      mindspore/core/abstract/abstract_value.cc
  9. +3
    -0
      mindspore/core/abstract/prim_structures.cc
  10. +12
    -6
      mindspore/core/ir/anf.cc

+ 1
- 1
mindspore/ccsrc/debug/anf_ir_dump.cc View File

@@ -76,7 +76,7 @@ void PrintTupleNodeUsedFlags(std::ostringstream &buffer, const abstract::Abstrac
buffer << "node={" << node->DebugString();
auto flags = GetSequenceNodeElementsUseFlags(node);
if (flags != nullptr) {
buffer << ", elements_use_flags=" << (*flags) << "}";
buffer << ", elements_use_flags: {ptr: " << flags << ", value: " << (*flags) << "}";
}
}
if (i != sequence_abs->sequence_nodes().size() - 1) {


+ 1
- 1
mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.h View File

@@ -25,7 +25,7 @@ class EliminateDeadNodePass {
EliminateDeadNodePass() = default;
~EliminateDeadNodePass() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element) {
return false;


+ 1
- 1
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc View File

@@ -396,7 +396,7 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
MS_LOG(DEBUG) << "[" << this << "/" << evaluator_name
<< "] cache hit. result: " << eval_result->abstract()->ToString() << ", args: " << args_spec_list;
// Update inputs sequence nodes info, if matched in cache.
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element) {
for (size_t i = 0; i < args_spec_list.size(); ++i) {


+ 14
- 20
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc View File

@@ -1583,20 +1583,17 @@ class MakeTupleEvaluator : public TransitionPrimEvaluator {
MS_DECLARE_PARENT(MakeTupleEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (!args_spec_list.empty()) {
if (enable_eliminate_unused_element) {
for (auto &arg : args_spec_list) {
SetSequenceElementsUseFlags(arg, true);
}
}
} else {
if (args_spec_list.empty()) {
MS_LOG(INFO) << "For MakeTuple, the inputs should not be empty. node: " << out_conf->node()->DebugString();
}

static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element) {
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size()));
auto flags = GetSequenceNodeElementsUseFlags(out_conf->node());
if (flags == nullptr) {
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size()));
}
}
AnfNodeWeakPtrList sequence_nodes =
(enable_eliminate_unused_element ? AnfNodeWeakPtrList({AnfNodeWeakPtr(out_conf->node())}) : AnfNodeWeakPtrList());
@@ -1614,20 +1611,17 @@ class MakeListEvaluator : public TransitionPrimEvaluator {
MS_DECLARE_PARENT(MakeListEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (!args_spec_list.empty()) {
if (enable_eliminate_unused_element) {
for (auto &arg : args_spec_list) {
SetSequenceElementsUseFlags(arg, true);
}
}
} else {
if (args_spec_list.empty()) {
MS_LOG(INFO) << "For MakeList, the inputs should not be empty. node: " << out_conf->node()->DebugString();
}

static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element) {
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size()));
auto flags = GetSequenceNodeElementsUseFlags(out_conf->node());
if (flags == nullptr) {
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size()));
}
}
AnfNodeWeakPtrList sequence_nodes =
(enable_eliminate_unused_element ? AnfNodeWeakPtrList({AnfNodeWeakPtr(out_conf->node())}) : AnfNodeWeakPtrList());


+ 4
- 4
mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc View File

@@ -79,7 +79,7 @@ FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisConte
}
auto res = SpecializeFuncGraph(fg, context);
// Call PurifyElements() to purify tuple/list elements.
static const auto only_mark_unused_element = common::GetEnv("MS_DEV_ONLY_MARK_SEQUENCE_UNUSED_ELEMENT");
static const auto only_mark_unused_element = common::GetEnv("MS_DEV_DDE_ONLY_MARK");
static const auto enable_only_mark_unused_element = (only_mark_unused_element == "1");
if (!enable_only_mark_unused_element) {
for (auto &sequence_abs : sequence_abstract_list_) {
@@ -592,7 +592,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return;
}
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
auto attrs = conf->ObtainEvalResult()->attribute();
auto c_old = node->cast<CNodePtr>();
@@ -1010,9 +1010,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
// Set the updated inputs.
cnode->set_inputs(new_inputs);

static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
static const auto only_mark_unused_element = common::GetEnv("MS_DEV_ONLY_MARK_SEQUENCE_UNUSED_ELEMENT");
static const auto only_mark_unused_element = common::GetEnv("MS_DEV_DDE_ONLY_MARK");
static const auto enable_only_mark_unused_element = (only_mark_unused_element == "1");
if (enable_eliminate_unused_element && !enable_only_mark_unused_element) {
EliminateUnusedSequenceItem(cnode);


+ 3
- 3
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc View File

@@ -177,7 +177,7 @@ void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const E
MS_LOG(DEBUG) << "Found previous result for NodeConfig: " << conf->ToString()
<< ", result: " << iter->second->abstract().get() << "/" << iter->second->abstract()->ToString();
// Update sequence nodes info, if matched in cache.
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element) {
auto new_sequence = dyn_cast<AbstractTuple>(result->abstract());
@@ -875,7 +875,7 @@ AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
new_elements.push_back(orig_elements[i]);
}
}
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
AbstractBasePtr new_abs;
if (orig_abs->isa<AbstractTuple>()) {
@@ -1112,7 +1112,7 @@ AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &cont
auto prim = value->cast<PrimitivePtr>();
return MakeAbstractClosure(prim, anf_node);
}
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element && value->isa<ValueSequence>()) {
auto abs = value->ToAbstract();


+ 3
- 1
mindspore/ccsrc/runtime/framework/control_node_parser.cc View File

@@ -1564,7 +1564,9 @@ void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &
continue;
}
std::vector<FuncGraphPtr> func_graphs;
if (common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT") == "1") {
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element) {
func_graphs = GetFuncGraphs(control_node->cast<CNodePtr>()->input(0));
} else {
func_graphs = func_graph_analyzer->GetCallerFuncGraphs(control_node);


+ 2
- 2
mindspore/core/abstract/abstract_value.cc View File

@@ -273,7 +273,7 @@ std::string AbstractSequence::ToString() const {
}
auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
if (flags != nullptr) {
ss << ", elements_use_flags: " << (*flags);
ss << ", elements_use_flags: {ptr: " << flags << ", value: " << (*flags) << "}";
}
if (i != sequence_nodes_.size() - 1) {
ss << ", ";
@@ -364,7 +364,7 @@ void SynchronizeSequenceNodesElementsUseFlagsInner(const AnfNodeWeakPtrList &seq

AnfNodeWeakPtrList AbstractSequence::SequenceNodesJoin(const AbstractBasePtr &other) {
AnfNodeWeakPtrList sequence_nodes;
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (!enable_eliminate_unused_element) {
return sequence_nodes;


+ 3
- 0
mindspore/core/abstract/prim_structures.cc View File

@@ -136,6 +136,7 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
} else {
index_unsigned_value = LongToSize(index_int64_value + SizeToLong(nelems));
}
MS_LOG(DEBUG) << "GetItem use flags, index: " << index_unsigned_value << ", for " << queue->ToString();
SetSequenceElementsUseFlags(queue, index_unsigned_value, true);
return queue->elements()[index_unsigned_value];
}
@@ -165,6 +166,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
size_t index_unsigned_value = LongToSize(index_positive_value);
constexpr int target_value_index = 2;
elements[index_unsigned_value] = args_spec_list[target_value_index];
MS_LOG(DEBUG) << "SetItem use flags, index: " << index_unsigned_value << ", for " << queue->ToString();
SetSequenceElementsUseFlags(queue, index_unsigned_value, true);
return std::make_shared<T>(elements, queue->sequence_nodes());
}
@@ -298,6 +300,7 @@ AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePt
MS_EXCEPTION_IF_NULL(item);
auto new_list = AbstractBasePtrList(list->elements());
new_list.emplace_back(item);
MS_LOG(DEBUG) << "ListAppend, new size: " << new_list.size() << ", for " << list->ToString();
return std::make_shared<AbstractList>(new_list, list->sequence_nodes());
}



+ 12
- 6
mindspore/core/ir/anf.cc View File

@@ -638,7 +638,7 @@ bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set)

// Set the sequence nodes' elements use flags to 'new_flag' at specific 'index' position.
void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, std::size_t index, bool new_flag) {
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (!enable_eliminate_unused_element) {
return;
@@ -658,16 +658,22 @@ void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, std::size_t index,
continue;
}
auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
if (flags != nullptr) {
(*flags)[index] = new_flag;
MS_LOG(DEBUG) << "Set item[" << index << "] use flag as " << new_flag << ", for " << sequence_node->DebugString();
if (flags == nullptr) {
continue;
}
if (index >= flags->size()) {
MS_LOG(ERROR) << "The index " << index << " is out of range, size is " << flags->size() << ", for "
<< sequence_node->DebugString();
return;
}
(*flags)[index] = new_flag;
MS_LOG(DEBUG) << "Set item[" << index << "] use flag as " << new_flag << ", for " << sequence_node->DebugString();
}
}

// Set the sequence nodes' elements use flags all to 'new_flag'.
void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, bool new_flag) {
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (!enable_eliminate_unused_element) {
return;
@@ -697,7 +703,7 @@ void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, bool new_flag) {

// Set the sequence nodes' elements use flags all to 'new_flag' recursively.
void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new_flag) {
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ENABLE_DDE");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (!enable_eliminate_unused_element) {
return;


Loading…
Cancel
Save