|
|
|
@@ -189,12 +189,8 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa |
|
|
|
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); |
|
|
|
func_graph_->joined_shapes_.clear(); |
|
|
|
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), |
|
|
|
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) { |
|
|
|
if (arg_spec->isa<AbstractRef>()) { |
|
|
|
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack(); |
|
|
|
} |
|
|
|
return arg_spec->GetShapeTrack(); |
|
|
|
}); |
|
|
|
std::back_inserter(func_graph_->joined_shapes_), |
|
|
|
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); |
|
|
|
joined_args_spec_list = NormalizeArgs(joined_args_spec_list); |
|
|
|
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; |
|
|
|
} |
|
|
|
@@ -212,12 +208,8 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa |
|
|
|
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); |
|
|
|
func_graph_->joined_shapes_.clear(); |
|
|
|
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), |
|
|
|
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) { |
|
|
|
if (arg_spec->isa<AbstractRef>()) { |
|
|
|
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack(); |
|
|
|
} |
|
|
|
return arg_spec->GetShapeTrack(); |
|
|
|
}); |
|
|
|
std::back_inserter(func_graph_->joined_shapes_), |
|
|
|
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); |
|
|
|
joined_args_spec_list = NormalizeArgs(joined_args_spec_list); |
|
|
|
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; |
|
|
|
} |
|
|
|
@@ -317,10 +309,17 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args |
|
|
|
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, |
|
|
|
AnfNodeConfigPtr) { |
|
|
|
AbstractBasePtrList args_spec_list; |
|
|
|
auto is_py_eval = (identifier_ == "PythonPrimEvaluator"); |
|
|
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), |
|
|
|
[](const ConfigPtr &conf) -> AbstractBasePtr { |
|
|
|
[is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr { |
|
|
|
MS_EXCEPTION_IF_NULL(conf); |
|
|
|
return conf->GetEvaluatedValue()->abstract(); |
|
|
|
auto abstract = conf->GetEvaluatedValue()->abstract(); |
|
|
|
// broaden the ref_key, while infer python prim for cache |
|
|
|
if (is_py_eval && abstract->isa<AbstractRef>()) { |
|
|
|
auto abs_ref = abstract->cast<AbstractRefPtr>(); |
|
|
|
abstract = std::make_shared<AbstractRef>(abs_ref->ref_key()->Broaden(), abs_ref); |
|
|
|
} |
|
|
|
return abstract; |
|
|
|
}); |
|
|
|
EvalResultPtr ret = EvalPrim(engine, args_spec_list); |
|
|
|
return ret; |
|
|
|
|