From 6102202abd782adcbd0dacef29b853ec2071c835 Mon Sep 17 00:00:00 2001 From: tanghuikang Date: Tue, 9 Mar 2021 11:28:55 +0800 Subject: [PATCH] Not save InitDatasetQueue and GetNext op in PyNative Mode --- mindspore/ccsrc/backend/session/gpu_session.cc | 3 +++ mindspore/ccsrc/utils/utils.h | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 45f94f51f6..34fdd6bbdb 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -498,6 +498,9 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, UpdateOutputAbstract(kernel_graph, op_run_info); } RunOpClearMemory(kernel_graph.get()); + if (kOpCacheAllowList.find(op_run_info->op_name) != kOpCacheAllowList.end()) { + run_op_graphs_.erase(graph_info); + } } void GPUSession::Dump(const std::shared_ptr &kernel_graph) const { diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 37804d3848..078adee9f1 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -60,6 +60,7 @@ constexpr auto kFusedBatchNormGradExWithAddAndActivation = "FusedBatchNormGradEx constexpr auto kClearZeroOpName = "ClearZero"; constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean"; constexpr auto kGetNextOpName = "GetNext"; +constexpr auto kInitDatasetQueueOpName = "InitDataSetQueue"; constexpr auto kEndOfSequence = "EndOfSequence"; constexpr auto kAllReduceOpName = "AllReduce"; constexpr auto kAllGatherOpName = "AllGather"; @@ -505,7 +506,8 @@ const std::set kOptOperatorSet = {kMomentumOpName, const std::set kPosteriorOperatorSet = {kPullOpName}; -const std::set kOpCacheAllowList = {kUniformCandidateSamplerOpName}; +const std::set kOpCacheAllowList = {kUniformCandidateSamplerOpName, kInitDatasetQueueOpName, + kGetNextOpName}; const std::set kHWSpecialFormatSet = { kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0,