|
|
|
@@ -13,8 +13,11 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
"""Test dump.""" |
|
|
|
import warnings |
|
|
|
|
|
|
|
import pytest |
|
|
|
|
|
|
|
import mindspore.context as context |
|
|
|
import mindspore.nn as nn |
|
|
|
import mindspore.ops as ops |
|
|
|
from mindspore import set_dump |
|
|
|
@@ -26,6 +29,7 @@ def test_set_dump_on_cell(): |
|
|
|
Description: Use set_dump API on Cell instance. |
|
|
|
Expectation: Success. |
|
|
|
""" |
|
|
|
|
|
|
|
class MyNet(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super(MyNet, self).__init__() |
|
|
|
@@ -38,7 +42,9 @@ def test_set_dump_on_cell(): |
|
|
|
return x |
|
|
|
|
|
|
|
net = MyNet() |
|
|
|
set_dump(net.conv1) |
|
|
|
set_dump(net.relu1) |
|
|
|
|
|
|
|
assert net.relu1.relu.attrs["dump"] == "true" |
|
|
|
|
|
|
|
|
|
|
|
def test_set_dump_on_primitive(): |
|
|
|
@@ -49,6 +55,7 @@ def test_set_dump_on_primitive(): |
|
|
|
""" |
|
|
|
op = ops.Add() |
|
|
|
set_dump(op) |
|
|
|
assert op.attrs["dump"] == "true" |
|
|
|
|
|
|
|
|
|
|
|
def test_input_type_check(): |
|
|
|
@@ -59,3 +66,21 @@ def test_input_type_check(): |
|
|
|
""" |
|
|
|
with pytest.raises(ValueError): |
|
|
|
set_dump(1) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason="Warning can only be triggered once, please execute " |
|
|
|
"this test case manually.") |
|
|
|
def test_set_dump_warning(): |
|
|
|
""" |
|
|
|
Feature: Python API set_dump. |
|
|
|
Description: Test the warning about device target and mode. |
|
|
|
Expectation: Trigger warning message. |
|
|
|
""" |
|
|
|
context.set_context(device_target="CPU") |
|
|
|
context.set_context(mode=context.PYNATIVE_MODE) |
|
|
|
op = ops.Add() |
|
|
|
with warnings.catch_warnings(record=True) as w: |
|
|
|
warnings.simplefilter("always") |
|
|
|
set_dump(op) |
|
|
|
assert "Only Ascend device target is supported" in str(w[-2].message) |
|
|
|
assert "Only GRAPH_MODE is supported" in str(w[-1].message) |