|
|
|
@@ -19,6 +19,7 @@ |
|
|
|
|
|
|
|
#include <algorithm> |
|
|
|
#include <memory> |
|
|
|
#include <set> |
|
|
|
#include <unordered_map> |
|
|
|
#include <vector> |
|
|
|
#include <utility> |
|
|
|
@@ -128,6 +129,178 @@ class GetItemTransformACrossGraph { |
|
|
|
private: |
|
|
|
std::unordered_map<FuncGraphPtr, std::unordered_map<int64_t, FuncGraphPtr>> cache_; |
|
|
|
}; |
|
|
|
|
|
|
|
bool HasMoreJ(const OptimizerPtr &optimizer) { |
|
|
|
bool more_j = false; |
|
|
|
auto res = optimizer->resource(); |
|
|
|
auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(res); |
|
|
|
if (resource_ptr != nullptr) { |
|
|
|
const auto &manager = optimizer->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
more_j = manager->func_graph_j_total(resource_ptr->func_graph()); |
|
|
|
} |
|
|
|
return more_j; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsOutputShrinkable(const AnfNodePtr &output) { |
|
|
|
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (GetValueNode<ValueTuplePtr>(output)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
size_t GetOutputSize(const AnfNodePtr &output) { |
|
|
|
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { |
|
|
|
const auto &output_cnode = output->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(output_cnode); |
|
|
|
return output_cnode->size() - 1; |
|
|
|
} |
|
|
|
const auto &value_tuple = GetValueNode<ValueTuplePtr>(output); |
|
|
|
if (value_tuple == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "fg output is not MakeTuple or ValueTuple, but: " << output->DebugString(); |
|
|
|
} |
|
|
|
return value_tuple->size(); |
|
|
|
} |
|
|
|
|
|
|
|
struct TpCNodeAndIndex { |
|
|
|
// CNode {TupleGetItem, call, index} |
|
|
|
CNodePtr tp_cnode; |
|
|
|
int64_t index; |
|
|
|
}; |
|
|
|
|
|
|
|
int64_t UpdateUserNodeIndex(const CNodePtr &fg_call_cnode, const int64_t current_index, |
|
|
|
const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) { |
|
|
|
const auto &manager = fg_call_cnode->func_graph()->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
int64_t new_index = current_index; |
|
|
|
auto txn = manager->Transact(); |
|
|
|
for (int64_t i = 0; i < SizeToLong(tp_cnodes_and_index.size()); ++i) { |
|
|
|
const auto &cnode_and_index = tp_cnodes_and_index[i]; |
|
|
|
if (cnode_and_index.index != i) { |
|
|
|
constexpr auto kInputIndex = 2; |
|
|
|
txn.SetEdge(cnode_and_index.tp_cnode, kInputIndex, NewValueNode(i)); |
|
|
|
} |
|
|
|
if (cnode_and_index.index == current_index) { |
|
|
|
new_index = i; |
|
|
|
} |
|
|
|
} |
|
|
|
txn.Commit(); |
|
|
|
return new_index; |
|
|
|
} |
|
|
|
|
|
|
|
// Builds a new AbstractTuple that keeps only the elements selected by `tp_cnodes_and_index`,
// in the order the entries appear in that list.
//
// Returns nullptr when `original_abstract` is null or is not an AbstractTuple.
// Raises an exception when an entry's index is negative or not smaller than the tuple size.
AbstractBasePtr ShrinkAbstract(const AbstractBasePtr &original_abstract,
                               const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) {
  if (original_abstract == nullptr || !original_abstract->isa<abstract::AbstractTuple>()) {
    return nullptr;
  }
  const auto &abs_tuple = original_abstract->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(abs_tuple);
  const auto &abs_tuple_elements = abs_tuple->elements();
  const int64_t before_shrink_tuple_size = SizeToLong(abs_tuple_elements.size());

  AbstractBasePtrList shrunk_abstract_elements;
  // The element list is captured by reference: it outlives the lambda, so copying the whole
  // vector of shared pointers per call is unnecessary.
  (void)std::transform(
    tp_cnodes_and_index.cbegin(), tp_cnodes_and_index.cend(), std::back_inserter(shrunk_abstract_elements),
    [&abs_tuple_elements, before_shrink_tuple_size](const auto &node_and_index) {
      // Reject negative indices as well; they would otherwise index out of bounds.
      if (node_and_index.index < 0 || node_and_index.index >= before_shrink_tuple_size) {
        MS_LOG(EXCEPTION) << "index should less than inputs size, index: " << node_and_index.index
                          << ", abstract tuple size: " << before_shrink_tuple_size;
      }
      return abs_tuple_elements[node_and_index.index];
    });
  return std::make_shared<abstract::AbstractTuple>(shrunk_abstract_elements);
}
|
|
|
|
|
|
|
// Clones `fg` and replaces the clone's tuple output with one that keeps only the elements
// selected by `tp_cnodes_and_index`, in list order.
//
// The clone's output must be a MakeTuple CNode with at least two elements, or a ValueNode holding
// a ValueTuple; anything else raises an exception. An entry whose index is negative or not smaller
// than the original element count also raises. The shrunk output gets the correspondingly shrunk
// abstract, which must exist (an exception is raised otherwise).
//
// Returns the cloned graph; `fg` itself is only cloned from and logged here.
FuncGraphPtr ShrinkUnsedOutput(const FuncGraphPtr &fg, const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) {
  const auto &manager = fg->manager();
  MS_EXCEPTION_IF_NULL(manager);

  auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("tp_use"));
  auto new_fg_output = new_fg->output();
  AnfNodePtr shrunk_output = nullptr;
  int64_t before_shrink_inputs_size = 0;
  if (IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
    // Shrink a MakeTuple output by rebuilding the CNode from the selected inputs.
    auto new_fg_output_cnode = new_fg_output->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(new_fg_output_cnode);
    const auto &new_fg_output_inputs = new_fg_output_cnode->inputs();
    constexpr size_t kMinimalSize = 2;
    if (new_fg_output_inputs.size() <= kMinimalSize) {
      MS_LOG(EXCEPTION) << "New fg output should at least 2 elements, but: " << new_fg_output->DebugString();
    }
    // Input 0 is the MakeTuple primitive, so the element count is one less than the input count.
    before_shrink_inputs_size = SizeToLong(new_fg_output_inputs.size() - 1);
    AnfNodePtrList shrunk_inputs{NewValueNode({prim::kPrimMakeTuple})};
    // Captures are by reference: every captured object outlives the std::transform call.
    (void)std::transform(
      tp_cnodes_and_index.cbegin(), tp_cnodes_and_index.cend(), std::back_inserter(shrunk_inputs),
      [&new_fg_output, &new_fg_output_inputs, before_shrink_inputs_size](const auto &node_and_index) {
        if (node_and_index.index < 0 || node_and_index.index >= before_shrink_inputs_size) {
          MS_LOG(EXCEPTION) << "index should less than inputs size, index: " << node_and_index.index
                            << ", output: " << new_fg_output->DebugString();
        }
        // +1 bypasses the MakeTuple primitive at input 0.
        return new_fg_output_inputs[node_and_index.index + 1];
      });
    shrunk_output = new_fg->NewCNode(shrunk_inputs);
  } else {
    // Shrink a constant tuple by building a new ValueTuple from the selected values.
    auto value_tuple = GetValueNode<ValueTuplePtr>(new_fg_output);
    if (value_tuple == nullptr) {
      MS_LOG(EXCEPTION) << "New fg output is not MakeTuple or ValueTuple, but " << new_fg_output->DebugString();
    }
    ValuePtrList shrunk_inputs;
    before_shrink_inputs_size = SizeToLong(value_tuple->size());
    (void)std::transform(
      tp_cnodes_and_index.cbegin(), tp_cnodes_and_index.cend(), std::back_inserter(shrunk_inputs),
      [&new_fg_output, &value_tuple, before_shrink_inputs_size](const auto &node_and_index) {
        if (node_and_index.index < 0 || node_and_index.index >= before_shrink_inputs_size) {
          MS_LOG(EXCEPTION) << "index should less than inputs size, index: " << node_and_index.index
                            << ", output: " << new_fg_output->DebugString();
        }
        return (*value_tuple)[node_and_index.index];
      });
    shrunk_output = NewValueNode(std::make_shared<ValueTuple>(shrunk_inputs));
  }

  auto shrunk_abstract = ShrinkAbstract(new_fg_output->abstract(), tp_cnodes_and_index);
  MS_EXCEPTION_IF_NULL(shrunk_abstract);
  shrunk_output->set_abstract(shrunk_abstract);
  new_fg->set_output(shrunk_output);
  MS_LOG(DEBUG) << "Partly item used; original size: " << before_shrink_inputs_size
                << ", new size: " << tp_cnodes_and_index.size() << ", fg: " << fg->ToString() << ", new graph"
                << new_fg->ToString();
  return new_fg;
}
|
|
|
|
|
|
|
struct FuncGraphIntVectorPairHasher { |
|
|
|
std::size_t Int64VectorHash(const std::vector<int64_t> &int_vector) const { |
|
|
|
std::size_t hash_value = 0; |
|
|
|
constexpr auto kMaxElementsNum = 4; |
|
|
|
for (size_t i = 0; (i < int_vector.size()) && (i < kMaxElementsNum); ++i) { |
|
|
|
hash_value = hash_combine(hash_value, std::hash<int64_t>{}(int_vector[i])); |
|
|
|
} |
|
|
|
return hash_value; |
|
|
|
} |
|
|
|
|
|
|
|
std::size_t operator()(const std::pair<FuncGraphPtr, std::vector<int64_t>> &p) const { |
|
|
|
auto h1 = std::hash<FuncGraphPtr>{}(p.first); |
|
|
|
auto h2 = Int64VectorHash(p.second); |
|
|
|
return hash_combine(h1, h2); |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
bool ShouldTransform(const AnfNodePtr &node, const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) { |
|
|
|
if (node->abstract() && node->abstract()->isa<abstract::AbstractTuple>()) { |
|
|
|
const auto &abs_tuple = *(node->abstract()->cast<abstract::AbstractTuplePtr>()); |
|
|
|
if (tp_cnodes_and_index[0].index == 0 && abs_tuple.size() > 0) { |
|
|
|
if (abs_tuple[0]->isa<abstract::AbstractScalar>() && abs_tuple[0]->GetTypeTrack()->isa<EnvType>()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
// fprop_fg will return MakeTuple(xx, bprop_fg). |
|
|
|
if (tp_cnodes_and_index.size() > 1 && tp_cnodes_and_index[1].index == 1 && abs_tuple.size() > 1 && |
|
|
|
abs_tuple[1]->isa<abstract::AbstractFunction>()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace internal |
|
|
|
|
|
|
|
// {prim::kPrimTupleGetItem, {G, Xs}, C} |
|
|
|
@@ -136,23 +309,52 @@ class IncorporateGetitem : public AnfVisitor { |
|
|
|
IncorporateGetitem() : getitem_transform_() {} |
|
|
|
~IncorporateGetitem() override = default; |
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int64Imm>})(node); |
|
|
|
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || |
|
|
|
fg_->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// This node had been substituted. |
|
|
|
if (processed_nodes_.find(fg_call_cnode_) != processed_nodes_.end()) { |
|
|
|
MS_LOG(DEBUG) << "fg call with same cnode is already replaced, node: " << node->DebugString() |
|
|
|
<< ", fg_call: " << fg_call_cnode_->DebugString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
const auto &manager = fg_->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
bool output_is_shrinkable = internal::IsOutputShrinkable(fg_->output()); |
|
|
|
std::vector<internal::TpCNodeAndIndex> tp_cnodes_and_index; |
|
|
|
auto fg_call_cnode_users_counter = MultipleUse(fg_call_cnode_, fg_, &tp_cnodes_and_index); |
|
|
|
bool multiple_use = (tp_cnodes_and_index.size() > 1); |
|
|
|
if (output_is_shrinkable && multiple_use && (tp_cnodes_and_index.size() == fg_call_cnode_users_counter)) { |
|
|
|
if (!internal::ShouldTransform(fg_call_cnode_, tp_cnodes_and_index) && !internal::HasMoreJ(optimizer)) { |
|
|
|
MS_LOG(DEBUG) << "No more j and multiple use, will shrink, node: " << node->DebugString() |
|
|
|
<< ", fg_call: " << fg_call_cnode_->DebugString(); |
|
|
|
const auto output_size = internal::GetOutputSize(fg_->output()); |
|
|
|
if (fg_call_cnode_users_counter == output_size) { |
|
|
|
processed_nodes_.emplace(fg_call_cnode_); |
|
|
|
MS_LOG(DEBUG) << "All elements in output is used, no need to transform, node: " << node->DebugString() |
|
|
|
<< ", fg_call: " << fg_call_cnode_->DebugString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto new_node = ShrinkFuncGraphOutput(node, tp_cnodes_and_index); |
|
|
|
if (new_node != nullptr) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Cannot shrink, transform_getitem, node: " << node->DebugString() |
|
|
|
<< ", fg_call: " << fg_call_cnode_->DebugString(); |
|
|
|
auto new_fg = getitem_transform_(node, fg_, idx_); |
|
|
|
MS_LOG(DEBUG) << "Original fg: " << fg_->ToString() << ", new fg: " << new_fg->ToString(); |
|
|
|
(void)args_.insert(args_.begin(), NewValueNode(new_fg)); |
|
|
|
auto new_node = node->func_graph()->NewCNode(args_); |
|
|
|
// Check whether the only other user of {G, Xs} is UpdateState{s, {G, Xs}}; if so, replace
// UpdateState{s, {G, Xs}} with UpdateState{s, new_node}.
|
|
|
const auto &manager = fg_->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto &node_users_map = manager->node_users(); |
|
|
|
auto it = node_users_map.find(fg_cnode_); |
|
|
|
auto it = node_users_map.find(fg_call_cnode_); |
|
|
|
if (it != node_users_map.end()) { |
|
|
|
AnfNodePtr update_state_node = nullptr; |
|
|
|
auto &node_users = it->second; |
|
|
|
@@ -166,7 +368,7 @@ class IncorporateGetitem : public AnfVisitor { |
|
|
|
if (update_state_node != nullptr) { |
|
|
|
auto update_state_cnode = update_state_node->cast<CNodePtr>(); |
|
|
|
// double check; |
|
|
|
if (update_state_cnode->input(2) == fg_cnode_) { |
|
|
|
if (update_state_cnode->input(2) == fg_call_cnode_) { |
|
|
|
MS_LOG(DEBUG) << "Replace UpdateState node: " << update_state_cnode->DebugString(2) |
|
|
|
<< ", input 2 with: " << new_node->DebugString(); |
|
|
|
manager->SetEdge(update_state_cnode, 2, new_node); |
|
|
|
@@ -177,12 +379,98 @@ class IncorporateGetitem : public AnfVisitor { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
|
|
|
|
size_t MultipleUse(const CNodePtr &fg_call, const FuncGraphPtr &fg, |
|
|
|
std::vector<internal::TpCNodeAndIndex> *cnodes_and_index) const { |
|
|
|
const auto &manager = fg->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto &cnode_and_index_vector = *cnodes_and_index; |
|
|
|
std::set<int64_t> index_set; |
|
|
|
std::size_t total_usage = 0; |
|
|
|
const auto &node_users_map = manager->node_users(); |
|
|
|
const auto &it = node_users_map.find(fg_call); |
|
|
|
if (it == node_users_map.end()) { |
|
|
|
return 0; |
|
|
|
} |
|
|
|
const auto &node_users = it->second; |
|
|
|
for (const auto &user : node_users) { |
|
|
|
if (IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { |
|
|
|
const auto &cnode = user.first->cast<CNodePtr>(); |
|
|
|
if (cnode->input(2)->isa<ValueNode>()) { |
|
|
|
auto idx = GetValue<int64_t>(cnode->input(2)->cast<ValueNodePtr>()->value()); |
|
|
|
cnode_and_index_vector.push_back({cnode, idx}); |
|
|
|
index_set.insert(idx); |
|
|
|
total_usage++; |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "tuple_getitem index is not valuenode, but: " << user.first->DebugString(); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "fg_call usre is not tuple_getitem, user: " << user.first->DebugString(); |
|
|
|
} |
|
|
|
} |
|
|
|
if (index_set.size() != total_usage) { |
|
|
|
MS_LOG(DEBUG) << "some index usage is duplicated, total_usage: " << total_usage; |
|
|
|
MS_LOG(DEBUG) << "index_set:"; |
|
|
|
for (auto idx : index_set) { |
|
|
|
MS_LOG(DEBUG) << " " << idx; |
|
|
|
} |
|
|
|
} |
|
|
|
// sort by index; |
|
|
|
std::sort(cnode_and_index_vector.begin(), cnode_and_index_vector.end(), |
|
|
|
[](const auto &tp1, const auto &tp2) { return tp1.index < tp2.index; }); |
|
|
|
return node_users.size(); |
|
|
|
} |
|
|
|
|
|
|
|
// Redirects fg_call_cnode_ to a version of fg_ whose tuple output keeps only the elements listed
// in `tp_cnodes_and_index`, then renumbers the TupleGetItem users accordingly.
//
// A shrunk graph is cached in processed_fgs_ per (fg_, index list) and reused on later calls.
// On success fg_call_cnode_ is recorded in processed_nodes_, its input 0 is pointed at the shrunk
// graph, its abstract is shrunk when it is a tuple abstract, and a new
// {TupleGetItem, fg_call_cnode_, new index} node carrying `node`'s abstract is returned.
// Returns nullptr when no shrunk graph could be produced.
AnfNodePtr ShrinkFuncGraphOutput(const AnfNodePtr &node,
                                 const std::vector<internal::TpCNodeAndIndex> &tp_cnodes_and_index) {
  const auto &manager = fg_->manager();
  MS_EXCEPTION_IF_NULL(manager);

  std::vector<int64_t> index_vector;
  (void)std::transform(tp_cnodes_and_index.begin(), tp_cnodes_and_index.end(), std::back_inserter(index_vector),
                       [](const auto &cnode_and_index) { return cnode_and_index.index; });

  // Build the cache key once; it is needed for both lookup and insertion.
  const auto cache_key = std::make_pair(fg_, index_vector);
  FuncGraphPtr shrunk_fg = nullptr;
  const auto iter = processed_fgs_.find(cache_key);
  if (iter != processed_fgs_.end()) {
    MS_LOG(DEBUG) << "fg is already processed, just update caller index, node: " << node->DebugString()
                  << ", fg_call: " << fg_call_cnode_->DebugString();
    MS_LOG(DEBUG) << "original fg: " << fg_->ToString() << ", processed_fg: " << iter->second->ToString();
    shrunk_fg = iter->second;
  } else {
    shrunk_fg = internal::ShrinkUnsedOutput(fg_, tp_cnodes_and_index);
    if (shrunk_fg == nullptr) {
      MS_LOG(DEBUG) << "Shrink failed. node: " << node->DebugString()
                    << ", fg_call: " << fg_call_cnode_->DebugString();
      return nullptr;
    }
    MS_LOG(DEBUG) << "fg output is shrunk, original fg: " << fg_->ToString() << ", new fg: " << shrunk_fg->ToString();
    (void)processed_fgs_.emplace(cache_key, shrunk_fg);
  }

  // From here on the cached and the freshly shrunk graph are handled identically.
  (void)processed_nodes_.emplace(fg_call_cnode_);
  manager->SetEdge(fg_call_cnode_, 0, NewValueNode(shrunk_fg));
  auto shrunk_abstract = internal::ShrinkAbstract(fg_call_cnode_->abstract(), tp_cnodes_and_index);
  if (shrunk_abstract != nullptr) {
    fg_call_cnode_->set_abstract(shrunk_abstract);
  }
  auto new_idx = internal::UpdateUserNodeIndex(fg_call_cnode_, idx_, tp_cnodes_and_index);
  auto new_node =
    node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fg_call_cnode_, NewValueNode(new_idx)});
  new_node->set_abstract(node->abstract());
  return new_node;
}
|
|
|
|
|
|
|
void Visit(const CNodePtr &cnode) override { |
|
|
|
if (cnode->size() == 0 || !IsValueNode<FuncGraph>(cnode->input(0))) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
fg_cnode_ = cnode; |
|
|
|
fg_call_cnode_ = cnode; |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
fg_ = GetValueNode<FuncGraphPtr>(inputs[0]); |
|
|
|
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); |
|
|
|
@@ -193,15 +481,19 @@ class IncorporateGetitem : public AnfVisitor { |
|
|
|
// Discards everything captured while matching the previous node, so the next match starts clean.
void Reset() {
  args_.clear();
  fg_call_cnode_ = nullptr;
  fg_cnode_ = nullptr;
  fg_ = nullptr;
  // -1 is the "no index matched" marker that operator() checks for.
  idx_ = -1;
}
|
|
|
|
|
|
|
private: |
|
|
|
int64_t idx_{-1}; |
|
|
|
FuncGraphPtr fg_{nullptr}; |
|
|
|
AnfNodePtr fg_cnode_{nullptr}; |
|
|
|
CNodePtr fg_call_cnode_{nullptr}; |
|
|
|
std::vector<AnfNodePtr> args_{}; |
|
|
|
std::set<AnfNodePtr> processed_nodes_; |
|
|
|
std::unordered_map<std::pair<FuncGraphPtr, std::vector<int64_t>>, FuncGraphPtr, |
|
|
|
internal::FuncGraphIntVectorPairHasher> |
|
|
|
processed_fgs_; |
|
|
|
internal::GetitemTransform getitem_transform_; |
|
|
|
}; |
|
|
|
|
|
|
|
@@ -298,7 +590,7 @@ class IncorporateGetitemSwitch : public AnfVisitor { |
|
|
|
IncorporateGetitemSwitch() : getitem_transform_() {} |
|
|
|
~IncorporateGetitemSwitch() override = default; |
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
is_in_get_ = true; |
|
|
|
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int64Imm>})(node); |
|
|
|
@@ -316,33 +608,57 @@ class IncorporateGetitemSwitch : public AnfVisitor { |
|
|
|
if (g2_ == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (processed_nodes_.find(switch_) != processed_nodes_.end()) { |
|
|
|
MS_LOG(DEBUG) << "fg in switch node has been replaced. node: " << node->DebugString() |
|
|
|
<< ", switch: " << switch_->DebugString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
bool g1_output_is_shrinkable = internal::IsOutputShrinkable(g1_->output()); |
|
|
|
bool g2_output_is_shrinkable = internal::IsOutputShrinkable(g2_->output()); |
|
|
|
|
|
|
|
auto tuple_getitem = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
bool has_env_type = false; |
|
|
|
if (tuple_getitem->input(1)->abstract() && tuple_getitem->input(1)->abstract()->isa<abstract::AbstractTuple>()) { |
|
|
|
const auto &abs_tuple = *(tuple_getitem->input(1)->abstract()->cast<abstract::AbstractTuplePtr>()); |
|
|
|
// eliminate (envinstance, value1, value2, ...) built by bprop func_graph() |
|
|
|
if (abs_tuple.size() >= 1) { |
|
|
|
// Value maybe kAnyValue, so check the type track; |
|
|
|
if (abs_tuple[0]->isa<abstract::AbstractScalar>() && abs_tuple[0]->GetTypeTrack()->isa<EnvType>()) { |
|
|
|
has_env_type = true; |
|
|
|
const auto &switch_call = tuple_getitem->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(switch_call); |
|
|
|
const auto &switch_call_cnode = switch_call->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(switch_call_cnode); |
|
|
|
// If exist env_getitem/env_setitem in this funcgraph or |
|
|
|
// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem; |
|
|
|
std::vector<internal::TpCNodeAndIndex> tp_cnodes_and_index; |
|
|
|
auto switch_call_users_counter = MultipleUseOfSwitch(switch_call, fg, &tp_cnodes_and_index); |
|
|
|
bool multiple_use = (tp_cnodes_and_index.size() > 1); |
|
|
|
if (g1_output_is_shrinkable && g2_output_is_shrinkable && multiple_use && |
|
|
|
(tp_cnodes_and_index.size() == switch_call_users_counter)) { |
|
|
|
if (!internal::HasMoreJ(optimizer) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) && |
|
|
|
!ExistEnvNodeInTupleItem(g2_) && !internal::ShouldTransform(switch_call, tp_cnodes_and_index)) { |
|
|
|
MS_LOG(DEBUG) << "No more j, will shrink. Node: " << node->DebugString() |
|
|
|
<< ", switch: " << switch_->DebugString(); |
|
|
|
const auto g1_output_size = internal::GetOutputSize(g1_->output()); |
|
|
|
const auto g2_output_size = internal::GetOutputSize(g2_->output()); |
|
|
|
if (g1_output_size != g2_output_size) { |
|
|
|
MS_LOG(EXCEPTION) << "output of g1 and g2 should have same tuple size, but g1 output: " |
|
|
|
<< g1_->output()->DebugString() << ", g2 output: " << g2_->output()->DebugString(); |
|
|
|
} |
|
|
|
} |
|
|
|
// eliminate (value, bprop_func) built by fprop func_graph |
|
|
|
if (abs_tuple.size() >= 2) { |
|
|
|
if (abs_tuple[1]->isa<abstract::AbstractFunction>()) { |
|
|
|
has_env_type = true; |
|
|
|
if (switch_call_users_counter == g1_output_size) { |
|
|
|
processed_nodes_.emplace(switch_call); |
|
|
|
MS_LOG(DEBUG) << "All elements in output is used, no need to transform, node: " << node->DebugString() |
|
|
|
<< ", switch: " << switch_->DebugString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto new_node = ShrinkFuncGraphOutput(node, switch_call_cnode, tp_cnodes_and_index); |
|
|
|
if (new_node != nullptr) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
// If exist env_getitem/env_setitem in this funcgraph or |
|
|
|
// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem; |
|
|
|
if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) && |
|
|
|
!ExistEnvNodeInTupleItem(g2_) && !has_env_type) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Cannot shrink output, transform_getitem_switch, node: " << node->DebugString() |
|
|
|
<< ", switch: " << switch_->DebugString(); |
|
|
|
auto new_g1 = getitem_transform_(node, g1_, idx_); |
|
|
|
auto new_g2 = getitem_transform_(node, g2_, idx_); |
|
|
|
MS_LOG(DEBUG) << "Original fg1: " << g1_->ToString() << ", new_fg1: " << new_g1->ToString(); |
|
|
|
MS_LOG(DEBUG) << "Original fg2: " << g2_->ToString() << ", new_fg2: " << new_g2->ToString(); |
|
|
|
auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); |
|
|
|
(void)args_.insert(args_.begin(), sw_node); |
|
|
|
|
|
|
|
@@ -350,7 +666,60 @@ class IncorporateGetitemSwitch : public AnfVisitor { |
|
|
|
new_node->set_abstract(node->abstract()); |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
|
|
|
|
// Points both branches of switch_ at versions of g1_ / g2_ whose tuple outputs keep only the
// elements listed in `tp_cnodes_and_index`, then renumbers the TupleGetItem users of
// `switch_call_cnode` accordingly.
//
// Shrunk graphs are cached in processed_fgs_ per (graph, index list). The cache is used only when
// both branches hit; otherwise both graphs are shrunk afresh. On success switch_ is recorded in
// processed_nodes_, inputs 2 and 3 of the switch node are replaced, the call's abstract is shrunk
// when it is a tuple abstract, and a new {TupleGetItem, switch_call_cnode, new index} node
// carrying `node`'s abstract is returned. Returns nullptr when either graph could not be shrunk.
AnfNodePtr ShrinkFuncGraphOutput(const AnfNodePtr &node, const CNodePtr &switch_call_cnode,
                                 const std::vector<internal::TpCNodeAndIndex> &tp_cnodes_and_index) {
  const auto &manager = node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto switch_cnode = switch_->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(switch_cnode);

  std::vector<int64_t> index_vector;
  (void)std::transform(tp_cnodes_and_index.begin(), tp_cnodes_and_index.end(), std::back_inserter(index_vector),
                       [](const auto &cnode_and_index) { return cnode_and_index.index; });

  // Build each cache key once; they are needed for both lookup and insertion.
  const auto g1_key = std::make_pair(g1_, index_vector);
  const auto g2_key = std::make_pair(g2_, index_vector);
  const auto iter1 = processed_fgs_.find(g1_key);
  const auto iter2 = processed_fgs_.find(g2_key);

  FuncGraphPtr new_g1 = nullptr;
  FuncGraphPtr new_g2 = nullptr;
  if (iter1 != processed_fgs_.end() && iter2 != processed_fgs_.end()) {
    MS_LOG(DEBUG) << "fg output had been processed, no need to transform, node: " << node->DebugString()
                  << ", switch: " << switch_->DebugString();
    new_g1 = iter1->second;
    new_g2 = iter2->second;
  } else {
    new_g1 = internal::ShrinkUnsedOutput(g1_, tp_cnodes_and_index);
    new_g2 = internal::ShrinkUnsedOutput(g2_, tp_cnodes_and_index);
    if (new_g1 == nullptr || new_g2 == nullptr) {
      MS_LOG(DEBUG) << "Shrink failed. node: " << node->DebugString()
                    << ", switch_call: " << switch_call_cnode->DebugString();
      return nullptr;
    }
    MS_LOG(DEBUG) << "Shrink output. node: " << node->DebugString() << ", switch: " << switch_->DebugString();
    // emplace keeps an existing entry, so a branch that was already cached is not overwritten.
    (void)processed_fgs_.emplace(g1_key, new_g1);
    (void)processed_fgs_.emplace(g2_key, new_g2);
  }
  MS_LOG(DEBUG) << "Original fg1: " << g1_->ToString() << ", new_fg1: " << new_g1->ToString();
  MS_LOG(DEBUG) << "Original fg2: " << g2_->ToString() << ", new_fg2: " << new_g2->ToString();

  // From here on the cached and the freshly shrunk graphs are handled identically.
  (void)processed_nodes_.emplace(switch_);
  // Switch node layout: {Switch, cond, g1, g2}.
  constexpr auto kSwitchG1Index = 2;
  constexpr auto kSwitchG2Index = 3;
  manager->SetEdge(switch_cnode, kSwitchG1Index, NewValueNode(new_g1));
  manager->SetEdge(switch_cnode, kSwitchG2Index, NewValueNode(new_g2));
  auto shrunk_abstract = internal::ShrinkAbstract(switch_call_cnode->abstract(), tp_cnodes_and_index);
  if (shrunk_abstract != nullptr) {
    switch_call_cnode->set_abstract(shrunk_abstract);
  }
  auto new_idx = internal::UpdateUserNodeIndex(switch_call_cnode, idx_, tp_cnodes_and_index);
  auto new_node =
    node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), switch_call_cnode, NewValueNode(new_idx)});
  new_node->set_abstract(node->abstract());
  return new_node;
}
|
|
|
void Visit(const AnfNodePtr &node) override { |
|
|
|
if (is_in_switch_ && x_ == nullptr) { |
|
|
|
x_ = node; |
|
|
|
@@ -393,22 +762,51 @@ class IncorporateGetitemSwitch : public AnfVisitor { |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
bool MultipleUseOfSwitch(const AnfNodePtr &switch_call, const FuncGraphPtr &fg) const { |
|
|
|
size_t MultipleUseOfSwitch(const AnfNodePtr &switch_call, const FuncGraphPtr &fg, |
|
|
|
std::vector<internal::TpCNodeAndIndex> *cnodes_and_index) const { |
|
|
|
auto switch_call_cnode = switch_call->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(switch_call_cnode); |
|
|
|
auto manager = fg->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto &cnode_and_index_vector = *cnodes_and_index; |
|
|
|
std::set<int64_t> index_set; |
|
|
|
std::size_t total_usage = 0; |
|
|
|
auto &node_users_map = manager->node_users(); |
|
|
|
auto it = node_users_map.find(switch_call); |
|
|
|
if (it == node_users_map.end()) { |
|
|
|
return false; |
|
|
|
return 0; |
|
|
|
} |
|
|
|
auto &node_users = it->second; |
|
|
|
// If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute.s |
|
|
|
auto tuple_getitem_num = std::count_if(node_users.begin(), node_users.end(), [](std::pair<AnfNodePtr, int> &user) { |
|
|
|
return IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem); |
|
|
|
}); |
|
|
|
return tuple_getitem_num > 1; |
|
|
|
// If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute. |
|
|
|
for (auto user : node_users) { |
|
|
|
if (IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { |
|
|
|
auto cnode = user.first->cast<CNodePtr>(); |
|
|
|
constexpr auto kInputIndex = 2; |
|
|
|
if (cnode->input(kInputIndex)->isa<ValueNode>()) { |
|
|
|
const auto &idx_node = cnode->input(kInputIndex)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(idx_node); |
|
|
|
auto idx = GetValue<int64_t>(idx_node->value()); |
|
|
|
cnode_and_index_vector.push_back({cnode, idx}); |
|
|
|
index_set.insert(idx); |
|
|
|
total_usage++; |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Tuple_getitem index is not valuenode, but: " << user.first->DebugString(2); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "switch_call user is not tuple_getitem, user: " << user.first->DebugString(2); |
|
|
|
} |
|
|
|
} |
|
|
|
if (index_set.size() != total_usage) { |
|
|
|
MS_LOG(DEBUG) << "some index is duplicated, total_usage: " << total_usage; |
|
|
|
MS_LOG(DEBUG) << "index_set: "; |
|
|
|
for (auto idx : index_set) { |
|
|
|
MS_LOG(DEBUG) << " " << idx; |
|
|
|
} |
|
|
|
} |
|
|
|
// sort by index; |
|
|
|
std::sort(cnode_and_index_vector.begin(), cnode_and_index_vector.end(), |
|
|
|
[](const auto &tp1, const auto &tp2) { return tp1.index < tp2.index; }); |
|
|
|
return node_users.size(); |
|
|
|
} |
|
|
|
|
|
|
|
static bool inline ExistEnvNode(const FuncGraphPtr &fg) { |
|
|
|
@@ -441,6 +839,10 @@ class IncorporateGetitemSwitch : public AnfVisitor { |
|
|
|
FuncGraphPtr g1_{nullptr}, g2_{nullptr}; |
|
|
|
bool is_in_get_{false}, is_in_switch_{false}; |
|
|
|
std::vector<AnfNodePtr> args_{}; |
|
|
|
std::set<AnfNodePtr> processed_nodes_; |
|
|
|
std::unordered_map<std::pair<FuncGraphPtr, std::vector<int64_t>>, FuncGraphPtr, |
|
|
|
internal::FuncGraphIntVectorPairHasher> |
|
|
|
processed_fgs_; |
|
|
|
internal::GetitemTransform getitem_transform_; |
|
|
|
}; |
|
|
|
|
|
|
|
|