Browse Source

Add registry-kernel (custom provider) support to the C++ API

tags/v1.3.0
chenjianping 4 years ago
parent
commit
3f4a587849
4 changed files with 66 additions and 4 deletions
  1. +9
    -0
      include/api/context.h
  2. +51
    -0
      mindspore/lite/src/cxx_api/context.cc
  3. +4
    -2
      mindspore/lite/src/cxx_api/model/model_impl.cc
  4. +2
    -2
      mindspore/lite/src/sub_graph_kernel.cc

+ 9
- 0
include/api/context.h View File

@@ -74,6 +74,15 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoC
return std::static_pointer_cast<T>(shared_from_this());
}

std::string GetProvider() const;
void SetProvider(const std::string &provider);

std::string GetProviderDevice() const;
void SetProviderDevice(const std::string &device);

void SetAllocator(const std::shared_ptr<Allocator> &allocator);
std::shared_ptr<Allocator> GetAllocator() const;

protected:
std::shared_ptr<Data> data_;
};


+ 51
- 0
mindspore/lite/src/cxx_api/context.cc View File

@@ -28,6 +28,8 @@ constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr auto kModelOptionProvider = "mindspore.option.provider";
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";

struct Context::Data {
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
@@ -37,6 +39,7 @@ struct Context::Data {

// Private backing storage for DeviceInfoContext (pimpl-style): a generic
// string -> std::any option map (provider, FP16 flags, frequency, ...) plus an
// optional user-supplied memory allocator shared with the runtime.
struct DeviceInfoContext::Data {
// Keyed by the kModelOption* constants above; values are read back via GetValue<T>.
std::map<std::string, std::any> params;
// Null unless the caller installs one through SetAllocator().
std::shared_ptr<Allocator> allocator = nullptr;
};

// Allocates the backing Data block. new (std::nothrow) leaves data_ null on
// allocation failure instead of throwing; accessors check for that and log.
Context::Context() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}
@@ -97,6 +100,54 @@ std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {

// Same no-throw allocation pattern as Context::Context(): a failed allocation
// yields a null data_, which every accessor below guards against.
DeviceInfoContext::DeviceInfoContext() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}

// Returns the registered kernel-provider name for this device context.
// Logs and returns an empty string if the backing data block was never
// allocated (see the no-throw construction in the constructor).
std::string DeviceInfoContext::GetProvider() const {
  if (data_ != nullptr) {
    return GetValue<std::string>(data_, kModelOptionProvider);
  }
  MS_LOG(ERROR) << "Invalid context.";
  return "";
}

// Records the kernel-provider name in the option map; silently ignored
// (with an error log) when the backing data block is missing.
void DeviceInfoContext::SetProvider(const std::string &provider) {
  if (data_ != nullptr) {
    data_->params[kModelOptionProvider] = provider;
    return;
  }
  MS_LOG(ERROR) << "Invalid context.";
}

// Returns the provider-specific device name, or an empty string (after
// logging) when the backing data block was never allocated.
std::string DeviceInfoContext::GetProviderDevice() const {
  if (data_ != nullptr) {
    return GetValue<std::string>(data_, kModelOptionProviderDevice);
  }
  MS_LOG(ERROR) << "Invalid context.";
  return "";
}

// Records the provider-specific device name in the option map; ignored
// (with an error log) when the backing data block is missing.
void DeviceInfoContext::SetProviderDevice(const std::string &device) {
  if (data_ != nullptr) {
    data_->params[kModelOptionProviderDevice] = device;
    return;
  }
  MS_LOG(ERROR) << "Invalid context.";
}

// Installs a caller-supplied allocator to be shared with the runtime;
// ignored (with an error log) when the backing data block is missing.
void DeviceInfoContext::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
  if (data_ != nullptr) {
    data_->allocator = allocator;
    return;
  }
  MS_LOG(ERROR) << "Invalid context.";
}

// Returns the installed allocator (null if none was set), or null after
// logging when the backing data block was never allocated.
std::shared_ptr<Allocator> DeviceInfoContext::GetAllocator() const {
  if (data_ != nullptr) {
    return data_->allocator;
  }
  MS_LOG(ERROR) << "Invalid context.";
  return nullptr;
}

void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";


+ 4
- 2
mindspore/lite/src/cxx_api/model/model_impl.cc View File

@@ -90,13 +90,15 @@ Status ModelImpl::Build() {
}
lite::DeviceInfo cpu_info = {0};
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
model_context.device_list_.push_back({lite::DT_CPU, cpu_info});
model_context.device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
if (device_list.size() == 2) {
lite::DeviceInfo device_info = {0};
if (device_list[1]->GetDeviceType() == kMaliGPU) {
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
model_context.device_list_.push_back({lite::DT_GPU, device_info});
model_context.device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
device_info.npu_device_info_ = {npu_context->GetFrequency()};


+ 2
- 2
mindspore/lite/src/sub_graph_kernel.cc View File

@@ -155,7 +155,7 @@ int CustomSubGraph::Prepare() {
}
auto provider = nodes_[0]->desc().provider;
auto context = this->Context();
AllocatorPtr allocator = nullptr;
AllocatorPtr allocator = context->allocator;
auto iter = std::find_if(context->device_list_.begin(), context->device_list_.end(),
[&provider](const auto &dev) { return dev.provider_ == provider; });
if (iter != context->device_list_.end()) {
@@ -173,7 +173,7 @@ int CustomSubGraph::Prepare() {
auto node = nodes_[nodes_.size() - 1];
for (auto tensor : node->out_tensors()) {
MS_ASSERT(tensor != nullptr);
tensor->set_allocator(this->Context()->allocator);
tensor->set_allocator(context->allocator);
}
return RET_OK;
}


Loading…
Cancel
Save