Browse Source

keep tuple_getitem as much as possible to reduce the number of func graphs

shrink output of func_graph other than set unused to dead node
tags/v1.5.0-rc1
zhousiyi 4 years ago
parent
commit
94e5fe6242
3 changed files with 462 additions and 40 deletions
  1. +22
    -3
      mindspore/ccsrc/backend/session/session_basic.cc
  2. +439
    -37
      mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
  3. +1
    -0
      mindspore/ccsrc/pipeline/jit/pass.cc

+ 22
- 3
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -466,6 +466,27 @@ bool NoPartialInPartialGraph(const AnfNodePtr &partial_node) {
return true;
}

// 1. Convert the node to make_tuple if the node is a ValueNode<ValueTuple> and it's the input of 'return' node.
// 2. Set the return of graph if node is "Return" node.
void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);

if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
constexpr auto kReturnInputIdx = 1;
auto return_node = node->cast<CNodePtr>();
graph->set_return(return_node);
auto graph_output = return_node->input(kReturnInputIdx);
MS_EXCEPTION_IF_NULL(graph_output);

// If return's input is value node, then the graph has no kernel, and the pass 'trans tuple to make_tuple' cannot
// match this pattern because that pass begin with output node but return node. So we add transform value tuple
// to make_tuple here.
if (AnfAlgo::IsTupleOutput(graph_output) && graph_output->isa<ValueNode>()) {
return_node->set_input(kReturnInputIdx, graph->TransTupleToMakeTuple(graph_output));
}
}
}
} // namespace

GraphId SessionBasic::graph_sum_ = 0;
@@ -1483,9 +1504,7 @@ bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph
new_cnode->set_fullname_with_scope(fullname);
new_cnode->set_scope(cnode->scope());
graph->FrontBackendlMapAdd(node, new_cnode);
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
graph->set_return(new_cnode);
}
SetReturnNode(new_cnode, graph);
return true;
}



+ 439
- 37
mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h View File

@@ -19,6 +19,7 @@

#include <algorithm>
#include <memory>
#include <set>
#include <unordered_map>
#include <vector>
#include <utility>
@@ -128,6 +129,178 @@ class GetItemTransformACrossGraph {
private:
std::unordered_map<FuncGraphPtr, std::unordered_map<int64_t, FuncGraphPtr>> cache_;
};

bool HasMoreJ(const OptimizerPtr &optimizer) {
bool more_j = false;
auto res = optimizer->resource();
auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(res);
if (resource_ptr != nullptr) {
const auto &manager = optimizer->manager();
MS_EXCEPTION_IF_NULL(manager);
more_j = manager->func_graph_j_total(resource_ptr->func_graph());
}
return more_j;
}

bool IsOutputShrinkable(const AnfNodePtr &output) {
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
return true;
}
if (GetValueNode<ValueTuplePtr>(output)) {
return true;
}
return false;
}

size_t GetOutputSize(const AnfNodePtr &output) {
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
const auto &output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
return output_cnode->size() - 1;
}
const auto &value_tuple = GetValueNode<ValueTuplePtr>(output);
if (value_tuple == nullptr) {
MS_LOG(EXCEPTION) << "fg output is not MakeTuple or ValueTuple, but: " << output->DebugString();
}
return value_tuple->size();
}

struct TpCNodeAndIndex {
// CNode {TupleGetItem, call, index}
CNodePtr tp_cnode;
int64_t index;
};

int64_t UpdateUserNodeIndex(const CNodePtr &fg_call_cnode, const int64_t current_index,
const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) {
const auto &manager = fg_call_cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
int64_t new_index = current_index;
auto txn = manager->Transact();
for (int64_t i = 0; i < SizeToLong(tp_cnodes_and_index.size()); ++i) {
const auto &cnode_and_index = tp_cnodes_and_index[i];
if (cnode_and_index.index != i) {
constexpr auto kInputIndex = 2;
txn.SetEdge(cnode_and_index.tp_cnode, kInputIndex, NewValueNode(i));
}
if (cnode_and_index.index == current_index) {
new_index = i;
}
}
txn.Commit();
return new_index;
}

