| @@ -20,11 +20,11 @@ | |||||
| #undef __STDC_FORMAT_MACROS | #undef __STDC_FORMAT_MACROS | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <utility> | #include <utility> | ||||
| #include "src/common/common.h" | |||||
| #include "include/ms_tensor.h" | |||||
| #include "include/context.h" | #include "include/context.h" | ||||
| #include "src/runtime/runtime_api.h" | |||||
| #include "include/ms_tensor.h" | |||||
| #include "include/version.h" | #include "include/version.h" | ||||
| #include "src/common/common.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -378,26 +378,23 @@ int Benchmark::RunBenchmark() { | |||||
| std::cerr << "New context failed while running " << model_name.c_str() << std::endl; | std::cerr << "New context failed while running " << model_name.c_str() << std::endl; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto &device_ctx = context->device_list_[0]; | |||||
| if (flags_->device_ == "CPU") { | |||||
| device_ctx.device_type_ = lite::DT_CPU; | |||||
| } else if (flags_->device_ == "GPU") { | |||||
| device_ctx.device_type_ = lite::DT_GPU; | |||||
| } | |||||
| if (device_ctx.device_type_ == DT_CPU) { | |||||
| if (flags_->cpu_bind_mode_ == MID_CPU) { | |||||
| device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU; | |||||
| } else if (flags_->cpu_bind_mode_ == HIGHER_CPU) { | |||||
| device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; | |||||
| } else { | |||||
| device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; | |||||
| } | |||||
| device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; | |||||
| auto &cpu_device_ctx = context->device_list_[0]; | |||||
| if (flags_->cpu_bind_mode_ == MID_CPU) { | |||||
| cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU; | |||||
| } else if (flags_->cpu_bind_mode_ == HIGHER_CPU) { | |||||
| cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; | |||||
| } else { | |||||
| cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; | |||||
| } | } | ||||
| if (device_ctx.device_type_ == DT_GPU) { | |||||
| device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_; | |||||
| cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; | |||||
| if (flags_->device_ == "GPU") { | |||||
| DeviceContext gpu_device_ctx{DT_GPU, {false}}; | |||||
| gpu_device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_; | |||||
| context->device_list_.push_back(gpu_device_ctx); | |||||
| } | } | ||||
| context->thread_num_ = flags_->num_threads_; | context->thread_num_ = flags_->num_threads_; | ||||
| session_ = session::LiteSession::CreateSession(context.get()); | session_ = session::LiteSession::CreateSession(context.get()); | ||||