
Remove redundant phi nodes

yujianfeng · commit 3176d377e6 · tags/v1.1.0
4 changed files with 168 additions and 7 deletions
  1. mindspore/ccsrc/pipeline/jit/parse/function_block.cc  (+12, -5)
  2. mindspore/ccsrc/pipeline/jit/parse/function_block.h  (+9, -0)
  3. tests/st/control/test_cont_grad.py  (+145, -0)
  4. tests/ut/cpp/optimizer/opt_test.cc  (+2, -2)

mindspore/ccsrc/pipeline/jit/parse/function_block.cc  (+12, -5)

@@ -49,11 +49,11 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
   if (vars_.count(var)) {
     AnfNodePtr node = vars_[var];
     MS_EXCEPTION_IF_NULL(node);
-    if (node->isa<ValueNode>()) {
-      return NewValueNode(GetValueNode(node));
-    } else {
-      return node;
+    auto iter = resolve_to_removable_phis_.find(node);
+    if (iter != resolve_to_removable_phis_.end()) {
+      return iter->second;
     }
+    return node;
   }
   // Get var from the predecessor block; if it can't be found, make a resolve node for it.
   if (matured_) {
@@ -64,7 +64,13 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
       return block->ReadVariable(var);
     } else if (prev_blocks_.empty()) {
       // Get the namespace and make a Resolve node.
-      return MakeResolveSymbol(var);
+      auto it = var_to_resolve_.find(var);
+      if (it != var_to_resolve_.end()) {
+        return it->second;
+      }
+      auto tmp_node = MakeResolveSymbol(var);
+      var_to_resolve_[var] = tmp_node;
+      return tmp_node;
     }
   }
   // If there is more than one predecessor block, build a phi node.
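
The second hunk is a simple per-variable cache: without it, every read of a free variable in a block with no predecessors makes a fresh Resolve node, and the duplicates only get cleaned up later by CSE (hence the smaller pre-CSE node counts in opt_test.cc below). A minimal standalone sketch of the same memoization pattern, with a hypothetical Symbol type standing in for the ANF node (this is not the MindSpore API):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Symbol { std::string name; };     // stand-in for an ANF resolve node
using SymbolPtr = std::shared_ptr<Symbol>;

class Block {
 public:
  // Returns the same node for repeated reads of the same variable.
  SymbolPtr ReadFreeVariable(const std::string &var) {
    auto it = var_to_resolve_.find(var);
    if (it != var_to_resolve_.end()) {
      return it->second;                 // cache hit: reuse the node
    }
    auto node = MakeResolveSymbol(var);  // cache miss: build exactly once
    var_to_resolve_[var] = node;
    return node;
  }

 private:
  SymbolPtr MakeResolveSymbol(const std::string &var) {
    return std::make_shared<Symbol>(Symbol{var});
  }
  std::unordered_map<std::string, SymbolPtr> var_to_resolve_;
};

int main() {
  Block block;
  auto first = block.ReadFreeVariable("x");
  auto second = block.ReadFreeVariable("x");
  std::cout << (first == second) << "\n";  // prints 1: one node, not two
}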
@@ -217,6 +223,7 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
   // Replace var with the new one. This is equivalent to the statement in the TR: "v0 is immediately replaced by v1."
   WriteVariable(var, arg_node);
   removable_phis_[phi] = arg_node;
+  resolve_to_removable_phis_[arg_node] = phi;
   // The following is equivalent to the statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized
   // recursively": check if phi1 was assigned with this phi before; then phi1 can be replaced with arg_node.
   for (auto &prev : prev_blocks_) {
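
The first and third hunks form one mechanism. When CollectRemovablePhi decides a phi can be replaced by arg_node, it now also records the reverse edge arg_node -> phi; ReadVariable then routes any read that lands on arg_node back to the phi, so every block keeps observing the same node until the removal pass rewrites the graph. A compressed sketch of that round trip, using simplified types rather than the real FunctionBlock interface:

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

struct Node { std::string tag; };
using NodePtr = std::shared_ptr<Node>;

class Block {
 public:
  void WriteVariable(const std::string &var, const NodePtr &node) { vars_[var] = node; }

  // Mirrors the write side added in CollectRemovablePhi: remember which
  // phi the replacement node stands for.
  void MarkPhiRemovable(const NodePtr &phi, const NodePtr &arg_node) {
    removable_phis_[phi] = arg_node;
    resolve_to_removable_phis_[arg_node] = phi;
  }

  // Mirrors the read side: a read that lands on a replacement node is
  // routed back to the phi it replaced.
  NodePtr ReadVariable(const std::string &var) {
    NodePtr node = vars_.at(var);
    auto iter = resolve_to_removable_phis_.find(node);
    if (iter != resolve_to_removable_phis_.end()) {
      return iter->second;
    }
    return node;
  }

 private:
  std::unordered_map<std::string, NodePtr> vars_;
  std::unordered_map<NodePtr, NodePtr> removable_phis_;
  std::unordered_map<NodePtr, NodePtr> resolve_to_removable_phis_;
};

int main() {
  Block block;
  auto phi = std::make_shared<Node>(Node{"phi"});
  auto arg = std::make_shared<Node>(Node{"arg"});
  block.WriteVariable("x", arg);           // var now bound to the replacement
  block.MarkPhiRemovable(phi, arg);
  assert(block.ReadVariable("x") == phi);  // the read is redirected to the phi
}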


mindspore/ccsrc/pipeline/jit/parse/function_block.h  (+9, -0)

@@ -101,12 +101,21 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   // keeps all removable phis which will be removed in one pass.
   std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;

+  // Keeps the map from a resolve node to the removable phi node, for the case
+  // that ReadVariable returns a phi node even though this phi node, generated
+  // in the previous block, has been identified as removable; the other blocks
+  // should still find this phi node.
+  std::unordered_map<AnfNodePtr, ParameterPtr> resolve_to_removable_phis_;
+
   // holds declared global variables in the function
   std::set<std::string> global_vars_;

   // other depends that need to be inserted before the function's return nodes,
   // e.g. summary or some other nodes
   std::vector<AnfNodePtr> auto_depends_;
+
+  // Keeps the newly made resolve symbol for a variable not found in vars_.
+  std::unordered_map<std::string, AnfNodePtr> var_to_resolve_;
 };

 }  // namespace parse
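
Both new members key their maps on node pointers. std::hash for std::shared_ptr hashes the stored pointer, so a lookup only hits when given the very same node object that was inserted, i.e. matching is by node identity rather than structural equality, which is what the redirect above relies on. A tiny standalone demo of that behavior (not MindSpore code):

#include <cassert>
#include <memory>
#include <unordered_map>

struct Node { int id; };

int main() {
  auto a = std::make_shared<Node>(Node{1});
  auto b = std::make_shared<Node>(Node{1});  // equal contents, different object
  std::unordered_map<std::shared_ptr<Node>, int> m;
  m[a] = 42;
  assert(m.count(a) == 1);  // found: the same object
  assert(m.count(b) == 0);  // not found: identity, not structural equality
  return 0;
}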


tests/st/control/test_cont_grad.py  (+145, -0)

@@ -14,6 +14,7 @@
 # ============================================================================
 """ test control ops """
 import numpy as np
+import pytest

 from mindspore import dtype as ms
 from mindspore import Tensor
@@ -1150,3 +1151,147 @@ def test_if_by_if_forward_all_const_branch():
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(0), dtype=ms.float32)
     net(idx, end, x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_if_const_grad():
+    class MyNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.TensorAdd()
+
+        def construct(self, *inputs):
+            out = self.add(*inputs)
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            a = 1
+            b = 2
+            if a > 0:
+                b = 1
+            a += b
+            return grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE)
+    my_net = MyNet()
+    net = GradNet(my_net)
+    a = Tensor(np.array(0), dtype=ms.int32)
+    b = Tensor(np.array(1), dtype=ms.int32)
+    net(a, b)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_if_by_if_const_grad():
+    class MyNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.TensorAdd()
+
+        def construct(self, *inputs):
+            out = self.add(*inputs)
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            a = 1
+            b = 2
+            if a > 0:
+                b = 1
+            if a < 0:
+                b = 0
+            if a == 0:
+                b = 3
+            a += b
+            return grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE)
+    my_net = MyNet()
+    net = GradNet(my_net)
+    a = Tensor(np.array(0), dtype=ms.int32)
+    b = Tensor(np.array(1), dtype=ms.int32)
+    net(a, b)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_while_const_grad():
+    class MyNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.TensorAdd()
+
+        def construct(self, *inputs):
+            out = self.add(*inputs)
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            a = 1
+            while a > 1:
+                a = a - 1
+            return grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE)
+    my_net = MyNet()
+    net = GradNet(my_net)
+    a = Tensor(np.array(0), dtype=ms.int32)
+    b = Tensor(np.array(1), dtype=ms.int32)
+    net(a, b)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_if_by_while_const_grad():
+    class MyNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.TensorAdd()
+
+        def construct(self, *inputs):
+            out = self.add(*inputs)
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            a = 1
+            b = 2
+            if a > 0:
+                b = 0
+            while a > 1:
+                a = a - 1
+            a += b
+            return grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE)
+    my_net = MyNet()
+    net = GradNet(my_net)
+    a = Tensor(np.array(0), dtype=ms.int32)
+    b = Tensor(np.array(1), dtype=ms.int32)
+    net(a, b)

tests/ut/cpp/optimizer/opt_test.cc  (+2, -2)

@@ -187,7 +187,7 @@ TEST_F(TestOptOpt, CSE) {
   FuncGraphManagerPtr manager1 = Manage(test_graph1);
   draw::Draw("opt_cse_before_1.dot", test_graph1);

-  ASSERT_EQ(manager1->all_nodes().size(), 10);
+  ASSERT_EQ(manager1->all_nodes().size(), 9);

   auto cse = std::make_shared<CSE>();
   ASSERT_TRUE(cse != nullptr);
@@ -205,7 +205,7 @@ TEST_F(TestOptOpt, CSE) {

   FuncGraphManagerPtr manager2 = Manage(test_graph2);
   draw::Draw("opt_cse_before_2.dot", test_graph2);
-  ASSERT_EQ(manager2->all_nodes().size(), 22);
+  ASSERT_EQ(manager2->all_nodes().size(), 16);
   is_changed = cse->Cse(test_graph2, manager2);
   ASSERT_TRUE(is_changed);
   ASSERT_EQ(manager2->all_nodes().size(), 12);

