Browse Source

Eliminate unused nodes in a tuple or list node.

tags/v1.6.0
Zhang Qinghua 4 years ago
parent
commit
d2572719f7
27 changed files with 1007 additions and 344 deletions
  1. +34
    -2
      mindspore/ccsrc/debug/anf_ir_dump.cc
  2. +97
    -57
      mindspore/ccsrc/frontend/operator/composite/composite.cc
  3. +6
    -0
      mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.h
  4. +8
    -8
      mindspore/ccsrc/pipeline/jit/action.cc
  5. +10
    -10
      mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc
  6. +4
    -3
      mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
  7. +69
    -3
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
  8. +50
    -50
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
  9. +199
    -44
      mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
  10. +6
    -0
      mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h
  11. +32
    -13
      mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
  12. +1
    -1
      mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h
  13. +189
    -31
      mindspore/core/abstract/abstract_value.cc
  14. +85
    -15
      mindspore/core/abstract/abstract_value.h
  15. +2
    -0
      mindspore/core/abstract/prim_arrays.cc
  16. +9
    -0
      mindspore/core/abstract/prim_others.cc
  17. +54
    -27
      mindspore/core/abstract/prim_structures.cc
  18. +1
    -0
      mindspore/core/base/base.h
  19. +50
    -0
      mindspore/core/ir/anf.cc
  20. +13
    -1
      mindspore/core/ir/anf.h
  21. +2
    -2
      mindspore/core/ir/func_graph_cloner.cc
  22. +1
    -1
      mindspore/core/ir/func_graph_cloner.h
  23. +4
    -1
      mindspore/core/ops/addn.cc
  24. +14
    -8
      tests/ut/cpp/operator/composite_test.cc
  25. +7
    -7
      tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc
  26. +48
    -48
      tests/ut/cpp/pipeline/static_analysis/prim_test.cc
  27. +12
    -12
      tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc

+ 34
- 2
mindspore/ccsrc/debug/anf_ir_dump.cc View File

@@ -59,16 +59,46 @@ void PrintKernelFormatAndType(std::ostringstream &buffer, const std::string &fmt
buffer << ">";
}

// Append a human-readable dump of an AbstractSequence's 'sequence_nodes' to
// 'buffer': for each (weakly referenced) producing node, its DebugString and,
// when recorded, the per-element use flags. Emits nothing when there is no
// sequence abstract or no recorded nodes.
void PrintTupleNodeUsedFlags(std::ostringstream &buffer, const abstract::AbstractSequencePtr &sequence_abs) {
  if (sequence_abs == nullptr || sequence_abs->sequence_nodes().empty()) {
    return;
  }

  buffer << ", sequence_nodes={";
  for (size_t i = 0; i < sequence_abs->sequence_nodes().size(); ++i) {
    // sequence_nodes() holds weak pointers; the producing node may be freed.
    auto node = sequence_abs->sequence_nodes()[i].lock();
    if (node == nullptr) {
      MS_LOG(DEBUG) << "The node in sequence_nodes is free.";
      buffer << "node={<freed node>}";
    } else {
      buffer << "node={" << node->DebugString();
      auto flags = GetSequenceNodeElementsUseFlags(node);
      if (flags != nullptr) {
        buffer << ", elements_use_flags=" << (*flags);
      }
      // Fix: always close the "node={" group; previously the '}' was only
      // emitted inside the flags branch, yielding unbalanced output when a
      // live node had no recorded use flags.
      buffer << "}";
    }
    if (i != sequence_abs->sequence_nodes().size() - 1) {
      buffer << ", ";
    }
  }
  buffer << "}";
}

void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) {
if (node == nullptr) {
return;
}

ValuePtr tensor_value = nullptr;
abstract::AbstractSequencePtr sequence_abs = nullptr;
auto abstract = node->abstract();
if (abstract != nullptr && abstract->isa<abstract::AbstractTensor>()) {
tensor_value = abstract->BuildValue();
if (abstract != nullptr) {
if (abstract->isa<abstract::AbstractTensor>()) {
tensor_value = abstract->BuildValue();
}
sequence_abs = dyn_cast<abstract::AbstractSequence>(abstract);
}

abstract::ShapePtr shape = dyn_cast<abstract::Shape>(node->Shape());
TypePtr type = dyn_cast<Type>(node->Type());
if ((shape != nullptr) && (type != nullptr)) {
@@ -76,12 +106,14 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &node) {
if (tensor_value != nullptr && tensor_value != kAnyValue) {
buffer << ", value=...";
}
PrintTupleNodeUsedFlags(buffer, sequence_abs);
buffer << ">";
} else if (type != nullptr) {
buffer << "<" << type;
if (tensor_value != nullptr && tensor_value != kAnyValue) {
buffer << ", value=...";
}
PrintTupleNodeUsedFlags(buffer, sequence_abs);
buffer << ">";
} else {
buffer << "<null>";


+ 97
- 57
mindspore/ccsrc/frontend/operator/composite/composite.cc View File

@@ -244,7 +244,14 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
inputs.emplace_back(call_node);
}
}
return func_graph->NewCNodeInOrder(inputs);

if (inputs.size() > 1) {
return func_graph->NewCNodeInOrder(inputs);
}
// Empty tuple.
auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
auto empty_tuple = NewValueNode(empty_tuple_value);
return empty_tuple;
}

AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
@@ -452,86 +459,119 @@ bool CheckTailGradFristSequence(const abstract::AbstractSequencePtr &sequeue, bo
CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
}

namespace {
// Build the output of func graph 'res' for grad-by-position: select the
// gradients whose indices are given by the runtime position sequence 'pos'.
// Adds two parameters to 'res' in this order: the gradients tuple/list, then
// the position sequence. Positions are shifted by +1 because index 0 of the
// gradients sequence is reserved (see CheckTailGradFristSequence usage).
// NOTE(review): assumes 'res' has no pre-existing parameters — TODO confirm
// against callers.
void GenerateSequenceFuncGraphByPosition(const FuncGraphPtr &res, const abstract::AbstractSequencePtr &sequeue,
                                         const abstract::AbstractSequencePtr &pos, bool enable_tuple_grad) {
  if (pos == nullptr) {
    MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!";
  }
  // First parameter: the sequence of gradients to pick from.
  AnfNodePtr tuple_parameter = res->add_parameter();
  std::vector<AnfNodePtr> pos_elements;
  PrimitivePtr pos_op = nullptr;
  // Mirror the container kind of 'pos' in both the result constructor and the
  // element getter (tuple vs. list).
  if (pos->isa<AbstractTuple>()) {
    pos_elements.push_back(NewValueNode(prim::kPrimMakeTuple));
    pos_op = prim::kPrimTupleGetItem;
  } else {
    pos_elements.push_back(NewValueNode(prim::kPrimMakeList));
    pos_op = prim::kPrimListGetItem;
  }
  AnfNodePtr pos_value = nullptr;
  AnfNodePtr pos_value_adjust = nullptr;
  // Second parameter: the position sequence itself.
  auto pos_parameter = res->add_parameter();
  if (pos->size() == 1) {
    // Single position: return that one gradient directly (no wrapping tuple).
    pos_value = res->NewCNode({NewValueNode(pos_op), pos_parameter, NewValueNode(SizeToLong(0))});
    // Shift by one to skip the reserved element at index 0.
    pos_value_adjust = res->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))});
    if (CheckTailGradFristSequence(sequeue, enable_tuple_grad)) {
      res->set_output(res->NewCNode({NewValueNode(pos_op), tuple_parameter, pos_value_adjust}));
    } else {
      // Nothing selectable: output an empty tuple.
      res->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
    }
  } else {
    // Multiple positions: gather each selected gradient into a new sequence.
    for (size_t i = 0; i < pos->size(); ++i) {
      pos_value = res->NewCNode({NewValueNode(pos_op), pos_parameter, NewValueNode(SizeToLong(i))});
      pos_value_adjust = res->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))});
      pos_elements.push_back(res->NewCNodeInOrder({NewValueNode(pos_op), tuple_parameter, pos_value_adjust}));
    }
    // pos_elements[0] is the Make(Tuple|List) primitive, so > 1 means at least
    // one real element was collected.
    if (pos_elements.size() > 1) {
      res->set_output(res->NewCNodeInOrder(pos_elements));
    } else if (pos->isa<AbstractTuple>()) {  // Empty tuple.
      auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
      auto empty_tuple = NewValueNode(empty_tuple_value);
      res->set_output(empty_tuple);
    } else {  // Empty list.
      auto empty_list_value = std::make_shared<ValueList>(ValuePtrList());
      auto empty_list = NewValueNode(empty_list_value);
      res->set_output(empty_list);
    }
  }
}
} // namespace

FuncGraphPtr Tail::GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr &sequeue,
const abstract::AbstractSequencePtr &pos) const {
MS_EXCEPTION_IF_NULL(sequeue);

FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("tail");
AnfNodePtr ptrTup = ret->add_parameter();

std::vector<AnfNodePtr> elems;
PrimitivePtr op = nullptr;
if (sequeue->isa<AbstractTuple>()) {
elems.push_back(NewValueNode(prim::kPrimMakeTuple));
op = prim::kPrimTupleGetItem;
} else {
elems.push_back(NewValueNode(prim::kPrimMakeList));
op = prim::kPrimListGetItem;
}
FuncGraphPtr res = std::make_shared<FuncGraph>();
res->set_flag(FUNC_GRAPH_FLAG_CORE, true);
res->debug_info()->set_name("tail");

if (tail_type_ == kGradFirst) {
AnfNodePtr tuple_parameter = res->add_parameter();
PrimitivePtr getitem_op = nullptr;
if (sequeue->isa<AbstractTuple>()) {
getitem_op = prim::kPrimTupleGetItem;
} else {
getitem_op = prim::kPrimListGetItem;
}
if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
res->set_output(res->NewCNode({NewValueNode(getitem_op), tuple_parameter, NewValueNode(SizeToLong(1))}));
} else {
ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
res->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList())));
}

return ret;
return res;
}

if (tail_type_ == kGradByPosition) {
if (pos == nullptr) {
MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!";
}
std::vector<AnfNodePtr> pos_elems;
PrimitivePtr pos_op = nullptr;
if (pos->isa<AbstractTuple>()) {
pos_elems.push_back(NewValueNode(prim::kPrimMakeTuple));
pos_op = prim::kPrimTupleGetItem;
} else {
pos_elems.push_back(NewValueNode(prim::kPrimMakeList));
pos_op = prim::kPrimListGetItem;
}
AnfNodePtr pos_value = nullptr;
AnfNodePtr pos_value_adjust = nullptr;
auto ptrpos = ret->add_parameter();
if (pos->size() == 1) {
pos_value = ret->NewCNode({NewValueNode(pos_op), ptrpos, NewValueNode(SizeToLong(0))});
pos_value_adjust = ret->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))});
if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, pos_value_adjust}));
} else {
ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
}
return ret;
} else {
for (size_t i = 0; i < pos->size(); ++i) {
pos_value = ret->NewCNode({NewValueNode(pos_op), ptrpos, NewValueNode(SizeToLong(i))});
pos_value_adjust = ret->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))});
pos_elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, pos_value_adjust}));
}
}
ret->set_output(ret->NewCNodeInOrder(pos_elems));
return ret;
GenerateSequenceFuncGraphByPosition(res, sequeue, pos, enable_tuple_grad_);
return res;
}

AnfNodePtr tuple_parameter = res->add_parameter();
std::vector<AnfNodePtr> elements;
PrimitivePtr op = nullptr;
if (sequeue->isa<AbstractTuple>()) {
elements.push_back(NewValueNode(prim::kPrimMakeTuple));
op = prim::kPrimTupleGetItem;
} else {
elements.push_back(NewValueNode(prim::kPrimMakeList));
op = prim::kPrimListGetItem;
}
for (size_t i = 1; i < sequeue->size(); ++i) {
if (tail_type_ == kGradAll) {
MS_EXCEPTION_IF_NULL((*sequeue)[i]);
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
(*sequeue)[i]->BuildType()->isa<Number>())) {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
}
} else {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
}
}

