Browse Source

Resolve specialize error during transforming after block in PyNative mode.

tags/v1.1.0
Zhang Qinghua 5 years ago
parent
commit
4e6e68f187
2 changed files with 16 additions and 12 deletions
  1. +15
    -12
      mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h
  2. +1
    -0
      mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc

+ 15
- 12
mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h View File

@@ -23,6 +23,7 @@
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include <tuple>

#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
@@ -42,13 +43,13 @@ class SpecializeTransform {
~SpecializeTransform() = default;

FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args,
std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> value_args) {
std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> tensor_value_args) {
if (cache_.count(func_graph) == 0) {
cache_[func_graph] = {};
}

auto &cache = cache_[func_graph];
auto key = std::make_pair(graph_args, prim_args);
auto key = std::make_tuple(graph_args, prim_args, tensor_value_args);
if (cache.count(key) == 0) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
@@ -70,8 +71,8 @@ class SpecializeTransform {
(void)mng->Replace(params[i], arg);
continue;
}
if (value_args[i] != nullptr) {
auto &const_tensor = *value_args[i];
if (tensor_value_args[i] != nullptr) {
auto &const_tensor = *tensor_value_args[i];
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
AnfNodePtr arg = NewValueNode(const_tensor_ptr);
(void)mng->Replace(params[i], arg);
@@ -87,8 +88,10 @@ class SpecializeTransform {
}

private:
std::unordered_map<FuncGraphPtr,
std::map<std::pair<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>>, FuncGraphPtr>>
std::unordered_map<
FuncGraphPtr,
std::map<std::tuple<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>, std::vector<tensor::TensorPtr>>,
FuncGraphPtr>>
cache_;
};
} // namespace internal
@@ -116,7 +119,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {

std::vector<FuncGraphPtr> graph_args;
std::vector<PrimitivePtr> prim_args;
std::vector<tensor::TensorPtr> value_node_args;
std::vector<tensor::TensorPtr> tensor_value_args;
std::vector<AnfNodePtr> new_xs;
bool hasVNode = false;
for (size_t i = 1; i < inputs.size(); i++) {
@@ -124,24 +127,24 @@ class SpecializeOnGraphArguments : public AnfVisitor {
auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]);
graph_args.push_back(fg_vnode);
prim_args.emplace_back(nullptr);
value_node_args.emplace_back(nullptr);
tensor_value_args.emplace_back(nullptr);
hasVNode = true;
} else if (IsValueNode<Primitive>(inputs[i])) {
auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]);
graph_args.emplace_back(nullptr);
prim_args.push_back(p_vnode);
value_node_args.emplace_back(nullptr);
tensor_value_args.emplace_back(nullptr);
hasVNode = true;
} else if (IsValueNode<tensor::Tensor>(inputs[i])) {
tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]);
graph_args.emplace_back(nullptr);
prim_args.emplace_back(nullptr);
value_node_args.emplace_back(t_vnode);
tensor_value_args.emplace_back(t_vnode);
hasVNode = true;
} else {
graph_args.emplace_back(nullptr);
prim_args.emplace_back(nullptr);
value_node_args.emplace_back(nullptr);
tensor_value_args.emplace_back(nullptr);
new_xs.push_back(inputs[i]);
}
}
@@ -150,7 +153,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
return nullptr;
}

auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args);
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, tensor_value_args);
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));

return node->func_graph()->NewCNode(new_xs);


+ 1
- 0
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc View File

@@ -17,6 +17,7 @@
#include "pipeline/jit/static_analysis/evaluator.h"

#include <algorithm>
#include <utility>
#include <unordered_set>

#include "ir/func_graph_cloner.h"


Loading…
Cancel
Save