Browse Source

modify external api impl for ABI

tags/v1.6.0
zhengyuanhua 4 years ago
parent
commit
18e6b49630
2 changed files with 26 additions and 14 deletions
  1. +14
    -4
      include/api/context.h
  2. +12
    -10
      mindspore/lite/src/cxx_api/context.cc

+ 14
- 4
include/api/context.h View File

@@ -138,21 +138,21 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoC
/// \brief obtain provider's name
///
/// \return provider's name.
std::string GetProvider() const;
inline std::string GetProvider() const;
/// \brief set provider's name.
///
/// \param[in] provider define the provider's name.

void SetProvider(const std::string &provider);
inline void SetProvider(const std::string &provider);
/// \brief obtain provider's device type.
///
/// \return provider's device type.

std::string GetProviderDevice() const;
inline std::string GetProviderDevice() const;
/// \brief set provider's device type.
///
/// \param[in] device define the provider's device type.EG: CPU.
void SetProviderDevice(const std::string &device);
inline void SetProviderDevice(const std::string &device);

/// \brief set memory allocator.
///
@@ -165,9 +165,19 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoC
std::shared_ptr<Allocator> GetAllocator() const;

protected:
std::vector<char> GetProviderChar() const;
void SetProvider(const std::vector<char> &provider);
std::vector<char> GetProviderDeviceChar() const;
void SetProviderDevice(const std::vector<char> &device);

std::shared_ptr<Data> data_;
};

std::string DeviceInfoContext::GetProvider() const { return CharToString(GetProviderChar()); }
void DeviceInfoContext::SetProvider(const std::string &provider) { SetProvider(StringToChar(provider)); }
std::string DeviceInfoContext::GetProviderDevice() const { return CharToString(GetProviderDeviceChar()); }
void DeviceInfoContext::SetProviderDevice(const std::string &device) { SetProviderDevice(StringToChar(device)); }

/// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid
/// for MindSpore Lite.
class MS_API CPUDeviceInfo : public DeviceInfoContext {


+ 12
- 10
mindspore/lite/src/cxx_api/context.cc View File

@@ -164,36 +164,38 @@ std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {

DeviceInfoContext::DeviceInfoContext() : data_(std::make_shared<Data>()) {}

std::string DeviceInfoContext::GetProvider() const {
std::vector<char> DeviceInfoContext::GetProviderChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return "";
return std::vector<char>();
}
return GetValue<std::string>(data_, kModelOptionProvider);
const std::string &ref = GetValue<std::string>(data_, kModelOptionProvider);
return StringToChar(ref);
}

void DeviceInfoContext::SetProvider(const std::string &provider) {
void DeviceInfoContext::SetProvider(const std::vector<char> &provider) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionProvider] = provider;
data_->params[kModelOptionProvider] = CharToString(provider);
}

std::string DeviceInfoContext::GetProviderDevice() const {
std::vector<char> DeviceInfoContext::GetProviderDeviceChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return "";
return std::vector<char>();
}
return GetValue<std::string>(data_, kModelOptionProviderDevice);
const std::string &ref = GetValue<std::string>(data_, kModelOptionProviderDevice);
return StringToChar(ref);
}

void DeviceInfoContext::SetProviderDevice(const std::string &device) {
void DeviceInfoContext::SetProviderDevice(const std::vector<char> &device) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionProviderDevice] = device;
data_->params[kModelOptionProviderDevice] = CharToString(device);
}

void DeviceInfoContext::SetAllocator(const std::shared_ptr<Allocator> &allocator) {


Loading…
Cancel
Save