diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc index fd6100077c..c4bcf0f806 100644 --- a/mindspore/lite/tools/benchmark/benchmark.cc +++ b/mindspore/lite/tools/benchmark/benchmark.cc @@ -20,11 +20,11 @@ #undef __STDC_FORMAT_MACROS #include #include -#include "src/common/common.h" -#include "include/ms_tensor.h" #include "include/context.h" -#include "src/runtime/runtime_api.h" +#include "include/ms_tensor.h" #include "include/version.h" +#include "src/common/common.h" +#include "src/runtime/runtime_api.h" namespace mindspore { namespace lite { @@ -378,26 +378,23 @@ int Benchmark::RunBenchmark() { std::cerr << "New context failed while running " << model_name.c_str() << std::endl; return RET_ERROR; } - auto &device_ctx = context->device_list_[0]; - if (flags_->device_ == "CPU") { - device_ctx.device_type_ = lite::DT_CPU; - } else if (flags_->device_ == "GPU") { - device_ctx.device_type_ = lite::DT_GPU; - } - if (device_ctx.device_type_ == DT_CPU) { - if (flags_->cpu_bind_mode_ == MID_CPU) { - device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU; - } else if (flags_->cpu_bind_mode_ == HIGHER_CPU) { - device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; - } else { - device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; - } - device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; + auto &cpu_device_ctx = context->device_list_[0]; + if (flags_->cpu_bind_mode_ == MID_CPU) { + cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU; + } else if (flags_->cpu_bind_mode_ == HIGHER_CPU) { + cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; + } else { + cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; } - if (device_ctx.device_type_ == DT_GPU) { - device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_; + cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; + + if (flags_->device_ == "GPU") { + DeviceContext gpu_device_ctx{DT_GPU, {false}}; + gpu_device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_; + context->device_list_.push_back(gpu_device_ctx); } + context->thread_num_ = flags_->num_threads_; session_ = session::LiteSession::CreateSession(context.get());