ret->set_output(ret->NewCNodeInOrder(elems));
return ret;
if (elements.size() > 1) {
res->set_output(res->NewCNodeInOrder(elements));
return res;
} else if (sequeue->isa<AbstractTuple>()) { // Empty tuple.
auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
auto empty_tuple = NewValueNode(empty_tuple_value);
res->set_output(empty_tuple);
return res;
} else { // Empty list.
auto empty_list_value = std::make_shared<ValueList>(ValuePtrList());
auto empty_list = NewValueNode(empty_list_value);
res->set_output(empty_list);
return res;
}
}

FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {


+ 6
- 0
mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.h View File

@@ -25,6 +25,12 @@ class EliminateDeadNodePass {
EliminateDeadNodePass() = default;
~EliminateDeadNodePass() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element) {
return false;
}

static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
MS_LOG(INFO) << "Closure enable:" << enable_closure;
if (!enable_closure) {


+ 8
- 8
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -240,20 +240,20 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;

abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &resource, const FuncGraphPtr &func_graph,
const abstract::AbstractBasePtrList &args_spec, bool clear) {
MS_LOG(DEBUG) << "AbstractAnalyze start";
auto engine = res->engine();
auto engine = resource->engine();
MS_EXCEPTION_IF_NULL(engine);
if (clear) {
auto manager = res->manager();
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
engine->Clear();
for (auto &node : manager->all_nodes()) {
MS_EXCEPTION_IF_NULL(node);

// Handle previous inferred value for CNode if is loaded from MindIR
if (res->is_load()) {
if (resource->is_load()) {
// If the primitive is not defined in front end,keep the inferred value loaded from MindIR.
auto primitive = GetCNodePrimitive(node);
if (primitive != nullptr && abstract::GetPrimEvaluator(primitive, engine) == nullptr) {
@@ -270,19 +270,19 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
}
}
}
auto ret = engine->Run(func_graph, args_spec);
auto res = engine->Run(func_graph, args_spec);
MS_LOG(INFO) << "function call max depth: " << abstract::FunctionCallMaxDepth()
<< ", simulate call max depth: " << abstract::StackFrameMaxDepth();
MS_LOG(DEBUG) << "AbstractAnalyze end";
return ret;
return res;
}

FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
const abstract::AnalysisContextPtr &context) {
MS_EXCEPTION_IF_NULL(res);
MS_LOG(DEBUG) << "ProgramSpecialize start";
abstract::ProgramSpecializer spc(res->engine());
FuncGraphPtr result = spc.Run(func_graph, context);
abstract::ProgramSpecializer specializer(res->engine());
FuncGraphPtr result = specializer.Run(func_graph, context);
auto manager = res->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->KeepRoots({result});


+ 10
- 10
mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc View File

@@ -239,30 +239,30 @@ AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &c
return async_eval_result->GetResult();
}

void AnalysisResultCacheMgr::SetCacheValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg,
void AnalysisResultCacheMgr::SetCacheValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &current_abs,
AnalysisConfigAsyncResultCache *cache) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(cache);
if (arg == nullptr) {
if (current_abs == nullptr) {
MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr";
}
std::lock_guard<std::mutex> lock(lock_);
AsyncAbstractPtr async_eval_result = cache->get(conf);
if (async_eval_result == nullptr) {
async_eval_result = std::make_shared<AsyncAbstract>();
async_eval_result->set_result(arg);
async_eval_result->set_result(current_abs);
cache->set(conf, async_eval_result);
} else {
auto ab1 = async_eval_result->TryGetResult();
AbstractBasePtrList absList;
if (ab1 != nullptr) {
absList.push_back(arg);
absList.push_back(ab1);
auto previous_abs = async_eval_result->TryGetResult();
AbstractBasePtrList abstract_list;
if (previous_abs != nullptr) {
abstract_list.push_back(previous_abs);
abstract_list.push_back(current_abs);
// Join two branches's result
auto joined_result = AnalysisEngine::ProcessEvalResults(absList, conf->node());
auto joined_result = AnalysisEngine::ProcessEvalResults(abstract_list, conf->node());
async_eval_result->set_result(joined_result->abstract());
} else {
async_eval_result->set_result(arg);
async_eval_result->set_result(current_abs);
}
}
}


+ 4
- 3
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc View File

@@ -240,7 +240,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
const auto &node = parameters[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, context, fg);
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString();
MS_LOG(DEBUG) << GetInferThread() << "Set parameter[" << i << "] for " << fg->ToString()
<< ", conf: " << conf->ToString() << ", arg: " << arg->ToString();
}
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
<< ", context: " << context->ToString() << ", return node: " << fg->get_return()->DebugString()
@@ -416,8 +417,8 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt

EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
if (args_conf_list.empty()) {
MS_LOG(EXCEPTION) << "Size should be greater than 0";
if (args_conf_list.empty() && identifier_ != "MakeTupleEvaluator" && identifier_ != "MakeListEvaluator") {
MS_LOG(EXCEPTION) << "Size should be greater than 0, during running " << identifier_;
}
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),


+ 69
- 3
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc View File

@@ -517,7 +517,7 @@ TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_lis
}
return TypeJoin(args_type_list);
}
} // end anonymous namespace
} // namespace

py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
MS_EXCEPTION_IF_NULL(abs_base);
@@ -648,7 +648,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
CheckCustomPrimOutputInferResult(prim_py, res_spec);
return res_spec;
}
} // end anonymous namespace
} // namespace

EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base,
const AbstractBasePtrList &args) {
@@ -761,6 +761,14 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
if (eval_result != nullptr) {
auto abs = eval_result->abstract()->Clone();
auto attr = eval_result->attribute();

// To check tuple/list operations with a white list of Python primitive.
if (prim_py_->name() == prim::kPrimStack->name()) {
// Set all used flags of tuple as true.
for (auto &arg : args) {
SetSequenceElementsUseFlags(arg, true);
}
}
return std::make_shared<EvalResult>(abs, attr);
}

@@ -774,6 +782,14 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
evaluator_cache_mgr_->SetValue(args, infer_result);

// To check tuple/list operations with a white list of Python primitive.
if (prim_py_->name() == prim::kPrimStack->name()) {
// Set all used flags of tuple as true.
for (auto &arg : args) {
SetSequenceElementsUseFlags(arg, true);
}
}
return infer_result;
}

@@ -1103,7 +1119,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
}
}
} // end anonymous namespace
} // namespace

namespace {
class EmbedEvaluator : public SymbolicPrimEvaluator {
@@ -1452,6 +1468,54 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
}
};

// Evaluator for the MakeTuple primitive: infers an AbstractTuple from the
// abstracts of all inputs. When MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT=1 it
// additionally records the producing node and per-element use flags so that
// unused tuple elements can be eliminated later during specialization.
class MakeTupleEvaluator : public TransitionPrimEvaluator {
 public:
  MakeTupleEvaluator() : TransitionPrimEvaluator("MakeTupleEvaluator") {}
  ~MakeTupleEvaluator() override = default;
  MS_DECLARE_PARENT(MakeTupleEvaluator, TransitionPrimEvaluator);
  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
                         const AnfNodeConfigPtr &out_conf) override {
    // An empty MakeTuple is tolerated (yields an empty tuple) but unusual
    // enough to warn about.
    if (args_spec_list.empty()) {
      MS_LOG(WARNING) << "For MakeTuple, the inputs should not be empty.";
    }
    // The elimination feature is opt-in via environment variable, read once.
    static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
    static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
    if (enable_eliminate_unused_element) {
      // All element-use flags start as false (vector<bool> value-initialized);
      // they are set true where uses are discovered during analysis.
      SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size()));
    }
    // Record the producing node (weakly) in the abstract so the flags can be
    // located from the abstract later.
    AnfNodeWeakPtrList sequence_nodes =
      (enable_eliminate_unused_element ? AnfNodeWeakPtrList({AnfNodeWeakPtr(out_conf->node())}) : AnfNodeWeakPtrList());
    auto abs = std::make_shared<AbstractTuple>(args_spec_list, sequence_nodes);
    auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
    evaluator_cache_mgr_->SetValue(args_spec_list, res);
    return res;
  }
};

// Evaluator for the MakeList primitive: infers an AbstractList from the
// abstracts of all inputs. Mirrors MakeTupleEvaluator, including the opt-in
// recording of the producing node and per-element use flags for later
// elimination of unused list elements.
class MakeListEvaluator : public TransitionPrimEvaluator {
 public:
  MakeListEvaluator() : TransitionPrimEvaluator("MakeListEvaluator") {}
  ~MakeListEvaluator() override = default;
  MS_DECLARE_PARENT(MakeListEvaluator, TransitionPrimEvaluator);
  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
                         const AnfNodeConfigPtr &out_conf) override {
    // An empty MakeList is tolerated (yields an empty list) but worth a warning.
    if (args_spec_list.empty()) {
      MS_LOG(WARNING) << "For MakeList, the inputs should not be empty.";
    }
    // The elimination feature is opt-in via environment variable, read once.
    static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
    static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
    if (enable_eliminate_unused_element) {
      // All element-use flags start as false (vector<bool> value-initialized).
      SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size()));
    }
    // Record the producing node (weakly) in the abstract for later lookup.
    AnfNodeWeakPtrList sequence_nodes =
      (enable_eliminate_unused_element ? AnfNodeWeakPtrList({AnfNodeWeakPtr(out_conf->node())}) : AnfNodeWeakPtrList());
    auto abs = std::make_shared<AbstractList>(args_spec_list, sequence_nodes);
    auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
    evaluator_cache_mgr_->SetValue(args_spec_list, res);
    return res;
  }
};

class PartialEvaluator : public Evaluator {
public:
PartialEvaluator() : Evaluator("PartialEvaluator") {}
@@ -1597,6 +1661,8 @@ void InitPrimEvaluatorConstructors() {
constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
constructor[prim::kPrimMakeTuple] = std::make_shared<MakeTupleEvaluator>();
constructor[prim::kPrimMakeList] = std::make_shared<MakeListEvaluator>();
}
} // namespace



+ 50
- 50
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h View File

@@ -66,6 +66,56 @@ class PythonPrimEvaluator final : public TrivialPrimEvaluator {
PrimitivePyPtr prim_py_;
};

using ValuePtrList = std::vector<ValuePtr>;
// Signature of a C++ primitive implementation evaluated on concrete values.
using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &);

// Evaluator for "uniform" primitives: the primitive's signature is described
// by a Function type (func_desc) and its value-level behavior by a plain C++
// function pointer (impl). Arguments that share the same declared type are
// grouped in type_map_ so they can be type-checked/joined together.
class UniformPrimEvaluator final : public TrivialPrimEvaluator {
 public:
  UniformPrimEvaluator(const FunctionPtr func_desc, PrimitiveImpl impl, bool eval_value, const TypePtr specify_out_type)
      : TrivialPrimEvaluator("UniformPrimEvaluator"),
        impl_(impl),
        eval_value_(eval_value),
        func_desc_(func_desc),
        nargs_(func_desc_->args().size()),
        return_value_type_(func_desc_->retval()),
        specify_out_type_(specify_out_type) {
    // Group argument indices by declared type; operator[] default-constructs
    // a null shared_ptr on first access, hence the null check.
    for (size_t i = 0; i < nargs_; ++i) {
      TypePtr type = func_desc_->args()[i];
      if (type_map_[type]) {
        type_map_[type]->push_back(i);
      } else {
        type_map_[type] = std::make_shared<std::vector<size_t>>();
        type_map_[type]->push_back(i);
      }
    }
  }
  ~UniformPrimEvaluator() override = default;
  MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);

  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
  // Run the concrete value-level implementation (impl_) on evaluated args.
  ValuePtr RunImpl(const ValuePtrList &args) const;

  // If eval_value_ is False, return broadened arguments.
  AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override {
    if (!eval_value_) {
      AbstractBasePtrList broadened_args_spec_list;
      (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened_args_spec_list),
                           [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
      return broadened_args_spec_list;
    }
    return args_spec_list;
  }

 private:
  PrimitiveImpl impl_;        // Value-level implementation of the primitive.
  bool eval_value_;           // Whether to evaluate on concrete values (vs. broaden).
  const FunctionPtr func_desc_;      // Declared signature of the primitive.
  const std::size_t nargs_;          // Number of declared arguments.
  const TypePtr return_value_type_;  // Declared return type.
  const TypePtr specify_out_type_;   // Forced output type, if any.
  // Declared type -> indices of the arguments with that type.
  mindspore::HashMap<TypePtr, std::shared_ptr<std::vector<size_t>>, TypeHasher, TypeEqual> type_map_;
};

