|
|
|
@@ -57,7 +57,27 @@ def remove_generated_file(file_name): |
|
|
|
def test_init_graph_cell_parameters_with_wrong_type(): |
|
|
|
""" |
|
|
|
Description: load mind ir and update parameters with wrong type. |
|
|
|
Expectation: raise a ValueError indicating the params type error. |
|
|
|
Expectation: raise a ValueError indicating the params_init type error. |
|
|
|
""" |
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
net = Net() |
|
|
|
mindir_name = "net_0.mindir" |
|
|
|
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') |
|
|
|
|
|
|
|
new_params = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32) |
|
|
|
with pytest.raises(TypeError) as err: |
|
|
|
graph = load(mindir_name) |
|
|
|
load_net = nn.GraphCell(graph, params_init=new_params) |
|
|
|
load_net(input_a, input_b) |
|
|
|
|
|
|
|
assert "The 'params_init' must be a dict, but got" in str(err.value) |
|
|
|
remove_generated_file(mindir_name) |
|
|
|
|
|
|
|
|
|
|
|
def test_init_graph_cell_parameters_with_wrong_value_type(): |
|
|
|
""" |
|
|
|
Description: load mind ir and update parameters with wrong value type. |
|
|
|
Expectation: raise a ValueError indicating the params value type error. |
|
|
|
""" |
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
net = Net() |
|
|
|
@@ -74,10 +94,10 @@ def test_init_graph_cell_parameters_with_wrong_type(): |
|
|
|
remove_generated_file(mindir_name) |
|
|
|
|
|
|
|
|
|
|
|
def test_init_graph_cell_parameters_with_wrong_shape(): |
|
|
|
def test_init_graph_cell_parameters_with_wrong_value_shape(): |
|
|
|
""" |
|
|
|
Description: load mind ir and update parameters with wrong tensor shape. |
|
|
|
Expectation: raise a ValueError indicating the tensor shape error. |
|
|
|
Expectation: raise a ValueError indicating the update value shape error. |
|
|
|
""" |
|
|
|
context.set_context(mode=context.PYNATIVE_MODE) |
|
|
|
net = Net() |
|
|
|
@@ -90,25 +110,25 @@ def test_init_graph_cell_parameters_with_wrong_shape(): |
|
|
|
load_net = nn.GraphCell(graph, params_init=new_params) |
|
|
|
load_net(input_a, input_b) |
|
|
|
|
|
|
|
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value) |
|
|
|
assert "Only support update parameter by Tensor or Parameter with same shape and dtype as it" in str(err.value) |
|
|
|
remove_generated_file(mindir_name) |
|
|
|
|
|
|
|
|
|
|
|
def test_init_graph_cell_parameters_with_wrong_dtype(): |
|
|
|
def test_init_graph_cell_parameters_with_wrong_value_dtype(): |
|
|
|
""" |
|
|
|
Description: load mind ir and update parameters with wrong tensor dtype. |
|
|
|
Expectation: raise a ValueError indicating the tensor dtype error. |
|
|
|
Expectation: raise a ValueError indicating the update value dtype error. |
|
|
|
""" |
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
net = Net() |
|
|
|
mindir_name = "net_3.mindir" |
|
|
|
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') |
|
|
|
|
|
|
|
new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))} |
|
|
|
new_params = {"weight": Tensor(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))} |
|
|
|
with pytest.raises(ValueError) as err: |
|
|
|
graph = load(mindir_name) |
|
|
|
load_net = nn.GraphCell(graph, params_init=new_params) |
|
|
|
load_net(input_a, input_b) |
|
|
|
|
|
|
|
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value) |
|
|
|
assert "Only support update parameter by Tensor or Parameter with same shape and dtype as it" in str(err.value) |
|
|
|
remove_generated_file(mindir_name) |