|
|
|
@@ -23,6 +23,7 @@ |
|
|
|
#include <utility> |
|
|
|
#include <unordered_map> |
|
|
|
#include <unordered_set> |
|
|
|
#include <tuple> |
|
|
|
|
|
|
|
#include "frontend/optimizer/irpass.h" |
|
|
|
#include "frontend/optimizer/optimizer.h" |
|
|
|
@@ -42,13 +43,13 @@ class SpecializeTransform { |
|
|
|
~SpecializeTransform() = default; |
|
|
|
|
|
|
|
FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args, |
|
|
|
std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> value_args) { |
|
|
|
std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> tensor_value_args) { |
|
|
|
if (cache_.count(func_graph) == 0) { |
|
|
|
cache_[func_graph] = {}; |
|
|
|
} |
|
|
|
|
|
|
|
auto &cache = cache_[func_graph]; |
|
|
|
auto key = std::make_pair(graph_args, prim_args); |
|
|
|
auto key = std::make_tuple(graph_args, prim_args, tensor_value_args); |
|
|
|
if (cache.count(key) == 0) { |
|
|
|
auto mng = func_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(mng); |
|
|
|
@@ -70,8 +71,8 @@ class SpecializeTransform { |
|
|
|
(void)mng->Replace(params[i], arg); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (value_args[i] != nullptr) { |
|
|
|
auto &const_tensor = *value_args[i]; |
|
|
|
if (tensor_value_args[i] != nullptr) { |
|
|
|
auto &const_tensor = *tensor_value_args[i]; |
|
|
|
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor); |
|
|
|
AnfNodePtr arg = NewValueNode(const_tensor_ptr); |
|
|
|
(void)mng->Replace(params[i], arg); |
|
|
|
@@ -87,8 +88,10 @@ class SpecializeTransform { |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
std::unordered_map<FuncGraphPtr, |
|
|
|
std::map<std::pair<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>>, FuncGraphPtr>> |
|
|
|
std::unordered_map< |
|
|
|
FuncGraphPtr, |
|
|
|
std::map<std::tuple<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>, std::vector<tensor::TensorPtr>>, |
|
|
|
FuncGraphPtr>> |
|
|
|
cache_; |
|
|
|
}; |
|
|
|
} // namespace internal |
|
|
|
@@ -116,7 +119,7 @@ class SpecializeOnGraphArguments : public AnfVisitor { |
|
|
|
|
|
|
|
std::vector<FuncGraphPtr> graph_args; |
|
|
|
std::vector<PrimitivePtr> prim_args; |
|
|
|
std::vector<tensor::TensorPtr> value_node_args; |
|
|
|
std::vector<tensor::TensorPtr> tensor_value_args; |
|
|
|
std::vector<AnfNodePtr> new_xs; |
|
|
|
bool hasVNode = false; |
|
|
|
for (size_t i = 1; i < inputs.size(); i++) { |
|
|
|
@@ -124,24 +127,24 @@ class SpecializeOnGraphArguments : public AnfVisitor { |
|
|
|
auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]); |
|
|
|
graph_args.push_back(fg_vnode); |
|
|
|
prim_args.emplace_back(nullptr); |
|
|
|
value_node_args.emplace_back(nullptr); |
|
|
|
tensor_value_args.emplace_back(nullptr); |
|
|
|
hasVNode = true; |
|
|
|
} else if (IsValueNode<Primitive>(inputs[i])) { |
|
|
|
auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]); |
|
|
|
graph_args.emplace_back(nullptr); |
|
|
|
prim_args.push_back(p_vnode); |
|
|
|
value_node_args.emplace_back(nullptr); |
|
|
|
tensor_value_args.emplace_back(nullptr); |
|
|
|
hasVNode = true; |
|
|
|
} else if (IsValueNode<tensor::Tensor>(inputs[i])) { |
|
|
|
tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]); |
|
|
|
graph_args.emplace_back(nullptr); |
|
|
|
prim_args.emplace_back(nullptr); |
|
|
|
value_node_args.emplace_back(t_vnode); |
|
|
|
tensor_value_args.emplace_back(t_vnode); |
|
|
|
hasVNode = true; |
|
|
|
} else { |
|
|
|
graph_args.emplace_back(nullptr); |
|
|
|
prim_args.emplace_back(nullptr); |
|
|
|
value_node_args.emplace_back(nullptr); |
|
|
|
tensor_value_args.emplace_back(nullptr); |
|
|
|
new_xs.push_back(inputs[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -150,7 +153,7 @@ class SpecializeOnGraphArguments : public AnfVisitor { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args); |
|
|
|
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, tensor_value_args); |
|
|
|
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); |
|
|
|
|
|
|
|
return node->func_graph()->NewCNode(new_xs); |
|
|
|
|