From 4e6e68f187b97eb66f93bb85609284f2bfbd4418 Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Fri, 6 Nov 2020 13:29:18 +0800 Subject: [PATCH] Resolve specialize error during transforming after block in PyNative mode. --- .../optimizer/irpass/specialize_transform.h | 27 ++++++++++--------- .../pipeline/jit/static_analysis/evaluator.cc | 1 + 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h b/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h index 6cb9312028..ae18ea5e91 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h @@ -23,6 +23,7 @@ #include #include #include +#include #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" @@ -42,13 +43,13 @@ class SpecializeTransform { ~SpecializeTransform() = default; FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector graph_args, - std::vector prim_args, std::vector value_args) { + std::vector prim_args, std::vector tensor_value_args) { if (cache_.count(func_graph) == 0) { cache_[func_graph] = {}; } auto &cache = cache_[func_graph]; - auto key = std::make_pair(graph_args, prim_args); + auto key = std::make_tuple(graph_args, prim_args, tensor_value_args); if (cache.count(key) == 0) { auto mng = func_graph->manager(); MS_EXCEPTION_IF_NULL(mng); @@ -70,8 +71,8 @@ class SpecializeTransform { (void)mng->Replace(params[i], arg); continue; } - if (value_args[i] != nullptr) { - auto &const_tensor = *value_args[i]; + if (tensor_value_args[i] != nullptr) { + auto &const_tensor = *tensor_value_args[i]; auto const_tensor_ptr = std::make_shared(const_tensor); AnfNodePtr arg = NewValueNode(const_tensor_ptr); (void)mng->Replace(params[i], arg); @@ -87,8 +88,10 @@ class SpecializeTransform { } private: - std::unordered_map, std::vector>, FuncGraphPtr>> + std::unordered_map< + FuncGraphPtr, + std::map, std::vector, std::vector>, + FuncGraphPtr>> cache_; }; } // namespace internal @@ -116,7 +119,7 @@ class SpecializeOnGraphArguments : public AnfVisitor { std::vector graph_args; std::vector prim_args; - std::vector value_node_args; + std::vector tensor_value_args; std::vector new_xs; bool hasVNode = false; for (size_t i = 1; i < inputs.size(); i++) { @@ -124,24 +127,24 @@ class SpecializeOnGraphArguments : public AnfVisitor { auto fg_vnode = GetValueNode(inputs[i]); graph_args.push_back(fg_vnode); prim_args.emplace_back(nullptr); - value_node_args.emplace_back(nullptr); + tensor_value_args.emplace_back(nullptr); hasVNode = true; } else if (IsValueNode(inputs[i])) { auto p_vnode = GetValueNode(inputs[i]); graph_args.emplace_back(nullptr); prim_args.push_back(p_vnode); - value_node_args.emplace_back(nullptr); + tensor_value_args.emplace_back(nullptr); hasVNode = true; } else if (IsValueNode(inputs[i])) { tensor::TensorPtr t_vnode = GetValueNode(inputs[i]); graph_args.emplace_back(nullptr); prim_args.emplace_back(nullptr); - value_node_args.emplace_back(t_vnode); + tensor_value_args.emplace_back(t_vnode); hasVNode = true; } else { graph_args.emplace_back(nullptr); prim_args.emplace_back(nullptr); - value_node_args.emplace_back(nullptr); + tensor_value_args.emplace_back(nullptr); new_xs.push_back(inputs[i]); } } @@ -150,7 +153,7 @@ class SpecializeOnGraphArguments : public AnfVisitor { return nullptr; } - auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args); + auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, tensor_value_args); (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); return node->func_graph()->NewCNode(new_xs); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index d9f3afb06f..6f3bbd4313 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -17,6 +17,7 @@ #include "pipeline/jit/static_analysis/evaluator.h" #include +#include #include #include "ir/func_graph_cloner.h"