Browse Source

add func graph cache avoid repeat func graphs

tags/v1.1.0
yao_yf 5 years ago
parent
commit
1a2fd0e0b0
2 changed files with 9 additions and 0 deletions
  1. +4
    -0
      mindspore/core/ir/func_graph.h
  2. +5
    -0
      mindspore/core/ir/func_graph_extends.cc

+ 4
- 0
mindspore/core/ir/func_graph.h View File

@@ -34,6 +34,7 @@
#include "utils/ordered_map.h" #include "utils/ordered_map.h"
#include "base/base_ref.h" #include "base/base_ref.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "abstract/abstract_value.h"


namespace mindspore { namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
@@ -417,6 +418,9 @@ class FuncGraph : public FuncGraphBase {
// Design switch_layer_input as a ptr to // Design switch_layer_input as a ptr to
// share between derived backpropagator and cloned graphs // share between derived backpropagator and cloned graphs
std::shared_ptr<bool> switch_layer_input_; std::shared_ptr<bool> switch_layer_input_;
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
abstract::AbstractBasePtrListEqual>
func_graph_cache_;
}; };


inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) { inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {


+ 5
- 0
mindspore/core/ir/func_graph_extends.cc View File

@@ -245,6 +245,10 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
if (!NeedGenerate(kwarg_list)) { if (!NeedGenerate(kwarg_list)) {
return shared_from_base<FuncGraph>(); return shared_from_base<FuncGraph>();
} }
auto iter = func_graph_cache_.find(args_spec_list);
if (iter != func_graph_cache_.end()) {
return iter->second;
}
FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>()); FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
size_t kwarg_count = kwarg_list.size(); size_t kwarg_count = kwarg_list.size();
int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_); int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_);
@@ -290,6 +294,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
specialized_graph->set_kwonlyargs_count(0); specialized_graph->set_kwonlyargs_count(0);
specialized_graph->ClearDefaultValues(); specialized_graph->ClearDefaultValues();
specialized_graph->set_is_generate(true); specialized_graph->set_is_generate(true);
func_graph_cache_[args_spec_list] = specialized_graph;
return specialized_graph; return specialized_graph;
} }




Loading…
Cancel
Save