
!8201 strided slice support axes as input

From: @zhaozhenlong
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit d5777b70e2
16 changed files with 325 additions and 155 deletions:

  1. +5 -1    mindspore/lite/nnacl/strided_slice.c
  2. +1 -1    mindspore/lite/src/ops/constant_of_shape.cc
  3. +1 -0    mindspore/lite/src/ops/populate/populate_register.h
  4. +1 -0    mindspore/lite/src/ops/populate/strided_slice_populate.cc
  5. +102 -4  mindspore/lite/src/ops/strided_slice.cc
  6. +1 -0    mindspore/lite/src/ops/strided_slice.h
  7. +32 -19  mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc
  8. +7 -2    mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h
  9. +0 -6    mindspore/lite/test/models_onnx.cfg
 10. +1 -1    mindspore/lite/test/run_benchmark_nets.sh
 11. +3 -0    mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc
 12. +51 -51  mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
 13. +13 -17  mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
 14. +64 -53  mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc
 15. +4 -0    mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h
 16. +39 -0   mindspore/lite/tools/converter/parser/onnx/onnx_tensor_parser.h

+5 -1  mindspore/lite/nnacl/strided_slice.c

@@ -104,8 +104,12 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p
dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5;
if (param->data_type == kDataTypeFloat) {
*((float *)out_data + out_offset) = *((float *)in_data + in_offset);
- } else {
+ } else if (param->data_type == kDataTypeInt8) {
*((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset);
+ } else if (param->data_type == kDataTypeInt) {
+ *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset);
+ } else {
+ return NNACL_ERR;
}
out_offset++;
}
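The copy loop previously treated every non-float tensor as int8; it now dispatches on the element type and rejects anything else with NNACL_ERR. A standalone mirror of that dispatch, as a sketch (the enum is a hypothetical stand-in for nnacl's kDataType* tags):

#include <cstdint>

// Hypothetical stand-ins for nnacl's kDataType* tags.
enum SliceElemType { kElemFloat, kElemInt8, kElemInt32, kElemUnknown };

// Copy one element between flat buffers by element type; -1 mirrors the
// new NNACL_ERR fallback for unsupported types.
int CopyElem(const void *in, void *out, int in_off, int out_off, SliceElemType t) {
  switch (t) {
    case kElemFloat:
      ((float *)out)[out_off] = ((const float *)in)[in_off];
      return 0;
    case kElemInt8:
      ((int8_t *)out)[out_off] = ((const int8_t *)in)[in_off];
      return 0;
    case kElemInt32:
      ((int32_t *)out)[out_off] = ((const int32_t *)in)[in_off];
      return 0;
    default:
      return -1;
  }
}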


+1 -1  mindspore/lite/src/ops/constant_of_shape.cc

@@ -87,7 +87,7 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
}
auto in_data = reinterpret_cast<int *>(in_tensor->data_c());
if (in_data == nullptr) {
- MS_LOG(ERROR) << "Input data is nullptr";
+ MS_LOG(INFO) << "Input data is nullptr. Input tensor has not been calculated out yet.";
return RET_INFER_INVALID;
}
int size = in_tensor->ElementsNum();


+1 -0  mindspore/lite/src/ops/populate/populate_register.h

@@ -50,6 +50,7 @@ class Registry {
}
};
OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive);
+ OpParameter *PopulateStridedSliceParameter(const mindspore::lite::PrimitiveC *primitive);
} // namespace lite
} // namespace mindspore
#endif

+1 -0  mindspore/lite/src/ops/populate/strided_slice_populate.cc

@@ -41,6 +41,7 @@ OpParameter *PopulateStridedSliceParameter(const mindspore::lite::PrimitiveC *pr
memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int));
auto in_shape = ((lite::StridedSlice *)primitive)->GetInShape();
memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int));
+ strided_slice_param->in_shape_length_ = static_cast<int>(in_shape.size());
return reinterpret_cast<OpParameter *>(strided_slice_param);
}



+102 -4  mindspore/lite/src/ops/strided_slice.cc

@@ -15,6 +15,7 @@
*/

#include "src/ops/strided_slice.h"
+ #include <algorithm>

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
@@ -172,7 +173,8 @@ Registry StridedSliceRegistry(schema::PrimitiveType_StridedSlice, StridedSliceCr
namespace {
constexpr size_t kStridedSliceOutputNum = 1;
constexpr size_t kStridedSliceInputNum = 1;
- constexpr size_t kStridedSliceMultiInputNum = 4;
+ constexpr size_t kStridedSliceMultiInputNumMin = 3;
+ constexpr size_t kStridedSliceMultiInputNumMax = 5;
} // namespace

void StridedSlice::ApplyNewAxisMask() {
@@ -251,13 +253,91 @@ void StridedSlice::TransIndexToPositive() {
}
}

+ int StridedSlice::HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs) {
+ // when the axes input exists:
+ // input order: data, begin, end, axes(opt), stride(opt)
+ auto input_tensor = inputs.at(0);
+ MS_ASSERT(input_tensor != nullptr);
+ auto begin_tensor = inputs.at(1);
+ MS_ASSERT(begin_tensor != nullptr);
+ int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData());
+ auto end_tensor = inputs.at(2);
+ MS_ASSERT(end_tensor != nullptr);
+ int *end_data = reinterpret_cast<int *>(end_tensor->MutableData());
+ if (begin_data == nullptr || end_data == nullptr) {
+ return RET_INFER_ERR;
+ }
+ // when the input contains axes, begins, ends and strides are expanded to the same length as the input rank
+ ndim_ = static_cast<int>(input_tensor->shape().size());
+ int begin_ndim = begin_tensor->ElementsNum();
+
+ int *axes_data = nullptr;
+ auto axes_tensor = inputs.at(3);
+ if (axes_tensor->ElementsNum() != 0) {
+ MS_ASSERT(axes_tensor->ElementsNum() == begin_ndim);
+ axes_data = reinterpret_cast<int *>(axes_tensor->MutableData());
+ if (axes_data == nullptr) {
+ return RET_INFER_ERR;
+ }
+ }
+
+ int *stride_data = nullptr;
+ auto stride_tensor = inputs.at(4);
+ if (stride_tensor->ElementsNum() != 0) {
+ MS_ASSERT(stride_tensor->ElementsNum() == begin_ndim);
+ stride_data = reinterpret_cast<int *>(stride_tensor->MutableData());
+ if (stride_data == nullptr) {
+ return RET_INFER_ERR;
+ }
+ }
+
+ std::vector<int> axes;
+ if (axes_data == nullptr) {
+ for (int i = 0; i < begin_ndim; ++i) {
+ axes.push_back(i);
+ }
+ } else {
+ axes.assign(axes_data, axes_data + begin_ndim);
+ for (int i = 0; i < begin_ndim; ++i) {
+ if (axes[i] < 0) {
+ axes[i] += ndim_;
+ }
+ }
+ }
+
+ in_shape_.assign(ndim_, 0);
+ begins_.assign(ndim_, 0);
+ ends_.assign(ndim_, 0);
+ strides_.assign(ndim_, 0);
+ auto input_shape = input_tensor->shape();
+ for (int i = 0; i < ndim_; ++i) {
+ in_shape_[i] = input_shape.at(i);
+ }
+ for (int i = 0; i < ndim_; ++i) {
+ auto axes_it = std::find(axes.begin(), axes.end(), i);
+ if (axes_it != axes.end()) {
+ auto axis = axes_it - axes.begin();
+ // begins or ends that exceed the limits are clamped to the limits
+ begins_[i] = std::max(std::min(begin_data[axis], input_shape[i] - 1), -input_shape[i]);
+ ends_[i] = std::max(std::min(end_data[axis], input_shape[i]), -input_shape[i] - 1);
+ strides_[i] = stride_data[axis];
+ } else {
+ begins_[i] = 0;
+ ends_[i] = input_shape[i];
+ strides_[i] = 1;
+ }
+ }
+ return RET_OK;
+ }

int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr);
if (outputs.size() != kStridedSliceOutputNum) {
MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
return RET_PARAM_INVALID;
}
- if (inputs.size() != kStridedSliceInputNum && inputs.size() != kStridedSliceMultiInputNum) {
+ if (inputs.size() != kStridedSliceInputNum &&
+ !(inputs.size() <= kStridedSliceMultiInputNumMax && inputs.size() >= kStridedSliceMultiInputNumMin)) {
MS_LOG(ERROR) << "Invalid input size " << inputs.size();
return RET_PARAM_INVALID;
}
@@ -268,6 +348,10 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
auto input_shape = input->shape();
auto inferflag = GetInferFlag();

+ in_shape_.clear();
+ begins_.clear();
+ ends_.clear();
+ strides_.clear();
if (inputs.size() == kStridedSliceInputNum) {
ndim_ = static_cast<int>(GetBegin().size());

@@ -279,7 +363,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
ends_.emplace_back((GetEnd())[i]);
strides_.emplace_back((GetStride())[i]);
}
- } else {
+ }
+ if (inputs.size() == 4) {
// input order: input, begins, ends, strides.
auto begin_tensor = inputs.at(1);
int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData());
auto end_tensor = inputs.at(2);
@@ -299,6 +385,13 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
strides_.emplace_back(stride_data[i]);
}
}
+ if (inputs.size() == 5) {
+ // input order: input, begins, end, axes, strides
+ auto ret = HandleAxesInputExist(inputs);
+ if (ret != RET_OK) {
+ return ret;
+ }
+ }

// set all mask to original input shape
begins_mask_.resize(ndim_);
@@ -333,7 +426,12 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
if (i < ndim_ && new_axis_mask_.at(i)) {
output_shape.at(i) = 1;
} else {
- output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
+ if (strides_.at(i) == 0) {
+ MS_LOG(ERROR) << "strides should not be 0.";
+ return RET_INFER_ERR;
+ }
+ output_shape.at(i) =
+ (ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i);
}
}
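HandleAxesInputExist above normalizes the five-input form (data, begin, end, axes, stride): negative axes are shifted by the rank, and the per-axis begin/end/stride triples are scattered onto the full input rank, with unlisted axes defaulting to the whole dimension; begins are clamped to [-dim, dim-1] and ends to [-dim-1, dim]. A self-contained sketch of that expansion (names and example values are illustrative):

#include <algorithm>
#include <vector>

struct Expanded {
  std::vector<int> begins, ends, strides;
};

// Expand per-axis (begin, end, stride) triples onto the full input rank,
// clamping begin/end the same way HandleAxesInputExist does above.
Expanded ExpandSliceArgs(const std::vector<int> &shape, const std::vector<int> &axes,
                         const std::vector<int> &begin, const std::vector<int> &end,
                         const std::vector<int> &stride) {
  int ndim = static_cast<int>(shape.size());
  Expanded e{std::vector<int>(ndim, 0), std::vector<int>(ndim, 0), std::vector<int>(ndim, 1)};
  for (int i = 0; i < ndim; ++i) {
    auto it = std::find(axes.begin(), axes.end(), i);
    if (it == axes.end()) {  // axis not listed: keep the whole dimension
      e.ends[i] = shape[i];
      continue;
    }
    auto k = it - axes.begin();
    e.begins[i] = std::max(std::min(begin[k], shape[i] - 1), -shape[i]);  // clamp into range
    e.ends[i] = std::max(std::min(end[k], shape[i]), -shape[i] - 1);
    e.strides[i] = stride[k];
  }
  return e;
}
// e.g. shape {4,5,6}, axes {1}, begin {1}, end {4}, stride {2}
// -> begins {0,1,0}, ends {4,4,6}, strides {1,2,1}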

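The reworked output-shape expression rounds the element count away from the truncation bias instead of toward it, so a final partial step still produces an element; plain (end - begin) / stride undercounts whenever the span is not an exact multiple of the stride. A worked sketch:

#include <cassert>

// Number of elements one sliced axis produces, per the InferShape change
// above; stride must be non-zero (now rejected with RET_INFER_ERR).
int SliceLen(int begin, int end, int stride) {
  return (end - begin + stride + (stride < 0 ? 1 : -1)) / stride;
}

int main() {
  assert(SliceLen(0, 7, 2) == 4);   // 0,2,4,6 -- plain (7-0)/2 would give 3
  assert(SliceLen(0, 6, 2) == 3);   // 0,2,4
  assert(SliceLen(6, 0, -2) == 3);  // 6,4,2
  return 0;
}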


+1 -0  mindspore/lite/src/ops/strided_slice.h

@@ -81,6 +81,7 @@ class StridedSlice : public PrimitiveC {
std::vector<bool> new_axis_mask_;
std::vector<bool> shrink_axis_mask_;
void TransIndexToPositive();
+ int HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs);
};
} // namespace lite
} // namespace mindspore


+32 -19  mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc

@@ -16,11 +16,11 @@

#include "src/runtime/kernel/arm/base/strided_slice.h"
#include <vector>
#include "nnacl/strided_slice.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/ops/populate/populate_register.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@@ -44,16 +44,16 @@ int StridedSliceCPUKernel::Init() {
}

int StridedSliceCPUKernel::ReSize() {
- auto input = in_tensors_.at(0);
- auto parameter = reinterpret_cast<StridedSliceParameter *>(op_parameter_);
- MS_ASSERT(input);
- MS_ASSERT(parameter);
- parameter->data_type = input->data_type() == kNumberTypeInt8 ? kDataTypeInt8 : kDataTypeFloat;
- auto input_shape = input->shape();
- for (size_t i = 0; i < input_shape.size(); ++i) {
- parameter->in_shape_[i] = input_shape[i];
- }
- parameter->in_shape_length_ = static_cast<int>(input_shape.size());
+ if (op_parameter_ != nullptr) {
+ free(op_parameter_);
+ op_parameter_ = nullptr;
+ }
+ op_parameter_ = PopulateStridedSliceParameter(primitive_);
+ if (op_parameter_ == nullptr) {
+ MS_LOG(ERROR) << "Malloc parameter failed";
+ return RET_ERROR;
+ }
+ param_ = reinterpret_cast<StridedSliceParameter *>(op_parameter_);
return RET_OK;
}

@@ -62,8 +62,7 @@ int StridedSliceCPUKernel::HandleMultiInputs() {
MS_LOG(ERROR) << "Inputs size should be " << kMultiInputsSize << ", got " << in_tensors_.size();
return RET_ERROR;
}
- auto param = reinterpret_cast<StridedSliceParameter *>(op_parameter_);
- if (param == nullptr) {
+ if (param_ == nullptr) {
MS_LOG(ERROR) << "StridedSliceParamater cast nullptr";
return RET_ERROR;
}
@@ -74,35 +73,49 @@ int StridedSliceCPUKernel::HandleMultiInputs() {
MS_LOG(ERROR) << "StridedSlice supports max dimension " << DIMENSION_6D << ", input begins dim is " << axis_num;
return RET_ERROR;
}
- memcpy(param->begins_, begins->MutableData(), axis_num * sizeof(int));
+ memcpy(param_->begins_, begins->MutableData(), axis_num * sizeof(int));

auto ends = in_tensors_.at(kEndsIndex);
MS_ASSERT(ends != nullptr);
MS_ASSERT(axis_num == ends->ElementsNum());
- memcpy(param->ends_, ends->MutableData(), axis_num * sizeof(int));
+ memcpy(param_->ends_, ends->MutableData(), axis_num * sizeof(int));

auto strides = in_tensors_.at(kStridesInex);
MS_ASSERT(strides != nullptr);
MS_ASSERT(axis_num == strides->ElementsNum());
- memcpy(param->strides_, strides->MutableData(), axis_num * sizeof(int));
+ memcpy(param_->strides_, strides->MutableData(), axis_num * sizeof(int));

- param->num_axes_ = axis_num;
+ param_->num_axes_ = axis_num;
return RET_OK;
}

int StridedSliceCPUKernel::Run() {
auto input = in_tensors_.at(0);
- auto output = out_tensors_.at(0);
MS_ASSERT(input);
+ switch (input->data_type()) {
+ case kNumberTypeInt8:
+ param_->data_type = kDataTypeInt8;
+ break;
+ case kNumberTypeFloat32:
+ param_->data_type = kDataTypeFloat;
+ break;
+ case kNumberTypeInt32:
+ param_->data_type = kDataTypeInt;
+ break;
+ default:
+ MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
+ return RET_ERROR;
+ }
+ auto output = out_tensors_.at(0);
MS_ASSERT(output);
// inputs order: input, begin, end, stride
if (in_tensors().size() == kMultiInputsSize) {
auto ret = HandleMultiInputs();
if (ret != RET_OK) {
return ret;
}
}
- auto ret = DoStridedSlice(input->MutableData(), output->MutableData(),
- reinterpret_cast<StridedSliceParameter *>(op_parameter_));
+ auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), param_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]";
return RET_ERROR;
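ReSize now frees the stale OpParameter and rebuilds it from the primitive via PopulateStridedSliceParameter, so begins/ends produced by the new axes-aware InferShape reach the kernel, and Run picks the nnacl data-type tag from the input tensor before dispatching. A sketch of that mapping, with hypothetical stand-in enums for the kNumberType*/kDataType* values used above:

// Hypothetical stand-ins for the runtime and nnacl type tags.
enum TensorType { kTypeInt8, kTypeFloat32, kTypeInt32, kTypeOther };
enum SliceDataType { kSliceFloat, kSliceInt8, kSliceInt, kSliceUnsupported };

// Runtime tensor type -> slice kernel type; an unsupported type makes
// Run() log an error and return RET_ERROR.
SliceDataType MapSliceDataType(TensorType t) {
  switch (t) {
    case kTypeInt8:    return kSliceInt8;
    case kTypeFloat32: return kSliceFloat;
    case kTypeInt32:   return kSliceInt;
    default:           return kSliceUnsupported;
  }
}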


+7 -2  mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h

@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_BACKEND_ARM_BASE_STRIDED_SLICE_H_

#include <vector>
#include "nnacl/strided_slice.h"
#include "src/lite_kernel.h"

namespace mindspore::kernel {
@@ -27,7 +27,9 @@ class StridedSliceCPUKernel : public LiteKernel {
StridedSliceCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
- : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
+ : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+ param_ = reinterpret_cast<StridedSliceParameter *>(parameter);
+ }
~StridedSliceCPUKernel() override = default;

int Init() override;
@@ -36,6 +38,9 @@ class StridedSliceCPUKernel : public LiteKernel {

private:
int HandleMultiInputs();

+ private:
+ StridedSliceParameter *param_;
};
} // namespace mindspore::kernel



+0 -6  mindspore/lite/test/models_onnx.cfg

@@ -25,9 +25,3 @@ psenet_lite_mbv2.onnx;1,32,32,3
super-resolution-10.onnx;1,224,224,1
tinyyolov2-8.onnx;1,416,416,3
ml_2012_ocr_cn.onnx
- ml_2012_ocr_cn_noLSTM.onnx
- candy-9.onnx
- mosaic-9.onnx
- pointilism-9.onnx
- rain-princess-9.onnx
- udnie-9.onnx

+1 -1  mindspore/lite/test/run_benchmark_nets.sh

@@ -62,7 +62,7 @@ function Run_Converter() {
if [ $? = 0 ]; then
converter_result='converter onnx '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
else
- converter_result='converter onnx '${model_name}' failed';echo ${converter_result} >> ${run_converter_result_file}
+ converter_result='converter onnx '${model_name}' failed';echo ${converter_result} >> ${run_converter_result_file};return 1
fi
done < ${models_onnx_config}



+3 -0  mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc

@@ -64,6 +64,9 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() {
param_value->set_tensor_addr(tensor_data);
param_value->set_tensor_size(size);
parameter->set_default_param(param_value);
+ } else if (std::find(meta_graph_->inputIndex.begin(), meta_graph_->inputIndex.end(), i) ==
+ meta_graph_->inputIndex.end()) {
+ parameter->set_default_param(param_value);
}
AddNode(i, parameter);
}


+51 -51  mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc

@@ -56,11 +56,11 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo
return dims;
}

- STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) {
+ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph) {
MS_LOG(DEBUG) << "set onnx constant tensors";
for (const auto &onnx_const_value : onnx_graph.initializer()) {
int index;
- const auto status = AddTensorProto(onnx_const_value, onnx_const_value.name(), GRAPH_INPUT, tensor_cache, &index);
+ const auto status = AddTensorProto(onnx_const_value, onnx_const_value.name(), GRAPH_INPUT, &index);
if (status != RET_OK) {
return status;
}
@@ -77,7 +77,7 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
if (attr.name() == "value") {
const auto &t = attr.t();
int index;
- const auto status = AddTensorProto(t, node.output(0), GRAPH_INPUT, tensor_cache, &index);
+ const auto status = AddTensorProto(t, node.output(0), GRAPH_INPUT, &index);
if (status != RET_OK) {
return status;
}
@@ -93,7 +93,7 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
}

STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type,
- TensorCache *tensor_cache, int *index) {
+ int *index) {
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "not support onnx data type "
@@ -109,12 +109,12 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const st
tensor->dims = GetDimsFromOnnxValue(proto);
tensor->format = schema::Format::Format_NCHW;
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
- *index = tensor_cache->AddTensor(name, tensor.release(), type);
+ *index = OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(name, tensor.release(), type);
return RET_OK;
}

STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type,
- TensorCache *tensor_cache, int *index) {
+ int *index) {
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type()));
if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type());
@@ -137,17 +137,16 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std
if (data_type == kNumberTypeInt64) {
tensor->dataType = kNumberTypeInt32; // CopyOnnxTensorData will convert int64 to int32
}
- *index = tensor_cache->AddTensor(name, tensor.release(), type);
+ *index = OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(name, tensor.release(), type);
return RET_OK;
}

- STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph,
- TensorCache *tensor_cache) {
+ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph) {
for (const auto &input_value : onnx_graph.input()) {
- auto ret = tensor_cache->FindTensor(input_value.name());
+ auto ret = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(input_value.name());
if (ret < 0) {
int index;
- const auto status = AddValueInfo(input_value, input_value.name(), GRAPH_INPUT, tensor_cache, &index);
+ const auto status = AddValueInfo(input_value, input_value.name(), GRAPH_INPUT, &index);
if (status != RET_OK) {
return status;
}
@@ -158,14 +157,13 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
return RET_OK;
}

- STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph,
- TensorCache *tensor_cache) {
+ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph) {
for (const auto &output_value : onnx_graph.output()) {
int index;
- if (tensor_cache->FindTensor(output_value.name()) != -1) {
- index = tensor_cache->FindTensor(output_value.name());
+ if (OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(output_value.name()) != -1) {
+ index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(output_value.name());
} else {
- const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index);
+ const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, &index);
if (status != RET_OK) {
return status;
}
@@ -178,7 +176,7 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,

void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::SubGraphT *sub_graph, schema::MetaGraphT *graph,
- TensorCache *tensor_cache, const QuantType &quant_type) {
+ const QuantType &quant_type) {
std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
dst_op_1->quantType = quant_type;
@@ -186,8 +184,8 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0);
std::vector<string> matmul_inputs{onnx_node.input(0), onnx_node.input(1)};
std::vector<string> matmul_outputs{matmul_output_id};
- SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node, tensor_cache);
- SetOpOutputIndex(matmul_outputs, dst_op_1.get(), tensor_cache);
+ SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node);
+ SetOpOutputIndex(matmul_outputs, dst_op_1.get());
graph->nodes.emplace_back(std::move(dst_op_1));
sub_graph->nodeIndices.push_back(graph->nodes.size() - 1);

@@ -197,15 +195,15 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get());
std::vector<string> biasadd_inputs{matmul_output_id, onnx_node.input(2)};
std::vector<string> biasadd_outputs{onnx_node.output(0)};
- SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node, tensor_cache);
- SetOpOutputIndex(biasadd_outputs, dst_op_2.get(), tensor_cache);
+ SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node);
+ SetOpOutputIndex(biasadd_outputs, dst_op_2.get());
graph->nodes.emplace_back(std::move(dst_op_2));
sub_graph->nodeIndices.push_back(graph->nodes.size() - 1);
}

- STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
+ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node) {
// convert GivenTensorFill node to a weight/bias tensor
- auto ret = tensor_cache->FindTensor(onnx_node.output(0));
+ auto ret = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node.output(0));
if (ret < 0) {
std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>();
std::vector<int> shape;
@@ -259,15 +257,16 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
return RET_ERROR;
}
}
- auto index = tensor_cache->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT);
+ auto index =
+ OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT);
MS_LOG(DEBUG) << "add given tensor: " << index;
}
return RET_OK;
}

STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
- schema::CNodeT *dst_op, TensorCache *tensor_cache,
- const QuantType &quantType, schema::MetaGraphT *dst_graph) {
+ schema::CNodeT *dst_op, const QuantType &quantType,
+ schema::MetaGraphT *dst_graph) {
// change op_type() to name(), that is unique
static bool interrupt = false;
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@@ -308,41 +307,43 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
// set op input index
std::vector<string> node_inputs;
(void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end());
- if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) {
+ if (SetOpInputIndex(node_inputs, dst_op, onnx_node)) {
interrupt = true;
MS_LOG(ERROR) << "SetOpInputIndex failed";
return RET_ERROR;
}
if (dst_op->primitive->value.type == schema::PrimitiveType_Conv2D) {
- auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
+ auto &weight_tensor =
+ OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
weight_tensor->format = dst_op->primitive->value.AsConv2D()->format;
} else if (dst_op->primitive->value.type == schema::PrimitiveType_DeConv2D) {
- auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
+ auto &weight_tensor =
+ OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
weight_tensor->format = dst_op->primitive->value.AsDeConv2D()->format;
}
// set op output index
std::vector<string> node_outputs;
(void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end());

- if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) {
+ if (SetOpOutputIndex(node_outputs, dst_op) != RET_OK) {
interrupt = true;
MS_LOG(ERROR) << "SetOpOutputIndex failed";
return RET_ERROR;
}
- auto &output_tensor = tensor_cache->GetCachedTensor().at(dst_op->outputIndex.front());
+ auto &output_tensor =
+ OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->outputIndex.front());
if (output_tensor == nullptr) {
interrupt = true;
MS_LOG(ERROR) << "Output tensor of node " << onnx_node.op_type() << "is nullptr.";
return RET_ERROR;
}
- SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor, tensor_cache);
+ SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor);
return RET_OK;
}

void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
- schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) {
+ schema::CNodeT *dst_op, schema::TensorT *dst_tensor) {
MS_ASSERT(dst_op != nullptr);
- MS_ASSERT(tensor_cache != nullptr);
std::vector<string> quant_node_name;
quant_node_name.insert(quant_node_name.begin(), onnx_node.input().begin(), onnx_node.input().end());
quant_node_name.insert(quant_node_name.end(), onnx_node.output().begin(), onnx_node.output().end());
@@ -404,10 +405,10 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co
}

STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
- const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
+ const onnx::NodeProto &onnx_node) {
for (const auto &onnx_node_input : node_inputs) {
if (onnx_node_input != "") {
- int index = tensor_cache->FindTensor(onnx_node_input);
+ int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node_input);
if (index < 0) {
MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found";
return RET_ERROR;
@@ -419,14 +420,14 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
return RET_OK;
}

- STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op,
- TensorCache *tensor_cache) {
+ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op) {
for (const auto &onnx_node_output : node_outputs) {
- auto index = tensor_cache->FindTensor(onnx_node_output);
+ auto index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node_output);
if (index < 0) { // when index >= 0, it's graph's output
std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>();
tensor->nodeType = schema::NodeType_Parameter;
- index = tensor_cache->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT);
+ index =
+ OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT);
}
MS_LOG(DEBUG) << "node: " << onnx_node_output << ", output index: " << index;
dst_op->outputIndex.emplace_back(index);
@@ -495,8 +496,8 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
return RET_OK;
}

- STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) {
- std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor();
+ STATUS OnnxModelParser::SetAllTensors(schema::MetaGraphT *graphDef) {
+ std::vector<schema::TensorT *> tensors = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor();
for (auto iter : tensors) {
std::unique_ptr<schema::TensorT> temp(iter);
graphDef->allTensors.emplace_back(move(temp));
@@ -549,12 +550,11 @@ STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodePr

int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph,
const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
- TensorCache tensor_cache;
// dst_graph->name = onnx_graph.name(); // this is not used
// find out input names and const names
FindGraphInputAndConst(onnx_graph);
// set const tensor
- int status = SetGraphConstTensor(onnx_graph, &tensor_cache);
+ int status = SetGraphConstTensor(onnx_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetGraphConstTensor failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -563,7 +563,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT

// init onnx model graph input tensor

- status = SetGraphInputTensor(onnx_graph, dst_sub_graph, &tensor_cache);
+ status = SetGraphInputTensor(onnx_graph, dst_sub_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetGraphInputTensor failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -579,12 +579,12 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
}
if (onnx_node.op_type() == "Gemm") {
if (status == RET_OK) {
- ParseOnnxGemmNode(onnx_graph, onnx_node, dst_sub_graph, dst_graph, &tensor_cache, quantType);
+ ParseOnnxGemmNode(onnx_graph, onnx_node, dst_sub_graph, dst_graph, quantType);
}
continue;
} else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") {
if (status == RET_OK) {
- status_node = ParseOnnxGivenFillNode(onnx_node, &tensor_cache);
+ status_node = ParseOnnxGivenFillNode(onnx_node);
if (status_node != RET_OK) {
MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status_node;
status = (status == RET_OK ? status_node : status);
@@ -594,7 +594,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
}

std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
- status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), &tensor_cache, quantType, dst_graph);
+ status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), quantType, dst_graph);
if (status_node != RET_OK) {
status = (status == RET_OK ? status_node : status);
continue;
@@ -604,19 +604,19 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
}
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
- for (auto &tensor : tensor_cache.GetCachedTensor()) {
+ for (auto &tensor : OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()) {
delete tensor;
}
return RET_ERROR;
}
// init onnx model graph output tensor
- status = SetGraphOutputTensor(onnx_graph, dst_sub_graph, &tensor_cache);
+ status = SetGraphOutputTensor(onnx_graph, dst_sub_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetGraphOutputTensor failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
- SetAllTensors(tensor_cache, dst_graph);
+ SetAllTensors(dst_graph);
return RET_OK;
}
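All of these edits are one mechanical refactor: the TensorCache that used to be threaded through every signature is now fetched from the OnnxTensorParser singleton (declared in onnx_tensor_parser.h below), which is what lets node parsers such as the Slice parser register tensors mid-parse. A sketch of the resulting call pattern (ResolveOrAdd is an illustrative helper, not part of the PR; FindTensor/AddTensor/OP_OUTPUT are used as in the diff):

#include <string>
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"

// Resolve a tensor by name, registering it on a miss, the way the
// rewritten call sites above do. FindTensor returns a negative index on miss.
int ResolveOrAdd(const std::string &name, schema::TensorT *tensor) {
  auto *cache = OnnxTensorParser::GetInstance()->GetTensorCache();
  int index = cache->FindTensor(name);
  if (index < 0) {
    index = cache->AddTensor(name, tensor, OP_OUTPUT);
  }
  return index;
}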



+13 -17  mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h

@@ -30,7 +30,7 @@
#include "securec/include/securec.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"
#include "proto/onnx.pb.h"

namespace mindspore {
@@ -53,42 +53,38 @@ class OnnxModelParser : public ModelParser {
private:
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);

- STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
+ STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph);

- STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, TensorCache *tensor_cache);
+ STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph);

- STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, TensorCache *tensor_cache);
+ STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph);

- STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type,
- TensorCache *tensor_cache, int *index);
+ STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, int *index);

- STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type,
- TensorCache *tensor_cache, int *index);
+ STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, int *index);

STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
- schema::CNodeT *dst_op, TensorCache *tensor_cache, const QuantType &quantType,
- schema::MetaGraphT *dst_graph);
+ schema::CNodeT *dst_op, const QuantType &quantType, schema::MetaGraphT *dst_graph);

void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
- schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache,
- const QuantType &quant_type);
+ schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, const QuantType &quant_type);

- STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
+ STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node);

STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const string &onnx_op_type, schema::CNodeT *dst_op);

void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op,
- schema::TensorT *dst_tensor, TensorCache *tensor_cache);
+ schema::TensorT *dst_tensor);

STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
- const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
+ const onnx::NodeProto &onnx_node);

- STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache);
+ STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op);

STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor);

- STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef);
+ STATUS SetAllTensors(schema::MetaGraphT *graphDef);

void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);



+64 -53  mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc

@@ -17,9 +17,35 @@
#include "tools/converter/parser/onnx/onnx_slice_parser.h"
#include <memory>
#include <vector>
+ #include <string>

namespace mindspore {
namespace lite {
+ STATUS OnnxSliceParser::InsertTensor(const std::vector<int> &onnx_val, const std::string &name,
+ onnx::NodeProto *onnx_node) {
+ std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>();
+ if (tensor == nullptr) {
+ MS_LOG(ERROR) << "new tensor failed";
+ return RET_ERROR;
+ }
+ tensor->dataType = mindspore::kNumberTypeInt32;
+ tensor->dims.push_back(onnx_val.size());
+ tensor->format = schema::Format::Format_NCHW;
+ tensor->nodeType = schema::NodeType::NodeType_ValueNode;
+ int data_size = sizeof(int32_t) * onnx_val.size();
+ tensor->data.resize(data_size);
+ if (data_size != 0 &&
+ memcpy_s(static_cast<void *>(tensor->data.data()), data_size, onnx_val.data(), data_size) != EOK) {
+ MS_LOG(ERROR) << "memcpy_s failed";
+ return RET_ERROR;
+ }
+ int tensor_num = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().size();
+ std::string tensor_name = name + std::to_string(tensor_num);
+ OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(tensor_name, tensor.release(), GRAPH_INPUT);
+ onnx_node->add_input(tensor_name);
+ return RET_OK;
+ }

STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SliceParser";
@@ -33,15 +59,15 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
return RET_NULL_PTR;
}

- std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>();
+ std::unique_ptr<schema::StridedSliceT> attr = std::make_unique<schema::StridedSliceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}

- std::vector<int> axes;
std::vector<int> starts;
std::vector<int> ends;
+ std::vector<int> axes;
std::vector<int> steps;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
@@ -71,64 +97,49 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
}
}
}

- if (onnx_node.input_size() > 1) {
- const auto &starts_name = onnx_node.input(1);
- for (const auto &it : onnx_graph.initializer()) {
- if (it.name() == starts_name) {
- starts.clear();
- for (int i = 0; i < it.int32_data_size(); ++i) {
- starts.push_back(it.int32_data(i));
- }
- }
- }
- }
-
- if (onnx_node.input_size() > 2) {
- const auto &ends_name = onnx_node.input(2);
- for (const auto &it : onnx_graph.initializer()) {
- if (it.name() == ends_name) {
- ends.clear();
- for (int i = 0; i < it.int32_data_size(); ++i) {
- ends.push_back(it.int32_data(i));
- }
- }
- }
- }
-
- if (onnx_node.input_size() > 3) {
- const auto &axes_name = onnx_node.input(3);
- for (const auto &it : onnx_graph.initializer()) {
- if (it.name() == axes_name) {
- axes.clear();
- for (int i = 0; i < it.int32_data_size(); ++i) {
- axes.push_back(it.int32_data(i));
- }
- }
- }
- }
-
- if (onnx_node.input_size() > 4) {
- const auto &steps_name = onnx_node.input(4);
- for (const auto &it : onnx_graph.initializer()) {
- if (it.name() == steps_name) {
- steps.clear();
- for (int i = 0; i < it.int32_data_size(); ++i) {
- steps.push_back(it.int32_data(i));
- }
- }
- }
- }
-
- std::vector<int> sizes(starts.size(), -1);
- for (size_t i = 0; i < starts.size(); ++i) {
- sizes[i] = (ends[i] < 0 ? ends[i] : ends[i] - starts[i]);
- }
- attr->axes = axes;
- attr->begin = starts;
- attr->size = sizes;
- attr->step = steps;
+ if (axes.empty()) {
+ for (size_t i = 0; i < starts.size(); ++i) {
+ axes.push_back(i);
+ }
+ }
+ if (steps.empty()) {
+ steps.assign(starts.size(), 1);
+ }
+
+ onnx::NodeProto *slice_node = nullptr;
+ for (auto &node : onnx_graph.node()) {
+ if (&node == &onnx_node) {
+ slice_node = const_cast<onnx::NodeProto *>(&node);
+ }
+ }
+
+ int insert_num = 5 - onnx_node.input_size();
+ int status = RET_OK;
+ switch (insert_num) {
+ case 4: {
+ std::string name = "slice/starts/";
+ status = InsertTensor(starts, name, slice_node);
+ }
+ case 3:
+ if (status == RET_OK) {
+ std::string name = "slice/ends/";
+ status = InsertTensor(ends, name, slice_node);
+ }
+ case 2:
+ if (status == RET_OK) {
+ std::string name = "slice/axes/";
+ status = InsertTensor(axes, name, slice_node);
+ }
+ case 1:
+ if (status == RET_OK) {
+ std::string name = "slice/steps/";
+ status = InsertTensor(steps, name, slice_node);
+ }
+ default:
+ if (status != RET_OK) {
+ MS_LOG(ERROR) << "onnx slice insert tensor failed";
+ return RET_ERROR;
+ }
+ }
- op->primitive->value.type = schema::PrimitiveType_Slice;
+ op->primitive->value.type = schema::PrimitiveType_StridedSlice;
op->primitive->value.value = attr.release();
return RET_OK;
}
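The switch in the new Parse body relies on deliberate fallthrough: insert_num = 5 - input_size() counts the trailing optional inputs (starts, ends, axes, steps) the node lacks, case 4 materializes starts and then falls into case 3, and so on, so every missing input is appended in order. The same control flow as a loop, in sketch form (InsertTensorStub stands in for OnnxSliceParser::InsertTensor):

#include <iostream>
#include <string>
#include <vector>

// Stand-in for OnnxSliceParser::InsertTensor: the real parser registers an
// int32 tensor in the cache and appends it as an input of the Slice node.
static int InsertTensorStub(const std::vector<int> &vals, const std::string &name) {
  std::cout << "insert " << name << " (" << vals.size() << " values)\n";
  return 0;  // RET_OK
}

// The fallthrough switch above, rewritten as a loop: inputs 1..4 that the
// node does not already carry are appended in order.
int AppendMissingSliceInputs(int input_size, const std::vector<int> &starts,
                             const std::vector<int> &ends, const std::vector<int> &axes,
                             const std::vector<int> &steps) {
  if (input_size < 1 || input_size > 5) return -1;  // Slice always has `data`
  const std::vector<int> *vals[] = {&starts, &ends, &axes, &steps};
  const char *names[] = {"slice/starts/", "slice/ends/", "slice/axes/", "slice/steps/"};
  for (int i = input_size - 1; i < 4; ++i) {
    if (int s = InsertTensorStub(*vals[i], names[i])) return s;
  }
  return 0;
}
// A node with only `data` (input_size == 1) gets starts, ends, axes and
// steps appended; with three inputs, only axes and steps are appended.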


+4 -0  mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h

@@ -17,8 +17,11 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H

+ #include <vector>
+ #include <string>
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"

namespace mindspore {
namespace lite {
@@ -27,6 +30,7 @@ class OnnxSliceParser : public OnnxNodeParser {
OnnxSliceParser() : OnnxNodeParser("Slice") {}
~OnnxSliceParser() override = default;

+ STATUS InsertTensor(const std::vector<int> &onnx_val, const std::string &name, onnx::NodeProto *onnx_node);
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite


+39 -0  mindspore/lite/tools/converter/parser/onnx/onnx_tensor_parser.h

@@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H

#include "tools/common/tensor_util.h"

namespace mindspore {
namespace lite {
class OnnxTensorParser {
public:
~OnnxTensorParser() = default;
static OnnxTensorParser *GetInstance() {
static OnnxTensorParser onnxTensorParser;
return &onnxTensorParser;
}
TensorCache *GetTensorCache() { return &tensor_cache_; }

private:
OnnxTensorParser() = default;
TensorCache tensor_cache_;
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H
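OnnxTensorParser is a Meyers singleton: the function-local static in GetInstance() is constructed once, on first use (thread-safe since C++11), and the private constructor prevents other instances. The same shape in isolation, as a sketch:

#include <cassert>

// The same Meyers-singleton shape as OnnxTensorParser above.
class Registry {
 public:
  static Registry *GetInstance() {
    static Registry instance;  // constructed exactly once, on first call
    return &instance;
  }

 private:
  Registry() = default;  // private constructor: no second instance
};

int main() {
  assert(Registry::GetInstance() == Registry::GetInstance());  // one object
  return 0;
}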