class DoSignatureEvaluator final : public Evaluator {
public:
explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
@@ -117,56 +167,6 @@ class MixedPrecisionCastEvaluator final : public Evaluator {

bool IsInWhiteList(const PrimitivePtr &primitive);

using ValuePtrList = std::vector<ValuePtr>;
using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &);

class UniformPrimEvaluator final : public TrivialPrimEvaluator {
public:
UniformPrimEvaluator(const FunctionPtr func_desc, PrimitiveImpl impl, bool eval_value, const TypePtr specify_out_type)
: TrivialPrimEvaluator("UniformPrimEvaluator"),
impl_(impl),
eval_value_(eval_value),
func_desc_(func_desc),
nargs_(func_desc_->args().size()),
return_value_type_(func_desc_->retval()),
specify_out_type_(specify_out_type) {
for (size_t i = 0; i < nargs_; ++i) {
TypePtr type = func_desc_->args()[i];
if (type_map_[type]) {
type_map_[type]->push_back(i);
} else {
type_map_[type] = std::make_shared<std::vector<size_t>>();
type_map_[type]->push_back(i);
}
}
}
~UniformPrimEvaluator() override = default;
MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);

EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
ValuePtr RunImpl(const ValuePtrList &args) const;

// If eval_value_ is False, return broadened arguments.
AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override {
if (!eval_value_) {
AbstractBasePtrList broadened_args_spec_list;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened_args_spec_list),
[](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
return broadened_args_spec_list;
}
return args_spec_list;
}

private:
PrimitiveImpl impl_;
bool eval_value_;
const FunctionPtr func_desc_;
const std::size_t nargs_;
const TypePtr return_value_type_;
const TypePtr specify_out_type_;
mindspore::HashMap<TypePtr, std::shared_ptr<std::vector<size_t>>, TypeHasher, TypeEqual> type_map_;
};

PrimEvaluatorMap &GetPrimEvaluatorConstructors();

// Check whether type x is a subtype of model.


+ 199
- 44
mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc View File

@@ -67,7 +67,12 @@ FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisConte
top_context_ = context;
MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
}
return SpecializeFuncGraph(fg, context);
auto res = SpecializeFuncGraph(fg, context);
// Call PurifyElements() to purify tuple/list elements.
for (auto &sequence_abs : sequence_abstract_list_) {
sequence_abs->PurifyElements();
}
return res;
}

FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
@@ -80,10 +85,10 @@ FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, con
}

std::shared_ptr<FuncGraphSpecializer> fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
FuncGraphPtr fg2 = fg_spec->specialized_func_graph();
FuncGraphPtr specialized_func_graph = fg_spec->specialized_func_graph();
specializations_[context->SpecializeKey()] = fg_spec;
fg_spec->Run();
return fg2;
return specialized_func_graph;
}

std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
@@ -290,6 +295,133 @@ void FuncGraphSpecializer::SecondPass() {
}
}

namespace {
// Update elements use flags for MakeTuple/tuple node after specialization
// replaced 'old_node' by 'new_node': migrate the per-element use flags to the
// new node and keep the 'sequence_nodes' bookkeeping of both the old and (when
// different) the new AbstractSequence consistent.
void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node, const AbstractBasePtr &old_abs) {
  if (new_node == old_node) {
    return;
  }
  AbstractSequencePtr old_sequence_abs = dyn_cast<AbstractSequence>(old_abs);
  if (old_sequence_abs == nullptr || old_sequence_abs->sequence_nodes().empty()) {
    MS_LOG(DEBUG) << "No sequence node in old abs, " << old_node->DebugString() << " --> " << new_node->DebugString();
    return;
  }

  for (auto &weak_node : old_sequence_abs->sequence_nodes()) {
    auto sequence_node = weak_node.lock();
    if (sequence_node == nullptr) {
      MS_LOG(DEBUG) << "The sequence_nodes is free. " << old_node->DebugString() << " --> " << new_node->DebugString();
      continue;
    }
    if (sequence_node != old_node) {
      continue;
    }

    // Update new node's flags with old one, and update old sequence abstract's source node.
    auto flags = GetSequenceNodeElementsUseFlags(old_node);
    if (flags == nullptr) {
      // Fix: the getter can return nullptr (see PurifySequenceValueNode); the
      // previous code dereferenced it unconditionally. With no recorded flags
      // there is nothing to migrate to the new node.
      MS_LOG(DEBUG) << "No elements_use_flags for " << old_node->DebugString() << " --> " << new_node->DebugString();
    } else {
      MS_LOG(DEBUG) << "Update sequence node, " << old_node->DebugString() << " --> " << new_node->DebugString()
                    << ", elements_use_flags: " << (*flags);
      SetSequenceNodeElementsUseFlags(new_node, flags);
    }
    old_sequence_abs->update_sequence_node(sequence_node, new_node);

    // Update new sequence abstract if it's not equal to old one.
    const AbstractBasePtr &new_abs = new_node->abstract();
    if (old_abs == new_abs) {
      continue;
    }
    AbstractSequencePtr new_sequence_abs = dyn_cast<AbstractSequence>(new_abs);
    if (new_sequence_abs == nullptr) {
      MS_LOG(EXCEPTION) << "New node should be sequence type as well, but got " << new_abs->ToString();
    }
    if (new_sequence_abs->sequence_nodes().empty()) {
      new_sequence_abs->set_sequence_nodes({AnfNodeWeakPtr(new_node)});
    } else {
      new_sequence_abs->insert_sequence_node(new_node);
    }
  }
}

// Purify specific input of a CNode: if inputs[index] is a ValueTuple/ValueList
// value node with recorded elements-use flags, rebuild it with every unused
// element replaced by the scalar 0, and install a fresh abstract (with reset
// use flags) on the replacement node. T is ValueTuple or ValueList.
template <typename T>
void PurifySequenceValueNode(const CNodePtr &cnode, size_t index) {
  const auto &old_input = cnode->input(index);
  auto sequence_value = GetValueNode<std::shared_ptr<T>>(old_input);
  if (sequence_value == nullptr) {
    // Not a ValueNode of the requested sequence type; nothing to do.
    return;
  }
  auto flags = GetSequenceNodeElementsUseFlags(old_input);
  if (flags == nullptr) {
    // No use information was recorded for this node; keep it untouched.
    return;
  }
  // NOTE(review): assumes flags->size() == sequence_value->value().size();
  // a flag vector longer than the value list would index out of range below —
  // TODO confirm the invariant holds for all producers of these flags.
  ValuePtrList elements;
  for (size_t i = 0; i < (*flags).size(); ++i) {
    if (!(*flags)[i]) {
      // Unused element: substitute a cheap placeholder scalar 0.
      auto zero = MakeValue(0);
      elements.emplace_back(zero);
      MS_LOG(INFO) << "Erase elements[" << i << "] as zero for " << old_input->DebugString() << ", which is inputs["
                   << index << "] of " << cnode->DebugString();
    } else {
      elements.emplace_back(sequence_value->value()[i]);
    }
  }
  auto new_sequence_value = std::make_shared<T>(elements);
  auto new_input = NewValueNode(new_sequence_value);
  auto new_input_abs = new_sequence_value->ToAbstract();
  AbstractSequencePtr new_sequence_abs = dyn_cast<AbstractSequence>(new_input_abs);
  MS_EXCEPTION_IF_NULL(new_sequence_abs);
  // The new abstract tracks the new value node as its (only) source.
  new_sequence_abs->set_sequence_nodes({AnfNodeWeakPtr(new_input)});
  new_input->set_abstract(new_sequence_abs);
  // Always reset tuple value node's use flags as non-use.
  SetSequenceNodeElementsUseFlags(new_input, std::make_shared<std::vector<bool>>(new_sequence_abs->elements().size()));
  MS_LOG(DEBUG) << "Update ValueTuple/ValueList, " << old_input->DebugString() << " --> " << new_input->DebugString()
                << ", which is inputs[" << index << "] of " << cnode->DebugString();
  cnode->set_input(index, new_input);
}
} // namespace

// Eliminate the unused items of Tuple/List.
// For a CNode whose abstract is an AbstractSequence with tracked source nodes:
// (1) queue the abstract for later PurifyElements() in the specializer,
// (2) if the node itself is a MakeTuple/MakeList, replace its unused operands
//     with the scalar 0, and
// (3) purify any ValueTuple/ValueList value-node operands in place.
void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) {
  if (cnode == nullptr || cnode->abstract() == nullptr) {
    MS_LOG(EXCEPTION) << "The parameter \'node\' and its abstract should not be null.";
  }
  const AbstractBasePtr abs = cnode->abstract();
  AbstractSequencePtr sequence_abs = dyn_cast<AbstractSequence>(abs);
  if (sequence_abs == nullptr || sequence_abs->sequence_nodes().empty()) {
    // Not a tracked sequence; nothing to eliminate.
    return;
  }
  // Not call PurifyElements() here, just add to list.
  specializer_->sequence_abstract_list().emplace_back(sequence_abs);
  // Purify MakeTuple/MakeList CNode.
  if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
    auto flags = GetSequenceNodeElementsUseFlags(cnode);
    if (flags != nullptr) {
      // NOTE(review): assumes flags->size() + 1 == cnode->inputs().size()
      // (one flag per operand after the primitive) — TODO confirm.
      std::vector<AnfNodePtr> inputs;
      inputs.emplace_back(cnode->input(0));  // Keep the primitive itself.
      for (size_t i = 0; i < (*flags).size(); ++i) {
        if (!(*flags)[i]) {
          // Unused operand: replace with a scalar-0 value node carrying a
          // matching abstract.
          auto zero_value = NewValueNode(MakeValue(0));
          zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0)));
          inputs.emplace_back(zero_value);
          MS_LOG(INFO) << "Erase inputs[" << i << "] as zero for " << cnode->DebugString();
        } else {
          inputs.emplace_back(cnode->input(i + 1));
        }
      }
      cnode->set_inputs(std::move(inputs));
      cnode->set_abstract(sequence_abs);
    }
  }
  // Purify each Tuple/List ValueNode in CNode.
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    if (IsValueNode<ValueTuple>(cnode->input(i))) {
      PurifySequenceValueNode<ValueTuple>(cnode, i);
    } else if (IsValueNode<ValueList>(cnode->input(i))) {
      PurifySequenceValueNode<ValueList>(cnode, i);
    }
  }
}