AbstractBasePtr ShrinkAbstract(const AbstractBasePtr &original_abstract,
const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) {
if (original_abstract != nullptr && original_abstract->isa<abstract::AbstractTuple>()) {
const auto &abs_tuple = original_abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abs_tuple);
const auto &abs_tuple_elements = abs_tuple->elements();
const int64_t before_shrink_tuple_size = SizeToLong(abs_tuple_elements.size());
AbstractBasePtrList shrunk_abstract_elements;
std::transform(tp_cnodes_and_index.cbegin(), tp_cnodes_and_index.cend(),
std::back_inserter(shrunk_abstract_elements),
[abs_tuple_elements, before_shrink_tuple_size](const auto &node_and_index) {
if (node_and_index.index >= before_shrink_tuple_size) {
MS_LOG(EXCEPTION) << "index should less than inputs size, index: " << node_and_index.index
<< ", abstract tuple size: " << before_shrink_tuple_size;
}
return abs_tuple_elements[node_and_index.index];
});
return std::make_shared<abstract::AbstractTuple>(shrunk_abstract_elements);
}
return nullptr;
}

FuncGraphPtr ShrinkUnsedOutput(const FuncGraphPtr &fg, const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) {
const auto &manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);

auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("tp_use"));
auto new_fg_output = new_fg->output();
AnfNodePtr shrunk_output = nullptr;
int64_t before_shrink_inputs_size = 0;
if (IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
// Shrink output;
auto new_fg_output_cnode = new_fg_output->cast<CNodePtr>();
const auto &new_fg_output_inputs = new_fg_output_cnode->inputs();
constexpr auto kMinimalSize = 2;
if (new_fg_output_inputs.size() <= kMinimalSize) {
MS_LOG(EXCEPTION) << "New fg output should at least 2 elements, but: " << new_fg_output->DebugString();
}
before_shrink_inputs_size = SizeToLong(new_fg_output_inputs.size() - 1);
AnfNodePtrList shrunk_inputs{NewValueNode({prim::kPrimMakeTuple})};
// Bypass maketuple primitive in new_fg_output_inputs;
std::transform(tp_cnodes_and_index.cbegin(), tp_cnodes_and_index.cend(), std::back_inserter(shrunk_inputs),
[new_fg_output, new_fg_output_inputs, before_shrink_inputs_size](const auto &node_and_index) {
if (node_and_index.index >= before_shrink_inputs_size) {
MS_LOG(EXCEPTION) << "index should less than inputs size, index: " << node_and_index.index
<< ", output: " << new_fg_output->DebugString();
}
return new_fg_output_inputs[node_and_index.index + 1];
});
shrunk_output = new_fg->NewCNode(shrunk_inputs);
} else {
auto value_tuple = GetValueNode<ValueTuplePtr>(new_fg_output);
if (value_tuple == nullptr) {
MS_LOG(EXCEPTION) << "New fg output is not MakeTuple or ValueTuple, but " << new_fg_output->DebugString();
}
ValuePtrList shrunk_inputs;
before_shrink_inputs_size = value_tuple->size();
std::transform(tp_cnodes_and_index.cbegin(), tp_cnodes_and_index.cend(), std::back_inserter(shrunk_inputs),
[new_fg_output, value_tuple, before_shrink_inputs_size](const auto &node_and_index) {
if (node_and_index.index >= before_shrink_inputs_size) {
MS_LOG(EXCEPTION) << "index should less than inputs size, index: " << node_and_index.index
<< ", output: " << new_fg_output->DebugString();
}
return (*value_tuple)[node_and_index.index];
});
shrunk_output = NewValueNode(std::make_shared<ValueTuple>(shrunk_inputs));
}
auto shrunk_abstract = ShrinkAbstract(new_fg_output->abstract(), tp_cnodes_and_index);
MS_EXCEPTION_IF_NULL(shrunk_abstract);
shrunk_output->set_abstract(shrunk_abstract);
new_fg->set_output(shrunk_output);
MS_LOG(DEBUG) << "Partly item used; original size: " << before_shrink_inputs_size
<< ", new size: " << tp_cnodes_and_index.size() << ", fg: " << fg->ToString() << ", new graph"
<< new_fg->ToString();
return new_fg;
}

