Browse Source

support callback api

tags/v1.4.0
zhangxuetong 4 years ago
parent
commit
15dab6daa6
16 changed files with 287 additions and 131 deletions
  1. +1
    -0
      include/OWNERS
  2. +10
    -7
      include/api/context.h
  3. +2
    -1
      include/api/model.h
  4. +12
    -0
      include/api/types.h
  5. +26
    -15
      mindspore/ccsrc/cxx_api/context.cc
  6. +2
    -1
      mindspore/ccsrc/cxx_api/model/model.cc
  7. +1
    -0
      mindspore/lite/minddata/CMakeLists.txt
  8. +1
    -0
      mindspore/lite/src/CMakeLists.txt
  9. +43
    -22
      mindspore/lite/src/cxx_api/context.cc
  10. +3
    -2
      mindspore/lite/src/cxx_api/model/model.cc
  11. +93
    -52
      mindspore/lite/src/cxx_api/model/model_impl.cc
  12. +5
    -1
      mindspore/lite/src/cxx_api/model/model_impl.h
  13. +0
    -3
      mindspore/lite/src/cxx_api/serialization.cc
  14. +77
    -0
      mindspore/lite/src/cxx_api/tensor_utils.cc
  15. +10
    -27
      mindspore/lite/src/cxx_api/tensor_utils.h
  16. +1
    -0
      mindspore/lite/test/ut/src/infer_test.cc

+ 1
- 0
include/OWNERS View File

@@ -3,5 +3,6 @@ approvers:
- hangangqiang
- xu-yfei
- wilfchen
- zhang_xue_tong
reviewers:
- lx0095

+ 10
- 7
include/api/context.h View File

@@ -46,8 +46,16 @@ class MS_API Context {
void SetThreadNum(int32_t thread_num);
int32_t GetThreadNum() const;

void SetAllocator(const std::shared_ptr<Allocator> &allocator);
std::shared_ptr<Allocator> GetAllocator() const;
/// \brief Set the thread affinity to CPU cores.
///
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);
int GetThreadAffinityMode() const;

void SetThreadAffinity(const std::vector<int> &core_list);
std::vector<int32_t> GetThreadAffinityCoreList() const;
void SetEnableParallel(bool is_parallel);
bool GetEnableParallel() const;

std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();

@@ -91,11 +99,6 @@ class MS_API CPUDeviceInfo : public DeviceInfoContext {
public:
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };

/// \brief Set the thread affinity to CPU cores.
///
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);
int GetThreadAffinity() const;
void SetEnableFP16(bool is_fp16);
bool GetEnableFP16() const;
};


+ 2
- 1
include/api/model.h View File

@@ -41,7 +41,8 @@ class MS_API Model {
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr);
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);

Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

std::vector<MSTensor> GetInputs();
inline MSTensor GetInputByTensorName(const std::string &tensor_name);


+ 12
- 0
include/api/types.h View File

@@ -20,6 +20,7 @@
#include <string>
#include <vector>
#include <memory>
#include <functional>
#include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"

@@ -142,5 +143,16 @@ MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vecto
: MSTensor(StringToChar(name), type, shape, data, data_len) {}

std::string MSTensor::Name() const { return CharToString(CharName()); }

/// \brief CallBackParam defined input arguments for callBack function.
struct MSCallBackParam {
std::string node_name_; /**< node name argument */
std::string node_type_; /**< node type argument */
};

/// \brief KernelCallBack defined the function pointer for callBack.
using MSKernelCallBack = std::function<bool(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
const MSCallBackParam &opInfo)>;

} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_TYPES_H

+ 26
- 15
mindspore/ccsrc/cxx_api/context.cc View File

@@ -21,7 +21,6 @@
#include "utils/log_adapter.h"

constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
@@ -48,7 +47,9 @@ class Allocator {};
struct Context::Data {
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
int32_t thread_num;
std::shared_ptr<Allocator> allocator;
bool enable_parallel_ = false;
std::vector<int32_t> affinity_core_list_;
int affinity_mode_ = 2;
};

