|
|
|
@@ -19,6 +19,19 @@ import pytest |
|
|
|
import mindspore.context as context |
|
|
|
from mindspore import Tensor, nn |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
from mindspore.ops.composite import GradOperation |
|
|
|
|
|
|
|
|
|
|
|
class Grad(nn.Cell): |
|
|
|
def __init__(self, net): |
|
|
|
super().__init__() |
|
|
|
self.grad = GradOperation(get_all=False) |
|
|
|
self.net = net |
|
|
|
|
|
|
|
def construct(self, x, y): |
|
|
|
grad_net = self.grad(self.net) |
|
|
|
grad = grad_net(x, y) |
|
|
|
return grad |
|
|
|
|
|
|
|
|
|
|
|
class CaseNet(nn.Cell): |
|
|
|
@@ -53,3 +66,51 @@ def test_switch_layer(): |
|
|
|
true_value = relu(data) |
|
|
|
ret = np.allclose(value.asnumpy(), true_value.asnumpy()) |
|
|
|
assert ret |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
|
@pytest.mark.platform_x86_gpu_training |
|
|
|
@pytest.mark.platform_x86_cpu_training |
|
|
|
@pytest.mark.env_onecard |
|
|
|
def test_cell_in_list(): |
|
|
|
""" |
|
|
|
Feature: Switch layer in while. |
|
|
|
Description: test recursive switch layer. |
|
|
|
Expectation: success if grad and output are correct. |
|
|
|
""" |
|
|
|
|
|
|
|
class TestCell(nn.Cell): |
|
|
|
def __init__(self, i): |
|
|
|
super().__init__() |
|
|
|
self.i = i |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
return self.i * x |
|
|
|
|
|
|
|
class CellInList(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.cell_list = nn.CellList() |
|
|
|
self.cell_list.append(TestCell(4)) |
|
|
|
self.cell_list.append(TestCell(5)) |
|
|
|
self.cell_list.append(TestCell(6)) |
|
|
|
|
|
|
|
def construct(self, t, x): |
|
|
|
out = t |
|
|
|
while x < 3: |
|
|
|
add = self.cell_list[x](t) |
|
|
|
out = out + add |
|
|
|
x += 1 |
|
|
|
return out |
|
|
|
|
|
|
|
net = CellInList() |
|
|
|
t = Tensor(10, mstype.int32) |
|
|
|
x = Tensor(0, mstype.int32) |
|
|
|
out = net(t, x) |
|
|
|
grad_net = Grad(net) |
|
|
|
grad_out = grad_net(t, x) |
|
|
|
|
|
|
|
assert out == Tensor(160, mstype.int32) |
|
|
|
assert grad_out == Tensor(16, mstype.int32) |