diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 5a99fa4e7b..b457f65e81 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1719,6 +1719,11 @@ void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull &ro } void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr> &useless_tensors) { + auto ms_context = MsContext::GetInstance(); + std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); + if (device_target == "CPU") { + return; + } for (const auto &tensor : *useless_tensors) { MS_EXCEPTION_IF_NULL(tensor); const auto &shape = tensor->shape(); diff --git a/tests/st/pynative/test_pynative_lenet.py b/tests/st/pynative/test_pynative_lenet.py index 75b5d0cfe5..29a7ddbc40 100644 --- a/tests/st/pynative/test_pynative_lenet.py +++ b/tests/st/pynative/test_pynative_lenet.py @@ -131,9 +131,11 @@ class GradWrap(nn.Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_ascend_pynative_lenet(): - context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + context.set_context(mode=context.PYNATIVE_MODE) epoch_size = 20 batch_size = 32