struct DeviceInfoContext::Data {
@@ -84,13 +85,32 @@ int32_t Context::GetThreadNum() const {
return data_->thread_num;
}

void Context::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
void Context::SetEnableParallel(bool is_parallel) {
MS_EXCEPTION_IF_NULL(data_);
data_->allocator = allocator;
data_->enable_parallel_ = is_parallel;
}
std::shared_ptr<Allocator> Context::GetAllocator() const {

bool Context::GetEnableParallel() const {
MS_EXCEPTION_IF_NULL(data_);
return data_->enable_parallel_;
}

void Context::SetThreadAffinity(int mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->affinity_mode_ = mode;
}
int Context::GetThreadAffinityMode() const {
MS_EXCEPTION_IF_NULL(data_);
return data_->affinity_mode_;
}

void Context::SetThreadAffinity(const std::vector<int> &core_list) {
MS_EXCEPTION_IF_NULL(data_);
data_->affinity_core_list_ = core_list;
}
std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
MS_EXCEPTION_IF_NULL(data_);
return data_->allocator;
return data_->affinity_core_list_;
}

std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
@@ -109,15 +129,6 @@ bool CPUDeviceInfo::GetEnableFP16() const {
return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
}

void CPUDeviceInfo::SetThreadAffinity(int affinity) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionCpuThreadAffinity] = affinity;
}
int CPUDeviceInfo::GetThreadAffinity() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<bool>(data_, kModelOptionCpuThreadAffinity);
}

void MaliGPUDeviceInfo::SetEnableFP16(bool is_fp16) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionMaliGpuEnableFP16] = is_fp16;


+ 2
- 1
mindspore/ccsrc/cxx_api/model/model.cc View File

@@ -73,7 +73,8 @@ Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std:
return impl_->Resize(inputs, dims);
}

Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Failed because this model has not been built.";
return kMCFailed;


+ 1
- 0
mindspore/lite/minddata/CMakeLists.txt View File

@@ -109,6 +109,7 @@ if(BUILD_MINDDATA STREQUAL "full")

set(MINDDATA_FULL_SRC
${TOP_DIR}/mindspore/lite/src/cxx_api/types.cc
${TOP_DIR}/mindspore/lite/src/cxx_api/tensor_utils.cc
${TOP_DIR}/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc
${TOP_DIR}/mindspore/lite/src/tensor.cc
${TOP_DIR}/mindspore/lite/src/ms_tensor.cc


+ 1
- 0
mindspore/lite/src/CMakeLists.txt View File

@@ -35,6 +35,7 @@ else()
${CORE_DIR}/utils/status.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/cell.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/serialization.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/types.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/context.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model.cc


+ 43
- 22
mindspore/lite/src/cxx_api/context.cc View File

@@ -25,7 +25,6 @@

namespace mindspore {
constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr auto kModelOptionProvider = "mindspore.option.provider";
@@ -34,7 +33,9 @@ constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
struct Context::Data {
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
int32_t thread_num = 2;
std::shared_ptr<Allocator> allocator = nullptr;
bool enable_parallel_ = false;
std::vector<int32_t> affinity_core_list_;
int affinity_mode_ = 2;
};

struct DeviceInfoContext::Data {
@@ -74,19 +75,54 @@ int32_t Context::GetThreadNum() const {
return data_->thread_num;
}

void Context::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
void Context::SetEnableParallel(bool is_parallel) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->allocator = allocator;
data_->enable_parallel_ = is_parallel;
}
std::shared_ptr<Allocator> Context::GetAllocator() const {

bool Context::GetEnableParallel() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return nullptr;
return false;
}
return data_->allocator;
return data_->enable_parallel_;
}

void Context::SetThreadAffinity(int mode) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->affinity_mode_ = mode;

return;
}
int Context::GetThreadAffinityMode() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return -1;
}
return data_->affinity_mode_;
}

void Context::SetThreadAffinity(const std::vector<int> &core_list) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->affinity_core_list_ = core_list;

