From a0680113c4ca5c00d0dee89c03c6be3835638346 Mon Sep 17 00:00:00 2001 From: chujinjin Date: Fri, 11 Dec 2020 14:52:55 +0800 Subject: [PATCH] fix cpu pynative coredump --- mindspore/ccsrc/backend/session/session_basic.cc | 5 +++++ tests/st/pynative/test_pynative_lenet.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 4fb20f4d61..896958992e 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1718,6 +1718,11 @@ void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull &ro } void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr> &useless_tensors) { + auto ms_context = MsContext::GetInstance(); + std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); + if (device_target == "CPU") { + return; + } for (const auto &tensor : *useless_tensors) { tensor->set_device_address(nullptr); } diff --git a/tests/st/pynative/test_pynative_lenet.py b/tests/st/pynative/test_pynative_lenet.py index 75b5d0cfe5..29a7ddbc40 100644 --- a/tests/st/pynative/test_pynative_lenet.py +++ b/tests/st/pynative/test_pynative_lenet.py @@ -131,9 +131,11 @@ class GradWrap(nn.Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_ascend_pynative_lenet(): - context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + context.set_context(mode=context.PYNATIVE_MODE) epoch_size = 20 batch_size = 32