struct FuncGraphIntVectorPairHasher {
std::size_t Int64VectorHash(const std::vector<int64_t> &int_vector) const {
std::size_t hash_value = 0;
constexpr auto kMaxElementsNum = 4;
for (size_t i = 0; (i < int_vector.size()) && (i < kMaxElementsNum); ++i) {
hash_value = hash_combine(hash_value, std::hash<int64_t>{}(int_vector[i]));
}
return hash_value;
}

std::size_t operator()(const std::pair<FuncGraphPtr, std::vector<int64_t>> &p) const {
auto h1 = std::hash<FuncGraphPtr>{}(p.first);
auto h2 = Int64VectorHash(p.second);
return hash_combine(h1, h2);
}
};

bool ShouldTransform(const AnfNodePtr &node, const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) {
if (node->abstract() && node->abstract()->isa<abstract::AbstractTuple>()) {
const auto &abs_tuple = *(node->abstract()->cast<abstract::AbstractTuplePtr>());
if (tp_cnodes_and_index[0].index == 0 && abs_tuple.size() > 0) {
if (abs_tuple[0]->isa<abstract::AbstractScalar>() && abs_tuple[0]->GetTypeTrack()->isa<EnvType>()) {
return true;
}
}
// fprop_fg will return MakeTuple(xx, bprop_fg).
if (tp_cnodes_and_index.size() > 1 && tp_cnodes_and_index[1].index == 1 && abs_tuple.size() > 1 &&
abs_tuple[1]->isa<abstract::AbstractFunction>()) {
return true;
}
}
return false;
}
} // namespace internal

// {prim::kPrimTupleGetItem, {G, Xs}, C}
@@ -136,23 +309,52 @@ class IncorporateGetitem : public AnfVisitor {
IncorporateGetitem() : getitem_transform_() {}
~IncorporateGetitem() override = default;

AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int64Imm>})(node);
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) ||
fg_->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
return nullptr;
}