void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
ScopeGuard scope_guard(node->scope());
@@ -304,7 +436,11 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
<< ", specialized_func_graph_: " << specialized_func_graph_->ToString();
return;
}
new_node->set_abstract(GetEvaluatedValue(conf));
try {
new_node->set_abstract(GetEvaluatedValue(conf));
} catch (const std::exception &) {
MS_LOG(EXCEPTION) << "Fail to get abstract value with " << conf->ToString() << ", for " << new_node->DebugString();
}
if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
if (partial_abstract->node() == node) {
@@ -315,35 +451,47 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
<< ", func_graph_: " << func_graph_->ToString()
<< ", specialized_func_graph_: " << specialized_func_graph_->ToString();

if (node->isa<CNode>()) {
auto attrs = conf->ObtainEvalResult()->attribute();
auto c_old = node->cast<CNodePtr>();
auto c_new = new_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_new);
auto new_inputs = c_new->inputs();
auto old_inputs = c_old->inputs();
for (size_t i = 0; i < old_inputs.size(); ++i) {
auto node_input = old_inputs[i];
AnfNodeConfigPtr iconf = MakeConfig(node_input);
AbstractBasePtr ival = GetEvaluatedValue(iconf);
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs, node);
if (replace_node == nullptr) {
replace_node = BuildReplacedNode(iconf);
replace_node->set_abstract(ival);
MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
} else {
MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
<< ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString();
}
if (new_inputs[i] != replace_node) {
new_inputs[i] = replace_node;
MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString();
}
if (!node->isa<CNode>()) {
return;
}
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
auto attrs = conf->ObtainEvalResult()->attribute();
auto c_old = node->cast<CNodePtr>();
auto c_new = new_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_new);
auto new_inputs = c_new->inputs();
auto old_inputs = c_old->inputs();
for (size_t i = 0; i < old_inputs.size(); ++i) {
auto node_input = old_inputs[i];
AnfNodeConfigPtr input_conf = MakeConfig(node_input);
AbstractBasePtr abs;
try {
abs = GetEvaluatedValue(input_conf);
} catch (const std::exception &) {
MS_LOG(EXCEPTION) << "Fail to get input's abstract value, with input config: " << input_conf->ToString()
<< ", in old node: " << c_old->DebugString();
}
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
AnfNodePtr replace_node = BuildPossibleValueNode(node_input, abs, attrs, node);
if (replace_node == nullptr) {
replace_node = BuildReplacedNode(input_conf);
replace_node->set_abstract(abs);
MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << abs->ToString();
} else {
MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
<< ", abs: " << abs->ToString() << ", replace_node: " << replace_node->ToString();
}
if (enable_eliminate_unused_element) {
UpdateSequenceNode(replace_node, node_input, abs);
}
if (new_inputs[i] != replace_node) {
new_inputs[i] = replace_node;
MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString();
}
c_new->set_inputs(new_inputs);
}
c_new->set_inputs(new_inputs);
}

AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
@@ -506,10 +654,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &fun
<< ", " << func->ToString();
return func;
}
FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
MS_EXCEPTION_IF_NULL(v);
v->set_flag(kFuncGraphFlagUndetermined, false);
return BuildValueNode(v, abs);
FuncGraphPtr func_graph = specializer_->SpecializeFuncGraph(context->func_graph(), context);
MS_EXCEPTION_IF_NULL(func_graph);
func_graph->set_flag(kFuncGraphFlagUndetermined, false);
return BuildValueNode(func_graph, abs);
}

AnalysisContextPtr FuncGraphSpecializer::MakeContext(const AnalysisEnginePtr &engine,
@@ -643,20 +791,21 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
return std::make_pair(AbstractBasePtrList(), nullptr);
}

void FuncGraphSpecializer::ProcessCNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (specializer_->seen().count(node) > 0) {
void FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (specializer_->seen().count(cnode) > 0) {
return;
}
specializer_->AddSeen(node);
auto new_inputs = node->inputs();
specializer_->AddSeen(cnode);

auto new_inputs = cnode->inputs();
if (new_inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
}
AnfNodePtr func = new_inputs[0];
MS_EXCEPTION_IF_NULL(func);
constexpr auto recursive_level = 2;
MS_LOG(DEBUG) << "Handle node: " << node->DebugString(recursive_level);
MS_LOG(DEBUG) << "Handle node: " << cnode->DebugString(recursive_level);

// First element is func so arg start from 1
std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end());
@@ -685,7 +834,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(func->func_graph());
if (status == kSpecializePoly ||
(func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
auto wrapped_node = BuildSpecializedParameterNode(node);
auto wrapped_node = BuildSpecializedParameterNode(cnode);
MS_LOG(DEBUG) << "Partial closure is handled, wrapped_node: " << wrapped_node->DebugString(recursive_level);
new_inputs[0] = wrapped_node;
}
@@ -723,7 +872,13 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &node) {
}

// Set the updated inputs.
node->set_inputs(new_inputs);
cnode->set_inputs(new_inputs);

static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element) {
EliminateUnusedSequenceItem(cnode);
}
}

namespace {
@@ -756,7 +911,7 @@ bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argv
}
return false;
}
} // end anonymous namespace
} // namespace

SpecializeStatusCode FuncGraphSpecializer::AcquireUniqueEvalVal(const AbstractFunctionPtr &func,
const EvaluatorPtr &eval,


+ 6
- 0
mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h View File

@@ -64,6 +64,8 @@ class ProgramSpecializer {

const AnalysisContextPtr &top_context() { return top_context_; }

std::vector<AbstractSequencePtr> &sequence_abstract_list() { return sequence_abstract_list_; }

private:
std::shared_ptr<AnalysisEngine> engine_;
mindspore::HashSet<AnfNodePtr> seen_;
@@ -71,6 +73,8 @@ class ProgramSpecializer {
std::unordered_map<AnalysisContextPtr, std::shared_ptr<FuncGraphSpecializer>, ContextHasher, ContextEqual>
specializations_;
AnalysisContextPtr top_context_;
// The list to purify tuple/list elements.
std::vector<AbstractSequencePtr> sequence_abstract_list_;
};

class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
@@ -99,6 +103,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
void ProcessNode(const AnfNodePtr &node);
void ProcessCNode(const CNodePtr &node);

void EliminateUnusedSequenceItem(const CNodePtr &cnode);

const NodeToNodeMap &cloned_nodes() const { return cloner_->cloned_nodes(); }

inline AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node);


+ 32
- 13
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc View File

@@ -122,7 +122,10 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
MS_LOG(INFO) << func_graph->ToString() << ": Run finished.";

MS_EXCEPTION_IF_NULL(output_conf);
result.inferred = output_conf->ObtainEvalResult();
auto eval_result = output_conf->ObtainEvalResult();
// Set the sequence nodes' elements use flags all true.
SetSequenceElementsUseFlagsRecursively(eval_result->abstract(), true);
result.eval_result = eval_result;
result.context = root_context;
} catch (const std::exception &ex) {
MS_LOG(INFO) << "Eval " << func_graph->ToString() << " threw exception.";
@@ -361,13 +364,14 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
return std::make_shared<MixedPrecisionCastEvaluator>(prim);
}

