Merge pull request !7342 from mengyuanli/add_device_contexttags/v1.1.0
| @@ -37,13 +37,35 @@ typedef enum { | |||
| DT_NPU /**< NPU device type, not supported yet */ | |||
| } DeviceType; | |||
| /// \brief Context defined for holding environment variables during runtime. | |||
| struct Context { | |||
| /// \brief CpuDeviceInfo defined for CPU's configuration information. | |||
| typedef struct { | |||
| bool enable_float16_ = false; /**< prior enable float16 inference */ | |||
| CpuBindMode cpu_bind_mode_ = MID_CPU; | |||
| } CpuDeviceInfo; | |||
| /// \brief GpuDeviceInfo defined for GPU's configuration information. | |||
| typedef struct { | |||
| bool enable_float16_ = false; /**< prior enable float16 inference */ | |||
| } GpuDeviceInfo; | |||
| /// \brief DeviceInfo defined for backend's configuration information. | |||
| union DeviceInfo { | |||
| CpuDeviceInfo cpu_device_info_; | |||
| GpuDeviceInfo gpu_device_info_; | |||
| }; | |||
| /// \brief DeviceContext defined for holding backend's configuration information. | |||
| struct DeviceContext { | |||
| DeviceType device_type_ = DT_CPU; | |||
| DeviceInfo device_info_; | |||
| }; | |||
| /// \brief Context defined for holding environment variables during runtime. | |||
| struct Context { | |||
| std::string vendor_name_; | |||
| int thread_num_ = 2; /**< thread number config for thread pool */ | |||
| AllocatorPtr allocator = nullptr; | |||
| CpuBindMode cpu_bind_mode_ = MID_CPU; | |||
| DeviceContextVector device_list_ = {{DT_CPU, {false, MID_CPU}}}; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_ | |||
| @@ -27,7 +27,11 @@ namespace mindspore::lite { | |||
| /// \note List public class and interface for reference. | |||
| class Allocator; | |||
| /// \brief DeviceContext defined a device context. | |||
| struct DeviceContext; | |||
| using TensorPtrVector = std::vector<mindspore::schema::Tensor *>; | |||
| using DeviceContextVector = std::vector<DeviceContext>; | |||
| using Uint32Vector = std::vector<uint32_t>; | |||
| using String = std::string; | |||
| using NodeType = schema::NodeType; | |||
| @@ -20,8 +20,13 @@ | |||
| namespace mindspore::lite { | |||
| int InnerContext::Init() { | |||
| if (this->thread_pool_ == nullptr) { | |||
| this->thread_pool_ = CreateLiteThreadPool(this->thread_num_, this->cpu_bind_mode_); | |||
| if (this->device_list_.empty()) { | |||
| MS_LOG(ERROR) << "Device list is empty."; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| if (this->thread_pool_ == nullptr && this->device_list_[0].device_type_ == DT_CPU) { | |||
| this->thread_pool_ = | |||
| CreateLiteThreadPool(this->thread_num_, this->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_); | |||
| if (this->thread_pool_ == nullptr) { | |||
| MS_LOG(ERROR) << "Create ThreadPool failed"; | |||
| return RET_NULL_PTR; | |||
| @@ -315,14 +315,27 @@ int LiteSession::Init(Context *context) { | |||
| return RET_ERROR; | |||
| } | |||
| MS_ASSERT(nullptr != context); | |||
| if (context->device_type_ == DT_NPU) { | |||
| if (context == nullptr) { | |||
| MS_LOG(ERROR) << "context is nullptr"; | |||
| is_running_.store(false); | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (context->device_list_.empty()) { | |||
| MS_LOG(ERROR) << "Device list is empty."; | |||
| is_running_.store(false); | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| auto &device_type = context->device_list_[0].device_type_; | |||
| if (device_type == DT_NPU) { | |||
| MS_LOG(ERROR) << "NPU is not supported."; | |||
| is_running_.store(false); | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| #ifndef SUPPORT_GPU | |||
| if (context->device_type_ == DT_GPU) { | |||
| if (device_type == DT_GPU) { | |||
| MS_LOG(ERROR) << "GPU is not supported."; | |||
| is_running_.store(false); | |||
| return RET_NOT_SUPPORT; | |||
| @@ -337,9 +350,10 @@ int LiteSession::Init(Context *context) { | |||
| } | |||
| this->context_->allocator = context->allocator; | |||
| this->context_->thread_num_ = context->thread_num_; | |||
| this->context_->cpu_bind_mode_ = context->cpu_bind_mode_; | |||
| this->context_->device_type_ = context->device_type_; | |||
| this->context_->enable_float16_ = context->enable_float16_; | |||
| this->context_->device_list_.clear(); | |||
| for (auto &device_ctx : context->device_list_) { | |||
| this->context_->device_list_.push_back(device_ctx); | |||
| } | |||
| auto ret = this->context_->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init Context failed"; | |||
| @@ -353,11 +367,12 @@ int LiteSession::Init(Context *context) { | |||
| return ret; | |||
| } | |||
| #if SUPPORT_GPU | |||
| if (context_->device_type_ == DT_GPU) { | |||
| if (device_type == DT_GPU) { | |||
| auto gpu_device_info = this->context_->device_list_[0].device_info_.gpu_device_info_; | |||
| auto opencl_runtime = ocl_runtime_wrap_.GetInstance(); | |||
| opencl_runtime->SetFp16Enable(context_->enable_float16_); | |||
| opencl_runtime->SetFp16Enable(gpu_device_info.enable_float16_); | |||
| if (opencl_runtime->Init() != RET_OK) { | |||
| context_->device_type_ = DT_CPU; | |||
| device_type = DT_CPU; | |||
| MS_LOG(WARNING) << "Init OpenCL runtime failed, change to CPU mode."; | |||
| } else { | |||
| MS_LOG(INFO) << "Init OpenCL runtime success."; | |||
| @@ -375,9 +390,18 @@ int LiteSession::Init(Context *context) { | |||
| } | |||
| void LiteSession::BindThread(bool if_bind) { | |||
| if (this->context_->cpu_bind_mode_ != NO_BIND) { | |||
| if (this->context_->device_list_.empty()) { | |||
| MS_LOG(ERROR) << "Device list is empty."; | |||
| return; | |||
| } | |||
| auto &device_ctx = this->context_->device_list_[0]; | |||
| if (device_ctx.device_type_ != DT_CPU) { | |||
| MS_LOG(ERROR) << "Device is not CPU."; | |||
| return; | |||
| } | |||
| if (device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ != NO_BIND) { | |||
| MS_ASSERT(this->context_->thread_pool_ != NULL); | |||
| BindThreads(this->context_->thread_pool_, if_bind, this->context_->cpu_bind_mode_); | |||
| BindThreads(this->context_->thread_pool_, if_bind, device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_); | |||
| } | |||
| } | |||
| @@ -192,10 +192,12 @@ void Scheduler::ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels) { | |||
| std::vector<kernel::LiteKernel *> subgraph_kernels; | |||
| size_t sub_cnt{0}; | |||
| auto &device_ctx = context_->device_list_[0]; | |||
| for (auto temp_kernels : sub_kernels_list) { | |||
| std::vector<Tensor *> output_tensor = kernel::LiteKernelUtil::SubgraphOutputTensors(temp_kernels); | |||
| for (auto tensor : output_tensor) { | |||
| if (context_->enable_float16_ && tensor->data_type() == kNumberTypeFloat16) { | |||
| if (device_ctx.device_type_ == DT_CPU && device_ctx.device_info_.cpu_device_info_.enable_float16_ && | |||
| tensor->data_type() == kNumberTypeFloat16) { | |||
| tensor->set_data_type(kNumberTypeFloat32); | |||
| } | |||
| } | |||
| @@ -246,8 +248,9 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<Tensor *> &in_tens | |||
| MS_ASSERT(primitive != nullptr); | |||
| TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors); | |||
| kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())}; | |||
| auto &device_ctx = context_->device_list_[0]; | |||
| #if SUPPORT_GPU | |||
| if (context_->device_type_ == DT_GPU) { | |||
| if (device_ctx.device_type_ == DT_GPU) { | |||
| desc.arch = kernel::KERNEL_ARCH::kGPU; | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||
| if (kernel != nullptr) { | |||
| @@ -262,7 +265,8 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<Tensor *> &in_tens | |||
| #endif | |||
| desc.arch = kernel::KERNEL_ARCH::kCPU; | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| if ((context_->enable_float16_ && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16) { | |||
| if ((device_ctx.device_info_.cpu_device_info_.enable_float16_ && data_type == kNumberTypeFloat32) || | |||
| data_type == kNumberTypeFloat16) { | |||
| // check if support fp16 | |||
| kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type}; | |||
| kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key); | |||
| @@ -106,8 +106,9 @@ TEST_F(InferTest, TestConvNode) { | |||
| meta_graph.reset(); | |||
| content = nullptr; | |||
| auto context = new lite::InnerContext; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->device_type_ = lite::DT_CPU; | |||
| auto &device_list = context->device_list_; | |||
| lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}}; | |||
| device_list.push_back(device_ctx); | |||
| context->thread_num_ = 4; | |||
| ASSERT_EQ(lite::RET_OK, context->Init()); | |||
| auto session = session::LiteSession::CreateSession(context); | |||
| @@ -205,8 +206,9 @@ TEST_F(InferTest, TestAddNode) { | |||
| meta_graph.reset(); | |||
| content = nullptr; | |||
| auto context = new lite::InnerContext; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->device_type_ = lite::DT_CPU; | |||
| auto &device_list = context->device_list_; | |||
| lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}}; | |||
| device_list.push_back(device_ctx); | |||
| context->thread_num_ = 4; | |||
| ASSERT_EQ(lite::RET_OK, context->Init()); | |||
| auto session = session::LiteSession::CreateSession(context); | |||
| @@ -295,8 +297,9 @@ TEST_F(InferTest, TestParallelExecutor) { | |||
| meta_graph.reset(); | |||
| content = nullptr; | |||
| auto context = new lite::InnerContext; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->device_type_ = lite::DT_CPU; | |||
| auto &device_list = context->device_list_; | |||
| lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}}; | |||
| device_list.push_back(device_ctx); | |||
| context->thread_num_ = 4; | |||
| ASSERT_EQ(lite::RET_OK, context->Init()); | |||
| auto session = new SessionWithParallelExecutor(); | |||
| @@ -336,8 +339,7 @@ TEST_F(InferTest, TestModel) { | |||
| ASSERT_NE(nullptr, model); | |||
| delete[] buf[0]; | |||
| auto context = new lite::InnerContext; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->device_type_ = lite::DT_CPU; | |||
| context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 4; | |||
| ASSERT_EQ(lite::RET_OK, context->Init()); | |||
| auto session = session::LiteSession::CreateSession(context); | |||
| @@ -68,7 +68,6 @@ TEST_F(TestBNGradFp32, BNGradFp32) { | |||
| std::vector<lite::Tensor *> outputs = {&dx_tensor, &dscale_tensor, &dbias_tensor}; | |||
| lite::InnerContext ctx; | |||
| ctx.device_type_ = lite::DT_CPU; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| @@ -171,7 +170,6 @@ TEST_F(TestBNGradFp32, BNTtrainFp32) { | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_FusedBatchNorm}; | |||
| mindspore::lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -108,7 +108,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32FilterGrad) { | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -182,7 +181,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32InputGrad) { | |||
| uint64_t time_avg = 0; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -255,7 +253,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) { | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -328,7 +325,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupInputGrad) { | |||
| uint64_t time_avg = 0; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -401,7 +397,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationFilterGrad) { | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -474,7 +469,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationInputGrad) { | |||
| uint64_t time_avg = 0; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -542,7 +536,6 @@ TEST_F(TestConvolutionGradFp32, ConvGroupDilation) { | |||
| uint64_t time_avg = 0; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -644,7 +637,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32Dilation2Group2Stride2FilterGrad) { | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -745,7 +737,6 @@ TEST_F(TestConvolutionGradFp32, ConvGroup2Dilation2Stride2) { | |||
| uint64_t time_avg = 0; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -90,7 +90,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32FilterGrad) { | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -190,7 +189,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2FilterGrad) { | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -290,7 +288,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group3FilterGrad) { | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -390,7 +387,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group3Stride1FilterGrad) { | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -490,7 +486,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group2Stride2FilterGrad) { | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -590,7 +585,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group12Stride2FilterGrad) { | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -357,8 +357,7 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| meta_graph.reset(); | |||
| content = nullptr; | |||
| lite::Context context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.cpu_bind_mode_ = lite::NO_BIND; | |||
| context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; | |||
| context.thread_num_ = 1; | |||
| auto session = session::TrainSession::CreateSession(&context); | |||
| ASSERT_NE(nullptr, session); | |||
| @@ -518,8 +517,7 @@ TEST_F(NetworkTest, efficient_net) { | |||
| auto model = lite::TrainModel::Import(buf, net_size); | |||
| delete[] buf; | |||
| auto context = new lite::Context; | |||
| context->device_type_ = lite::DT_CPU; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 1; | |||
| auto session = session::TrainSession::CreateSession(context); | |||
| @@ -544,8 +542,7 @@ TEST_F(NetworkTest, lenetnet) { | |||
| auto model = lite::TrainModel::Import(buf, net_size); | |||
| delete[] buf; | |||
| auto context = new lite::Context; | |||
| context->device_type_ = lite::DT_CPU; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 1; | |||
| // check registration | |||
| @@ -589,8 +586,7 @@ TEST_F(NetworkTest, retina_net) { | |||
| auto model = lite::Model::Import(buf, net_size); | |||
| delete[] buf; | |||
| auto context = new lite::Context; | |||
| context->device_type_ = lite::DT_CPU; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 1; | |||
| // auto session = session::TrainSession::CreateSession(context); | |||
| @@ -640,8 +636,7 @@ TEST_F(NetworkTest, mobileface_net) { | |||
| auto model = lite::Model::Import(buf, net_size); | |||
| delete[] buf; | |||
| auto context = new lite::Context; | |||
| context->device_type_ = lite::DT_CPU; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 1; | |||
| // auto session = session::TrainSession::CreateSession(context); | |||
| @@ -141,7 +141,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) { | |||
| std::vector<lite::Tensor *> outputs = {&dx_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -206,7 +205,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolingBatchGradFp32) { | |||
| std::vector<lite::Tensor *> outputs = {&dx_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -269,7 +267,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride2Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {&out_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -334,7 +331,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride3Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {&out_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -455,7 +451,6 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradBatchFp32) { | |||
| std::vector<lite::Tensor *> maxpool_outputs = {&out_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -530,7 +525,6 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride2Fp32) { | |||
| std::vector<lite::Tensor *> maxpool_outputs = {&out_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -605,7 +599,6 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride3Fp32) { | |||
| std::vector<lite::Tensor *> maxpool_outputs = {&out_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -61,7 +61,6 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { | |||
| std::vector<lite::Tensor *> outputs = {&loss_tensor, &grad_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| @@ -378,21 +378,28 @@ int Benchmark::RunBenchmark() { | |||
| std::cerr << "New context failed while running " << model_name.c_str() << std::endl; | |||
| return RET_ERROR; | |||
| } | |||
| auto &device_ctx = context->device_list_[0]; | |||
| if (flags_->device_ == "CPU") { | |||
| context->device_type_ = lite::DT_CPU; | |||
| device_ctx.device_type_ = lite::DT_CPU; | |||
| } else if (flags_->device_ == "GPU") { | |||
| context->device_type_ = lite::DT_GPU; | |||
| device_ctx.device_type_ = lite::DT_GPU; | |||
| } | |||
| if (flags_->cpu_bind_mode_ == 2) { | |||
| context->cpu_bind_mode_ = MID_CPU; | |||
| } else if (flags_->cpu_bind_mode_ == 1) { | |||
| context->cpu_bind_mode_ = HIGHER_CPU; | |||
| } else { | |||
| context->cpu_bind_mode_ = NO_BIND; | |||
| if (device_ctx.device_type_ == DT_CPU) { | |||
| if (flags_->cpu_bind_mode_ == MID_CPU) { | |||
| device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU; | |||
| } else if (flags_->cpu_bind_mode_ == HIGHER_CPU) { | |||
| device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; | |||
| } else { | |||
| device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; | |||
| } | |||
| device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; | |||
| } | |||
| if (device_ctx.device_type_ == DT_GPU) { | |||
| device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_; | |||
| } | |||
| context->thread_num_ = flags_->num_threads_; | |||
| context->enable_float16_ = flags_->enable_fp16_; | |||
| session_ = session::LiteSession::CreateSession(context.get()); | |||
| if (session_ == nullptr) { | |||
| MS_LOG(ERROR) << "CreateSession failed while running ", model_name.c_str(); | |||
| @@ -1369,9 +1369,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||
| auto model = lite::Model::Import(content, size); | |||
| Context ctx; | |||
| ctx.device_type_ = DT_CPU; | |||
| ctx.thread_num_ = calibrator_->GetThreadNum(); | |||
| ctx.cpu_bind_mode_ = MID_CPU; | |||
| fp32_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&ctx)); | |||
| if (fp32_session_ == nullptr) { | |||
| @@ -1452,9 +1450,8 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||
| auto int8_model = lite::Model::Import(int8_content, size); | |||
| Context int8_ctx; | |||
| int8_ctx.device_type_ = DT_CPU; | |||
| int8_ctx.thread_num_ = calibrator_->GetThreadNum(); | |||
| int8_ctx.cpu_bind_mode_ = HIGHER_CPU; | |||
| int8_ctx.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; | |||
| int8_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&int8_ctx)); | |||
| if (int8_session_ == nullptr) { | |||