// This node had been substituted.
if (processed_nodes_.find(fg_call_cnode_) != processed_nodes_.end()) {
MS_LOG(DEBUG) << "fg call with same cnode is already replaced, node: " << node->DebugString()
<< ", fg_call: " << fg_call_cnode_->DebugString();
return nullptr;
}
const auto &manager = fg_->manager();
MS_EXCEPTION_IF_NULL(manager);
bool output_is_shrinkable = internal::IsOutputShrinkable(fg_->output());
std::vector<internal::TpCNodeAndIndex> tp_cnodes_and_index;
auto fg_call_cnode_users_counter = MultipleUse(fg_call_cnode_, fg_, &tp_cnodes_and_index);
bool multiple_use = (tp_cnodes_and_index.size() > 1);
if (output_is_shrinkable && multiple_use && (tp_cnodes_and_index.size() == fg_call_cnode_users_counter)) {
if (!internal::ShouldTransform(fg_call_cnode_, tp_cnodes_and_index) && !internal::HasMoreJ(optimizer)) {
MS_LOG(DEBUG) << "No more j and multiple use, will shrink, node: " << node->DebugString()
<< ", fg_call: " << fg_call_cnode_->DebugString();
const auto output_size = internal::GetOutputSize(fg_->output());
if (fg_call_cnode_users_counter == output_size) {
processed_nodes_.emplace(fg_call_cnode_);
MS_LOG(DEBUG) << "All elements in output is used, no need to transform, node: " << node->DebugString()
<< ", fg_call: " << fg_call_cnode_->DebugString();
return nullptr;
}
auto new_node = ShrinkFuncGraphOutput(node, tp_cnodes_and_index);
if (new_node != nullptr) {
return new_node;
}
}
}
MS_LOG(DEBUG) << "Cannot shrink, transform_getitem, node: " << node->DebugString()
<< ", fg_call: " << fg_call_cnode_->DebugString();
auto new_fg = getitem_transform_(node, fg_, idx_);
MS_LOG(DEBUG) << "Original fg: " << fg_->ToString() << ", new fg: " << new_fg->ToString();
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
auto new_node = node->func_graph()->NewCNode(args_);
// Check if the another only usage of {G, Xs} is UpdateState{s, {G, Xs}}, if yes, replace
// UpdateState{s, {G, Xs}} with UpdateState{s, new_node};
const auto &manager = fg_->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users_map = manager->node_users();
auto it = node_users_map.find(fg_cnode_);
auto it = node_users_map.find(fg_call_cnode_);
if (it != node_users_map.end()) {
AnfNodePtr update_state_node = nullptr;
auto &node_users = it->second;
@@ -166,7 +368,7 @@ class IncorporateGetitem : public AnfVisitor {
if (update_state_node != nullptr) {
auto update_state_cnode = update_state_node->cast<CNodePtr>();
// double check;
if (update_state_cnode->input(2) == fg_cnode_) {
if (update_state_cnode->input(2) == fg_call_cnode_) {
MS_LOG(DEBUG) << "Replace UpdateState node: " << update_state_cnode->DebugString(2)
<< ", input 2 with: " << new_node->DebugString();
manager->SetEdge(update_state_cnode, 2, new_node);
@@ -177,12 +379,98 @@ class IncorporateGetitem : public AnfVisitor {
return new_node;
}

size_t MultipleUse(const CNodePtr &fg_call, const FuncGraphPtr &fg,
std::vector<internal::TpCNodeAndIndex> *cnodes_and_index) const {
const auto &manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &cnode_and_index_vector = *cnodes_and_index;
std::set<int64_t> index_set;
std::size_t total_usage = 0;
const auto &node_users_map = manager->node_users();
const auto &it = node_users_map.find(fg_call);
if (it == node_users_map.end()) {
return 0;
}
const auto &node_users = it->second;
for (const auto &user : node_users) {
if (IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
const auto &cnode = user.first->cast<CNodePtr>();
if (cnode->input(2)->isa<ValueNode>()) {
auto idx = GetValue<int64_t>(cnode->input(2)->cast<ValueNodePtr>()->value());
cnode_and_index_vector.push_back({cnode, idx});
index_set.insert(idx);
total_usage++;
} else {
MS_LOG(EXCEPTION) << "tuple_getitem index is not valuenode, but: " << user.first->DebugString();
}
} else {
MS_LOG(DEBUG) << "fg_call usre is not tuple_getitem, user: " << user.first->DebugString();
}
}
if (index_set.size() != total_usage) {
MS_LOG(DEBUG) << "some index usage is duplicated, total_usage: " << total_usage;
MS_LOG(DEBUG) << "index_set:";
for (auto idx : index_set) {
MS_LOG(DEBUG) << " " << idx;
}
}
// sort by index;
std::sort(cnode_and_index_vector.begin(), cnode_and_index_vector.end(),
[](const auto &tp1, const auto &tp2) { return tp1.index < tp2.index; });
return node_users.size();
}

AnfNodePtr ShrinkFuncGraphOutput(const AnfNodePtr &node,
const std::vector<internal::TpCNodeAndIndex> &tp_cnodes_and_index) {
const auto &manager = fg_->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<int64_t> index_vector;
(void)std::transform(tp_cnodes_and_index.begin(), tp_cnodes_and_index.end(), std::back_inserter(index_vector),
[](const auto &cnode_and_index) { return cnode_and_index.index; });
auto iter = processed_fgs_.find(std::make_pair(fg_, index_vector));
if (iter != processed_fgs_.end()) {
MS_LOG(DEBUG) << "fg is already processed, just update caller index, node: " << node->DebugString()
<< ", fg_call: " << fg_call_cnode_->DebugString();
MS_LOG(DEBUG) << "original fg: " << fg_->ToString() << ", processed_fg: " << iter->second->ToString();
processed_nodes_.emplace(fg_call_cnode_);
manager->SetEdge(fg_call_cnode_, 0, NewValueNode(iter->second));
auto shrunk_abstract = internal::ShrinkAbstract(fg_call_cnode_->abstract(), tp_cnodes_and_index);
if (shrunk_abstract != nullptr) {
fg_call_cnode_->set_abstract(shrunk_abstract);
}
auto new_idx = internal::UpdateUserNodeIndex(fg_call_cnode_, idx_, tp_cnodes_and_index);
auto new_node =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fg_call_cnode_, NewValueNode(new_idx)});
new_node->set_abstract(node->abstract());
return new_node;
}
const auto new_fg = internal::ShrinkUnsedOutput(fg_, tp_cnodes_and_index);
if (new_fg != nullptr) {
MS_LOG(DEBUG) << "fg output is shrunk, original fg: " << fg_->ToString() << ", new fg: " << new_fg->ToString();
processed_nodes_.emplace(fg_call_cnode_);
processed_fgs_.emplace(std::make_pair(fg_, index_vector), new_fg);
manager->SetEdge(fg_call_cnode_, 0, NewValueNode(new_fg));
auto shrunk_abstract = internal::ShrinkAbstract(fg_call_cnode_->abstract(), tp_cnodes_and_index);
if (shrunk_abstract != nullptr) {
fg_call_cnode_->set_abstract(shrunk_abstract);
}
auto new_idx = internal::UpdateUserNodeIndex(fg_call_cnode_, idx_, tp_cnodes_and_index);
auto new_node =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fg_call_cnode_, NewValueNode(new_idx)});
new_node->set_abstract(node->abstract());
return new_node;
}
MS_LOG(DEBUG) << "Shrink failed. node: " << node->DebugString()
<< ", switch_call: " << fg_call_cnode_->DebugString();
return nullptr;
}

void Visit(const CNodePtr &cnode) override {
if (cnode->size() == 0 || !IsValueNode<FuncGraph>(cnode->input(0))) {
return;
}

fg_cnode_ = cnode;
fg_call_cnode_ = cnode;
auto &inputs = cnode->inputs();
fg_ = GetValueNode<FuncGraphPtr>(inputs[0]);
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));
@@ -193,15 +481,19 @@ class IncorporateGetitem : public AnfVisitor {
void Reset() {
idx_ = -1;
fg_ = nullptr;
fg_cnode_ = nullptr;
fg_call_cnode_ = nullptr;
args_.clear();
}

private:
int64_t idx_{-1};
FuncGraphPtr fg_{nullptr};
AnfNodePtr fg_cnode_{nullptr};
CNodePtr fg_call_cnode_{nullptr};
std::vector<AnfNodePtr> args_{};
std::set<AnfNodePtr> processed_nodes_;
std::unordered_map<std::pair<FuncGraphPtr, std::vector<int64_t>>, FuncGraphPtr,
internal::FuncGraphIntVectorPairHasher>
processed_fgs_;
internal::GetitemTransform getitem_transform_;
};