return;
}
std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return {};
}
return data_->affinity_core_list_;
}

std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
@@ -163,21 +199,6 @@ bool CPUDeviceInfo::GetEnableFP16() const {
return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
}

void CPUDeviceInfo::SetThreadAffinity(int affinity) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionCpuThreadAffinity] = affinity;
}
int CPUDeviceInfo::GetThreadAffinity() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return 0;
}
return GetValue<int>(data_, kModelOptionCpuThreadAffinity);
}

void MaliGPUDeviceInfo::SetEnableFP16(bool is_fp16) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";


+ 3
- 2
mindspore/lite/src/cxx_api/model/model.cc View File

@@ -53,12 +53,13 @@ Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std:
return impl_->Resize(inputs, dims);
}

Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteNullptr;
}
return impl_->Predict(inputs, outputs);
return impl_->Predict(inputs, outputs, before, after);
}

Model::Model() : impl_(nullptr) {}


+ 93
- 52
mindspore/lite/src/cxx_api/model/model_impl.cc View File

@@ -16,16 +16,12 @@

#include "src/cxx_api/model/model_impl.h"
#include <memory>
#include <unordered_map>
#include <algorithm>
#include "include/api/types.h"
#include "include/api/context.h"
#include "include/api/dual_abi_helper.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "src/lite_model.h"
#include "src/runtime/inner_allocator.h"
#include "src/common/string_util.h"
#include "src/cxx_api/graph/graph_data.h"
#include "src/cxx_api/tensor/tensor_impl.h"
#include "src/cxx_api/tensor_utils.h"
@@ -35,22 +31,21 @@ namespace mindspore {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

Status ModelImpl::Build() {
MS_LOG(DEBUG) << "Start build model.";
auto model = graph_->graph_data_->lite_model();
if (graph_ == nullptr || graph_->graph_data_ == nullptr || model == nullptr) {
MS_LOG(ERROR) << "Invalid graph.";
return kLiteNullptr;
lite::CpuBindMode ModelImpl::GetCpuBindMode() {
auto affinity_mode = context_->GetThreadAffinityMode();
switch (affinity_mode) {
case 0:
return lite::NO_BIND;
case 1:
return lite::HIGHER_CPU;
case 2:
return lite::MID_CPU;
default:
return lite::NO_BIND;
}
if (model->buf == nullptr) {
MS_LOG(ERROR) << "Lite model has been freed.";
return kLiteError;
}
if (context_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return kLiteNullptr;
}
lite::Context model_context;
}

Status ModelImpl::ConverterContext(lite::Context *model_context) {
auto device_list = context_->MutableDeviceInfo();
if (device_list.size() == 0) {
MS_LOG(ERROR) << "Invalid device list.";
@@ -60,54 +55,73 @@ Status ModelImpl::Build() {
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
return kLiteInputParamInvalid;
}
model_context.allocator = context_->GetAllocator();
if (model_context.allocator == nullptr) {
model_context.allocator = Allocator::Create();
if (model_context.allocator == nullptr) {
MS_LOG(ERROR) << "Create Allocator failed.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "Set new allocator.";
context_->SetAllocator(model_context.allocator);
}
model_context.thread_num_ = context_->GetThreadNum();
model_context.device_list_.clear();

model_context->thread_num_ = context_->GetThreadNum();
model_context->enable_parallel_ = context_->GetEnableParallel();
model_context->affinity_core_list_ = context_->GetThreadAffinityCoreList();
model_context->device_list_.clear();
if (device_list[0]->GetDeviceType() != kCPU) {
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
return kLiteInputParamInvalid;
}

auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
lite::CpuBindMode mode;
if (cpu_context->GetThreadAffinity() == 0) {
mode = lite::NO_BIND;
} else if (cpu_context->GetThreadAffinity() == 1) {
mode = lite::HIGHER_CPU;
} else if (cpu_context->GetThreadAffinity() == 2) {
mode = lite::MID_CPU;
} else {
MS_LOG(ERROR) << "Invalid thread affinity.";
return kLiteInputParamInvalid;
model_context->allocator = cpu_context->GetAllocator();
if (model_context->allocator == nullptr) {
model_context->allocator = Allocator::Create();
if (model_context->allocator == nullptr) {
MS_LOG(ERROR) << "Create Allocator failed.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "Set new allocator.";
cpu_context->SetAllocator(model_context->allocator);
}

lite::CpuBindMode mode = GetCpuBindMode();
lite::DeviceInfo cpu_info = {0};
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
model_context.device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
model_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
if (device_list.size() == 2) {
lite::DeviceInfo device_info = {0};
if (device_list[1]->GetDeviceType() == kMaliGPU) {
auto gpu_context = device_list[1]->Cast<MaliGPUDeviceInfo>();
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
model_context.device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
model_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
device_info.npu_device_info_ = {npu_context->GetFrequency()};
model_context.device_list_.push_back({lite::DT_NPU, device_info});
model_context->device_list_.push_back({lite::DT_NPU, device_info});
} else {
MS_LOG(ERROR) << "Invalid device.";
return kLiteInputParamInvalid;
}
}
return kSuccess;
}

Status ModelImpl::Build() {
MS_LOG(DEBUG) << "Start build model.";
auto model = graph_->graph_data_->lite_model();
if (graph_ == nullptr || graph_->graph_data_ == nullptr || model == nullptr) {
MS_LOG(ERROR) << "Invalid graph.";
return kLiteNullptr;
}
if (model->buf == nullptr) {
MS_LOG(ERROR) << "Lite model has been freed.";
return kLiteError;
}
if (context_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return kLiteNullptr;
}
lite::Context model_context;
auto status = ConverterContext(&model_context);
if (status != kSuccess) {
return status;
}

auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&model_context));
if (session == nullptr) {
MS_LOG(ERROR) << "Allocate session failed.";
@@ -130,7 +144,36 @@ static void ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MS
}
}

Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (before == nullptr || after == nullptr) {
auto ret = session_->RunGraph();
return static_cast<StatusCode>(ret);
}
auto before_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
const CallBackParam &call_param) {
std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
MSCallBackParam mscall_param;
mscall_param.node_name_ = call_param.node_name;
mscall_param.node_type_ = call_param.node_type;
return before(inputs, outputs, mscall_param);
};
auto after_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
const CallBackParam &call_param) {
std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
MSCallBackParam mscall_param;
mscall_param.node_name_ = call_param.node_name;
mscall_param.node_type_ = call_param.node_type;
return after(inputs, outputs, mscall_param);
};
auto ret = session_->RunGraph(before_call_back, after_call_back);
return static_cast<StatusCode>(ret);
}
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (outputs == nullptr) {
MS_LOG(ERROR) << "outputs is nullptr.";
return kLiteError;
@@ -188,13 +231,11 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
}
}
}
session_->BindThread(true);
auto ret = session_->RunGraph();
session_->BindThread(false);
auto ret = RunGraph(before, after);
ResetTensorData(old_data, input_tensors);
if (ret != RET_OK) {
if (ret != kSuccess) {
MS_LOG(ERROR) << "Run graph failed.";
return static_cast<StatusCode>(ret);
return ret;
}
MS_LOG(DEBUG) << "Run graph success.";
auto res = GetOutputs();


+ 5
- 1
mindspore/lite/src/cxx_api/model/model_impl.h View File

@@ -38,7 +38,8 @@ class ModelImpl {
Status Build();
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);

Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
const MSKernelCallBack &after);

