Browse Source

!27372 fix permission bug when remove read_only_file on windows

Merge pull request !27372 from zhangbuxue/fix_permission_bug_when_remove_read_only_file_on_windows
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
d46e047fe1
3 changed files with 38 additions and 14 deletions
  1. +6
    -6
      mindspore/ccsrc/pybind_api/ir/func_graph_py.cc
  2. +4
    -0
      tests/st/export_and_load/test_get_and_init_graph_cell_parameters.py
  3. +28
    -8
      tests/ut/python/mindir/test_init_graph_cell_parameters_with_illegal_data.py

+ 6
- 6
mindspore/ccsrc/pybind_api/ir/func_graph_py.cc View File

@@ -39,12 +39,12 @@ py::dict UpdateFuncGraphHyperParams(const FuncGraphPtr &func_graph, const py::di
const auto &new_value = params_init[param_name].cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(new_value);
if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
"The parameter '"
<< param_name.cast<std::string>() << "' has shape " << old_value->shape()
<< " and dtype " << TypeIdLabel(old_value->data_type())
<< ", but got the update Tensor with shape " << new_value->shape() << " and dtype "
<< TypeIdLabel(new_value->data_type()) << ".";
MS_EXCEPTION(ValueError)
<< "Only support update parameter by Tensor or Parameter with same shape and dtype as it. "
"The parameter '"
<< param_name.cast<std::string>() << "' has shape " << old_value->shape() << " and dtype "
<< TypeIdLabel(old_value->data_type()) << ", but got the update value with shape " << new_value->shape()
<< " and dtype " << TypeIdLabel(new_value->data_type()) << ".";
}
new_param = fn(*new_value);
} else {


+ 4
- 0
tests/st/export_and_load/test_get_and_init_graph_cell_parameters.py View File

@@ -14,7 +14,9 @@
# ============================================================================

"""test get and init GraphCell parameters"""

import os
import stat

import numpy as np
import pytest
@@ -72,8 +74,10 @@ def get_and_init_graph_cell_parameters():
assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 2.0)

if os.path.isfile(mindir_name):
os.chmod(mindir_name, stat.S_IWUSR)
os.remove(mindir_name)
if os.path.isfile(ckpt_name):
os.chmod(ckpt_name, stat.S_IWUSR)
os.remove(ckpt_name)




+ 28
- 8
tests/ut/python/mindir/test_init_graph_cell_parameters_with_illegal_data.py View File

@@ -57,7 +57,27 @@ def remove_generated_file(file_name):
def test_init_graph_cell_parameters_with_wrong_type():
"""
Description: load mind ir and update parameters with wrong type.
Expectation: raise a ValueError indicating the params type error.
Expectation: raise a ValueError indicating the params_init type error.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net()
mindir_name = "net_0.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')

new_params = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)
with pytest.raises(TypeError) as err:
graph = load(mindir_name)
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)

assert "The 'params_init' must be a dict, but got" in str(err.value)
remove_generated_file(mindir_name)


def test_init_graph_cell_parameters_with_wrong_value_type():
"""
Description: load mind ir and update parameters with wrong value type.
Expectation: raise a ValueError indicating the params value type error.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net()
@@ -74,10 +94,10 @@ def test_init_graph_cell_parameters_with_wrong_type():
remove_generated_file(mindir_name)


def test_init_graph_cell_parameters_with_wrong_shape():
def test_init_graph_cell_parameters_with_wrong_value_shape():
"""
Description: load mind ir and update parameters with wrong tensor shape.
Expectation: raise a ValueError indicating the tensor shape error.
Expectation: raise a ValueError indicating the update value shape error.
"""
context.set_context(mode=context.PYNATIVE_MODE)
net = Net()
@@ -90,25 +110,25 @@ def test_init_graph_cell_parameters_with_wrong_shape():
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)

assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
assert "Only support update parameter by Tensor or Parameter with same shape and dtype as it" in str(err.value)
remove_generated_file(mindir_name)


def test_init_graph_cell_parameters_with_wrong_dtype():
def test_init_graph_cell_parameters_with_wrong_value_dtype():
"""
Description: load mind ir and update parameters with wrong tensor dtype.
Expectation: raise a ValueError indicating the tensor dtype error.
Expectation: raise a ValueError indicating the update value dtype error.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net()
mindir_name = "net_3.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')

new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
new_params = {"weight": Tensor(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
with pytest.raises(ValueError) as err:
graph = load(mindir_name)
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)

assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
assert "Only support update parameter by Tensor or Parameter with same shape and dtype as it" in str(err.value)
remove_generated_file(mindir_name)

Loading…
Cancel
Save