@@ -395,9 +395,9 @@ int LiteSession::CompileGraph(Model *model) {
   }
   // scheduler kernels
 #if SUPPORT_NPU
-  Scheduler scheduler(context_, model, &tensors_, npu_manager_, npu_pass_manager_);
+  Scheduler scheduler(context_, model, &tensors_, is_train_session_, npu_manager_, npu_pass_manager_);
 #else
-  Scheduler scheduler(context_, model, &tensors_);
+  Scheduler scheduler(context_, model, &tensors_, is_train_session_);
 #endif
   ret = scheduler.Schedule(&kernels_);
   if (ret != RET_OK) {
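The overall pattern in this PR: instead of compiling train-specific behavior in or out with `#ifdef SUPPORT_TRAIN`, the session records at construction time whether it is a train session and hands that flag to the scheduler in CompileGraph. A compilable sketch of the idea, using simplified stand-in classes (SessionSketch, SchedulerSketch, TensorStub are illustrative, not the real MindSpore Lite types):

```cpp
// Sketch only: a runtime "is train session" flag replacing a compile-time
// #ifdef SUPPORT_TRAIN switch. Names are simplified stand-ins.
#include <vector>

struct TensorStub {};  // placeholder for lite::Tensor

class SchedulerSketch {
 public:
  SchedulerSketch(std::vector<TensorStub *> *tensors, bool is_train_session)
      : tensors_(tensors), is_train_session_(is_train_session) {}

  void Schedule() {
    if (!is_train_session_) {
      // inference-only preprocessing (e.g. const-tensor cast/copy) goes here
    }
    // scheduling work common to both modes
  }

 private:
  std::vector<TensorStub *> *tensors_;
  bool is_train_session_ = false;  // runtime switch instead of #ifdef
};

class SessionSketch {
 public:
  void CompileGraph() {
    SchedulerSketch scheduler(&tensors_, is_train_session_);
    scheduler.Schedule();
  }

 protected:
  std::vector<TensorStub *> tensors_;
  bool is_train_session_ = false;  // flipped to true by the train subclass
};

class TrainSessionSketch : public SessionSketch {
 public:
  TrainSessionSketch() { is_train_session_ = true; }
};
```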
@@ -599,7 +599,7 @@ LiteSession::~LiteSession() {
   npu_manager_->Reset();
   delete npu_manager_;
 #endif
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   delete opencl_runtime_wrapper_;
 #endif
   delete (model_);
@@ -737,7 +737,7 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
 }
 int LiteSession::InitGPURuntime() {
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   if (this->context_->IsGpuEnabled()) {
     opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
     if (opencl_runtime_wrapper_ == nullptr) {
@@ -754,7 +754,7 @@ int LiteSession::InitGPURuntime() {
       MS_LOG(INFO) << "Init OpenCL runtime success.";
     }
   }
-#elif GPU_VULKAN && !SUPPORT_TRAIN
+#elif GPU_VULKAN
   if (this->context_->IsGpuEnabled()) {
     auto gpu_device_info = this->context_->GetGpuInfo();
     vk_runtime_wrap_ = new (std::nothrow) gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime>;
@@ -134,13 +134,14 @@ class LiteSession : public session::LiteSession {
   Executor *executor_ = nullptr;
   Model *model_ = nullptr;
   std::atomic<bool> is_running_ = false;
+  bool is_train_session_ = false;
 #if SUPPORT_NPU
   NPUManager *npu_manager_ = nullptr;
   NPUPassManager *npu_pass_manager_ = nullptr;
 #endif
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   opencl::OpenCLRuntimeWrapper *opencl_runtime_wrapper_{nullptr};
-#elif GPU_VULKAN && !SUPPORT_TRAIN
+#elif GPU_VULKAN
   gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime> *vk_runtime_wrap_{nullptr};
 #endif
 };
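A note on the destructor hunk earlier in this diff (the `#if GPU_OPENCL` guard losing `&& !SUPPORT_TRAIN`): both wrapper members default to nullptr, as the `{nullptr}` initializers just above show, and deleting a null pointer is well defined in C++, so the unconditional `delete` is safe even in sessions that never create a GPU runtime. A minimal illustration (WrapperStub is a placeholder, not a real class):

```cpp
// Deleting a null pointer does nothing, so a member that defaults to nullptr
// can be deleted unconditionally in a destructor.
struct WrapperStub {};

int main() {
  WrapperStub *opencl_runtime_wrapper = nullptr;  // mirrors opencl_runtime_wrapper_{nullptr}
  delete opencl_runtime_wrapper;                  // no-op, safe without any guard
  return 0;
}
```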
@@ -215,7 +215,6 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter
 }
 namespace {
-#ifndef SUPPORT_TRAIN
 int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
   MS_ASSERT(tensor != nullptr);
@@ -319,7 +318,6 @@ int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
   }
   return RET_OK;
 }
-#endif
 inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
   MS_ASSERT(restored_origin_tensors != nullptr);
@@ -368,19 +366,20 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
     return nullptr;
   }
   std::map<Tensor *, Tensor *> restored_origin_tensors;
-#ifndef SUPPORT_TRAIN
-  ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
-  if (ret != RET_OK) {
-    MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
-    return nullptr;
-  }
-  // we don't need to restore tensor for copy data
-  ret = CopyConstTensorData(in_tensors, op_type);
-  if (ret != RET_OK) {
-    MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
-    return nullptr;
-  }
-#endif
+  if (!is_train_session_) {
+    ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
+      return nullptr;
+    }
+    // we don't need to restore tensor for copy data
+    ret = CopyConstTensorData(in_tensors, op_type);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
+      return nullptr;
+    }
+  }
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
   if (kernel != nullptr) {
     MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);
@@ -30,16 +30,17 @@
 namespace mindspore::lite {
 class Scheduler {
  public:
-  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)
-      : context_(ctx), src_model_(src_model), src_tensors_(src_tensors) {}
+  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session)
+      : context_(ctx), src_model_(src_model), src_tensors_(src_tensors), is_train_session_(is_train_session) {}
 #if SUPPORT_NPU
-  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
+  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session,
             NPUManager *npu_manager = nullptr, NPUPassManager *npu_pass_manager = nullptr)
       : context_(ctx),
         src_model_(src_model),
         src_tensors_(src_tensors),
         npu_manager_(npu_manager),
-        npu_pass_manager_(npu_pass_manager) {}
+        npu_pass_manager_(npu_pass_manager),
+        is_train_session_(is_train_session) {}
 #endif
   ~Scheduler() = default;
@@ -113,6 +114,7 @@ class Scheduler {
 #endif
   std::vector<size_t> graph_output_node_indexes_;
   std::map<int, OpParameter *> op_parameters_;
+  bool is_train_session_ = false;
 };
 }  // namespace mindspore::lite
@@ -166,11 +166,6 @@ int CpuSubGraph::Run(const KernelCallBack &before, const KernelCallBack &after)
     }
   }
-#ifdef SUPPORT_TRAIN
-  for (auto out_tensor : out_tensors_) {  // increase RefCount of output tensors, such that Run will not free them
-    out_tensor->set_ref_count(out_tensor->ref_count() + 1);
-  }
-#endif
 #ifdef SUPPORT_GPU
   // In heterogeneous scenarios of CPU and GPU, call MutableData to MapBuffer(synchronize data).
   if (static_cast<const lite::InnerContext *>(context_)->IsGpuEnabled()) {
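For readers unfamiliar with the block removed above: output tensor buffers in this runtime are released by reference counting, so the old SUPPORT_TRAIN-only loop bumped each output's count so that Run() would not free its data. A rough sketch of that convention, using a stand-in type rather than the real lite::Tensor:

```cpp
// Rough sketch of ref-count-based tensor lifetime, inferred from the removed
// loop above. TensorSketch is a simplified stand-in, not the real lite::Tensor.
struct TensorSketch {
  int ref_count_ = 1;
  int ref_count() const { return ref_count_; }
  void set_ref_count(int c) { ref_count_ = c; }
  void DecRefAndMaybeFree() {
    if (--ref_count_ <= 0) {
      // the runtime would release the underlying data buffer here
    }
  }
};

// What the removed SUPPORT_TRAIN loop did for each output tensor:
inline void KeepAliveAcrossRun(TensorSketch *out_tensor) {
  out_tensor->set_ref_count(out_tensor->ref_count() + 1);
}
```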
@@ -57,6 +57,7 @@ static kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *>
   return *it;
 }
 TrainSession::TrainSession() {
+  is_train_session_ = true;
 #ifdef ENABLE_V0
   if (VersionManager::GetInstance()->CheckV0Schema()) {
     kernel::PopulateTrainV0Parameters();