@@ -298,7 +590,7 @@ class IncorporateGetitemSwitch : public AnfVisitor {
IncorporateGetitemSwitch() : getitem_transform_() {}
~IncorporateGetitemSwitch() override = default;

AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset();
is_in_get_ = true;
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int64Imm>})(node);
@@ -316,33 +608,57 @@ class IncorporateGetitemSwitch : public AnfVisitor {
if (g2_ == nullptr) {
return nullptr;
}
if (processed_nodes_.find(switch_) != processed_nodes_.end()) {
MS_LOG(DEBUG) << "fg in switch node has been replaced. node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
return nullptr;
}

bool g1_output_is_shrinkable = internal::IsOutputShrinkable(g1_->output());
bool g2_output_is_shrinkable = internal::IsOutputShrinkable(g2_->output());

auto tuple_getitem = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
bool has_env_type = false;
if (tuple_getitem->input(1)->abstract() && tuple_getitem->input(1)->abstract()->isa<abstract::AbstractTuple>()) {
const auto &abs_tuple = *(tuple_getitem->input(1)->abstract()->cast<abstract::AbstractTuplePtr>());
// eliminate (envinstance, value1, value2, ...) built by bprop func_graph()
if (abs_tuple.size() >= 1) {
// Value maybe kAnyValue, so check the type track;
if (abs_tuple[0]->isa<abstract::AbstractScalar>() && abs_tuple[0]->GetTypeTrack()->isa<EnvType>()) {
has_env_type = true;
const auto &switch_call = tuple_getitem->input(1);
MS_EXCEPTION_IF_NULL(switch_call);
const auto &switch_call_cnode = switch_call->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_call_cnode);
// If exist env_getitem/env_setitem in this funcgraph or
// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem;
std::vector<internal::TpCNodeAndIndex> tp_cnodes_and_index;
auto switch_call_users_counter = MultipleUseOfSwitch(switch_call, fg, &tp_cnodes_and_index);
bool multiple_use = (tp_cnodes_and_index.size() > 1);
if (g1_output_is_shrinkable && g2_output_is_shrinkable && multiple_use &&
(tp_cnodes_and_index.size() == switch_call_users_counter)) {
if (!internal::HasMoreJ(optimizer) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) &&
!ExistEnvNodeInTupleItem(g2_) && !internal::ShouldTransform(switch_call, tp_cnodes_and_index)) {
MS_LOG(DEBUG) << "No more j, will shrink. Node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
const auto g1_output_size = internal::GetOutputSize(g1_->output());
const auto g2_output_size = internal::GetOutputSize(g2_->output());
if (g1_output_size != g2_output_size) {
MS_LOG(EXCEPTION) << "output of g1 and g2 should have same tuple size, but g1 output: "
<< g1_->output()->DebugString() << ", g2 output: " << g2_->output()->DebugString();
}
}
// eliminate (value, bprop_func) built by fprop func_graph
if (abs_tuple.size() >= 2) {
if (abs_tuple[1]->isa<abstract::AbstractFunction>()) {
has_env_type = true;
if (switch_call_users_counter == g1_output_size) {
processed_nodes_.emplace(switch_call);
MS_LOG(DEBUG) << "All elements in output is used, no need to transform, node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
return nullptr;
}

auto new_node = ShrinkFuncGraphOutput(node, switch_call_cnode, tp_cnodes_and_index);
if (new_node != nullptr) {
return new_node;
}
}
}
// If exist env_getitem/env_setitem in this funcgraph or
// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem;
if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) &&
!ExistEnvNodeInTupleItem(g2_) && !has_env_type) {
return nullptr;
}
MS_LOG(DEBUG) << "Cannot shrink output, transform_getitem_switch, node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
auto new_g1 = getitem_transform_(node, g1_, idx_);
auto new_g2 = getitem_transform_(node, g2_, idx_);
MS_LOG(DEBUG) << "Original fg1: " << g1_->ToString() << ", new_fg1: " << new_g1->ToString();
MS_LOG(DEBUG) << "Original fg2: " << g2_->ToString() << ", new_fg2: " << new_g2->ToString();
auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)});
(void)args_.insert(args_.begin(), sw_node);

