|
|
|
@@ -122,8 +122,9 @@ class GradWrap(nn.Cell): |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
|
@pytest.mark.env_single |
|
|
|
@pytest.mark.env_onecard |
|
|
|
def test_ascend_pynative_lenet(): |
|
|
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") |
|
|
|
|
|
|
|
@@ -152,6 +153,5 @@ def test_ascend_pynative_lenet(): |
|
|
|
total_time = total_time + cost_time |
|
|
|
|
|
|
|
print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) |
|
|
|
assert(total_time < 20.0) |
|
|
|
assert(loss_output.asnumpy() < 0.01) |
|
|
|
assert(loss_output.asnumpy() < 0.1) |
|
|
|
|