std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
@@ -56,6 +57,9 @@ class ModelImpl {
std::shared_ptr<Context> context_;
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
lite::CpuBindMode GetCpuBindMode();
Status ConverterContext(lite::Context *model_context);
Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
};
} // namespace mindspore



+ 0
- 3
mindspore/lite/src/cxx_api/serialization.cc View File

@@ -17,12 +17,9 @@
#include "include/api/serialization.h"
#include <algorithm>
#include <queue>
#include <set>
#include "include/api/graph.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/model.h"
#include "include/ms_tensor.h"
#include "src/cxx_api/graph/graph_data.h"
#include "src/common/log_adapter.h"



+ 77
- 0
mindspore/lite/src/cxx_api/tensor_utils.cc View File

@@ -0,0 +1,77 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/cxx_api/tensor_utils.h"
#include "src/common/log_adapter.h"

namespace mindspore {
std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
bool verify_size) {
std::vector<int32_t> empty;
if (shape.empty()) {
return empty;
}
std::vector<int32_t> truncated_shape;
truncated_shape.resize(shape.size());
size_t element_size = lite::DataTypeSize(type);
for (size_t i = 0; i < shape.size(); i++) {
auto dim = shape[i];
if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast<size_t>(dim)) {
MS_LOG(ERROR) << "Invalid shape.";
return empty;
} else {
element_size *= static_cast<size_t>(dim);
truncated_shape[i] = static_cast<int32_t>(dim);
}
}
if (verify_size) {
if (element_size != data_len) {
MS_LOG(ERROR) << "Invalid data size.";
return empty;
}
}
return truncated_shape;
}
Status LiteTensorToMSTensor(tensor::MSTensor *srcTensor, MSTensor *dstTensor) {
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(srcTensor));
if (impl == nullptr || impl->lite_tensor() == nullptr) {
MS_LOG(ERROR) << "Create tensor failed.";
return kLiteError;
}
auto tensor = MSTensor(impl);
if (tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor failed.";
return kLiteError;
}
*dstTensor = tensor;
return kSuccess;
}