@@ -350,7 +666,60 @@ class IncorporateGetitemSwitch : public AnfVisitor {
new_node->set_abstract(node->abstract());
return new_node;
}

AnfNodePtr ShrinkFuncGraphOutput(const AnfNodePtr &node, const CNodePtr &switch_call_cnode,
const std::vector<internal::TpCNodeAndIndex> &tp_cnodes_and_index) {
const auto &manager = node->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
auto switch_cnode = switch_->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
std::vector<int64_t> index_vector;
(void)std::transform(tp_cnodes_and_index.begin(), tp_cnodes_and_index.end(), std::back_inserter(index_vector),
[](const auto &cnode_and_index) { return cnode_and_index.index; });
const auto &iter1 = processed_fgs_.find(std::make_pair(g1_, index_vector));
const auto &iter2 = processed_fgs_.find(std::make_pair(g2_, index_vector));
if (iter1 != processed_fgs_.end() && iter2 != processed_fgs_.end()) {
MS_LOG(DEBUG) << "fg output had been processed, no need to transform, node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
MS_LOG(DEBUG) << "Original fg1: " << g1_->ToString() << ", new_fg1: " << iter1->second->ToString();
MS_LOG(DEBUG) << "Original fg2: " << g2_->ToString() << ", new_fg2: " << iter2->second->ToString();
processed_nodes_.emplace(switch_);
manager->SetEdge(switch_cnode, 2, NewValueNode(iter1->second));
manager->SetEdge(switch_cnode, 3, NewValueNode(iter2->second));
auto shrunk_abstract = internal::ShrinkAbstract(switch_call_cnode->abstract(), tp_cnodes_and_index);
if (shrunk_abstract != nullptr) {
switch_call_cnode->set_abstract(shrunk_abstract);
}
auto new_idx = internal::UpdateUserNodeIndex(switch_call_cnode, idx_, tp_cnodes_and_index);
auto new_node =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), switch_call_cnode, NewValueNode(new_idx)});
new_node->set_abstract(node->abstract());
return new_node;
}
const auto &new_g1 = internal::ShrinkUnsedOutput(g1_, tp_cnodes_and_index);
const auto &new_g2 = internal::ShrinkUnsedOutput(g2_, tp_cnodes_and_index);
if (new_g1 != nullptr && new_g2 != nullptr) {
MS_LOG(DEBUG) << "Shrink output. node: " << node->DebugString() << ", switch: " << switch_->DebugString();
MS_LOG(DEBUG) << "Original fg1: " << g1_->ToString() << ", new_fg1: " << new_g1->ToString();
MS_LOG(DEBUG) << "Original fg2: " << g2_->ToString() << ", new_fg2: " << new_g2->ToString();
processed_nodes_.emplace(switch_);
processed_fgs_.emplace(std::make_pair(g1_, index_vector), new_g1);
processed_fgs_.emplace(std::make_pair(g2_, index_vector), new_g2);
manager->SetEdge(switch_cnode, 2, NewValueNode(new_g1));
manager->SetEdge(switch_cnode, 3, NewValueNode(new_g2));
auto shrunk_abstract = internal::ShrinkAbstract(switch_call_cnode->abstract(), tp_cnodes_and_index);
if (shrunk_abstract != nullptr) {
switch_call_cnode->set_abstract(shrunk_abstract);
}
auto new_idx = internal::UpdateUserNodeIndex(switch_call_cnode, idx_, tp_cnodes_and_index);
auto new_node =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), switch_call_cnode, NewValueNode(new_idx)});
new_node->set_abstract(node->abstract());
return new_node;
}
MS_LOG(DEBUG) << "Shrink failed. node: " << node->DebugString()
<< ", switch_call: " << switch_call_cnode->DebugString();
return nullptr;
}
void Visit(const AnfNodePtr &node) override {
if (is_in_switch_ && x_ == nullptr) {
x_ = node;
@@ -393,22 +762,51 @@ class IncorporateGetitemSwitch : public AnfVisitor {
}

private:
bool MultipleUseOfSwitch(const AnfNodePtr &switch_call, const FuncGraphPtr &fg) const {
size_t MultipleUseOfSwitch(const AnfNodePtr &switch_call, const FuncGraphPtr &fg,
std::vector<internal::TpCNodeAndIndex> *cnodes_and_index) const {
auto switch_call_cnode = switch_call->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_call_cnode);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &cnode_and_index_vector = *cnodes_and_index;
std::set<int64_t> index_set;
std::size_t total_usage = 0;
auto &node_users_map = manager->node_users();
auto it = node_users_map.find(switch_call);
if (it == node_users_map.end()) {
return false;
return 0;
}
auto &node_users = it->second;
// If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute.s
auto tuple_getitem_num = std::count_if(node_users.begin(), node_users.end(), [](std::pair<AnfNodePtr, int> &user) {
return IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem);
});
return tuple_getitem_num > 1;
// If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute.
for (auto user : node_users) {
if (IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
auto cnode = user.first->cast<CNodePtr>();
constexpr auto kInputIndex = 2;
if (cnode->input(kInputIndex)->isa<ValueNode>()) {
const auto &idx_node = cnode->input(kInputIndex)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(idx_node);
auto idx = GetValue<int64_t>(idx_node->value());
cnode_and_index_vector.push_back({cnode, idx});
index_set.insert(idx);
total_usage++;
} else {
MS_LOG(EXCEPTION) << "Tuple_getitem index is not valuenode, but: " << user.first->DebugString(2);
}
} else {
MS_LOG(DEBUG) << "switch_call user is not tuple_getitem, user: " << user.first->DebugString(2);
}
}
if (index_set.size() != total_usage) {
MS_LOG(DEBUG) << "some index is duplicated, total_usage: " << total_usage;
MS_LOG(DEBUG) << "index_set: ";
for (auto idx : index_set) {
MS_LOG(DEBUG) << " " << idx;
}
}
// sort by index;
std::sort(cnode_and_index_vector.begin(), cnode_and_index_vector.end(),
[](const auto &tp1, const auto &tp2) { return tp1.index < tp2.index; });
return node_users.size();
}

static bool inline ExistEnvNode(const FuncGraphPtr &fg) {
@@ -441,6 +839,10 @@ class IncorporateGetitemSwitch : public AnfVisitor {
FuncGraphPtr g1_{nullptr}, g2_{nullptr};
bool is_in_get_{false}, is_in_switch_{false};
std::vector<AnfNodePtr> args_{};
std::set<AnfNodePtr> processed_nodes_;
std::unordered_map<std::pair<FuncGraphPtr, std::vector<int64_t>>, FuncGraphPtr,
internal::FuncGraphIntVectorPairHasher>
processed_fgs_;
internal::GetitemTransform getitem_transform_;
};



+ 1
- 0
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -300,6 +300,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.merge_addn_,
irpass.float_tuple_getitem_switch_,
irpass.float_env_getitem_switch_,
irpass.inline_,
irpass.incorporate_getitem_set_,
irpass.incorporate_call_,
irpass.incorporate_call_switch_,


Loading…
Cancel
Save