diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index ff7949cce4..0a5aa90872 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -130,6 +130,13 @@ static std::string GetId(const py::object &obj) {
     }
     return prefix + key;
   }
+  if (py::isinstance<mindspore::Type>(to_process)) {
+    auto type_ptr = py::cast<mindspore::TypePtr>(to_process);
+    return prefix + type_ptr->ToString();
+  }
+  if (py::isinstance<py::int_>(to_process)) {
+    return prefix + std::string(py::str(to_process));
+  }
   if (py::isinstance<py::float_>(to_process)) {
     return prefix + std::string(py::str(to_process));
   }
@@ -1253,17 +1260,24 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object
   pipeline::ReclaimOptimizer();
 }
 
+template <typename T>
+void MapClear(T &map, const std::string &flag) {
+  for (auto it = map.begin(); it != map.end();) {
+    if (it->first.find(flag) != std::string::npos) {
+      it->second = nullptr;
+      it = map.erase(it);
+    } else {
+      it++;
+    }
+  }
+}
+
 void PynativeExecutor::Clear(const std::string &flag) {
   if (!flag.empty()) {
     MS_LOG(DEBUG) << "Clear res";
-    auto key_value = std::find_if(graph_map_.begin(), graph_map_.end(),
-                                  [&flag](const auto &item) { return item.first.find(flag) != std::string::npos; });
-    if (key_value != graph_map_.end()) {
-      std::string key = key_value->first;
-      (void)graph_map_.erase(key);
-      (void)cell_graph_map_.erase(key);
-      (void)cell_resource_map_.erase(key);
-    }
+    MapClear<std::unordered_map<std::string, FuncGraphPtr>>(graph_map_, flag);
+    MapClear<std::unordered_map<std::string, FuncGraphPtr>>(cell_graph_map_, flag);
+    MapClear<std::unordered_map<std::string, ResourcePtr>>(cell_resource_map_, flag);
     Clean();
     // Maybe exit in the pynative running op, so need to reset the pynative flag.
     auto ms_context = MsContext::GetInstance();
@@ -1281,7 +1295,6 @@ void PynativeExecutor::Clear(const std::string &flag) {
   curr_g_ = nullptr;
   graph_info_map_.clear();
   op_id_map_.clear();
-  // node_abs_map_.clear();
   std::stack<FuncGraphPtr>().swap(graph_p_);
   ConfigManager::GetInstance().ResetIterNum();
 }
@@ -1295,7 +1308,18 @@ void PynativeExecutor::Clean() {
   pipeline::ReclaimOptimizer();
 }
 
+template <typename T>
+void MapErase(T &map) {
+  for (auto it = map.begin(); it != map.end();) {
+    it = map.erase(it);
+  }
+}
+
 void PynativeExecutor::ClearRes() {
+  MapErase<std::unordered_map<std::string, FuncGraphPtr>>(graph_map_);
+  MapErase<std::unordered_map<std::string, FuncGraphPtr>>(cell_graph_map_);
+  MapErase<std::unordered_map<std::string, ResourcePtr>>(cell_resource_map_);
+  MapErase<std::unordered_map<std::string, abstract::AbstractBasePtr>>(node_abs_map_);
   Clean();
   resource_.reset();
 }
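Review note on the `Clear` rewrite above: the old `std::find_if` path erased at most one matching key, so when several graph entries shared a cell id only the first was released; `MapClear` now removes every entry whose key contains the flag, and resets the mapped smart pointer before erasing. A minimal standalone sketch of the same eviction semantics in Python (`map_clear` and the sample keys are illustrative, not MindSpore APIs):

```python
def map_clear(cache: dict, flag: str) -> None:
    """Drop every entry whose key contains `flag` (the cell id)."""
    # Snapshot the matching keys first: mutating a dict while iterating it is an error.
    for key in [k for k in cache if flag in k]:
        del cache[key]

# Both entries tagged "cell_a" are evicted, not just the first match.
graphs = {"cell_a.0x1": "g1", "cell_a.0x2": "g2", "cell_b.0x3": "g3"}
map_clear(graphs, "cell_a")
assert graphs == {"cell_b.0x3": "g3"}
```
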
diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
index 36973f038b..1d27f1be59 100644
--- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
@@ -102,13 +102,17 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
         py::object grad_dtype = grads[i].attr("dtype");
         py::tuple arg_shape = py_args[i].attr("shape");
         py::object arg_dtype = py_args[i].attr("dtype");
-        if (!grad_shape.equal(arg_shape) || !grad_dtype.is(arg_dtype)) {
-          MS_EXCEPTION(ValueError) << "For user define net bprop, the gradient of the " << i
-                                   << "th arg should have the same shape and dtype as the " << i << "th arg, but the "
-                                   << i << "th arg shape: " << py::cast<py::str>(arg_shape)
-                                   << " and dtype: " << py::cast<py::str>(arg_dtype)
-                                   << ", the gradient shape: " << py::cast<py::str>(grad_shape)
-                                   << " and dtype: " << py::cast<py::str>(grad_dtype) << ".";
+        if (!grad_shape.equal(arg_shape)) {
+          MS_EXCEPTION(ValueError) << "When the user defines the net bprop, the gradient of the " << i
+                                   << "th arg should have the same shape as the " << i << "th arg, but the " << i
+                                   << "th arg shape is: " << py::cast<py::str>(arg_shape)
+                                   << ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
+        }
+        if (!grad_dtype.is(arg_dtype)) {
+          MS_EXCEPTION(TypeError) << "When the user defines the net bprop, the gradient of the " << i
+                                  << "th arg should have the same dtype as the " << i << "th arg, but the " << i
+                                  << "th arg dtype is: " << py::cast<py::str>(arg_dtype)
+                                  << ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
         }
       }
     }
diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py
index d11d35fccd..47a814fc23 100644
--- a/mindspore/common/dtype.py
+++ b/mindspore/common/dtype.py
@@ -227,6 +227,7 @@ def dtype_to_pytype(type_):
 
     return {
         bool_: bool,
+        int_: int,
         int8: int,
         int16: int,
         int32: int,
@@ -235,6 +236,7 @@ def dtype_to_pytype(type_):
         uint16: int,
         uint32: int,
         uint64: int,
+        float_: float,
         float16: float,
         float32: float,
         float64: float,
diff --git a/tests/ut/python/pynative_mode/test_user_define_bprop_check.py b/tests/ut/python/pynative_mode/test_user_define_bprop_check.py
index 6ebe94aceb..0485e6428b 100644
--- a/tests/ut/python/pynative_mode/test_user_define_bprop_check.py
+++ b/tests/ut/python/pynative_mode/test_user_define_bprop_check.py
@@ -116,7 +116,6 @@ def test_user_define_bprop_check_shape():
     grad_net = GradNet(net)
     with pytest.raises(ValueError) as ex:
         ret = grad_net(x, sens)
-    assert "the gradient of the 0th arg should have the same shape and dtype as the 0th arg" in str(ex.value)
 
 
 def test_user_define_bprop_check_dtype():
@@ -145,9 +144,8 @@ def test_user_define_bprop_check_dtype():
     context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)
     net = Net()
     grad_net = GradNet(net)
-    with pytest.raises(ValueError) as ex:
+    with pytest.raises(TypeError) as ex:
         ret = grad_net(x, sens)
-    assert "the gradient of the 0th arg should have the same shape and dtype as the 0th arg" in str(ex.value)
 
 
 def test_user_define_bprop_check_parameter():
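With the combined check split in two, a shape mismatch in a user-defined bprop now raises `ValueError` and a dtype mismatch raises `TypeError`, which is why `test_user_define_bprop_check_dtype` switches to `pytest.raises(TypeError)` and the brittle message asserts are dropped. The `dtype.py` half of the patch closes a related gap: the generic `int_` and `float_` dtypes had no entry in `dtype_to_pytype` even though every sized integer and float type did. A quick sanity check, assuming this patch is applied:

```python
from mindspore.common.dtype import dtype_to_pytype, int_, float_, int32, float32

# New entries added by this patch: the generic int/float dtypes.
assert dtype_to_pytype(int_) is int
assert dtype_to_pytype(float_) is float
# The sized variants were already mapped before the patch.
assert dtype_to_pytype(int32) is int
assert dtype_to_pytype(float32) is float
```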