Browse Source

Fix bugs of the ConvertSwitchReplacement::OptimizerCaller function

tags/v0.7.0-beta
wuyongkang 5 years ago
parent
commit
bb141b753a
1 changed files with 14 additions and 7 deletions
  1. +14
    -7
      mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h

+ 14
- 7
mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h View File

@@ -98,12 +98,19 @@ class ConvertSwitchReplacement : public OptimizerCaller {
return nullptr;
}

auto cnode_ = node->cast<CNodePtr>();
if (cnode_->size() < 1) {
return nullptr;
}

auto node_ = cnode_->input(0);

PatternNode<AnfNodePtr> cond, true_br, false_br;

auto ConvertSwitchLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node));
auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node));
auto x_ = cond.GetNode(node);
auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr {
auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node_));
auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node_));
auto x_ = cond.GetNode(node_);

// for switch replace method, only graphs without graph inside can be replaced
for (auto &item : g1_->value_nodes()) {
@@ -126,7 +133,7 @@ class ConvertSwitchReplacement : public OptimizerCaller {
auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_);

std::vector<AnfNodePtr> params;
auto fg = node->func_graph();
auto fg = node_->func_graph();
auto cloned_g1 = InlineClone(trans_g1, fg, params);
auto cloned_g2 = InlineClone(trans_g2, fg, params);
auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
@@ -135,8 +142,8 @@ class ConvertSwitchReplacement : public OptimizerCaller {
};

MATCH_REPLACE_LAMBDA_IF(
node, PCNode(PPrimitive(prim::kPrimSwitch, cond, true_br, false_br)).MinExtraNodes(0), ConvertSwitchLambda,
true_br.CheckFunc(IsValueNode<FuncGraph>, node) && false_br.CheckFunc(IsValueNode<FuncGraph>, node));
node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda,
true_br.CheckFunc(IsValueNode<FuncGraph>, node_) && false_br.CheckFunc(IsValueNode<FuncGraph>, node_));

return nullptr;
}


Loading…
Cancel
Save