Browse Source

!7446 [MS][LITE][Develop]fix bug of benchmark context

Merge pull request !7446 from mengyuanli/fix_bug_of_benchmark_context
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
e9c71d517e
1 changed files with 17 additions and 20 deletions
  1. +17
    -20
      mindspore/lite/tools/benchmark/benchmark.cc

+ 17
- 20
mindspore/lite/tools/benchmark/benchmark.cc View File

@@ -20,11 +20,11 @@
#undef __STDC_FORMAT_MACROS
#include <algorithm>
#include <utility>
#include "src/common/common.h"
#include "include/ms_tensor.h"
#include "include/context.h"
#include "src/runtime/runtime_api.h"
#include "include/ms_tensor.h"
#include "include/version.h"
#include "src/common/common.h"
#include "src/runtime/runtime_api.h"

namespace mindspore {
namespace lite {
@@ -378,26 +378,23 @@ int Benchmark::RunBenchmark() {
std::cerr << "New context failed while running " << model_name.c_str() << std::endl;
return RET_ERROR;
}
auto &device_ctx = context->device_list_[0];
if (flags_->device_ == "CPU") {
device_ctx.device_type_ = lite::DT_CPU;
} else if (flags_->device_ == "GPU") {
device_ctx.device_type_ = lite::DT_GPU;
}

if (device_ctx.device_type_ == DT_CPU) {
if (flags_->cpu_bind_mode_ == MID_CPU) {
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU;
} else if (flags_->cpu_bind_mode_ == HIGHER_CPU) {
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
} else {
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
}
device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
auto &cpu_device_ctx = context->device_list_[0];
if (flags_->cpu_bind_mode_ == MID_CPU) {
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU;
} else if (flags_->cpu_bind_mode_ == HIGHER_CPU) {
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
} else {
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
}
if (device_ctx.device_type_ == DT_GPU) {
device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_;
cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;

if (flags_->device_ == "GPU") {
DeviceContext gpu_device_ctx{DT_GPU, {false}};
gpu_device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_;
context->device_list_.push_back(gpu_device_ctx);
}

context->thread_num_ = flags_->num_threads_;

session_ = session::LiteSession::CreateSession(context.get());


Loading…
Cancel
Save