
1. If a node is a Call/Switch/SwitchLayer, do not replace its Load or remove its UpdateState; 2. add control flow test cases; 3. fix CodeDEX (static analysis) issues

pull/15784/head
huangbingjian committed 4 years ago
commit 2a85af5d83
7 changed files with 90 additions and 3 deletions

1. mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc (+1, -1)
2. mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc (+5, -0)
3. mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc (+8, -0)
4. mindspore/ccsrc/frontend/optimizer/opt.cc (+1, -1)
5. mindspore/core/ir/tensor.cc (+1, -1)
6. tests/st/control/inner/test_030_for_in_if.py (+49, -0)
7. tests/st/control/inner/test_100_if_after_if.py (+25, -0)

mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc (+1, -1)

@@ -880,7 +880,7 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
     } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
       // To find J user.
       auto j_user = GetJUser(node_user_map, cnode, index);
-      primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user));
+      (void)primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user));
     }
   }
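Note: the only change above is the added (void) cast, which marks the return value of emplace_back as intentionally ignored (since C++17 it returns a reference to the new element), the kind of static-analysis cleanup the commit message calls a CodeDEX fix. A minimal standalone sketch of the idiom, not MindSpore code:

    #include <utility>
    #include <vector>

    int main() {
      std::vector<std::pair<int, int>> pairs;
      // emplace_back returns a reference (C++17); casting to void tells
      // static checkers the return value is discarded on purpose.
      (void)pairs.emplace_back(std::make_pair(1, 2));
      return 0;
    }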




mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc (+5, -0)

@@ -85,6 +85,11 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
     if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
       return false;
     }
+    // if Call/Switch/SwitchLayer, do not replace load.
+    if (IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) ||
+        IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) {
+      return true;
+    }
     auto cnode = node->cast<CNodePtr>();
     auto &inputs = cnode->inputs();
     return std::any_of(inputs.begin(), inputs.end(),
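Why this guard matters: SplitGroup decides which Load nodes can share one group and later be replaced by a single load. A Call/Switch/SwitchLayer may jump into another graph that writes the loaded Parameter, so Loads before and after it must stay in separate groups. A minimal sketch of that splitting rule, with a string tag standing in for the real node-type check (not MindSpore code):

    #include <cstddef>
    #include <string>
    #include <vector>

    // Walk nodes in topological order; start a new Load group at every
    // "barrier" node whose callee might mutate state.
    std::vector<std::vector<size_t>> SplitAtBarriers(const std::vector<std::string> &nodes) {
      std::vector<std::vector<size_t>> groups;
      std::vector<size_t> current;
      for (size_t i = 0; i < nodes.size(); ++i) {
        const auto &tag = nodes[i];
        if (tag == "Call" || tag == "Switch" || tag == "SwitchLayer") {
          // The invoked graph may write the loaded Parameter, so Loads on
          // either side of this node must not be merged.
          if (!current.empty()) {
            groups.push_back(current);
            current.clear();
          }
        } else if (tag == "Load") {
          current.push_back(i);
        }
      }
      if (!current.empty()) {
        groups.push_back(current);
      }
      return groups;
    }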


mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc (+8, -0)

@@ -28,6 +28,7 @@ namespace {
 // data = Load(input, attach)
 // data = Depend(input, attach)
 // monad = UpdateState(input, attach)
+constexpr size_t kFirstInputIndex = 0;
 constexpr size_t kInputIndex = 1;
 constexpr size_t kAttachIndex = 2;
 constexpr size_t kMakeTupleSize = 3;
@@ -120,6 +121,13 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
       return nullptr;
     }
   }
+  // Skip Call/Switch/SwitchLayer.
+  auto first_input_node = cnode->input(kFirstInputIndex);
+  if (IsPrimitiveCNode(first_input_node, prim::kPrimCall) || IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
+      IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
+    return nullptr;
+  }
+
   // Remove UpdateState by replace it with its input monad.
   return update_state->input(kInputIndex);
 }
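Same rationale as the auto_monad_eliminate.cc change: EliminateUpdateStateForPureNode drops an UpdateState only when its attached node is provably side-effect free, and an attached node whose first input is a Call/Switch/SwitchLayer hides its real computation in another graph. In the comment notation the file itself uses, a sketch of the pattern that must now be kept (hypothetical example, not from the source):

    // out   = Switch(cond, true_graph, false_graph)()   // a branch may Assign a Parameter
    // monad = UpdateState(input, out)
    //
    // Replacing this UpdateState with its input monad would let the
    // optimizer reorder or drop the hidden Parameter write, so the pass
    // now returns nullptr (no change) for such nodes.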


mindspore/ccsrc/frontend/optimizer/opt.cc (+1, -1)

@@ -246,7 +246,7 @@ void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_map<std:
        << std::endl;
     for (size_t i = 0; i < list_.size(); i++) {
       auto name = list_[i]->name_;
-      ss << std::left << std::setw(space + 4) << name << "\t";
+      ss << std::left << std::setw(SizeToInt(space) + 4) << name << "\t";
       for (auto change : status.at(name + std::to_string(i))) {
         ss << change << " ";
       }
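std::setw takes an int, while space here is presumably a size_t, so the old line relied on an implicit narrowing conversion, another of the CodeDEX issues from the commit message. SizeToInt is MindSpore's explicit size_t-to-int conversion helper; a minimal stand-in showing the intent (assumed behavior, not the actual implementation):

    #include <cstddef>
    #include <iomanip>
    #include <iostream>
    #include <limits>
    #include <stdexcept>

    // Stand-in for MindSpore's SizeToInt: an explicit, checked
    // size_t -> int conversion instead of a silent implicit narrowing.
    int SizeToInt(size_t u) {
      if (u > static_cast<size_t>(std::numeric_limits<int>::max())) {
        throw std::out_of_range("size_t value does not fit in int");
      }
      return static_cast<int>(u);
    }

    int main() {
      size_t space = 24;
      // std::setw expects an int; converting explicitly documents the
      // narrowing and satisfies static analyzers.
      std::cout << std::left << std::setw(SizeToInt(space) + 4) << "name" << "|\n";
      return 0;
    }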


mindspore/core/ir/tensor.cc (+1, -1)

@@ -393,7 +393,7 @@ class TensorDataImpl : public TensorData {
       pos++;
     }
     size_t len = pos - index;
-    std::string space(max_width - len, ' ');
+    std::string space(max_width - SizeToInt(len), ' ');
     str = str.replace(index, len, space);
     index = str.find('#', index);
   }
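Same family of fix as in opt.cc: len is a size_t, so max_width - len presumably mixed signed and unsigned operands; converting len with SizeToInt keeps the subtraction in signed arithmetic (see the SizeToInt sketch above).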


tests/st/control/inner/test_030_for_in_if.py (+49, -0)

@@ -170,3 +170,52 @@ def test_for_in_if_03():
 
     assert graph_forward_res == pynative_forward_res
     assert graph_backward_res == pynative_backward_res
+
+
+def test_for_in_if_04():
+    class ForInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = self.param_a
+            x = self.func(x)
+            out *= x
+            return out
+
+        def func(self, x):
+            if self.param_a > self.param_b:
+                for _ in range(0, 4):
+                    self.param_a += 1
+                    self.param_b -= 3
+            self.param_b += 10
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(5, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_in_if_net = ForInIfNet()
+    net = GradNet(for_in_if_net)
+    graph_forward_res = for_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_in_if_net = ForInIfNet()
+    net = GradNet(for_in_if_net)
+    pynative_forward_res = for_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
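This new test is a regression case for the two optimizer guards above: func mutates param_a and param_b inside a for loop nested in an if, so the conditional call becomes a Switch whose branches write Parameters. Without the new checks, the Load/UpdateState passes could fold those writes away, and the graph-mode results would diverge from the PyNative results the asserts compare against.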

tests/st/control/inner/test_100_if_after_if.py (+25, -0)

@@ -74,6 +74,25 @@ class IfAfterIfNet2(nn.Cell):
         return y
 
 
+class IfAfterIfNet3(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+        self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+    def construct(self, x, y):
+        out = x * y + self.func(self.param_b)
+        if self.param_a > self.param_b:
+            out += 5
+        return out
+
+    def func(self, x):
+        if self.param_a > self.param_b:
+            x += 5
+        self.param_b += 4
+        return x
+
+
 class GradNet(nn.Cell):
     def __init__(self, net):
         super(GradNet, self).__init__()
@@ -118,3 +137,9 @@ def test_if_after_if_02():
     x = Tensor(2, mstype.int32)
     y = Tensor(5, mstype.int32)
     control_flow_if_after_if(IfAfterIfNet2, x, y)
+
+
+def test_if_after_if_03():
+    x = Tensor(2, mstype.int32)
+    y = Tensor(5, mstype.int32)
+    control_flow_if_after_if(IfAfterIfNet3, x, y)
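Likewise, IfAfterIfNet3 routes a Parameter through func, which both reads and writes param_b across two if statements, covering the Call/Switch case from the test side; test_if_after_if_03 runs it through the shared control_flow_if_after_if graph-vs-PyNative comparison.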