// find prim infer function in the prim function map return a standard evaluator
// Find prim infer function in the prim function map return a standard evaluator
auto eval_impl = GetPrimitiveInferImpl(prim);
if (eval_impl.infer_shape_impl_ != nullptr) {
if (eval_impl.infer_shape_impl_ != nullptr && prim->name() != prim::kPrimMakeTuple->name() &&
prim->name() != prim::kPrimMakeList->name()) { // Refactoring infer routine soon.
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
}

// use python infer function if the infer function not founded in the map return a python evaluator
// Use python infer function if the infer function not founded in the map return a python evaluator
EvaluatorPtr evaluator = nullptr;
if (prim->HasPyEvaluator()) {
auto prim_py = dyn_cast<PrimitivePy>(prim);
@@ -388,7 +392,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
return nullptr;
}

// return a default evaluator
// Return a default evaluator
if (engine == nullptr) {
// If engine is nullptr, get constructor from default.
const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
@@ -674,12 +678,13 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt

std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBasePtr &last_spec,
const AnfNodePtr &node, const std::string &error_info) {
constexpr int recursive_level = 2;
std::ostringstream buffer;
buffer << "The return values of different branches do not join. \n"
<< error_info << "\nFor more details, please refer to the FAQ at https://www.mindspore.cn.\n"
<< "The abstract type of the return value of the current branch is " << spec->ToString()
<< ", and that of the previous branch is " << last_spec->ToString() << ".\n"
<< "The node " << node->DebugString();
<< "The node is " << node->DebugString(recursive_level);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>()->input(0);
if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
@@ -803,10 +808,9 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
const std::vector<AsyncAbstractPtr> &pending_async_abstract_list,
const std::vector<std::size_t> &index) {
if (orig_abs->isa<AbstractSequence>()) {
const auto &orig_abstract_seq = orig_abs->cast<AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(orig_abstract_seq);
const auto &orig_elements = orig_abstract_seq->elements();
const auto sequence_abs = dyn_cast<AbstractSequence>(orig_abs);
if (sequence_abs != nullptr) {
const auto &orig_elements = sequence_abs->elements();
AbstractBasePtrList new_elements;
for (size_t i = 0; i < orig_elements.size(); ++i) {
if (orig_elements[i]->isa<AbstractFuncAtom>()) {
@@ -826,11 +830,15 @@ AbstractBasePtr BuildAsyncAbstractRecursively(const AbstractBasePtr &orig_abs,
new_elements.push_back(orig_elements[i]);
}
}
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
AbstractBasePtr new_abs;
if (orig_abs->isa<AbstractTuple>()) {
new_abs = std::make_shared<AbstractTuple>(new_elements);
new_abs = std::make_shared<AbstractTuple>(
new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : AnfNodeWeakPtrList()));
} else if (orig_abs->isa<AbstractList>()) {
new_abs = std::make_shared<AbstractList>(new_elements);
new_abs = std::make_shared<AbstractList>(
new_elements, (enable_eliminate_unused_element ? sequence_abs->sequence_nodes() : AnfNodeWeakPtrList()));
} else {
MS_LOG(EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << orig_abs->ToString();
}
@@ -864,7 +872,7 @@ void BuildPossibleSpecs(const AbstractBasePtr &first_result,
MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString()
<< ", new: " << new_first_result->ToString();
std::replace_if(
out_specs->begin(), out_specs->end(), [first_result](const auto &elem) { return elem == first_result; },
out_specs->begin(), out_specs->end(), [first_result](const auto &element) { return element == first_result; },
new_first_result);
} else {
MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
@@ -1059,6 +1067,17 @@ AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &cont
auto prim = value->cast<PrimitivePtr>();
return MakeAbstractClosure(prim, anf_node);
}
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element && value->isa<ValueSequence>()) {
auto abs = value->ToAbstract();
auto sequence_abs = dyn_cast<AbstractSequence>(abs);
MS_EXCEPTION_IF_NULL(sequence_abs);
if (anf_node != nullptr) {
SetSequenceNodeElementsUseFlags(anf_node, std::make_shared<std::vector<bool>>(sequence_abs->elements().size()));
sequence_abs->set_sequence_nodes({AnfNodeWeakPtr(anf_node)});
}
}
return value->ToAbstract();
}



+ 1
- 1
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h View File

@@ -213,7 +213,7 @@ using AnfNodeConfigMap =
std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;

struct AnalysisResult {
EvalResultPtr inferred;
EvalResultPtr eval_result;
AnalysisContextPtr context;
};



+ 189
- 31
mindspore/core/abstract/abstract_value.cc View File

@@ -241,13 +241,13 @@ const AbstractBasePtr AbstractSequence::operator[](const std::size_t &dim) const
return elements_[dim];
}

std::string AbstractSequence::ToString() const {
std::string AbstractSequence::ToStringInternal() const {
std::ostringstream buffer;
size_t i = 0;
size_t size = elements_.size();
for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele);
buffer << "element[" << i << "]: " << ele->ToString();
for (const auto &element : elements_) {
MS_EXCEPTION_IF_NULL(element);
buffer << "element[" << i << "]: " << element->ToString();
if (i < size - 1) {
buffer << ", ";
}
@@ -256,11 +256,169 @@ std::string AbstractSequence::ToString() const {
return buffer.str();
}

// Render the sequence abstract as "<TypeName>{elements..., sequence_nodes: {...}}".
// Each sequence-node entry shows either the node's debug string (plus its
// elements_use_flags, when present) or "<freed node>" for an expired weak reference.
std::string AbstractSequence::ToString() const {
  std::stringstream ss;
  ss << type_name();
  ss << "{";
  ss << ToStringInternal();
  if (!sequence_nodes_.empty()) {
    ss << ", sequence_nodes: {";
    for (size_t i = 0; i < sequence_nodes_.size(); ++i) {
      auto sequence_node = sequence_nodes_[i].lock();
      if (sequence_node == nullptr) {
        // The weak reference expired; still print a placeholder entry.
        ss << "<freed node>";
      } else {
        ss << sequence_node->DebugString();
        auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
        if (flags != nullptr) {
          ss << ", elements_use_flags: " << (*flags);
        }
      }
      // Separator between entries. (Previously a 'continue' on freed nodes skipped this,
      // gluing adjacent entries together; handled uniformly now.)
      if (i != sequence_nodes_.size() - 1) {
        ss << ", ";
      }
    }
    ss << "}";
  }
  ss << "}";
  return ss.str();
}

namespace {
void CollectSequenceNodes(const AnfNodeWeakPtrList &source_sequence_nodes, AnfNodeWeakPtrList *sequence_nodes_ptr) {
AnfNodeWeakPtrList &sequence_nodes = *sequence_nodes_ptr;
auto sequence_nodes_size = source_sequence_nodes.size();
for (size_t i = 0; i < sequence_nodes_size; ++i) {
// Lock sequence nodes of this.
auto &source_weak_node = source_sequence_nodes[i];
auto this_sequence_node = source_weak_node.lock();
if (this_sequence_node == nullptr) {
continue;
}
// Check and emplace sequence node for this.
auto this_iter = std::find_if(
sequence_nodes.begin(), sequence_nodes.end(),
[&this_sequence_node](const AnfNodeWeakPtr &weak_node) { return this_sequence_node == weak_node.lock(); });
if (this_iter == sequence_nodes.end()) {
sequence_nodes.emplace_back(AnfNodeWeakPtr(this_sequence_node));
}
}
}

// Make all nodes in 'sequence_nodes' share a single 'elements_use_flags' vector.
// The resulting shared vector marks an element used if any of the nodes used it
// (element-wise OR of the per-node flags).
void SynchronizeSequenceNodesElementsUseFlags(const AnfNodeWeakPtrList &sequence_nodes) {
  // Synchronize the elements use flags for all sequence nodes.
  auto current_sequence_node = sequence_nodes[0].lock();
  MS_EXCEPTION_IF_NULL(current_sequence_node);
  for (size_t i = 1; i < sequence_nodes.size(); ++i) {
    // Synchronize the 'elements_use_flags' for all sequence node.
    // We set the same 'elements_use_flags' for them after here.
    auto latter_sequence_node = sequence_nodes[i].lock();
    MS_EXCEPTION_IF_NULL(latter_sequence_node);
    // The 'current_sequence_node' is not equal to 'latter_sequence_node'.
    auto current_flags = GetSequenceNodeElementsUseFlags(current_sequence_node);
    auto latter_flags = GetSequenceNodeElementsUseFlags(latter_sequence_node);
    std::shared_ptr<std::vector<bool>> unique_flags = nullptr;  // Choose the ptr (use_count > 1) as unique flags.
    if (current_flags.use_count() == 1 && latter_flags.use_count() == 1) {
      // Neither flags vector is shared elsewhere yet; reuse the first one.
      unique_flags = current_flags;
    } else {
      // NOTE(review): the message says only one side may have more than one use count, but
      // the condition below requires BOTH sides to exceed 1 -- confirm the intended invariant.
      MS_EXCEPTION_IF_CHECK_FAIL(current_flags.use_count() > 1 && latter_flags.use_count() > 1,
                                 "Allow only one side has more than one use count.");
      if (current_flags.use_count() > 1) {
        unique_flags = current_flags;
      } else {  // If latter_flags.use_count() > 1
        unique_flags = latter_flags;
      }
    }
    // Merge: any element used by either node is marked used in the shared flags.
    for (size_t j = 0; j < current_flags->size(); ++j) {
      MS_LOG(DEBUG) << "Check elements_use_flags[" << j << "], this_flag: " << (*current_flags)[j]
                    << ", other_flag: " << (*latter_flags)[j];
      if ((*current_flags)[j] != (*latter_flags)[j]) {
        (*unique_flags)[j] = true;
      } else {
        (*unique_flags)[j] = (*current_flags)[j];
      }
    }
    // Attach the shared flags vector to whichever node does not hold it yet.
    if (unique_flags != current_flags) {
      SetSequenceNodeElementsUseFlags(current_sequence_node, unique_flags);
    }
    if (unique_flags != latter_flags) {
      SetSequenceNodeElementsUseFlags(latter_sequence_node, unique_flags);
    }
  }
}
} // namespace

// Combine this abstract's sequence nodes with those of 'other' (if 'other' is also a
// sequence abstract), deduplicated and with their elements-use flags synchronized.
// Returns an empty list when the MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT feature is off
// or neither side tracks any sequence node.
AnfNodeWeakPtrList AbstractSequence::SequenceNodesJoin(const AbstractBasePtr &other) {
  AnfNodeWeakPtrList sequence_nodes;
  static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
  static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
  if (!enable_eliminate_unused_element) {
    return sequence_nodes;
  }

  MS_LOG(DEBUG) << "this: " << ToString() << ", other: " << other->ToString();
  auto other_sequence = dyn_cast<AbstractSequence>(other);
  auto this_sequence_nodes_size = sequence_nodes_.size();
  auto other_sequence_nodes_size = (other_sequence != nullptr ? other_sequence->sequence_nodes_.size() : 0);
  if (this_sequence_nodes_size == 0 && other_sequence_nodes_size == 0) {
    return sequence_nodes;
  }
  // Collect this and other sequence nodes.
  CollectSequenceNodes(sequence_nodes_, &sequence_nodes);
  // 'other' may not be a sequence abstract at all; guard the dereference.
  // (Previously 'other_sequence' was dereferenced unconditionally here, crashing when this
  // side had sequence nodes but 'other' was not an AbstractSequence.)
  if (other_sequence != nullptr) {
    CollectSequenceNodes(other_sequence->sequence_nodes_, &sequence_nodes);
  }
  if (sequence_nodes.empty()) {
    MS_LOG(EXCEPTION) << "Sequence nodes size should not be empty.";
  }
  // Synchronize the elements use flags for all sequence nodes.
  SynchronizeSequenceNodesElementsUseFlags(sequence_nodes);
  return sequence_nodes;
}

// Purify the elements list: replace every element whose use flag is false with a scalar 0
// placeholder. The use flags are fetched from any still-alive sequence node of this abstract.
void AbstractSequence::PurifyElements() {
  if (sequence_nodes_.empty()) {
    return;
  }
  // Just use any sequence node's elements_use_flags; take the first one available.
  std::shared_ptr<std::vector<bool>> use_flags = nullptr;
  for (const auto &weak_node : sequence_nodes_) {
    auto sequence_node = weak_node.lock();
    if (sequence_node == nullptr) {
      MS_LOG(DEBUG) << "The node in sequence_nodes is free.";
      continue;
    }
    use_flags = GetSequenceNodeElementsUseFlags(sequence_node);
    if (use_flags != nullptr) {
      break;
    }
  }
  if (use_flags == nullptr) {
    // All nodes released, or none of them carries use flags; nothing to purify.
    MS_LOG(ERROR) << "Check if all sequence nodes are released, or none elements use flags in them. " << ToString();
    return;
  }
  if (use_flags->size() != elements_.size()) {
    MS_LOG(EXCEPTION) << "Elements size should be equal to elements use flags size.";
  }
  // Replace each unused element with a scalar 0 placeholder; keep the used ones.
  for (size_t i = 0; i < use_flags->size(); ++i) {
    MS_EXCEPTION_IF_NULL(elements_[i]);
    if ((*use_flags)[i]) {
      MS_LOG(DEBUG) << "Keep element[" << i << "] as " << elements_[i]->ToString();
    } else {
      elements_[i] = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(0));
      MS_LOG(INFO) << "Set element[" << i << "] to Zero.";
    }
  }
}

TypePtrList AbstractSequence::ElementsType() const {
TypePtrList element_type_list;
for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele);
TypePtr element_type = ele->BuildType();
for (const auto &element : elements_) {
MS_EXCEPTION_IF_NULL(element);
TypePtr element_type = element->BuildType();
element_type_list.push_back(element_type);
}
return element_type_list;
@@ -268,50 +426,50 @@ TypePtrList AbstractSequence::ElementsType() const {

BaseShapePtrList AbstractSequence::ElementsShape() const {
BaseShapePtrList element_shape_list;
for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele);
BaseShapePtr element_shape = ele->BuildShape();
for (const auto &element : elements_) {
MS_EXCEPTION_IF_NULL(element);
BaseShapePtr element_shape = element->BuildShape();
element_shape_list.push_back(element_shape);
}
return element_shape_list;
}

AbstractBasePtrList AbstractSequence::ElementsClone() const {
AbstractBasePtrList ele_list;
for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele);
AbstractBasePtr clone = ele->Clone();
ele_list.push_back(clone);
AbstractBasePtrList element_list;
for (const auto &element : elements_) {
MS_EXCEPTION_IF_NULL(element);
AbstractBasePtr clone = element->Clone();
element_list.push_back(clone);
}
return ele_list;
return element_list;
}

AbstractBasePtrList AbstractSequence::ElementsBroaden() const {
AbstractBasePtrList ele_list;
for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele);
AbstractBasePtr broadend = ele->Broaden();
ele_list.push_back(broadend);
AbstractBasePtrList element_list;
for (const auto &element : elements_) {
MS_EXCEPTION_IF_NULL(element);
AbstractBasePtr broadend = element->Broaden();
element_list.push_back(broadend);
}
return ele_list;
return element_list;
}

AbstractBasePtrList AbstractSequence::ElementsPartialBroaden() const {
AbstractBasePtrList ele_list;
for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele);
AbstractBasePtr broadend = ele->PartialBroaden();
ele_list.push_back(broadend);
AbstractBasePtrList element_list;
for (const auto &element : elements_) {
MS_EXCEPTION_IF_NULL(element);
AbstractBasePtr broadend = element->PartialBroaden();
element_list.push_back(broadend);
}
return ele_list;
return element_list;
}

