Browse Source

fix bug of ctx use

tags/v1.1.0
mengyuanli 5 years ago
parent
commit
16f43d8dc0
1 changed files with 17 additions and 20 deletions
  1. +17
    -20
      mindspore/lite/tools/benchmark/benchmark.cc

+ 17
- 20
mindspore/lite/tools/benchmark/benchmark.cc View File

@@ -20,11 +20,11 @@
#undef __STDC_FORMAT_MACROS #undef __STDC_FORMAT_MACROS
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include "src/common/common.h"
#include "include/ms_tensor.h"
#include "include/context.h" #include "include/context.h"
#include "src/runtime/runtime_api.h"
#include "include/ms_tensor.h"
#include "include/version.h" #include "include/version.h"
#include "src/common/common.h"
#include "src/runtime/runtime_api.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@@ -378,26 +378,23 @@ int Benchmark::RunBenchmark() {
std::cerr << "New context failed while running " << model_name.c_str() << std::endl; std::cerr << "New context failed while running " << model_name.c_str() << std::endl;
return RET_ERROR; return RET_ERROR;
} }
auto &device_ctx = context->device_list_[0];
if (flags_->device_ == "CPU") {
device_ctx.device_type_ = lite::DT_CPU;
} else if (flags_->device_ == "GPU") {
device_ctx.device_type_ = lite::DT_GPU;
}


if (device_ctx.device_type_ == DT_CPU) {
if (flags_->cpu_bind_mode_ == MID_CPU) {
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU;
} else if (flags_->cpu_bind_mode_ == HIGHER_CPU) {
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
} else {
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
}
device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
auto &cpu_device_ctx = context->device_list_[0];
if (flags_->cpu_bind_mode_ == MID_CPU) {
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU;
} else if (flags_->cpu_bind_mode_ == HIGHER_CPU) {
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
} else {
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
} }
if (device_ctx.device_type_ == DT_GPU) {
device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_;
cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;

if (flags_->device_ == "GPU") {
DeviceContext gpu_device_ctx{DT_GPU, {false}};
gpu_device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_;
context->device_list_.push_back(gpu_device_ctx);
} }

context->thread_num_ = flags_->num_threads_; context->thread_num_ = flags_->num_threads_;


session_ = session::LiteSession::CreateSession(context.get()); session_ = session::LiteSession::CreateSession(context.get());


Loading…
Cancel
Save