Add a cache for PythonPrimEvaluator. Note that the infer function of a PythonPrimitive implemented in Python code must be idempotent, since its result may now be returned from the cache instead of being recomputed.
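Why the idempotency caveat matters (a minimal, self-contained sketch; NonIdempotentInfer, CachedInfer, and the int -> std::string signature are invented for illustration and are not MindSpore APIs): once a result is stored under its argument list, every later call with equal arguments returns the stored value, so an infer function whose output depends on hidden state silently diverges from what an uncached call would produce.

#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for a Python infer routine. Its output depends on
// hidden state, so equal inputs do NOT give equal outputs: not idempotent.
static int g_calls = 0;
std::string NonIdempotentInfer(int x) {
  return std::to_string(x) + "#" + std::to_string(++g_calls);
}

// Minimal memoization, mirroring the args -> result cache added below:
// the first result "wins" for every later call with an equal argument.
std::string CachedInfer(std::map<int, std::string> *cache, int x) {
  auto it = cache->find(x);
  if (it != cache->end()) {
    return it->second;
  }
  return (*cache)[x] = NonIdempotentInfer(x);
}

int main() {
  std::map<int, std::string> cache;
  std::cout << CachedInfer(&cache, 7) << "\n";  // prints "7#1"
  std::cout << CachedInfer(&cache, 7) << "\n";  // still "7#1"; an uncached call would say "7#2"
  return 0;
}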
@@ -17,7 +17,9 @@
 #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
 #define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
+#include <algorithm>
 #include <functional>
+#include <iterator>
 #include <memory>
 #include <string>
 #include <vector>

@@ -129,29 +131,38 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
     return optimizer;
   }

-  FuncGraphPtr step(FuncGraphPtr func_graph, const abstract::AbstractBasePtrList &args_spec, bool use_profile = true) {
+  FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
     // Optimizer step counter.
     int counter = 1;
     bool changes = true;
     while (changes) {
       changes = false;
-      auto run_func = [&counter, &func_graph, &args_spec, &changes, use_profile, this]() {
+      auto run_func = [&counter, &func_graph, &changes, use_profile, this]() {
         for (size_t i = 0; i < passes_.size(); ++i) {
           const OptPass &opt = passes_[i];
-          auto opt_func = [&func_graph, &args_spec, &changes, &opt, this]() {
+          auto opt_func = [&func_graph, &changes, &opt, this]() {
             if (opt.is_renormalize()) {
               auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
               if (resource_ptr != nullptr) {
+                // StepParallel may replace the AbstractValue of the parameters of func_graph,
+                // so regenerate the args_spec from the parameters themselves.
+                abstract::AbstractBasePtrList maybe_new_args_spec;
                 if (is_watch_renormalize_) {
                   if (untyped_nodes_.size() > 0) {
-                    func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
+                    std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
+                                   std::back_inserter(maybe_new_args_spec),
+                                   [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
+                    func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
                     clear_untyped_nodes();
                   } else {
                     MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty.";
                   }
                 } else {
-                  func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
+                  std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
+                                 std::back_inserter(maybe_new_args_spec),
+                                 [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
+                  func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
                 }
               }
             } else if (opt(func_graph, shared_from_this())) {
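The rebuild-from-parameters pattern appears twice in the hunk above; factored into a standalone helper it would look roughly like the sketch below (it reuses the diff's own types and calls, but ArgsSpecFromParameters itself is hypothetical, not an existing function):

// Sketch only (needs <algorithm> and <iterator>): collect the possibly
// renewed abstracts of all parameters, instead of trusting an args_spec
// that was captured before StepParallel ran.
abstract::AbstractBasePtrList ArgsSpecFromParameters(const FuncGraphPtr &func_graph) {
  abstract::AbstractBasePtrList args_spec;
  std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
                 std::back_inserter(args_spec),
                 [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
  return args_spec;
}

This is also why step no longer needs an args_spec parameter: the authoritative abstracts now live on the parameters of func_graph.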
@@ -1230,7 +1230,11 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
                << MakeValue(slice_shape)->ToString();
   std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
   MS_EXCEPTION_IF_NULL(parallel_shape);
-  abstract->set_shape(parallel_shape);
+  // Don't modify it in place, as this AbstractValue's pointer may be used as a cache key in StaticAnalysis.
+  auto cloned_abstract = abstract->Clone();
+  MS_EXCEPTION_IF_NULL(cloned_abstract);
+  cloned_abstract->set_shape(parallel_shape);
+  parameter->set_abstract(cloned_abstract);
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);

@@ -1330,7 +1334,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
   cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout());
   MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
   MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
-  cloned_parameter_node->abstract()->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+  auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
+  MS_EXCEPTION_IF_NULL(cloned_abstract);
+  cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+  cloned_parameter_node->set_abstract(cloned_abstract);
   MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                << " is cloned, the cloned-from parameter is: " << cloned_from_parameter->name()
                << ", clone index is: " << cloned_index;

@@ -1743,7 +1750,10 @@ void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_l
   auto slice_shape = loss_grad_layout.slice_shape().array();
   std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
   MS_EXCEPTION_IF_NULL(parallel_shape);
-  abstract->set_shape(parallel_shape);
+  auto cloned_abstract = abstract->Clone();
+  MS_EXCEPTION_IF_NULL(cloned_abstract);
+  cloned_abstract->set_shape(parallel_shape);
+  sens_tensor_node->set_abstract(cloned_abstract);
   auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
   sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
   return;
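The three hunks above apply one and the same fix; condensed into a single helper it reads like this (a sketch in the diff's own vocabulary; SetShapeViaClone is hypothetical):

// Sketch: replace a node's AbstractValue instead of mutating it, because
// StaticAnalysis may already hold the old pointer as a cache key.
void SetShapeViaClone(const AnfNodePtr &node, const std::shared_ptr<abstract::BaseShape> &shape) {
  MS_EXCEPTION_IF_NULL(node->abstract());
  auto cloned_abstract = node->abstract()->Clone();  // the old object stays untouched
  MS_EXCEPTION_IF_NULL(cloned_abstract);
  cloned_abstract->set_shape(shape);
  node->set_abstract(cloned_abstract);  // swap in the modified copy
}

Mutating the existing AbstractValue would silently change the meaning of cache entries already keyed by its pointer; swapping in a clone leaves those entries intact.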
@@ -276,9 +276,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa
   (void)parse::python_adapter::set_python_scoped();
-  abstract::AbstractBasePtrList args_spec;
   MS_EXCEPTION_IF_NULL(opt_resolve);
-  (void)opt_resolve->step(func_graph, args_spec, use_profile);
+  (void)opt_resolve->step(func_graph, use_profile);
   return true;
 }

@@ -205,14 +205,15 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
     return false;
   }
-  abstract::AbstractBasePtrList args = res->args_spec();
   FuncGraphPtr func_graph = res->func_graph();
   MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", "
                 << func_graph->get_return()->DebugString(true);
   InitOpt(res);
   if (g_pass_opts.find(name) != g_pass_opts.end()) {
-    res->set_func_graph(g_pass_opts[name]->step(func_graph, args));
+    res->set_func_graph(g_pass_opts[name]->step(func_graph));
   }
+  // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not yet
+  // updated in res->args_spec_. So if any later pass or action wants to use args_spec_, it should be set here.
   return true;
 }

@@ -255,10 +256,9 @@ bool ValidatePass(const ResourcePtr &res) {
 bool InferenceOptPreparePass(const ResourcePtr &res) {
   FuncGraphPtr func_graph = res->func_graph();
   MS_EXCEPTION_IF_NULL(func_graph);
-  abstract::AbstractBasePtrList args_spec = res->args_spec();
   auto prepare_map = GetInferenceOptPreparePhases();
   auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map);
-  (void)infer_opt_prepare->step(func_graph, args_spec, false);
+  (void)infer_opt_prepare->step(func_graph, false);
   return true;
 }

@@ -260,7 +260,6 @@ AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const Config
     return conf->GetEvaluatedValue();
   });
   AbstractBasePtr ret = EvalPrim(engine, args_spec_list);
-  (*cache_)[args_spec_list] = ret;
   return ret;
 }

@@ -405,6 +405,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
 AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
   MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
+  const auto &iter = cache_->find(args);
+  if (iter != cache_->end()) {
+    return iter->second;
+  }
   auto py_args = PreparePyInputs(prim_py_, args);
   auto pyobj = prim_py_->GetPyObj();

@@ -418,6 +422,7 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A
   auto res_spec = PyInferRes2Abstract(prim_py_, output);
   MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
+  (*cache_)[args] = res_spec;
   return res_spec;
 }

@@ -271,6 +271,18 @@ void AnalysisEngine::ClearEvaluatorCache() {
     MS_EXCEPTION_IF_NULL(evaluator->cache());
     evaluator->cache()->clear();
   }
+  for (auto &element : prim_constructors_) {
+    EvaluatorPtr evaluator = element.second;
+    MS_EXCEPTION_IF_NULL(evaluator);
+    MS_EXCEPTION_IF_NULL(evaluator->cache());
+    evaluator->cache()->clear();
+  }
+  for (auto &element : prim_py_evaluators_) {
+    EvaluatorPtr evaluator = element.second;
+    MS_EXCEPTION_IF_NULL(evaluator);
+    MS_EXCEPTION_IF_NULL(evaluator->cache());
+    evaluator->cache()->clear();
+  }
 }

 void AnalysisEngine::Clear() {

@@ -296,7 +308,17 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
    if (prim_py != nullptr) {
-      return std::make_shared<PythonPrimEvaluator>(prim_py);
+      if (engine == nullptr) {
+        return std::make_shared<PythonPrimEvaluator>(prim_py);
+      }
+      const auto &iter = engine->prim_py_evaluators_.find(prim_py);
+      if (iter != engine->prim_py_evaluators_.end()) {
+        return iter->second;
+      }
+      evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
+      engine->prim_py_evaluators_[prim_py] = evaluator;
+      return evaluator;
     }
     MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
   }
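The indirection above is what makes the new per-args cache effective: PythonPrimEvaluator keeps its cache as instance state, so constructing a fresh evaluator on every GetPrimEvaluator call would start from an empty cache each time. Memoizing one evaluator per PrimitivePy on the engine keeps cached results alive for the whole analysis run, and ClearEvaluatorCache (above) is extended to cover this new map so stale entries cannot survive across renormalize steps.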
@@ -194,6 +194,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
   const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }

   AnalysisCache cache_;
+  std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;

  private:
   const PrimEvaluatorMap &prim_constructors_;

@@ -57,8 +57,7 @@ TEST_F(TestOptOptimizer, test_step_opt) {
                                               true);
   EXPECT_TRUE(optimizer.get() != nullptr);
-  abstract::AbstractBasePtrList args;
-  auto after = optimizer->step(before, args);
+  auto after = optimizer->step(before);
   draw::Draw("optimizer_test_expendJ_before.dot", before);
   draw::Draw("optimizer_test_expendJ_after.dot", after);