template <typename T>
ValuePtr AbstractSequence::ElementsBuildValue() const {
std::vector<ValuePtr> element_value_list;
for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele);
ValuePtr element_value = ele->BuildValue();
for (const auto &element : elements_) {
MS_EXCEPTION_IF_NULL(element);
ValuePtr element_value = element->BuildValue();
MS_EXCEPTION_IF_NULL(element_value);
if (element_value->isa<AnyValue>()) {
return kAnyValue;


+ 85
- 15
mindspore/core/abstract/abstract_value.h View File

@@ -692,7 +692,9 @@ class MS_CORE_API AbstractSequence : public AbstractBase {
/// \brief Constructor of AbstractSequence.
///
/// \param[in] elements A list of abstracts.
explicit AbstractSequence(const AbstractBasePtrList &elements) : elements_(elements) {}
/// \param[in] sequence_nodes The nodes of tuple/list, usually are MakeTuple/MakeList CNodes or tuple/list ValueNodes.
explicit AbstractSequence(const AbstractBasePtrList &elements, const AnfNodeWeakPtrList &sequence_nodes)
: elements_(elements), sequence_nodes_(sequence_nodes) {}

/// \brief Destructor of AbstractSequence.
~AbstractSequence() override = default;
@@ -738,6 +740,12 @@ class MS_CORE_API AbstractSequence : public AbstractBase {
template <typename T>
AbstractBasePtr ElementsJoin(const AbstractBasePtr &other);

/// \brief Combine other sequence nodes with this one.
///
/// \param[in] other The other abstract to be joined.
/// \return A sequence nodes list combined.
AnfNodeWeakPtrList SequenceNodesJoin(const AbstractBasePtr &other);

/// \brief Get the size of the stored elements.
///
/// \return A size_t.
@@ -748,8 +756,51 @@ class MS_CORE_API AbstractSequence : public AbstractBase {
/// \return A vector of elements.
const AbstractBasePtrList &elements() const { return elements_; }

/// \brief Purify the elements list, and clean unused elements.
void PurifyElements();

/// \brief Get the sequence nodes where these 'AbstractSequence' evaluated from.
///
/// \return The nodes of tuple/list, usually are MakeTuple/MakeList CNodes or tuple/list ValueNodes.
const AnfNodeWeakPtrList &sequence_nodes() const { return sequence_nodes_; }

/// \brief Set the sequence nodes where these 'AbstractSequence' evaluated from.
///
/// \param[in] sequence_nodes The nodes of tuple/list, usually are MakeTuple/MakeList CNodes or tuple/list ValueNodes.
void set_sequence_nodes(const AnfNodeWeakPtrList &sequence_nodes) { sequence_nodes_ = sequence_nodes; }

/// \brief Insert a node into the sequence nodes.
///
/// \param[in] sequence_node The node to insert into sequence nodes.
///
/// \note Throws if the node is already present in the sequence nodes,
///       since a duplicate insertion indicates an internal bookkeeping error.
void insert_sequence_node(const AnfNodePtr &sequence_node) {
  // Membership is checked by comparing against each locked weak pointer.
  auto iter =
    std::find_if(sequence_nodes_.begin(), sequence_nodes_.end(),
                 [&sequence_node](const AnfNodeWeakPtr &weak_node) { return sequence_node == weak_node.lock(); });
  if (iter == sequence_nodes_.end()) {
    sequence_nodes_.emplace_back(sequence_node);
  } else {
    MS_LOG(EXCEPTION) << "Fail to insert node \'" << sequence_node->DebugString() << "\' into sequence nodes.";
  }
}

/// \brief Update the sequence nodes.
///
/// \param[in] old_sequence_node The old node in sequence nodes.
/// \param[in] new_sequence_node The new node to replace old node in sequence nodes.
///
/// \note Throws if 'old_sequence_node' is not found in the sequence nodes.
void update_sequence_node(const AnfNodePtr &old_sequence_node, const AnfNodePtr &new_sequence_node) {
  // Locate the old node by comparing against each locked weak pointer.
  auto iter = std::find_if(
    sequence_nodes_.begin(), sequence_nodes_.end(),
    [&old_sequence_node](const AnfNodeWeakPtr &weak_node) { return old_sequence_node == weak_node.lock(); });
  if (iter != sequence_nodes_.end()) {
    *iter = new_sequence_node;
    return;
  }
  MS_LOG(EXCEPTION) << "Not found old node \'" << old_sequence_node->DebugString() << "\' in sequence nodes.";
}

std::size_t hash() const override;

std::string ToStringInternal() const;
std::string ToString() const override;

/// \brief Overwrite the operator '[]' to get an element.
@@ -767,6 +818,7 @@ class MS_CORE_API AbstractSequence : public AbstractBase {

protected:
AbstractBasePtrList elements_;
AnfNodeWeakPtrList sequence_nodes_;
};
using AbstractSequencePtr = std::shared_ptr<AbstractSequence>;

@@ -776,7 +828,9 @@ class MS_CORE_API AbstractTuple final : public AbstractSequence {
/// \brief Constructor of AbstractTuple.
///
/// \param[in] elements A list of abstracts.
explicit AbstractTuple(const AbstractBasePtrList &elements) : AbstractSequence(elements) {}
/// \param[in] tuple_node The nodes of tuple, usually are MakeTuple CNodes or tuple ValueNodes.
explicit AbstractTuple(const AbstractBasePtrList &elements, const AnfNodeWeakPtrList &tuple_nodes = {})
: AbstractSequence(elements, tuple_nodes) {}

/// \brief Destructor of AbstractTuple.
~AbstractTuple() override = default;
@@ -786,15 +840,22 @@ class MS_CORE_API AbstractTuple final : public AbstractSequence {

BaseShapePtr BuildShape() const override { return std::make_shared<TupleShape>(ElementsShape()); }

AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone()); }

AbstractBasePtr Broaden() const override { return std::make_shared<AbstractTuple>(ElementsBroaden()); }
AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone(), sequence_nodes_); }

AbstractBasePtr PartialBroaden() const override { return std::make_shared<AbstractTuple>(ElementsPartialBroaden()); }
AbstractBasePtr Broaden() const override {
return std::make_shared<AbstractTuple>(ElementsBroaden(), sequence_nodes_);
}

AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractTuple>(other); }
AbstractBasePtr PartialBroaden() const override {
return std::make_shared<AbstractTuple>(ElementsPartialBroaden(), sequence_nodes_);
}

std::string ToString() const override { return type_name() + "(" + AbstractSequence::ToString() + ")"; }
AbstractBasePtr Join(const AbstractBasePtr &other) override {
auto res = dyn_cast<AbstractSequence>(ElementsJoin<AbstractTuple>(other));
MS_EXCEPTION_IF_NULL(res);
res->set_sequence_nodes(SequenceNodesJoin(other));
return res;
}

/// \brief Check whether all elements of the tuple are tensors.
///
@@ -821,7 +882,9 @@ class MS_CORE_API AbstractList final : public AbstractSequence {
/// \brief Constructor of AbstractList.
///
/// \param[in] elements A list of abstracts.
explicit AbstractList(const AbstractBasePtrList &elements) : AbstractSequence(elements) {}
/// \param[in] list_node The nodes of list, usually are MakeList CNodes or list ValueNodes.
explicit AbstractList(const AbstractBasePtrList &elements, const AnfNodeWeakPtrList &list_nodes = {})
: AbstractSequence(elements, list_nodes) {}

/// \brief Destructor of AbstractList.
~AbstractList() override = default;
@@ -831,15 +894,22 @@ class MS_CORE_API AbstractList final : public AbstractSequence {

BaseShapePtr BuildShape() const override { return std::make_shared<ListShape>(ElementsShape()); }

AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone()); }

AbstractBasePtr Broaden() const override { return std::make_shared<AbstractList>(ElementsBroaden()); }
AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone(), sequence_nodes_); }

AbstractBasePtr PartialBroaden() const override { return std::make_shared<AbstractList>(ElementsPartialBroaden()); }
AbstractBasePtr Broaden() const override {
return std::make_shared<AbstractList>(ElementsBroaden(), sequence_nodes_);
}

AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractList>(other); }
AbstractBasePtr PartialBroaden() const override {
return std::make_shared<AbstractList>(ElementsPartialBroaden(), sequence_nodes_);
}

std::string ToString() const override { return type_name() + "[" + AbstractSequence::ToString() + "]"; }
/// \brief Join this abstract list with another abstract value.
///
/// Joins element-wise via ElementsJoin<AbstractList>, then merges the tracked
/// sequence nodes of both operands into the joined result.
AbstractBasePtr Join(const AbstractBasePtr &other) override {
  // ElementsJoin is expected to yield an AbstractSequence; a null cast is an internal error.
  auto res = dyn_cast<AbstractSequence>(ElementsJoin<AbstractList>(other));
  MS_EXCEPTION_IF_NULL(res);
  // Keep the sequence-node bookkeeping of both operands on the joined abstract.
  res->set_sequence_nodes(SequenceNodesJoin(other));
  return res;
}

/// \brief Overwrite the operator '==' to compare other abstract list.
///


+ 2
- 0
mindspore/core/abstract/prim_arrays.cc View File

@@ -136,6 +136,8 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr
arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
tuple_len = arg->elements().size();
tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
// For Stack(tuple), set all used flags of tuple as true.
SetSequenceElementsUseFlags(args_spec_list[0], true);
} else if (args_spec_list[0]->isa<AbstractTensor>()) {
tuple_len = args_spec_list.size();
tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);


+ 9
- 0
mindspore/core/abstract/prim_others.cc View File

@@ -191,6 +191,9 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
return args_spec_list[0];
}

// For F.depend(x, MakeTuple()) or F.depend(x, tuple), set all used flags of tuple as true.
SetSequenceElementsUseFlags(dependant_abstract, true);

auto depends = args_spec_list[0]->Broaden(); // Avoid eliminating the dependent node.
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR)) {
// For scalar, need to set value to kAnyValue, because broaden scalar will not change the value.
@@ -207,6 +210,12 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP
MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0";
}
MS_EXCEPTION_IF_NULL(args_spec_list[0]);

// For UpdateState(x, MakeTuple()) or UpdateState(x, tuple), set all used flags of tuple as true.
for (size_t i = 1; i < args_spec_list.size(); i++) {
SetSequenceElementsUseFlags(args_spec_list[i], true);
}

return args_spec_list[0]->Broaden();
}



+ 54
- 27
mindspore/core/abstract/prim_structures.cc View File

@@ -36,7 +36,8 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tuples.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
constexpr int args_spec_size = 2;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);

@@ -66,7 +67,8 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list) {
// Inputs: a string and an object of a subclass of AbstractBase.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
constexpr int args_spec_size = 2;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);

ValuePtr keyPtr = key->BuildValue();
@@ -82,7 +84,8 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive
const AbstractBasePtrList &args_spec_list) {
// Inputs: a string and a keyword.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
constexpr int args_spec_size = 2;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);

@@ -103,7 +106,8 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive
template <typename T>
AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list and a scalar whose value is an int32 number.
CheckArgsSize(op_name, args_spec_list, 2);
constexpr int args_spec_size = 2;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);

@@ -117,26 +121,41 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
}
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " << index->ToString();
}
auto idx_v = GetValue<int64_t>(index_value);
auto index_int64_value = GetValue<int64_t>(index_value);
std::size_t nelems = queue->elements().size();
if (idx_v >= SizeToLong(nelems) || idx_v < -SizeToLong(nelems)) {
if (index_int64_value >= SizeToLong(nelems) || index_int64_value < -SizeToLong(nelems)) {
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToLong(nelems) << ", "
<< SizeToLong(nelems) << "), but got " << idx_v << ".";
<< SizeToLong(nelems) << "), but got " << index_int64_value << ".";
}

std::size_t uidx_v = 0;
if (idx_v >= 0) {
uidx_v = LongToSize(idx_v);
std::size_t index_unsigned_value = 0;
if (index_int64_value >= 0) {
index_unsigned_value = LongToSize(index_int64_value);
} else {
uidx_v = LongToSize(idx_v + SizeToLong(nelems));
index_unsigned_value = LongToSize(index_int64_value + SizeToLong(nelems));
}
return queue->elements()[uidx_v];
if (!queue->sequence_nodes().empty()) {
for (auto &node : queue->sequence_nodes()) {
auto sequence_node = node.lock();
if (sequence_node == nullptr) {
MS_LOG(DEBUG) << "The node in sequence_nodes is free.";
continue;
}
auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
if (flags != nullptr) {
(*flags)[index_unsigned_value] = true;
MS_LOG(DEBUG) << "Set item[" << index_unsigned_value << "] as use flag for " << sequence_node->DebugString();
}
}
}
return queue->elements()[index_unsigned_value];
}

template <typename T>
AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list, a scalar whose value is an int64 number and an object of a subclass of AbstractBase.
CheckArgsSize(op_name, args_spec_list, 3);
constexpr int args_spec_size = 3;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);

