|
|
|
@@ -20,11 +20,11 @@ |
|
|
|
#undef __STDC_FORMAT_MACROS |
|
|
|
#include <algorithm> |
|
|
|
#include <utility> |
|
|
|
#include "src/common/common.h" |
|
|
|
#include "include/ms_tensor.h" |
|
|
|
#include "include/context.h" |
|
|
|
#include "src/runtime/runtime_api.h" |
|
|
|
#include "include/ms_tensor.h" |
|
|
|
#include "include/version.h" |
|
|
|
#include "src/common/common.h" |
|
|
|
#include "src/runtime/runtime_api.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace lite { |
|
|
|
@@ -378,26 +378,23 @@ int Benchmark::RunBenchmark() { |
|
|
|
std::cerr << "New context failed while running " << model_name.c_str() << std::endl; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto &device_ctx = context->device_list_[0]; |
|
|
|
if (flags_->device_ == "CPU") { |
|
|
|
device_ctx.device_type_ = lite::DT_CPU; |
|
|
|
} else if (flags_->device_ == "GPU") { |
|
|
|
device_ctx.device_type_ = lite::DT_GPU; |
|
|
|
} |
|
|
|
|
|
|
|
if (device_ctx.device_type_ == DT_CPU) { |
|
|
|
if (flags_->cpu_bind_mode_ == MID_CPU) { |
|
|
|
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU; |
|
|
|
} else if (flags_->cpu_bind_mode_ == HIGHER_CPU) { |
|
|
|
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; |
|
|
|
} else { |
|
|
|
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; |
|
|
|
} |
|
|
|
device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; |
|
|
|
auto &cpu_device_ctx = context->device_list_[0]; |
|
|
|
if (flags_->cpu_bind_mode_ == MID_CPU) { |
|
|
|
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU; |
|
|
|
} else if (flags_->cpu_bind_mode_ == HIGHER_CPU) { |
|
|
|
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; |
|
|
|
} else { |
|
|
|
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; |
|
|
|
} |
|
|
|
if (device_ctx.device_type_ == DT_GPU) { |
|
|
|
device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_; |
|
|
|
cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; |
|
|
|
|
|
|
|
if (flags_->device_ == "GPU") { |
|
|
|
DeviceContext gpu_device_ctx{DT_GPU, {false}}; |
|
|
|
gpu_device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_; |
|
|
|
context->device_list_.push_back(gpu_device_ctx); |
|
|
|
} |
|
|
|
|
|
|
|
context->thread_num_ = flags_->num_threads_; |
|
|
|
|
|
|
|
session_ = session::LiteSession::CreateSession(context.get()); |
|
|
|
|