diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index 9c6e2f853e..ab061e1fa1 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -395,9 +395,9 @@ int LiteSession::CompileGraph(Model *model) {
   }
   // scheduler kernels
 #if SUPPORT_NPU
-  Scheduler scheduler(context_, model, &tensors_, npu_manager_, npu_pass_manager_);
+  Scheduler scheduler(context_, model, &tensors_, is_train_session_, npu_manager_, npu_pass_manager_);
 #else
-  Scheduler scheduler(context_, model, &tensors_);
+  Scheduler scheduler(context_, model, &tensors_, is_train_session_);
 #endif
   ret = scheduler.Schedule(&kernels_);
   if (ret != RET_OK) {
@@ -599,7 +599,7 @@ LiteSession::~LiteSession() {
   npu_manager_->Reset();
   delete npu_manager_;
 #endif
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   delete opencl_runtime_wrapper_;
 #endif
   delete (model_);
@@ -737,7 +737,7 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
 }
 
 int LiteSession::InitGPURuntime() {
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   if (this->context_->IsGpuEnabled()) {
     opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
     if (opencl_runtime_wrapper_ == nullptr) {
@@ -754,7 +754,7 @@ int LiteSession::InitGPURuntime() {
       MS_LOG(INFO) << "Init OpenCL runtime success.";
     }
   }
-#elif GPU_VULKAN && !SUPPORT_TRAIN
+#elif GPU_VULKAN
   if (this->context_->IsGpuEnabled()) {
     auto gpu_device_info = this->context_->GetGpuInfo();
     vk_runtime_wrap_ = new (std::nothrow) gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime>;
diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h
index 89672c1dfc..43d8bf8cb1 100644
--- a/mindspore/lite/src/lite_session.h
+++ b/mindspore/lite/src/lite_session.h
@@ -134,13 +134,14 @@ class LiteSession : public session::LiteSession {
   Executor *executor_ = nullptr;
   Model *model_ = nullptr;
   std::atomic<bool> is_running_ = false;
+  bool is_train_session_ = false;
 #if SUPPORT_NPU
   NPUManager *npu_manager_ = nullptr;
   NPUPassManager *npu_pass_manager_ = nullptr;
 #endif
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   opencl::OpenCLRuntimeWrapper *opencl_runtime_wrapper_{nullptr};
-#elif GPU_VULKAN && !SUPPORT_TRAIN
+#elif GPU_VULKAN
   gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime> *vk_runtime_wrap_{nullptr};
 #endif
 };
diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc
index b198e9679a..59a5a2b662 100644
--- a/mindspore/lite/src/scheduler.cc
+++ b/mindspore/lite/src/scheduler.cc
@@ -215,7 +215,6 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter
 }
 
 namespace {
-#ifndef SUPPORT_TRAIN
 int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
   MS_ASSERT(tensor != nullptr);
@@ -319,7 +318,6 @@ int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
   }
   return RET_OK;
 }
-#endif
 
 inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
   MS_ASSERT(restored_origin_tensors != nullptr);
@@ -368,19 +366,20 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
     return nullptr;
   }
   std::map<Tensor *, Tensor *> restored_origin_tensors;
-#ifndef SUPPORT_TRAIN
-  ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
-  if (ret != RET_OK) {
-    MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
-    return nullptr;
-  }
-  // we don't need to restore tensor for copy data
-  ret = CopyConstTensorData(in_tensors, op_type);
-  if (ret != RET_OK) {
-    MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
-    return nullptr;
+
+  if (!is_train_session_) {
+    ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
+      return nullptr;
+    }
+    // we don't need to restore tensor for copy data
+    ret = CopyConstTensorData(in_tensors, op_type);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
+      return nullptr;
+    }
   }
-#endif
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
   if (kernel != nullptr) {
     MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);
diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h
index e6efef57e6..1d69d4519e 100644
--- a/mindspore/lite/src/scheduler.h
+++ b/mindspore/lite/src/scheduler.h
@@ -30,16 +30,17 @@ namespace mindspore::lite {
 class Scheduler {
  public:
-  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)
-      : context_(ctx), src_model_(src_model), src_tensors_(src_tensors) {}
+  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session)
+      : context_(ctx), src_model_(src_model), src_tensors_(src_tensors), is_train_session_(is_train_session) {}
 #if SUPPORT_NPU
-  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
+  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session,
             NPUManager *npu_manager = nullptr, NPUPassManager *npu_pass_manager = nullptr)
       : context_(ctx),
         src_model_(src_model),
         src_tensors_(src_tensors),
         npu_manager_(npu_manager),
-        npu_pass_manager_(npu_pass_manager) {}
+        npu_pass_manager_(npu_pass_manager),
+        is_train_session_(is_train_session) {}
 #endif
 
   ~Scheduler() = default;
 
@@ -113,6 +114,7 @@ class Scheduler {
 #endif
   std::vector<size_t> graph_output_node_indexes_;
   std::map<const Model::Node *, OpParameter *> op_parameters_;
+  bool is_train_session_ = false;
 };
 }  // namespace mindspore::lite
 
diff --git a/mindspore/lite/src/sub_graph_kernel.cc b/mindspore/lite/src/sub_graph_kernel.cc
index 51cf64040f..1392652772 100644
--- a/mindspore/lite/src/sub_graph_kernel.cc
+++ b/mindspore/lite/src/sub_graph_kernel.cc
@@ -166,11 +166,6 @@ int CpuSubGraph::Run(const KernelCallBack &before, const KernelCallBack &after)
     }
   }
 
-#ifdef SUPPORT_TRAIN
-  for (auto out_tensor : out_tensors_) {  // increase RefCount of output tensors, such that Run will not free them
-    out_tensor->set_ref_count(out_tensor->ref_count() + 1);
-  }
-#endif
 #ifdef SUPPORT_GPU
   // In heterogeneous scenarios of CPU and GPU, call MutableData to MapBuffer(synchronize data).
   if (static_cast<const lite::InnerContext *>(context_)->IsGpuEnabled()) {
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index 7d15a40aef..143000d392 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -57,6 +57,7 @@ static kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *>
   return *it;
 }
 TrainSession::TrainSession() {
+  is_train_session_ = true;
 #ifdef ENABLE_V0
   if (VersionManager::GetInstance()->CheckV0Schema()) {
     kernel::PopulateTrainV0Parameters();