Browse Source

add func graph cache avoid repeat func graphs

tags/v1.1.0
yao_yf 5 years ago
parent
commit
1a2fd0e0b0
2 changed files with 9 additions and 0 deletions
  1. +4
    -0
      mindspore/core/ir/func_graph.h
  2. +5
    -0
      mindspore/core/ir/func_graph_extends.cc

+ 4
- 0
mindspore/core/ir/func_graph.h View File

@@ -34,6 +34,7 @@
#include "utils/ordered_map.h"
#include "base/base_ref.h"
#include "ir/func_graph_cloner.h"
#include "abstract/abstract_value.h"

namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
@@ -417,6 +418,9 @@ class FuncGraph : public FuncGraphBase {
// Design switch_layer_input as a ptr to
// share between derived backpropagator and cloned graphs
std::shared_ptr<bool> switch_layer_input_;
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
abstract::AbstractBasePtrListEqual>
func_graph_cache_;
};

inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {


+ 5
- 0
mindspore/core/ir/func_graph_extends.cc View File

@@ -245,6 +245,10 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
if (!NeedGenerate(kwarg_list)) {
return shared_from_base<FuncGraph>();
}
auto iter = func_graph_cache_.find(args_spec_list);
if (iter != func_graph_cache_.end()) {
return iter->second;
}
FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
size_t kwarg_count = kwarg_list.size();
int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_);
@@ -290,6 +294,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
specialized_graph->set_kwonlyargs_count(0);
specialized_graph->ClearDefaultValues();
specialized_graph->set_is_generate(true);
func_graph_cache_[args_spec_list] = specialized_graph;
return specialized_graph;
}



Loading…
Cancel
Save