std::vector<MSTensor> LiteTensorsToMSTensors(const std::vector<mindspore::tensor::MSTensor *> &srcTensors) {
std::vector<MSTensor> dstTensors;
dstTensors.reserve(srcTensors.size());
for (auto inTensor : srcTensors) {
MSTensor tensor;
auto status = LiteTensorToMSTensor(inTensor, &tensor);
if (status != kSuccess) {
return {};
}
dstTensors.emplace_back(tensor);
}
return dstTensors;
}

} // namespace mindspore

+ 10
- 27
mindspore/lite/src/cxx_api/tensor_utils.h View File

@@ -19,36 +19,19 @@

#include <limits.h>
#include <vector>
#include <memory>
#include "ir/dtype/type_id.h"
#include "include/ms_tensor.h"
#include "include/api/types.h"
#include "src/cxx_api/tensor/tensor_impl.h"

namespace mindspore {
static std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
bool verify_size) {
std::vector<int32_t> empty;
if (shape.empty()) {
return empty;
}
std::vector<int32_t> truncated_shape;
truncated_shape.resize(shape.size());
size_t element_size = lite::DataTypeSize(type);
for (size_t i = 0; i < shape.size(); i++) {
auto dim = shape[i];
if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast<size_t>(dim)) {
MS_LOG(ERROR) << "Invalid shape.";
return empty;
} else {
element_size *= static_cast<size_t>(dim);
truncated_shape[i] = static_cast<int32_t>(dim);
}
}
if (verify_size) {
if (element_size != data_len) {
MS_LOG(ERROR) << "Invalid data size.";
return empty;
}
}
return truncated_shape;
}
std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
bool verify_size);
Status LiteTensorToMSTensor(tensor::MSTensor *srcTensor, MSTensor *dstTensor);

std::vector<MSTensor> LiteTensorsToMSTensors(const std::vector<mindspore::tensor::MSTensor *> &srcTensors);

} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_CXX_API_TENSOR_UTILS_H

+ 1
- 0
mindspore/lite/test/ut/src/infer_test.cc View File

@@ -353,4 +353,5 @@ TEST_F(InferTest, TestModel) {
auto outputs = session->GetOutputs();
MS_LOG(INFO) << "Passed";
}

} // namespace mindspore

Loading…
Cancel
Save