@@ -395,9 +395,9 @@ int LiteSession::CompileGraph(Model *model) {
   }
   // scheduler kernels
 #if SUPPORT_NPU
-  Scheduler scheduler(context_, model, &tensors_, npu_manager_, npu_pass_manager_);
+  Scheduler scheduler(context_, model, &tensors_, is_train_session_, npu_manager_, npu_pass_manager_);
 #else
-  Scheduler scheduler(context_, model, &tensors_);
+  Scheduler scheduler(context_, model, &tensors_, is_train_session_);
 #endif
   ret = scheduler.Schedule(&kernels_);
   if (ret != RET_OK) {
@@ -599,7 +599,7 @@ LiteSession::~LiteSession() {
   npu_manager_->Reset();
   delete npu_manager_;
 #endif
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   delete opencl_runtime_wrapper_;
 #endif
   delete (model_);
@@ -737,7 +737,7 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
 }
 int LiteSession::InitGPURuntime() {
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   if (this->context_->IsGpuEnabled()) {
     opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
     if (opencl_runtime_wrapper_ == nullptr) {
@@ -754,7 +754,7 @@ int LiteSession::InitGPURuntime() {
       MS_LOG(INFO) << "Init OpenCL runtime success.";
     }
   }
-#elif GPU_VULKAN && !SUPPORT_TRAIN
+#elif GPU_VULKAN
   if (this->context_->IsGpuEnabled()) {
     auto gpu_device_info = this->context_->GetGpuInfo();
     vk_runtime_wrap_ = new (std::nothrow) gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime>;
@@ -134,13 +134,14 @@ class LiteSession : public session::LiteSession {
   Executor *executor_ = nullptr;
   Model *model_ = nullptr;
   std::atomic<bool> is_running_ = false;
+  bool is_train_session_ = false;
 #if SUPPORT_NPU
   NPUManager *npu_manager_ = nullptr;
   NPUPassManager *npu_pass_manager_ = nullptr;
 #endif
-#if GPU_OPENCL && !SUPPORT_TRAIN
+#if GPU_OPENCL
   opencl::OpenCLRuntimeWrapper *opencl_runtime_wrapper_{nullptr};
-#elif GPU_VULKAN && !SUPPORT_TRAIN
+#elif GPU_VULKAN
   gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime> *vk_runtime_wrap_{nullptr};
 #endif
 };
@@ -215,7 +215,6 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter
 }
 namespace {
-#ifndef SUPPORT_TRAIN
 int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
   MS_ASSERT(tensor != nullptr);
@@ -319,7 +318,6 @@ int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
   }
   return RET_OK;
 }
-#endif
 inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
   MS_ASSERT(restored_origin_tensors != nullptr);
@@ -368,19 +366,20 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
     return nullptr;
   }
   std::map<Tensor *, Tensor *> restored_origin_tensors;
-#ifndef SUPPORT_TRAIN
-  ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
-  if (ret != RET_OK) {
-    MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
-    return nullptr;
-  }
-  // we don't need to restore tensor for copy data
-  ret = CopyConstTensorData(in_tensors, op_type);
-  if (ret != RET_OK) {
-    MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
-    return nullptr;
+  if (!is_train_session_) {
+    ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
+      return nullptr;
+    }
+    // we don't need to restore tensor for copy data
+    ret = CopyConstTensorData(in_tensors, op_type);
+    if (ret != RET_OK) {
+      MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
+      return nullptr;
+    }
   }
-#endif
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
   if (kernel != nullptr) {
     MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);
@@ -30,16 +30,17 @@
 namespace mindspore::lite {
 class Scheduler {
  public:
-  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)
-      : context_(ctx), src_model_(src_model), src_tensors_(src_tensors) {}
+  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session)
+      : context_(ctx), src_model_(src_model), src_tensors_(src_tensors), is_train_session_(is_train_session) {}
 #if SUPPORT_NPU
-  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
+  Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session,
             NPUManager *npu_manager = nullptr, NPUPassManager *npu_pass_manager = nullptr)
       : context_(ctx),
         src_model_(src_model),
         src_tensors_(src_tensors),
         npu_manager_(npu_manager),
-        npu_pass_manager_(npu_pass_manager) {}
+        npu_pass_manager_(npu_pass_manager),
+        is_train_session_(is_train_session) {}
 #endif
   ~Scheduler() = default;
@@ -113,6 +114,7 @@ class Scheduler {
 #endif
   std::vector<size_t> graph_output_node_indexes_;
   std::map<int, OpParameter *> op_parameters_;
+  bool is_train_session_ = false;
 };
 }  // namespace mindspore::lite
@@ -166,11 +166,6 @@ int CpuSubGraph::Run(const KernelCallBack &before, const KernelCallBack &after)
     }
   }
-#ifdef SUPPORT_TRAIN
-  for (auto out_tensor : out_tensors_) {  // increase RefCount of output tensors, such that Run will not free them
-    out_tensor->set_ref_count(out_tensor->ref_count() + 1);
-  }
-#endif
 #ifdef SUPPORT_GPU
   // In heterogeneous scenarios of CPU and GPU, call MutableData to MapBuffer(synchronize data).
   if (static_cast<const lite::InnerContext *>(context_)->IsGpuEnabled()) {
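For context on the block removed above: its own comment says it raised each output tensor's ref count so Run would not free them in training builds. A minimal, self-contained sketch of that ref-count-guard idea follows; TensorSketch is an illustrative type, not the real lite::Tensor.

#include <iostream>
#include <vector>

struct TensorSketch {
  int ref_count = 1;
  bool freed = false;
  void DecRefAndMaybeFree() {
    if (--ref_count <= 0) {
      freed = true;  // stands in for freeing the tensor's data buffer
    }
  }
};

int main() {
  std::vector<TensorSketch> outputs(2);
  // Equivalent of the removed SUPPORT_TRAIN bump: one extra reference per output.
  for (auto &t : outputs) {
    t.ref_count += 1;
  }
  // During Run, each consumer releases one reference.
  for (auto &t : outputs) {
    t.DecRefAndMaybeFree();
  }
  for (const auto &t : outputs) {
    std::cout << (t.freed ? "freed\n" : "still alive\n");  // prints "still alive"
  }
  return 0;
}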
@@ -57,6 +57,7 @@ static kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *>
   return *it;
 }
 TrainSession::TrainSession() {
+  is_train_session_ = true;
 #ifdef ENABLE_V0
   if (VersionManager::GetInstance()->CheckV0Schema()) {
     kernel::PopulateTrainV0Parameters();
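Taken together, the flag now threads through the session at run time instead of build time. Quoting only the changed lines from the hunks above (surrounding bodies elided): the TrainSession constructor sets it, LiteSession::CompileGraph forwards it, and Scheduler::FindCpuKernel tests it.

TrainSession::TrainSession() { is_train_session_ = true; /* ... */ }

Scheduler scheduler(context_, model, &tensors_, is_train_session_);                                    // non-NPU build
Scheduler scheduler(context_, model, &tensors_, is_train_session_, npu_manager_, npu_pass_manager_);   // NPU build

if (!is_train_session_) { /* cast/copy const tensor data for inference-only sessions */ }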