add cache for PythonPrimEvaluator. Be careful that the infer function of a PythonPrimitive implemented in Python code must be idempotent, since its cached result may be reused instead of re-running inference.
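The caching pattern in a nutshell, as a standalone sketch in plain C++ (stand-in types, not the MindSpore API): look the argument list up in a per-evaluator map before calling the real infer, and store the result on a miss. This is only sound when infer is idempotent, which is exactly the caveat in the commit message.

```cpp
// Minimal sketch (stand-in types, not the MindSpore API) of the memoization
// this commit adds: results are cached per distinct argument list, so the
// wrapped infer runs at most once per key. That is only sound if infer is
// idempotent -- same arguments, same result, no side effects a later call
// would be needed for.
#include <functional>
#include <iostream>
#include <map>
#include <vector>

using Args = std::vector<int>;  // stand-in for the argument abstracts

class CachedEvaluator {
 public:
  explicit CachedEvaluator(std::function<int(const Args &)> infer) : infer_(std::move(infer)) {}

  int Eval(const Args &args) {
    auto iter = cache_.find(args);
    if (iter != cache_.end()) {
      return iter->second;      // cache hit: infer_ is not called again
    }
    int result = infer_(args);  // cache miss: run the real inference once
    cache_[args] = result;
    return result;
  }

 private:
  std::function<int(const Args &)> infer_;
  std::map<Args, int> cache_;
};

int main() {
  int calls = 0;
  CachedEvaluator eval([&calls](const Args &a) {
    ++calls;  // count how often the real infer actually runs
    int sum = 0;
    for (int v : a) sum += v;
    return sum;
  });
  std::cout << eval.Eval({1, 2}) << " " << eval.Eval({1, 2}) << "\n";  // "3 3"
  std::cout << "infer ran " << calls << " time(s)\n";                  // 1
  return 0;
}
```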
@@ -17,7 +17,9 @@
 #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
 #define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_

+#include <algorithm>
 #include <functional>
+#include <iterator>
 #include <memory>
 #include <string>
 #include <vector>
@@ -129,29 +131,38 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
     return optimizer;
   }

-  FuncGraphPtr step(FuncGraphPtr func_graph, const abstract::AbstractBasePtrList &args_spec, bool use_profile = true) {
+  FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
     // Optimizer step counter;
     int counter = 1;
     bool changes = true;

     while (changes) {
       changes = false;
-      auto run_runc = [&counter, &func_graph, &args_spec, &changes, use_profile, this]() {
+      auto run_runc = [&counter, &func_graph, &changes, use_profile, this]() {
         for (size_t i = 0; i < passes_.size(); ++i) {
           const OptPass &opt = passes_[i];
-          auto opt_func = [&func_graph, &args_spec, &changes, &opt, this]() {
+          auto opt_func = [&func_graph, &changes, &opt, this]() {
             if (opt.is_renormalize()) {
               auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
               if (resource_ptr != nullptr) {
+                // StepParallel may replace the AbstractValue of the parameters of func_graph,
+                // So generate the args_spec from parameters.
+                abstract::AbstractBasePtrList maybe_new_args_spec;
                 if (is_watch_renormalize_) {
                   if (untyped_nodes_.size() > 0) {
-                    func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
+                    std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
+                                   std::back_inserter(maybe_new_args_spec),
+                                   [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
+                    func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
                     clear_untyped_nodes();
                   } else {
                     MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty.";
                   }
                 } else {
-                  func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
+                  std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
+                                 std::back_inserter(maybe_new_args_spec),
+                                 [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
+                  func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
                 }
               }
             } else if (opt(func_graph, shared_from_this())) {
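For reference, a standalone sketch (hypothetical simplified types) of the std::transform + std::back_inserter idiom the hunk above uses to regenerate args_spec from the graph's current parameters, rather than trusting a possibly stale caller-supplied list:

```cpp
// Standalone sketch (hypothetical stand-in types) of rebuilding an argument
// spec from the graph's current parameters with std::transform, so shape
// changes made by earlier passes (e.g. a parallel split) are picked up.
#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

struct Abstract { std::string shape; };
using AbstractPtr = std::shared_ptr<Abstract>;
struct Param { AbstractPtr abstract; };  // stand-in for a graph parameter node

int main() {
  std::vector<Param> parameters = {
      {std::make_shared<Abstract>(Abstract{"[2, 4]"})},
      {std::make_shared<Abstract>(Abstract{"[4, 1]"})},
  };
  std::vector<AbstractPtr> args_spec;
  // Collect each parameter's *current* abstract into the spec list.
  std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
                 [](const Param &p) { return p.abstract; });
  for (const auto &a : args_spec) std::cout << a->shape << "\n";
  return 0;
}
```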
@@ -1230,7 +1230,11 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
                 << MakeValue(slice_shape)->ToString();
   std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
   MS_EXCEPTION_IF_NULL(parallel_shape);
-  abstract->set_shape(parallel_shape);
+  // Don't modify it in place, as the pointer of this AbstractValue may be used as a cache key in StaticAnalysis.
+  auto cloned_abstract = abstract->Clone();
+  MS_EXCEPTION_IF_NULL(cloned_abstract);
+  cloned_abstract->set_shape(parallel_shape);
+  parameter->set_abstract(cloned_abstract);
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);
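This and the following two hunks in the parallel pass all make the same copy-on-write change: the AbstractValue is cloned before its shape is modified, because StaticAnalysis may be holding the original pointer as a cache key. A minimal sketch of the hazard and the fix, with a hypothetical simplified type:

```cpp
// Minimal sketch (hypothetical simplified type, not the MindSpore
// AbstractValue) of the copy-on-write rule the parallel hunks follow: never
// mutate a shared abstract in place, because static analysis may already be
// using that very object as a cache key; clone it, mutate the clone, and
// re-attach the clone to the node instead.
#include <cassert>
#include <memory>
#include <string>

struct Abstract {
  std::string shape;
  std::shared_ptr<Abstract> Clone() const { return std::make_shared<Abstract>(*this); }
};

int main() {
  auto cached_key = std::make_shared<Abstract>(Abstract{"[8, 8]"});  // held by a cache
  auto node_abstract = cached_key;                                   // node shares it

  // Wrong: node_abstract->shape = "[4, 8]"; would rewrite the cached key too.
  auto cloned = node_abstract->Clone();   // copy-on-write instead
  cloned->shape = "[4, 8]";               // slice shape after the parallel split
  node_abstract = cloned;                 // node now points at the new object

  assert(cached_key->shape == "[8, 8]");  // the cache key is untouched
  return 0;
}
```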
@@ -1330,7 +1334,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
       cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout());
       MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
       MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
-      cloned_parameter_node->abstract()->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
+      MS_EXCEPTION_IF_NULL(cloned_abstract);
+      cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      cloned_parameter_node->set_abstract(cloned_abstract);
       MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                    << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
                    << ", clone index is: " << cloned_index;
@@ -1743,7 +1750,10 @@ void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_l
   auto slice_shape = loss_grad_layout.slice_shape().array();
   std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
   MS_EXCEPTION_IF_NULL(parallel_shape);
-  abstract->set_shape(parallel_shape);
+  auto cloned_abstract = abstract->Clone();
+  MS_EXCEPTION_IF_NULL(cloned_abstract);
+  cloned_abstract->set_shape(parallel_shape);
+  sens_tensor_node->set_abstract(cloned_abstract);
   auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
   sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
   return;
@@ -276,9 +276,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa
   (void)parse::python_adapter::set_python_scoped();

-  abstract::AbstractBasePtrList args_spec;
   MS_EXCEPTION_IF_NULL(opt_resolve);
-  (void)opt_resolve->step(func_graph, args_spec, use_profile);
+  (void)opt_resolve->step(func_graph, use_profile);
   return true;
 }
@@ -205,14 +205,15 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
     return false;
   }
-  abstract::AbstractBasePtrList args = res->args_spec();
   FuncGraphPtr func_graph = res->func_graph();
   MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", "
                 << func_graph->get_return()->DebugString(true);
   InitOpt(res);
   if (g_pass_opts.find(name) != g_pass_opts.end()) {
     res->set_func_graph(g_pass_opts[name]->step(func_graph));
   }
+  // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to
+  // res->args_spec_ yet. So if any later pass or action wants to use that variable, it should be set here.
   return true;
 }
@@ -255,10 +256,9 @@ bool ValidatePass(const ResourcePtr &res) {
 bool InferenceOptPreparePass(const ResourcePtr &res) {
   FuncGraphPtr func_graph = res->func_graph();
   MS_EXCEPTION_IF_NULL(func_graph);
-  abstract::AbstractBasePtrList args_spec = res->args_spec();
   auto prepare_map = GetInferenceOptPreparePhases();
   auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map);
-  (void)infer_opt_prepare->step(func_graph, args_spec, false);
+  (void)infer_opt_prepare->step(func_graph, false);
   return true;
 }
@@ -260,7 +260,6 @@ AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const Config
                            return conf->GetEvaluatedValue();
                          });
   AbstractBasePtr ret = EvalPrim(engine, args_spec_list);
-  (*cache_)[args_spec_list] = ret;
   return ret;
 }
@@ -405,6 +405,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
 AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
   MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
+  const auto &iter = cache_->find(args);
+  if (iter != cache_->end()) {
+    return iter->second;
+  }
   auto py_args = PreparePyInputs(prim_py_, args);
   auto pyobj = prim_py_->GetPyObj();
@@ -418,6 +422,7 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A
   auto res_spec = PyInferRes2Abstract(prim_py_, output);
   MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
+  (*cache_)[args] = res_spec;
   return res_spec;
 }
@@ -271,6 +271,18 @@ void AnalysisEngine::ClearEvaluatorCache() {
     MS_EXCEPTION_IF_NULL(evaluator->cache());
     evaluator->cache()->clear();
   }
+  for (auto &element : prim_constructors_) {
+    EvaluatorPtr evaluator = element.second;
+    MS_EXCEPTION_IF_NULL(evaluator);
+    MS_EXCEPTION_IF_NULL(evaluator->cache());
+    evaluator->cache()->clear();
+  }
+  for (auto &element : prim_py_evaluators_) {
+    EvaluatorPtr evaluator = element.second;
+    MS_EXCEPTION_IF_NULL(evaluator);
+    MS_EXCEPTION_IF_NULL(evaluator->cache());
+    evaluator->cache()->clear();
+  }
 }

 void AnalysisEngine::Clear() {
@@ -296,7 +308,17 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
     if (prim_py != nullptr) {
-      return std::make_shared<PythonPrimEvaluator>(prim_py);
+      if (engine == nullptr) {
+        return std::make_shared<PythonPrimEvaluator>(prim_py);
+      }
+      const auto &iter = engine->prim_py_evaluators_.find(prim_py);
+      if (iter != engine->prim_py_evaluators_.end()) {
+        return iter->second;
+      }
+      evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
+      engine->prim_py_evaluators_[prim_py] = evaluator;
+      return evaluator;
     }
     MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
   }
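With the per-evaluator cache in place, GetPrimEvaluator must also stop constructing a fresh PythonPrimEvaluator on every call, or each cache would be born empty and thrown away immediately; the hunk above keeps one instance per PrimitivePy in the engine. A standalone sketch of that instance-reuse registry, with hypothetical simplified types:

```cpp
// Standalone sketch (hypothetical simplified types, not the MindSpore API) of
// the registry pattern: hand out one evaluator instance per primitive so its
// cache survives across calls, instead of constructing a fresh evaluator --
// with an empty cache -- every time.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Evaluator {
  int cache_hits = 0;  // stand-in for the real args -> result cache
};
using EvaluatorPtr = std::shared_ptr<Evaluator>;

class Engine {
 public:
  EvaluatorPtr GetEvaluator(const std::string &prim) {
    auto iter = evaluators_.find(prim);
    if (iter != evaluators_.end()) {
      return iter->second;               // reuse: the cache is preserved
    }
    auto evaluator = std::make_shared<Evaluator>();
    evaluators_[prim] = evaluator;       // remember it for the next lookup
    return evaluator;
  }

 private:
  std::unordered_map<std::string, EvaluatorPtr> evaluators_;
};

int main() {
  Engine engine;
  auto a = engine.GetEvaluator("MatMul");
  auto b = engine.GetEvaluator("MatMul");
  std::cout << (a == b) << "\n";  // 1: same instance, shared cache
  return 0;
}
```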
@@ -194,6 +194,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
   const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }

   AnalysisCache cache_;
+  std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;

  private:
   const PrimEvaluatorMap &prim_constructors_;
@@ -57,8 +57,7 @@ TEST_F(TestOptOptimizer, test_step_opt) {
                                    true);
   EXPECT_TRUE(optimizer.get() != nullptr);
-  abstract::AbstractBasePtrList args;
-  auto after = optimizer->step(before, args);
+  auto after = optimizer->step(before);

   draw::Draw("optimizer_test_expendJ_before.dot", before);
   draw::Draw("optimizer_test_expendJ_after.dot", after);