|
|
|
@@ -72,6 +72,52 @@ class GetitemTransform { |
|
|
|
private: |
|
|
|
std::unordered_map<FuncGraphPtr, std::unordered_map<int, FuncGraphPtr>> cache_; |
|
|
|
}; |
|
|
|
|
|
|
|
class GetItemTransformACrossGraph { |
|
|
|
public: |
|
|
|
GetItemTransformACrossGraph() : cache_() {} |
|
|
|
~GetItemTransformACrossGraph() = default; |
|
|
|
|
|
|
|
FuncGraphPtr operator()(const FuncGraphPtr &fg, int idx) { |
|
|
|
if (cache_.find(fg) == cache_.end()) { |
|
|
|
cache_[fg] = {}; |
|
|
|
} |
|
|
|
|
|
|
|
auto &cache = cache_[fg]; |
|
|
|
if (cache.find(idx) == cache.end()) { |
|
|
|
std::ostringstream ss("tp", std::ostringstream::app); |
|
|
|
ss << idx; |
|
|
|
|
|
|
|
auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str())); |
|
|
|
auto output_outer = new_fg_outer->output(); |
|
|
|
if (!IsValueNode<FuncGraph>(output_outer)) { |
|
|
|
MS_LOG(WARNING) << "Output of outer graph should be a func_graph"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto fg_inner = GetValueNode<FuncGraphPtr>(output_outer); |
|
|
|
auto new_fg = TransformableClone(fg_inner, std::make_shared<TraceTransform>(ss.str())); |
|
|
|
new_fg_outer->set_output(NewValueNode(new_fg)); |
|
|
|
auto output = new_fg->output(); |
|
|
|
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { |
|
|
|
auto cnode = output->cast<CNodePtr>(); |
|
|
|
auto ids = IntToSize(idx + 1); |
|
|
|
// Inputs should be [make_tuple, item1, item2, ...], so have to offset idx in tuple_getitem by 1. |
|
|
|
if (ids >= cnode->size()) { |
|
|
|
MS_LOG(EXCEPTION) << "index " << ids << " is out of inputs length " << cnode->size(); |
|
|
|
} |
|
|
|
new_fg->set_output(cnode->input(ids)); |
|
|
|
} else { |
|
|
|
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(idx)})); |
|
|
|
} |
|
|
|
|
|
|
|
cache[idx] = new_fg_outer; |
|
|
|
} |
|
|
|
return cache[idx]; |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
std::unordered_map<FuncGraphPtr, std::unordered_map<int, FuncGraphPtr>> cache_; |
|
|
|
}; |
|
|
|
} // namespace internal |
|
|
|
|
|
|
|
// {prim::kPrimTupleGetItem, {G, Xs}, C} |
|
|
|
@@ -385,13 +431,199 @@ class IncorporateGetitemSwitch : public AnfVisitor { |
|
|
|
internal::GetitemTransform getitem_transform_; |
|
|
|
}; |
|
|
|
|
|
|
|
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, C} |
|
|
|
class IncorporateGetitemSwitchLayerA : public AnfVisitor { |
|
|
|
public: |
|
|
|
IncorporateGetitemSwitchLayerA() : getitem_transform_() {} |
|
|
|
~IncorporateGetitemSwitchLayerA() override = default; |
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
is_in_get_ = true; |
|
|
|
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node); |
|
|
|
is_in_get_ = false; |
|
|
|
|
|
|
|
auto fg = node->func_graph(); |
|
|
|
if (idx_ == -1 || switch_layer_ == nullptr || fg == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
is_in_switch_ = true; |
|
|
|
AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(switch_layer_); |
|
|
|
is_in_switch_ = false; |
|
|
|
|
|
|
|
if (graphs_.empty()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> layers; |
|
|
|
for (auto &graph : graphs_) { |
|
|
|
auto fg_transform = getitem_transform_(graph, idx_); |
|
|
|
if (fg_transform == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
layers.push_back(NewValueNode(fg_transform)); |
|
|
|
} |
|
|
|
auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers); |
|
|
|
std::vector<AnfNodePtr> sw_args{NewValueNode(prim::kPrimSwitchLayer), x_, layers_node}; |
|
|
|
auto sw_node = fg->NewCNode(sw_args); |
|
|
|
(void)args_.insert(args_.begin(), sw_node); |
|
|
|
|
|
|
|
return fg->NewCNode(args_); |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { |
|
|
|
if (is_in_switch_ && x_ == nullptr) { |
|
|
|
x_ = node; |
|
|
|
return; |
|
|
|
} |
|
|
|
AnfVisitor::Visit(node); |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const CNodePtr &cnode) override { |
|
|
|
if (is_in_get_ && cnode->size() != 0) { |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
switch_layer_ = inputs[0]; |
|
|
|
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); |
|
|
|
} |
|
|
|
if (is_in_switch_ && cnode->size() > 2) { |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) { |
|
|
|
(void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_), |
|
|
|
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); }); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override { |
|
|
|
if (is_in_get_) { |
|
|
|
idx_ = GetValue<int>(vnode->value()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Reset() { |
|
|
|
x_ = nullptr; |
|
|
|
graphs_.clear(); |
|
|
|
switch_layer_ = nullptr; |
|
|
|
args_.clear(); |
|
|
|
is_in_get_ = false; |
|
|
|
is_in_switch_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
int idx_{-1}; |
|
|
|
AnfNodePtr switch_layer_{nullptr}, x_{nullptr}; |
|
|
|
std::vector<FuncGraphPtr> graphs_{}; |
|
|
|
bool is_in_get_{false}, is_in_switch_{false}; |
|
|
|
std::vector<AnfNodePtr> args_{}; |
|
|
|
internal::GetitemTransform getitem_transform_; |
|
|
|
}; |
|
|
|
|
|
|
|
// {prim::kPrimTupleGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C} |
|
|
|
class IncorporateGetitemSwitchLayerB : public AnfVisitor { |
|
|
|
public: |
|
|
|
IncorporateGetitemSwitchLayerB() : getitem_transform_() {} |
|
|
|
~IncorporateGetitemSwitchLayerB() override = default; |
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
is_in_get_ = true; |
|
|
|
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node); |
|
|
|
is_in_get_ = false; |
|
|
|
|
|
|
|
auto fg = node->func_graph(); |
|
|
|
if (idx_ == -1 || switch_layer_call_ == nullptr || !switch_layer_call_->isa<CNode>() || fg == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto &switch_layer_call_inputs = switch_layer_call_->cast<CNodePtr>()->inputs(); |
|
|
|
(void)std::copy(switch_layer_call_inputs.begin() + 1, switch_layer_call_inputs.end(), std::back_inserter(args_)); |
|
|
|
|
|
|
|
is_in_switch_ = true; |
|
|
|
AnfVisitor::Match(prim::kPrimSwitchLayer, {IsNode, IsCNode})(switch_layer_call_inputs[0]); |
|
|
|
is_in_switch_ = false; |
|
|
|
|
|
|
|
if (graphs_.empty()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> layers; |
|
|
|
for (auto &graph : graphs_) { |
|
|
|
auto fg_transform = getitem_transform_(graph, idx_); |
|
|
|
if (fg_transform == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
layers.push_back(NewValueNode(fg_transform)); |
|
|
|
} |
|
|
|
auto layers_node = fg->NewCNode(prim::kPrimMakeTuple, layers); |
|
|
|
std::vector<AnfNodePtr> sw_args{NewValueNode(prim::kPrimSwitchLayer), x_, layers_node}; |
|
|
|
auto sw_node = fg->NewCNode(sw_args); |
|
|
|
(void)args_.insert(args_.begin(), sw_node); |
|
|
|
auto call_switch_layer = fg->NewCNode(args_); |
|
|
|
(void)outer_call_args_.insert(outer_call_args_.begin(), call_switch_layer); |
|
|
|
return fg->NewCNode(outer_call_args_); |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { |
|
|
|
if (is_in_switch_ && x_ == nullptr) { |
|
|
|
x_ = node; |
|
|
|
return; |
|
|
|
} |
|
|
|
AnfVisitor::Visit(node); |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const CNodePtr &cnode) override { |
|
|
|
if (is_in_get_ && cnode->size() != 0) { |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
switch_layer_call_ = inputs[0]; |
|
|
|
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outer_call_args_)); |
|
|
|
} |
|
|
|
if (is_in_switch_ && cnode->size() > 2) { |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) { |
|
|
|
(void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_), |
|
|
|
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); }); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override { |
|
|
|
if (is_in_get_) { |
|
|
|
idx_ = GetValue<int>(vnode->value()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Reset() { |
|
|
|
x_ = nullptr; |
|
|
|
graphs_.clear(); |
|
|
|
switch_layer_call_ = nullptr; |
|
|
|
args_.clear(); |
|
|
|
outer_call_args_.clear(); |
|
|
|
is_in_get_ = false; |
|
|
|
is_in_switch_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
int idx_{-1}; |
|
|
|
AnfNodePtr switch_layer_call_{nullptr}, x_{nullptr}; |
|
|
|
std::vector<FuncGraphPtr> graphs_{}; |
|
|
|
bool is_in_get_{false}, is_in_switch_{false}; |
|
|
|
std::vector<AnfNodePtr> args_{}; |
|
|
|
std::vector<AnfNodePtr> outer_call_args_{}; |
|
|
|
internal::GetItemTransformACrossGraph getitem_transform_; |
|
|
|
}; |
|
|
|
|
|
|
|
class IncorporateGetitemSet : public OptimizerCaller { |
|
|
|
public: |
|
|
|
IncorporateGetitemSet() |
|
|
|
: incorporate_getitem_(std::make_shared<IncorporateGetitem>()), |
|
|
|
incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) { |
|
|
|
incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()), |
|
|
|
incorporate_getitem_switch_layer_a_(std::make_shared<IncorporateGetitemSwitchLayerA>()), |
|
|
|
incorporate_getitem_switch_layer_b_(std::make_shared<IncorporateGetitemSwitchLayerB>()) { |
|
|
|
eliminaters_.emplace_back(incorporate_getitem_); |
|
|
|
eliminaters_.emplace_back(incorporate_getitem_switch_); |
|
|
|
eliminaters_.emplace_back(incorporate_getitem_switch_layer_a_); |
|
|
|
eliminaters_.emplace_back(incorporate_getitem_switch_layer_b_); |
|
|
|
} |
|
|
|
~IncorporateGetitemSet() = default; |
|
|
|
|
|
|
|
@@ -407,7 +639,8 @@ class IncorporateGetitemSet : public OptimizerCaller { |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; |
|
|
|
OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_, incorporate_getitem_switch_layer_a_, |
|
|
|
incorporate_getitem_switch_layer_b_; |
|
|
|
std::vector<OptimizerCallerPtr> eliminaters_{}; |
|
|
|
}; |
|
|
|
} // namespace irpass |
|
|
|
|