@@ -146,16 +165,17 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got "
<< index_value->ToString();
}
auto idx_v = GetValue<int64_t>(index_value);
auto index_int64_value = GetValue<int64_t>(index_value);
AbstractBasePtrList elements = queue->elements();
std::size_t nelems = elements.size();
int64_t idx_t = idx_v >= 0 ? idx_v : idx_v + SizeToLong(nelems);
if (idx_t < 0 || idx_t >= SizeToLong(nelems)) {
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << idx_v << " to set out of range: [-" << nelems
<< "," << (nelems - 1) << "].";
int64_t index_positive_value = index_int64_value >= 0 ? index_int64_value : index_int64_value + SizeToLong(nelems);
if (index_positive_value < 0 || index_positive_value >= SizeToLong(nelems)) {
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << index_int64_value << " to set out of range: [-"
<< nelems << "," << (nelems - 1) << "].";
}
size_t uidx_v = LongToSize(idx_t);
elements[uidx_v] = args_spec_list[2];
size_t index_unsigned_value = LongToSize(index_positive_value);
constexpr int target_value_index = 2;
elements[index_unsigned_value] = args_spec_list[target_value_index];
return std::make_shared<T>(elements);
}

@@ -183,7 +203,8 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict and a scalar whose value is a string.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
constexpr int args_spec_size = 2;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);

@@ -206,7 +227,8 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
constexpr int args_spec_size = 3;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);

@@ -235,7 +257,8 @@ AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
AbstractBasePtrList keys;
@@ -248,7 +271,8 @@ AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const Primitiv
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
AbstractBasePtrList values;
@@ -261,7 +285,8 @@ AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
AbstractBasePtrList items;
@@ -276,7 +301,8 @@ AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePt
const AbstractBasePtrList &args_spec_list) {
// Inputs: a list and an object of a subclass of AbstractBase.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
constexpr int args_spec_size = 2;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
AbstractBasePtr item = dyn_cast<AbstractBase>(args_spec_list[1]);
MS_EXCEPTION_IF_NULL(item);
@@ -298,7 +324,8 @@ AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
auto arg_abs = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto shape = arg_abs->BuildShape()->cast<ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);


+ 1
- 0
mindspore/core/base/base.h View File

@@ -203,6 +203,7 @@ using AnfNodePtr = std::shared_ptr<AnfNode>;
using AnfNodePtrList = std::vector<AnfNodePtr>;
using AnfNodeSet = OrderedSet<AnfNodePtr>;
using AnfNodeWeakPtr = std::weak_ptr<AnfNode>;
using AnfNodeWeakPtrList = std::vector<AnfNodeWeakPtr>;

class FuncGraph;
using FuncGraphPtr = std::shared_ptr<FuncGraph>;


+ 50
- 0
mindspore/core/ir/anf.cc View File

@@ -643,4 +643,54 @@ bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set)
}
return IsOneOfPrimitive(cnode->input(0), prim_set);
}

// Set all elements' use flags of the sequence nodes in `abs` to `new_flag`.
// Does nothing unless MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT=1, or when `abs`
// is not a sequence abstract / carries no sequence nodes.
void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, bool new_flag) {
  // The elimination feature is opt-in via environment variable; read once per process.
  static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
  static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
  if (!enable_eliminate_unused_element) {
    return;
  }

  auto sequence_abs = dyn_cast<abstract::AbstractSequence>(abs);
  if (sequence_abs == nullptr || sequence_abs->sequence_nodes().empty()) {
    return;
  }
  for (auto &weak_node : sequence_abs->sequence_nodes()) {
    auto sequence_node = weak_node.lock();
    if (sequence_node == nullptr) {
      MS_LOG(DEBUG) << "The node in sequence_nodes is free.";
      continue;
    }
    auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
    if (flags != nullptr) {
      // std::fill expresses the intent directly; the original std::transform
      // ignored its input element, which is just a fill in disguise.
      std::fill(flags->begin(), flags->end(), new_flag);
    }
  }
}

// Recursively set all elements' use flags of `abs` and of every nested
// sequence element to `new_flag`. No-op unless the elimination feature is
// enabled via MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT=1.
void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new_flag) {
  static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
  static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
  if (!enable_eliminate_unused_element) {
    return;
  }

  // Update the flags on this level first.
  SetSequenceElementsUseFlags(abs, new_flag);

  // Then descend into the elements when `abs` itself is a sequence abstract.
  auto seq_abs = dyn_cast<abstract::AbstractSequence>(abs);
  if (seq_abs != nullptr) {
    for (auto &elem : seq_abs->elements()) {
      SetSequenceElementsUseFlagsRecursively(elem, new_flag);
    }
  }
}
} // namespace mindspore

+ 13
- 1
mindspore/core/ir/anf.h View File

@@ -301,7 +301,7 @@ class MS_CORE_API AnfNode : public Base {
user_data_.set<T>(T::key, value);
}

/// \brief Set user data.
/// \brief Get user data.
///
/// \param[in] key The key of user data.
/// \return Pointer to user data.
@@ -1200,6 +1200,18 @@ struct GraphSegment {
uint32_t graph_id_{0};
};
using GraphSegmentPtr = std::shared_ptr<GraphSegment>;

// Key under which a sequence node stores its per-element use flags in user data.
constexpr auto kElementsUseFlagsKey = "elements_use_flags";
// Fetch the per-element use flags attached to `node`; returns nullptr when no
// flags have been set on it yet.
inline std::shared_ptr<std::vector<bool>> GetSequenceNodeElementsUseFlags(const AnfNodePtr &node) {
  return node->template user_data<std::vector<bool>>(kElementsUseFlagsKey);
}

// Attach per-element use flags to `node` (presumably one bool per sequence
// element — confirm against callers), stored in the node's user data.
inline void SetSequenceNodeElementsUseFlags(const AnfNodePtr &node, const std::shared_ptr<std::vector<bool>> &flags) {
  node->set_user_data(kElementsUseFlagsKey, flags);
}

void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, bool new_flag);
void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new_flag);
} // namespace mindspore

#endif // MINDSPORE_CORE_IR_ANF_H_

+ 2
- 2
mindspore/core/ir/func_graph_cloner.cc View File

@@ -140,7 +140,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) {
repl_node_[node] = std::move(new_const);
}

void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
void Cloner::CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target);
auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_);
@@ -232,7 +232,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
auto parent = cnode.first->first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(parent);
const auto &valuenode = parent->input(cnode.first->second);
CloneValueNode(valuenode, target_func_graph);
CloneFuncGraphValueNode(valuenode, target_func_graph);
}
}



+ 1
- 1
mindspore/core/ir/func_graph_cloner.h View File

@@ -88,7 +88,7 @@ class Cloner {
void SetDefaults();
void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneValueNode(const AnfNodePtr &node);
void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false);
void CloneValueNodes(const FuncGraphPtr &func_graph);


+ 4
- 1
mindspore/core/ops/addn.cc View File

@@ -98,7 +98,10 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args));
auto res = abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args));
// For AddN(MakeTuple()) or AddN(tuple), set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[0], true);
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer, nullptr, true);
} // namespace ops


+ 14
- 8
tests/ut/cpp/operator/composite_test.cc View File

@@ -160,7 +160,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
AbstractTuplePtr ret =
dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
@@ -186,7 +187,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
AbstractTuplePtr ret =
dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
@@ -212,7 +214,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
AbstractTuplePtr ret =
dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
@@ -238,7 +241,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
AbstractTuplePtr ret =
dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
@@ -265,7 +269,8 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);

AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
AbstractTuplePtr ret =
dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
@@ -292,7 +297,8 @@ TEST_F(TestComposite, test_UnpackCall_5args) {
abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);

AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
AbstractTuplePtr ret =
dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
@@ -314,7 +320,7 @@ TEST_F(TestComposite, test_ZipOperation) {
auto tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tuple};

AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract());
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
@@ -362,7 +368,7 @@ TEST_F(TestComposite, test_shard) {
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
AbstractBasePtrList args_spec_list = {tensor};

auto ret = engine_->Run(shard_func_graph, args_spec_list).inferred->abstract();
auto ret = engine_->Run(shard_func_graph, args_spec_list).eval_result->abstract();
ASSERT_NE(ret, nullptr);
ASSERT_TRUE(ret->isa<abstract::AbstractTensor>());
auto build_shape = ret->BuildShape();


+ 7
- 7
tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc View File

@@ -111,7 +111,7 @@ TEST_F(TestStandardEvaluator, test_multiple_conv2d) {
std::vector<int64_t> shape = {2, 2, 6, 6};
expected->set_shape(std::make_shared<Shape>(shape));

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected: " << expected->ToString();

@@ -143,7 +143,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_resolved) {
AbstractBasePtr abstract_x = FromValue(x, false);
args_spec_list.push_back(abstract_x);

AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
}
@@ -159,7 +159,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_unresolved) {
AbstractBasePtr abstract_x = FromValue(x, false);
args_spec_list.push_back(abstract_x);

AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
}
@@ -178,7 +178,7 @@ TEST_F(TestPartialEvaluator, test_infer_add_resolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);

AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
@@ -197,7 +197,7 @@ TEST_F(TestPartialEvaluator, test_infer_sub_unresolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);

AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
@@ -216,7 +216,7 @@ TEST_F(TestPartialEvaluator, test_infer_net_construct_add_resolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);

AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
@@ -235,7 +235,7 @@ TEST_F(TestPartialEvaluator, test_infer_construct_sub_unresolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);

AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}


+ 48
- 48
tests/ut/cpp/pipeline/static_analysis/prim_test.cc View File

@@ -139,7 +139,7 @@ TEST_F(TestPrim, test_typeof) {

auto prim_typeof = std::make_shared<Primitive>("typeof");
FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
res->dump();
TypePtr res_value = res->GetValueTrack()->cast<TypePtr>();
res_value->dump();
@@ -164,7 +164,7 @@ TEST_F(TestPrim, test_list_map) {

auto prim_list_map = std::make_shared<Primitive>("list_map");
FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
auto expected = std::make_shared<AbstractList>(
AbstractBasePtrList({FromValue(static_cast<int64_t>(3), false), FromValue(static_cast<int64_t>(3), false)}));
res->dump();
@@ -189,7 +189,7 @@ TEST_F(TestPrim, test_list_reduce) {

auto prim_list_reduce = std::make_shared<Primitive>("list_reduce");
FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
res->dump();
TypePtr res_type = res->GetTypeTrack();
res_type->dump();
@@ -206,7 +206,7 @@ TEST_F(TestPrim, test_scalar_to_array) {

auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array");
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
res->dump();
TypePtr res_type = res->BuildType();
res_type->dump();
@@ -224,7 +224,7 @@ TEST_F(TestPrim, test_array_to_scalar) {

auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar");
FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
res->dump();
TypePtr res_type = res->BuildType();
res_type->dump();
@@ -240,7 +240,7 @@ TEST_F(TestPrim, test_J_1) {

auto prim_J = std::make_shared<Primitive>("J");
FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res);
ASSERT_TRUE(res_J != nullptr);
ASSERT_TRUE(*(res_J->element()) == *abstract_v1);
@@ -280,7 +280,7 @@ TEST_F(TestPrim, test_J_2) {
int64_t v1 = 1;
AbstractBasePtr abstract_v1 = FromValue(v1, false);
AbstractBasePtrList args_spec_list = {abstract_v1};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
res->dump();
AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res);
ASSERT_TRUE(res_J != nullptr);
@@ -301,7 +301,7 @@ TEST_F(TestPrim, test_switch1) {
AbstractBasePtr arg2 = FromValue(static_cast<int64_t>(2), false);
AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*res == *arg1);
}

@@ -314,7 +314,7 @@ TEST_F(TestPrim, test_switch2) {
AbstractBasePtr arg2 = FromValue(static_cast<int64_t>(2), false);
AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "make result res: " << res->ToString();
MS_LOG(INFO) << "make result arg2: " << arg2->ToString();
ASSERT_TRUE(*res == *arg2);
@@ -327,7 +327,7 @@ TEST_F(TestPrim, test_identity) {
AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtrList args_spec_list = {abstract_v1};

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*res == *abstract_v1);
}

@@ -341,7 +341,7 @@ TEST_F(TestPrim, test_broadcast_shape) {

AbstractBasePtrList args_spec_list = {a, b};

AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).eval_result->abstract());

auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)};
@@ -361,7 +361,7 @@ TEST_F(TestPrim, test_partial) {
AbstractBasePtr abstract_v2 = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2};

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2};
auto expected = std::make_shared<PartialAbstractClosure>(
std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list);
@@ -377,7 +377,7 @@ TEST_F(TestPrim, test_environ_set) {
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtrList args_spec_list = {abstract_x};
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).eval_result->abstract();

FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvironSet, 3);

@@ -385,7 +385,7 @@ TEST_F(TestPrim, test_environ_set) {
AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
args_spec_list = {abstract_environ, embed_x, abstract_y};

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
ASSERT_TRUE(*res == *exp);
}
@@ -397,7 +397,7 @@ TEST_F(TestPrim, test_environ_get) {
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtrList args_spec_list = {abstract_x};
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).eval_result->abstract();

FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3);

@@ -405,7 +405,7 @@ TEST_F(TestPrim, test_environ_get) {
AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
args_spec_list = {abstract_environ, embed_x, abstract_y};

AbstractBasePtr res = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(graph_environ_set, args_spec_list).eval_result->abstract();
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
ASSERT_TRUE(*res == *exp);

@@ -414,7 +414,7 @@ TEST_F(TestPrim, test_environ_get) {
AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
args_spec_list = {res, embed_x, abstract_z};

res = engine_->Run(graph_environ_get, args_spec_list).inferred->abstract();
res = engine_->Run(graph_environ_get, args_spec_list).eval_result->abstract();

ASSERT_TRUE(*res == *abstract_x);
}
@@ -426,7 +426,7 @@ TEST_F(TestPrim, test_environ_add) {
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
AbstractBasePtrList args_spec_list = {abstract_x};
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).eval_result->abstract();

FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3);

@@ -434,19 +434,19 @@ TEST_F(TestPrim, test_environ_add) {
AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
args_spec_list = {abstract_environ, embed_x, abstract_y};

AbstractBasePtr abstract_e1 = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract();
AbstractBasePtr abstract_e1 = engine_->Run(graph_environ_set, args_spec_list).eval_result->abstract();
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
ASSERT_TRUE(*abstract_e1 == *exp);

AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
args_spec_list = {abstract_environ, embed_x, abstract_z};

AbstractBasePtr abstract_e2 = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract();
AbstractBasePtr abstract_e2 = engine_->Run(graph_environ_set, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*abstract_e2 == *exp);

FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvironAdd, 2);
args_spec_list = {abstract_e1, abstract_e2};
AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).eval_result->abstract();

ASSERT_TRUE(*res == *exp);
}
@@ -459,7 +459,7 @@ TEST_F(TestPrim, test_relu) {
AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW
AbstractBasePtrList args_spec_list = {expected};

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*res == *expected);
}

@@ -472,7 +472,7 @@ TEST_F(TestPrim, test_relu2) {
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});

AbstractBasePtrList args_spec_list = {arr};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
auto res = dyn_cast<AbstractTensor>(ret);
ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
}
@@ -505,7 +505,7 @@ TEST_F(TestPrim, test_conv2d1) {
std::vector<int64_t> shape = {2, 64, 14, 14};
expected->set_shape(std::make_shared<Shape>(shape));

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected: " << expected->ToString();

@@ -523,7 +523,7 @@ TEST_F(TestPrim, test_conv2d) {
auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3});

AbstractBasePtrList args_spec_list = {input, weight};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
auto res = dyn_cast<AbstractTensor>(ret);
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16});
MS_LOG(INFO) << "result: " << res->ToString();
@@ -539,7 +539,7 @@ TEST_F(TestPrim, test_conv2d_native) {
auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3});

AbstractBasePtrList args_spec_list = {input, weight};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
auto res = dyn_cast<AbstractTensor>(ret);
auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16});
MS_LOG(INFO) << "result: " << res->ToString();
@@ -555,7 +555,7 @@ TEST_F(TestPrim, test_biasAdd) {
auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32});

AbstractBasePtrList args_spec_list = {value, bias};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
auto res = dyn_cast<AbstractTensor>(ret);
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
MS_LOG(INFO) << "result: " << res->ToString();
@@ -571,7 +571,7 @@ TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) {
auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});

AbstractBasePtrList args_spec_list = {logits, labels};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_NE(ret, nullptr);
auto res = dyn_cast<AbstractTuple>(ret);
auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64});
@@ -600,7 +600,7 @@ TEST_F(TestPrim, test_tensor_to_scalar_prim) {
auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});

AbstractBasePtrList args_spec_list = {logits, labels};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
auto res = dyn_cast<AbstractScalar>(ret);
AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
expected->set_type(UTPrimUtils::kF64);
@@ -627,7 +627,7 @@ TEST_F(TestPrim, test_pooling) {
inputs->set_shape(inputs_dims);
AbstractBasePtr abstract_input = FromValue(inputs, false);
AbstractBasePtrList args_spec_list = {abstract_input};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();

AbstractBasePtr expected = abstract_input->Clone()->Broaden();
std::vector<int64_t> expected_dims = {8, 64, 2, 2};
@@ -652,7 +652,7 @@ TEST_F(TestPrim, test_hastype) {
auto prim = std::make_shared<Primitive>("hastype");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*res == *expected);
}

@@ -666,7 +666,7 @@ TEST_F(TestPrim, test_array_len) {
auto prim = std::make_shared<Primitive>("array_len");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*res == *expected);
}

@@ -680,7 +680,7 @@ TEST_F(TestPrim, test_list_len) {
auto prim = std::make_shared<Primitive>("list_len");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*res == *expected);
}

@@ -694,7 +694,7 @@ TEST_F(TestPrim, test_tuple_len) {
auto prim = std::make_shared<Primitive>("tuple_len");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*res == *expected);
}

@@ -708,7 +708,7 @@ TEST_F(TestPrim, test_tuple_reversed) {
auto prim = std::make_shared<Primitive>("tuple_reversed");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "expect=" << expected->ToString();
ASSERT_TRUE(*res == *expected);
}
@@ -730,7 +730,7 @@ TEST_F(TestPrim, test_list_getitem) {
auto prim = std::make_shared<Primitive>("list_getitem");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*res == *elem);
}

@@ -749,7 +749,7 @@ TEST_F(TestPrim, test_list_setitem) {
auto prim = std::make_shared<Primitive>("list_setitem");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "result: " << res->ToString();
AbstractBasePtrList elems_exp = {elem1, elem2};
auto expected = std::make_shared<AbstractList>(elems_exp);
@@ -771,7 +771,7 @@ TEST_F(TestPrim, test_list_append) {
auto prim = std::make_shared<Primitive>("list_append");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "result: " << res->ToString();
auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
MS_LOG(INFO) << "expected: " << expected->ToString();
@@ -795,7 +795,7 @@ TEST_F(TestPrim, test_tuple_setitem) {
auto prim = std::make_shared<Primitive>("tuple_setitem");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "result: " << res->ToString();
AbstractBasePtrList elems_exp = {elem1, elem2};
auto expected = std::make_shared<AbstractTuple>(elems_exp);
@@ -821,7 +821,7 @@ TEST_F(TestPrim, test_make_list) {
auto prim = std::make_shared<Primitive>("make_list");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(*res == *expected);
}

@@ -844,7 +844,7 @@ TEST_F(TestPrim, test_make_range) {
AbstractBasePtrList elem_list({ele1, ele2, ele3});
AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list);

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "res=" << res->ToString();
MS_LOG(INFO) << "expected=" << expected->ToString();
ASSERT_TRUE(*res == *expected);
@@ -887,7 +887,7 @@ TEST_F(TestPrim, test_layernorm) {
AbstractBasePtr expected1 = abstract_mean_var->Clone();
AbstractBasePtr expected2 = abstract_mean_var->Clone();

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected0: " << expected0->ToString();
MS_LOG(INFO) << "expected1: " << expected1->ToString();
@@ -933,7 +933,7 @@ TEST_F(TestPrim, test_DropoutGenMask) {
AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
std::make_shared<Shape>(std::vector<int64_t>{79}));

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "res=" << res->ToString();
MS_LOG(INFO) << "expected=" << expected->ToString();
ASSERT_TRUE(*res == *expected);
@@ -963,7 +963,7 @@ TEST_F(TestPrim, test_dropout) {
std::vector<int64_t> shape = {2, 20, 32, 32};
expected->set_shape(std::make_shared<Shape>(shape));

AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected: " << expected->ToString();

@@ -984,7 +984,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) {
auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
AbstractBasePtrList args_spec_list = {x_input, y_input};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
auto res = dyn_cast<AbstractTuple>(ret);
AbstractBasePtrList x_idx_list;
auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
@@ -1008,7 +1008,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) {
auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
AbstractBasePtrList args_spec_list = {x_input, y_input};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
auto res = dyn_cast<AbstractTuple>(ret);
AbstractBasePtrList x_idx_list({abstract::FromValue(1)});
auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
@@ -1033,7 +1033,7 @@ TEST_F(TestPrim, test_DictGetItem) {
AbstractBasePtr key = abstract::FromValue("x");
AbstractBasePtrList args_spec_list = {array_dict, key};

AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second));

@@ -1052,7 +1052,7 @@ TEST_F(TestPrim, test_DictGetItem2) {
AbstractBasePtr key = abstract::FromValue("x");
AbstractBasePtrList args_spec_list = {array_dict, key};

AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x);



+ 12
- 12
tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc View File

@@ -164,7 +164,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) {

auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).eval_result->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
}

@@ -262,7 +262,7 @@ TEST_F(TestInferGraph, test_inferred) {
MS_LOG(INFO) << "" << graph_f_->get_return()->ToString();
AbstractBasePtr abstract_v1 = FromValue(static_cast<int64_t>(1), false);
args_spec_list.push_back(abstract_v1);
AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred->abstract();
AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).eval_result->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());

// now this test case failed randomly, have to debug.
@@ -273,7 +273,7 @@ TEST_F(TestInferGraph, test_inferred) {
args_spec_list.clear();
args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2);
abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred->abstract();
abs_base_got = engine_->Run(graph_alpha_, args_spec_list).eval_result->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
}

@@ -359,7 +359,7 @@ TEST_F(TestInferMetaGraph, test_inferred) {
AbstractBasePtr abstract_v2 = FromValue(v1, false);
args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2);
AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred->abstract();
AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).eval_result->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
}

@@ -391,7 +391,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {

auto prim_scalar_add = std::make_shared<Primitive>(prim::kScalarAdd);
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract();
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).eval_result->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt64);
}
@@ -446,7 +446,7 @@ void TestGraphEval::TearDown() {
TEST_F(TestGraphInfer, test_graph_infer_defaults) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract();
AbstractBasePtr expect = FromValue(MakeValue(50), false);
ASSERT_EQ(*res, *expect);
}
@@ -454,7 +454,7 @@ TEST_F(TestGraphInfer, test_graph_infer_defaults) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract();
AbstractBasePtr expect = FromValue(MakeValue(1), false);
ASSERT_EQ(*res, *expect);
}
@@ -462,7 +462,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
TEST_F(TestGraphInfer, test_graph_infer_vararg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract();
AbstractBasePtr expect = FromValue(MakeValue(9), false);
ASSERT_EQ(*res, *expect);
}
@@ -470,7 +470,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract();
AbstractBasePtr expect = FromValue(MakeValue(48), false);
ASSERT_EQ(*res, *expect);
}
@@ -478,7 +478,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract();
AbstractBasePtr expect = FromValue(MakeValue(7), false);
ASSERT_EQ(*res, *expect);
}
@@ -486,7 +486,7 @@ TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract();
AbstractBasePtr expect = FromValue(MakeValue(46), false);
ASSERT_EQ(*res, *expect);
}
@@ -494,7 +494,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr res = engine_->Run(graph, args_spec_list).eval_result->abstract();
AbstractBasePtr expect = FromValue(MakeValue(57), false);
ASSERT_EQ(*res, *expect);
}


Loading…
Cancel
Save