Browse Source

!25761 Fix switch layer recursive bug

Merge pull request !25761 from chenfei_mindspore/switch_layer_reursive_fix
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
0a4cc28c9d
2 changed files with 75 additions and 5 deletions
  1. +14
    -5
      mindspore/ccsrc/backend/session/session_basic.cc
  2. +61
    -0
      tests/st/control/test_switch_layer.py

+ 14
- 5
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1957,6 +1957,19 @@ AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
}
return nullptr;
}

bool IsUnusedInternlOutput(const AnfNodePtr &user) {
if (!CNodeFirstInputIsPrimitive(user)) {
return true;
}
if (IsPrimitiveCNode(user, prim::kPrimSwitch) || IsPrimitiveCNode(user, prim::kPrimSwitchLayer)) {
return true;
}
if (!AnfAlgo::IsRealKernel(user)) {
return true;
}
return false;
}
} // namespace

constexpr auto kMixTarget = "MixTarget";
@@ -2040,11 +2053,7 @@ void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, cons
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimUpdateState)) {
continue;
}
if (!CNodeFirstInputIsPrimitive(user)) {
internal_output = false;
break;
}
if (!AnfAlgo::IsRealKernel(user)) {
if (IsUnusedInternlOutput(user)) {
internal_output = false;
break;
}


+ 61
- 0
tests/st/control/test_switch_layer.py View File

@@ -19,6 +19,19 @@ import pytest
import mindspore.context as context
from mindspore import Tensor, nn
from mindspore.common import dtype as mstype
from mindspore.ops.composite import GradOperation


class Grad(nn.Cell):
def __init__(self, net):
super().__init__()
self.grad = GradOperation(get_all=False)
self.net = net

def construct(self, x, y):
grad_net = self.grad(self.net)
grad = grad_net(x, y)
return grad


class CaseNet(nn.Cell):
@@ -53,3 +66,51 @@ def test_switch_layer():
true_value = relu(data)
ret = np.allclose(value.asnumpy(), true_value.asnumpy())
assert ret


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_cell_in_list():
"""
Feature: Switch layer in while.
Description: test recursive switch layer.
Expectation: success if grad and output are correct.
"""

class TestCell(nn.Cell):
def __init__(self, i):
super().__init__()
self.i = i

def construct(self, x):
return self.i * x

class CellInList(nn.Cell):
def __init__(self):
super().__init__()
self.cell_list = nn.CellList()
self.cell_list.append(TestCell(4))
self.cell_list.append(TestCell(5))
self.cell_list.append(TestCell(6))

def construct(self, t, x):
out = t
while x < 3:
add = self.cell_list[x](t)
out = out + add
x += 1
return out

net = CellInList()
t = Tensor(10, mstype.int32)
x = Tensor(0, mstype.int32)
out = net(t, x)
grad_net = Grad(net)
grad_out = grad_net(t, x)

assert out == Tensor(160, mstype.int32)
assert grad_out == Tensor(16, mstype.int32)

Loading…
Cancel
Save