Browse Source

fixed issue I3HZIK and removed some more SUPPORT_TRAIN ifdefs

pull/15402/head
Emir Haleva 4 years ago
parent
commit
8974c7f0bc
6 changed files with 28 additions and 30 deletions
  1. +5
    -5
      mindspore/lite/src/lite_session.cc
  2. +3
    -2
      mindspore/lite/src/lite_session.h
  3. +13
    -14
      mindspore/lite/src/scheduler.cc
  4. +6
    -4
      mindspore/lite/src/scheduler.h
  5. +0
    -5
      mindspore/lite/src/sub_graph_kernel.cc
  6. +1
    -0
      mindspore/lite/src/train/train_session.cc

+ 5
- 5
mindspore/lite/src/lite_session.cc View File

@@ -395,9 +395,9 @@ int LiteSession::CompileGraph(Model *model) {
}
// scheduler kernels
#if SUPPORT_NPU
Scheduler scheduler(context_, model, &tensors_, npu_manager_, npu_pass_manager_);
Scheduler scheduler(context_, model, &tensors_, is_train_session_, npu_manager_, npu_pass_manager_);
#else
Scheduler scheduler(context_, model, &tensors_);
Scheduler scheduler(context_, model, &tensors_, is_train_session_);
#endif
ret = scheduler.Schedule(&kernels_);
if (ret != RET_OK) {
@@ -599,7 +599,7 @@ LiteSession::~LiteSession() {
npu_manager_->Reset();
delete npu_manager_;
#endif
#if GPU_OPENCL && !SUPPORT_TRAIN
#if GPU_OPENCL
delete opencl_runtime_wrapper_;
#endif
delete (model_);
@@ -737,7 +737,7 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
}

int LiteSession::InitGPURuntime() {
#if GPU_OPENCL && !SUPPORT_TRAIN
#if GPU_OPENCL
if (this->context_->IsGpuEnabled()) {
opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
if (opencl_runtime_wrapper_ == nullptr) {
@@ -754,7 +754,7 @@ int LiteSession::InitGPURuntime() {
MS_LOG(INFO) << "Init OpenCL runtime success.";
}
}
#elif GPU_VULKAN && !SUPPORT_TRAIN
#elif GPU_VULKAN
if (this->context_->IsGpuEnabled()) {
auto gpu_device_info = this->context_->GetGpuInfo();
vk_runtime_wrap_ = new (std::nothrow) gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime>;


+ 3
- 2
mindspore/lite/src/lite_session.h View File

@@ -134,13 +134,14 @@ class LiteSession : public session::LiteSession {
Executor *executor_ = nullptr;
Model *model_ = nullptr;
std::atomic<bool> is_running_ = false;
bool is_train_session_ = false;
#if SUPPORT_NPU
NPUManager *npu_manager_ = nullptr;
NPUPassManager *npu_pass_manager_ = nullptr;
#endif
#if GPU_OPENCL && !SUPPORT_TRAIN
#if GPU_OPENCL
opencl::OpenCLRuntimeWrapper *opencl_runtime_wrapper_{nullptr};
#elif GPU_VULKAN && !SUPPORT_TRAIN
#elif GPU_VULKAN
gpu::GpuRuntimeWrapper<vulkan::VulkanRuntime> *vk_runtime_wrap_{nullptr};
#endif
};


+ 13
- 14
mindspore/lite/src/scheduler.cc View File

@@ -215,7 +215,6 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter
}

namespace {
#ifndef SUPPORT_TRAIN
int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
MS_ASSERT(tensor != nullptr);
@@ -319,7 +318,6 @@ int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
}
return RET_OK;
}
#endif

inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
MS_ASSERT(restored_origin_tensors != nullptr);
@@ -368,19 +366,20 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
return nullptr;
}
std::map<Tensor *, Tensor *> restored_origin_tensors;
#ifndef SUPPORT_TRAIN
ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
return nullptr;
}
// we don't need to restore tensor for copy data
ret = CopyConstTensorData(in_tensors, op_type);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
return nullptr;

if (!is_train_session_) {
ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
return nullptr;
}
// we don't need to restore tensor for copy data
ret = CopyConstTensorData(in_tensors, op_type);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
return nullptr;
}
}
#endif
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
if (kernel != nullptr) {
MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);


+ 6
- 4
mindspore/lite/src/scheduler.h View File

@@ -30,16 +30,17 @@
namespace mindspore::lite {
class Scheduler {
public:
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)
: context_(ctx), src_model_(src_model), src_tensors_(src_tensors) {}
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session)
: context_(ctx), src_model_(src_model), src_tensors_(src_tensors), is_train_session_(is_train_session) {}
#if SUPPORT_NPU
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session,
NPUManager *npu_manager = nullptr, NPUPassManager *npu_pass_manager = nullptr)
: context_(ctx),
src_model_(src_model),
src_tensors_(src_tensors),
npu_manager_(npu_manager),
npu_pass_manager_(npu_pass_manager) {}
npu_pass_manager_(npu_pass_manager),
is_train_session_(is_train_session) {}
#endif
~Scheduler() = default;

@@ -113,6 +114,7 @@ class Scheduler {
#endif
std::vector<size_t> graph_output_node_indexes_;
std::map<int, OpParameter *> op_parameters_;
bool is_train_session_ = false;
};
} // namespace mindspore::lite



+ 0
- 5
mindspore/lite/src/sub_graph_kernel.cc View File

@@ -166,11 +166,6 @@ int CpuSubGraph::Run(const KernelCallBack &before, const KernelCallBack &after)
}
}

#ifdef SUPPORT_TRAIN
for (auto out_tensor : out_tensors_) { // increase RefCount of output tensors, such that Run will not free them
out_tensor->set_ref_count(out_tensor->ref_count() + 1);
}
#endif
#ifdef SUPPORT_GPU
// In heterogeneous scenarios of CPU and GPU, call MutableData to MapBuffer(synchronize data).
if (static_cast<const lite::InnerContext *>(context_)->IsGpuEnabled()) {


+ 1
- 0
mindspore/lite/src/train/train_session.cc View File

@@ -57,6 +57,7 @@ static kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *>
return *it;
}
TrainSession::TrainSession() {
is_train_session_ = true;
#ifdef ENABLE_V0
if (VersionManager::GetInstance()->CheckV0Schema()) {
kernel::PopulateTrainV0Parameters();


Loading…
Cancel
Save