|
|
|
@@ -49,9 +49,8 @@ def test_switch_mode(): |
|
|
|
def test_set_device_id(): |
|
|
|
""" test_set_device_id """ |
|
|
|
with pytest.raises(TypeError): |
|
|
|
context.set_context(device_id=1) |
|
|
|
context.set_context(device_id="cpu") |
|
|
|
assert context.get_context("device_id") == 0 |
|
|
|
context.set_context(device_id=1) |
|
|
|
assert context.get_context("device_id") == 1 |
|
|
|
|
|
|
|
|
|
|
|
@@ -115,14 +114,17 @@ def test_variable_memory_max_size(): |
|
|
|
context.set_context(variable_memory_max_size=True) |
|
|
|
with pytest.raises(TypeError): |
|
|
|
context.set_context(variable_memory_max_size=1) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
context.set_context(variable_memory_max_size="") |
|
|
|
with pytest.raises(ValueError): |
|
|
|
context.set_context(variable_memory_max_size="1G") |
|
|
|
with pytest.raises(ValueError): |
|
|
|
context.set_context(variable_memory_max_size="32GB") |
|
|
|
context.set_context(variable_memory_max_size="3GB") |
|
|
|
context.set_context.__wrapped__(variable_memory_max_size="3GB") |
|
|
|
|
|
|
|
def test_max_device_memory_size(): |
|
|
|
"""test_max_device_memory_size""" |
|
|
|
with pytest.raises(TypeError): |
|
|
|
context.set_context(max_device_memory=True) |
|
|
|
with pytest.raises(TypeError): |
|
|
|
context.set_context(max_device_memory=1) |
|
|
|
context.set_context(max_device_memory="3.5G") |
|
|
|
context.set_context.__wrapped__(max_device_memory="3GB") |
|
|
|
|
|
|
|
def test_print_file_path(): |
|
|
|
"""test_print_file_path""" |
|
|
|
@@ -132,8 +134,9 @@ def test_print_file_path(): |
|
|
|
|
|
|
|
def test_set_context(): |
|
|
|
""" test_set_context """ |
|
|
|
context.set_context.__wrapped__(device_id=0) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", |
|
|
|
device_id=0, save_graphs=True, save_graphs_path="mindspore_ir_path") |
|
|
|
save_graphs=True, save_graphs_path="mindspore_ir_path") |
|
|
|
assert context.get_context("device_id") == 0 |
|
|
|
assert context.get_context("device_target") == "Ascend" |
|
|
|
assert context.get_context("save_graphs") |
|
|
|
|