diff --git a/mindspore/lite/nnacl/strided_slice.c b/mindspore/lite/nnacl/strided_slice.c index 48a6199343..f227a082f8 100644 --- a/mindspore/lite/nnacl/strided_slice.c +++ b/mindspore/lite/nnacl/strided_slice.c @@ -104,8 +104,12 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5; if (param->data_type == kDataTypeFloat) { *((float *)out_data + out_offset) = *((float *)in_data + in_offset); - } else { + } else if (param->data_type == kDataTypeInt8) { *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); + } else if (param->data_type == kDataTypeInt) { + *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); + } else { + return NNACL_ERR; } out_offset++; } diff --git a/mindspore/lite/src/ops/constant_of_shape.cc b/mindspore/lite/src/ops/constant_of_shape.cc index 73630f650d..cacf29ae35 100644 --- a/mindspore/lite/src/ops/constant_of_shape.cc +++ b/mindspore/lite/src/ops/constant_of_shape.cc @@ -87,7 +87,7 @@ int ConstantOfShape::InferShape(std::vector inputs_, std::vector(in_tensor->data_c()); if (in_data == nullptr) { - MS_LOG(ERROR) << "Input data is nullptr"; + MS_LOG(INFO) << "Input data is nullptr. Input tensor has not been calculated out yet."; return RET_INFER_INVALID; } int size = in_tensor->ElementsNum(); diff --git a/mindspore/lite/src/ops/populate/populate_register.h b/mindspore/lite/src/ops/populate/populate_register.h index d6754644d7..fe3244841e 100644 --- a/mindspore/lite/src/ops/populate/populate_register.h +++ b/mindspore/lite/src/ops/populate/populate_register.h @@ -50,6 +50,7 @@ class Registry { } }; OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive); +OpParameter *PopulateStridedSliceParameter(const mindspore::lite::PrimitiveC *primitive); } // namespace lite } // namespace mindspore #endif diff --git a/mindspore/lite/src/ops/populate/strided_slice_populate.cc b/mindspore/lite/src/ops/populate/strided_slice_populate.cc index 9f25183f62..1ff1218930 100644 --- a/mindspore/lite/src/ops/populate/strided_slice_populate.cc +++ b/mindspore/lite/src/ops/populate/strided_slice_populate.cc @@ -41,6 +41,7 @@ OpParameter *PopulateStridedSliceParameter(const mindspore::lite::PrimitiveC *pr memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int)); auto in_shape = ((lite::StridedSlice *)primitive)->GetInShape(); memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); + strided_slice_param->in_shape_length_ = static_cast(in_shape.size()); return reinterpret_cast(strided_slice_param); } diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index 8b6a8fa97d..e40bde14bf 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -15,6 +15,7 @@ */ #include "src/ops/strided_slice.h" +#include #ifndef PRIMITIVE_WRITEABLE #include "src/ops/ops_register.h" @@ -172,7 +173,8 @@ Registry StridedSliceRegistry(schema::PrimitiveType_StridedSlice, StridedSliceCr namespace { constexpr size_t kStridedSliceOutputNum = 1; constexpr size_t kStridedSliceInputNum = 1; -constexpr size_t kStridedSliceMultiInputNum = 4; +constexpr size_t kStridedSliceMultiInputNumMin = 3; +constexpr size_t kStridedSliceMultiInputNumMax = 5; } // namespace void StridedSlice::ApplyNewAxisMask() { @@ -251,13 +253,91 @@ void StridedSlice::TransIndexToPositive() { } } +int StridedSlice::HandleAxesInputExist(const std::vector &inputs) { + // when axes input exist: + // input order: data, begin, end, axes(opt), stride(opt) + auto input_tensor = inputs.at(0); + MS_ASSERT(input_tensor != nullptr); + auto begin_tensor = inputs.at(1); + MS_ASSERT(begin_tensor != nullptr); + int *begin_data = reinterpret_cast(begin_tensor->MutableData()); + auto end_tensor = inputs.at(2); + MS_ASSERT(end_tensor != nullptr); + int *end_data = reinterpret_cast(end_tensor->MutableData()); + if (begin_data == nullptr || end_data == nullptr) { + return RET_INFER_ERR; + } + // when input contains axes, begins, ends, strides will be expand to the same length as input rank + ndim_ = static_cast(input_tensor->shape().size()); + int begin_ndim = begin_tensor->ElementsNum(); + + int *axes_data = nullptr; + auto axes_tensor = inputs.at(3); + if (axes_tensor->ElementsNum() != 0) { + MS_ASSERT(axes_tensor->ElementsNum() == begin_ndim); + axes_data = reinterpret_cast(axes_tensor->MutableData()); + if (axes_data == nullptr) { + return RET_INFER_ERR; + } + } + + int *stride_data = nullptr; + auto stride_tensor = inputs.at(4); + if (stride_tensor->ElementsNum() != 0) { + MS_ASSERT(stride_tensor->ElementsNum() == begin_ndim); + stride_data = reinterpret_cast(stride_tensor->MutableData()); + if (stride_data == nullptr) { + return RET_INFER_ERR; + } + } + + std::vector axes; + if (axes_data == nullptr) { + for (int i = 0; i < begin_ndim; ++i) { + axes[i] = i; + } + } else { + axes.assign(axes_data, axes_data + begin_ndim); + for (int i = 0; i < begin_ndim; ++i) { + if (axes[i] < 0) { + axes[i] += ndim_; + } + } + } + + in_shape_.assign(ndim_, 0); + begins_.assign(ndim_, 0); + ends_.assign(ndim_, 0); + strides_.assign(ndim_, 0); + auto input_shape = input_tensor->shape(); + for (int i = 0; i < ndim_; ++i) { + in_shape_[i] = input_shape.at(i); + } + for (int i = 0; i < ndim_; ++i) { + auto axes_it = std::find(axes.begin(), axes.end(), i); + if (axes_it != axes.end()) { + auto axis = axes_it - axes.begin(); + // begins or ends exceed limit will be set to limit + begins_[i] = std::max(std::min(begin_data[axis], input_shape[i] - 1), -input_shape[i]); + ends_[i] = std::max(std::min(end_data[axis], input_shape[i]), -input_shape[i] - 1); + strides_[i] = stride_data[axis]; + } else { + begins_[i] = 0; + ends_[i] = input_shape[i]; + strides_[i] = 1; + } + } + return RET_OK; +} + int StridedSlice::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive_ != nullptr); if (outputs.size() != kStridedSliceOutputNum) { MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); return RET_PARAM_INVALID; } - if (inputs.size() != kStridedSliceInputNum && inputs.size() != kStridedSliceMultiInputNum) { + if (inputs.size() != kStridedSliceInputNum && + !(inputs.size() <= kStridedSliceMultiInputNumMax && inputs.size() >= kStridedSliceMultiInputNumMin)) { MS_LOG(ERROR) << "Invalid input size " << inputs.size(); return RET_PARAM_INVALID; } @@ -268,6 +348,10 @@ int StridedSlice::InferShape(std::vector inputs, std::vectorshape(); auto inferflag = GetInferFlag(); + in_shape_.clear(); + begins_.clear(); + ends_.clear(); + strides_.clear(); if (inputs.size() == kStridedSliceInputNum) { ndim_ = static_cast(GetBegin().size()); @@ -279,7 +363,9 @@ int StridedSlice::InferShape(std::vector inputs, std::vector(begin_tensor->MutableData()); auto end_tensor = inputs.at(2); @@ -299,6 +385,13 @@ int StridedSlice::InferShape(std::vector inputs, std::vector inputs, std::vector new_axis_mask_; std::vector shrink_axis_mask_; void TransIndexToPositive(); + int HandleAxesInputExist(const std::vector &inputs); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc index 9155a85af1..a01f53e381 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc @@ -16,11 +16,11 @@ #include "src/runtime/kernel/arm/base/strided_slice.h" #include -#include "nnacl/strided_slice.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" #include "src/runtime/runtime_api.h" +#include "src/ops/populate/populate_register.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -44,16 +44,16 @@ int StridedSliceCPUKernel::Init() { } int StridedSliceCPUKernel::ReSize() { - auto input = in_tensors_.at(0); - auto parameter = reinterpret_cast(op_parameter_); - MS_ASSERT(input); - MS_ASSERT(parameter); - parameter->data_type = input->data_type() == kNumberTypeInt8 ? kDataTypeInt8 : kDataTypeFloat; - auto input_shape = input->shape(); - for (size_t i = 0; i < input_shape.size(); ++i) { - parameter->in_shape_[i] = input_shape[i]; + if (op_parameter_ != nullptr) { + free(op_parameter_); + op_parameter_ = nullptr; + } + op_parameter_ = PopulateStridedSliceParameter(primitive_); + if (op_parameter_ == nullptr) { + MS_LOG(ERROR) << "Malloc parameter failed"; + return RET_ERROR; } - parameter->in_shape_length_ = static_cast(input_shape.size()); + param_ = reinterpret_cast(op_parameter_); return RET_OK; } @@ -62,8 +62,7 @@ int StridedSliceCPUKernel::HandleMultiInputs() { MS_LOG(ERROR) << "Inputs size should be " << kMultiInputsSize << ", got " << in_tensors_.size(); return RET_ERROR; } - auto param = reinterpret_cast(op_parameter_); - if (param == nullptr) { + if (param_ == nullptr) { MS_LOG(ERROR) << "StridedSliceParamater cast nullptr"; return RET_ERROR; } @@ -74,35 +73,49 @@ int StridedSliceCPUKernel::HandleMultiInputs() { MS_LOG(ERROR) << "StridedSlice supports max dimension " << DIMENSION_6D << ", input begins dim is " << axis_num; return RET_ERROR; } - memcpy(param->begins_, begins->MutableData(), axis_num * sizeof(int)); + memcpy(param_->begins_, begins->MutableData(), axis_num * sizeof(int)); auto ends = in_tensors_.at(kEndsIndex); MS_ASSERT(ends != nullptr); MS_ASSERT(axis_num == ends->ElementsNum()); - memcpy(param->ends_, ends->MutableData(), axis_num * sizeof(int)); + memcpy(param_->ends_, ends->MutableData(), axis_num * sizeof(int)); auto strides = in_tensors_.at(kStridesInex); MS_ASSERT(strides != nullptr); MS_ASSERT(axis_num == strides->ElementsNum()); - memcpy(param->strides_, strides->MutableData(), axis_num * sizeof(int)); + memcpy(param_->strides_, strides->MutableData(), axis_num * sizeof(int)); - param->num_axes_ = axis_num; + param_->num_axes_ = axis_num; return RET_OK; } int StridedSliceCPUKernel::Run() { auto input = in_tensors_.at(0); - auto output = out_tensors_.at(0); MS_ASSERT(input); + switch (input->data_type()) { + case kNumberTypeInt8: + param_->data_type = kDataTypeInt8; + break; + case kNumberTypeFloat32: + param_->data_type = kDataTypeFloat; + break; + case kNumberTypeInt32: + param_->data_type = kDataTypeInt; + break; + default: + MS_LOG(ERROR) << "Not supported data type: " << input->data_type(); + return RET_ERROR; + } + auto output = out_tensors_.at(0); MS_ASSERT(output); + // inputs order: input, begin, end, stride if (in_tensors().size() == kMultiInputsSize) { auto ret = HandleMultiInputs(); if (ret != RET_OK) { return ret; } } - auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), - reinterpret_cast(op_parameter_)); + auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), param_); if (ret != RET_OK) { MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h index 7a0c835e96..0a1751eda5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_BACKEND_ARM_BASE_STRIDED_SLICE_H_ #include - +#include "nnacl/strided_slice.h" #include "src/lite_kernel.h" namespace mindspore::kernel { @@ -27,7 +27,9 @@ class StridedSliceCPUKernel : public LiteKernel { StridedSliceCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + param_ = reinterpret_cast(parameter); + } ~StridedSliceCPUKernel() override = default; int Init() override; @@ -36,6 +38,9 @@ class StridedSliceCPUKernel : public LiteKernel { private: int HandleMultiInputs(); + + private: + StridedSliceParameter *param_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg index 04e9248fc0..7ac0833bb2 100644 --- a/mindspore/lite/test/models_onnx.cfg +++ b/mindspore/lite/test/models_onnx.cfg @@ -25,9 +25,3 @@ psenet_lite_mbv2.onnx;1,32,32,3 super-resolution-10.onnx;1,224,224,1 tinyyolov2-8.onnx;1,416,416,3 ml_2012_ocr_cn.onnx -ml_2012_ocr_cn_noLSTM.onnx -candy-9.onnx -mosaic-9.onnx -pointilism-9.onnx -rain-princess-9.onnx -udnie-9.onnx diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh index acde0c356a..d4c7a3375a 100644 --- a/mindspore/lite/test/run_benchmark_nets.sh +++ b/mindspore/lite/test/run_benchmark_nets.sh @@ -62,7 +62,7 @@ function Run_Converter() { if [ $? = 0 ]; then converter_result='converter onnx '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} else - converter_result='converter onnx '${model_name}' failed';echo ${converter_result} >> ${run_converter_result_file} + converter_result='converter onnx '${model_name}' failed';echo ${converter_result} >> ${run_converter_result_file};return 1 fi done < ${models_onnx_config} diff --git a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc index 913567335c..fb558104cc 100644 --- a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc +++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc @@ -64,6 +64,9 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { param_value->set_tensor_addr(tensor_data); param_value->set_tensor_size(size); parameter->set_default_param(param_value); + } else if (std::find(meta_graph_->inputIndex.begin(), meta_graph_->inputIndex.end(), i) == + meta_graph_->inputIndex.end()) { + parameter->set_default_param(param_value); } AddNode(i, parameter); } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 11293f1fb4..0d3125e32b 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -56,11 +56,11 @@ std::vector OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo return dims; } -STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { +STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph) { MS_LOG(DEBUG) << "set onnx constant tensors"; for (const auto &onnx_const_value : onnx_graph.initializer()) { int index; - const auto status = AddTensorProto(onnx_const_value, onnx_const_value.name(), GRAPH_INPUT, tensor_cache, &index); + const auto status = AddTensorProto(onnx_const_value, onnx_const_value.name(), GRAPH_INPUT, &index); if (status != RET_OK) { return status; } @@ -77,7 +77,7 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, if (attr.name() == "value") { const auto &t = attr.t(); int index; - const auto status = AddTensorProto(t, node.output(0), GRAPH_INPUT, tensor_cache, &index); + const auto status = AddTensorProto(t, node.output(0), GRAPH_INPUT, &index); if (status != RET_OK) { return status; } @@ -93,7 +93,7 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, } STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, - TensorCache *tensor_cache, int *index) { + int *index) { auto data_type = GetDataTypeFromOnnx(static_cast(proto.type().tensor_type().elem_type())); if (data_type == kTypeUnknown) { MS_LOG(ERROR) << "not support onnx data type " @@ -109,12 +109,12 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const st tensor->dims = GetDimsFromOnnxValue(proto); tensor->format = schema::Format::Format_NCHW; tensor->nodeType = schema::NodeType::NodeType_ValueNode; - *index = tensor_cache->AddTensor(name, tensor.release(), type); + *index = OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(name, tensor.release(), type); return RET_OK; } STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, - TensorCache *tensor_cache, int *index) { + int *index) { auto data_type = GetDataTypeFromOnnx(static_cast(proto.data_type())); if (data_type == kTypeUnknown) { MS_LOG(ERROR) << "not support onnx data type " << static_cast(proto.data_type()); @@ -137,17 +137,16 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std if (data_type == kNumberTypeInt64) { tensor->dataType = kNumberTypeInt32; // CopyOnnxTensorData will convert int64 to int32 } - *index = tensor_cache->AddTensor(name, tensor.release(), type); + *index = OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(name, tensor.release(), type); return RET_OK; } -STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, - TensorCache *tensor_cache) { +STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph) { for (const auto &input_value : onnx_graph.input()) { - auto ret = tensor_cache->FindTensor(input_value.name()); + auto ret = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(input_value.name()); if (ret < 0) { int index; - const auto status = AddValueInfo(input_value, input_value.name(), GRAPH_INPUT, tensor_cache, &index); + const auto status = AddValueInfo(input_value, input_value.name(), GRAPH_INPUT, &index); if (status != RET_OK) { return status; } @@ -158,14 +157,13 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, return RET_OK; } -STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, - TensorCache *tensor_cache) { +STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph) { for (const auto &output_value : onnx_graph.output()) { int index; - if (tensor_cache->FindTensor(output_value.name()) != -1) { - index = tensor_cache->FindTensor(output_value.name()); + if (OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(output_value.name()) != -1) { + index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(output_value.name()); } else { - const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index); + const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, &index); if (status != RET_OK) { return status; } @@ -178,7 +176,7 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, - TensorCache *tensor_cache, const QuantType &quant_type) { + const QuantType &quant_type) { std::unique_ptr dst_op_1 = std::make_unique(); dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); dst_op_1->quantType = quant_type; @@ -186,8 +184,8 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0); std::vector matmul_inputs{onnx_node.input(0), onnx_node.input(1)}; std::vector matmul_outputs{matmul_output_id}; - SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node, tensor_cache); - SetOpOutputIndex(matmul_outputs, dst_op_1.get(), tensor_cache); + SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node); + SetOpOutputIndex(matmul_outputs, dst_op_1.get()); graph->nodes.emplace_back(std::move(dst_op_1)); sub_graph->nodeIndices.push_back(graph->nodes.size() - 1); @@ -197,15 +195,15 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get()); std::vector biasadd_inputs{matmul_output_id, onnx_node.input(2)}; std::vector biasadd_outputs{onnx_node.output(0)}; - SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node, tensor_cache); - SetOpOutputIndex(biasadd_outputs, dst_op_2.get(), tensor_cache); + SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node); + SetOpOutputIndex(biasadd_outputs, dst_op_2.get()); graph->nodes.emplace_back(std::move(dst_op_2)); sub_graph->nodeIndices.push_back(graph->nodes.size() - 1); } -STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { +STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node) { // convert GivenTensorFill node to a weight/bias tensor - auto ret = tensor_cache->FindTensor(onnx_node.output(0)); + auto ret = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node.output(0)); if (ret < 0) { std::unique_ptr tensor = std::make_unique(); std::vector shape; @@ -259,15 +257,16 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, return RET_ERROR; } } - auto index = tensor_cache->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT); + auto index = + OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT); MS_LOG(DEBUG) << "add given tensor: " << index; } return RET_OK; } STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, TensorCache *tensor_cache, - const QuantType &quantType, schema::MetaGraphT *dst_graph) { + schema::CNodeT *dst_op, const QuantType &quantType, + schema::MetaGraphT *dst_graph) { // change op_type() to name(), that is unique static bool interrupt = false; dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); @@ -308,41 +307,43 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, // set op input index std::vector node_inputs; (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end()); - if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) { + if (SetOpInputIndex(node_inputs, dst_op, onnx_node)) { interrupt = true; MS_LOG(ERROR) << "SetOpInputIndex failed"; return RET_ERROR; } if (dst_op->primitive->value.type == schema::PrimitiveType_Conv2D) { - auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex)); + auto &weight_tensor = + OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex)); weight_tensor->format = dst_op->primitive->value.AsConv2D()->format; } else if (dst_op->primitive->value.type == schema::PrimitiveType_DeConv2D) { - auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex)); + auto &weight_tensor = + OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex)); weight_tensor->format = dst_op->primitive->value.AsDeConv2D()->format; } // set op output index std::vector node_outputs; (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end()); - if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) { + if (SetOpOutputIndex(node_outputs, dst_op) != RET_OK) { interrupt = true; MS_LOG(ERROR) << "SetOpOutputIndex failed"; return RET_ERROR; } - auto &output_tensor = tensor_cache->GetCachedTensor().at(dst_op->outputIndex.front()); + auto &output_tensor = + OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->outputIndex.front()); if (output_tensor == nullptr) { interrupt = true; MS_LOG(ERROR) << "Output tensor of node " << onnx_node.op_type() << "is nullptr."; return RET_ERROR; } - SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor, tensor_cache); + SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor); return RET_OK; } void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) { + schema::CNodeT *dst_op, schema::TensorT *dst_tensor) { MS_ASSERT(dst_op != nullptr); - MS_ASSERT(tensor_cache != nullptr); std::vector quant_node_name; quant_node_name.insert(quant_node_name.begin(), onnx_node.input().begin(), onnx_node.input().end()); quant_node_name.insert(quant_node_name.end(), onnx_node.output().begin(), onnx_node.output().end()); @@ -404,10 +405,10 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co } STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, - const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { + const onnx::NodeProto &onnx_node) { for (const auto &onnx_node_input : node_inputs) { if (onnx_node_input != "") { - int index = tensor_cache->FindTensor(onnx_node_input); + int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node_input); if (index < 0) { MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found"; return RET_ERROR; @@ -419,14 +420,14 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, return RET_OK; } -STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op, - TensorCache *tensor_cache) { +STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op) { for (const auto &onnx_node_output : node_outputs) { - auto index = tensor_cache->FindTensor(onnx_node_output); + auto index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node_output); if (index < 0) { // when index >= 0, it's graph's output std::unique_ptr tensor = std::make_unique(); tensor->nodeType = schema::NodeType_Parameter; - index = tensor_cache->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT); + index = + OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT); } MS_LOG(DEBUG) << "node: " << onnx_node_output << ", output index: " << index; dst_op->outputIndex.emplace_back(index); @@ -495,8 +496,8 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v return RET_OK; } -STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) { - std::vector tensors = tensor_cache.GetCachedTensor(); +STATUS OnnxModelParser::SetAllTensors(schema::MetaGraphT *graphDef) { + std::vector tensors = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor(); for (auto iter : tensors) { std::unique_ptr temp(iter); graphDef->allTensors.emplace_back(move(temp)); @@ -549,12 +550,11 @@ STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodePr int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, const QuantType &quantType) { - TensorCache tensor_cache; // dst_graph->name = onnx_graph.name(); // this is not used // find out input names and const names FindGraphInputAndConst(onnx_graph); // set const tensor - int status = SetGraphConstTensor(onnx_graph, &tensor_cache); + int status = SetGraphConstTensor(onnx_graph); if (status != RET_OK) { MS_LOG(ERROR) << "SetGraphConstTensor failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); @@ -563,7 +563,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT // init onnx model graph input tensor - status = SetGraphInputTensor(onnx_graph, dst_sub_graph, &tensor_cache); + status = SetGraphInputTensor(onnx_graph, dst_sub_graph); if (status != RET_OK) { MS_LOG(ERROR) << "SetGraphInputTensor failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); @@ -579,12 +579,12 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT } if (onnx_node.op_type() == "Gemm") { if (status == RET_OK) { - ParseOnnxGemmNode(onnx_graph, onnx_node, dst_sub_graph, dst_graph, &tensor_cache, quantType); + ParseOnnxGemmNode(onnx_graph, onnx_node, dst_sub_graph, dst_graph, quantType); } continue; } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { if (status == RET_OK) { - status_node = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); + status_node = ParseOnnxGivenFillNode(onnx_node); if (status_node != RET_OK) { MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status_node; status = (status == RET_OK ? status_node : status); @@ -594,7 +594,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT } std::unique_ptr dst_op = std::make_unique(); - status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), &tensor_cache, quantType, dst_graph); + status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), quantType, dst_graph); if (status_node != RET_OK) { status = (status == RET_OK ? status_node : status); continue; @@ -604,19 +604,19 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT } if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - for (auto &tensor : tensor_cache.GetCachedTensor()) { + for (auto &tensor : OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()) { delete tensor; } return RET_ERROR; } // init onnx model graph output tensor - status = SetGraphOutputTensor(onnx_graph, dst_sub_graph, &tensor_cache); + status = SetGraphOutputTensor(onnx_graph, dst_sub_graph); if (status != RET_OK) { MS_LOG(ERROR) << "SetGraphOutputTensor failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return RET_ERROR; } - SetAllTensors(tensor_cache, dst_graph); + SetAllTensors(dst_graph); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index 42200806db..420f4661ec 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -30,7 +30,7 @@ #include "securec/include/securec.h" #include "tools/converter/model_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" -#include "tools/common/tensor_util.h" +#include "tools/converter/parser/onnx/onnx_tensor_parser.h" #include "proto/onnx.pb.h" namespace mindspore { @@ -53,42 +53,38 @@ class OnnxModelParser : public ModelParser { private: std::vector GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); - STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); + STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph); - STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, TensorCache *tensor_cache); + STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph); - STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, TensorCache *tensor_cache); + STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph); - STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, - TensorCache *tensor_cache, int *index); + STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, int *index); - STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, - TensorCache *tensor_cache, int *index); + STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, int *index); STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, TensorCache *tensor_cache, const QuantType &quantType, - schema::MetaGraphT *dst_graph); + schema::CNodeT *dst_op, const QuantType &quantType, schema::MetaGraphT *dst_graph); void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache, - const QuantType &quant_type); + schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, const QuantType &quant_type); - STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); + STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node); STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, const string &onnx_op_type, schema::CNodeT *dst_op); void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, - schema::TensorT *dst_tensor, TensorCache *tensor_cache); + schema::TensorT *dst_tensor); STATUS SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, - const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); + const onnx::NodeProto &onnx_node); - STATUS SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache); + STATUS SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op); STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); - STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef); + STATUS SetAllTensors(schema::MetaGraphT *graphDef); void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc index 9629057ab4..5672819be9 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc @@ -17,9 +17,35 @@ #include "tools/converter/parser/onnx/onnx_slice_parser.h" #include #include +#include namespace mindspore { namespace lite { +STATUS OnnxSliceParser::InsertTensor(const std::vector &onnx_val, const std::string &name, + onnx::NodeProto *onnx_node) { + std::unique_ptr tensor = std::make_unique(); + if (tensor == nullptr) { + MS_LOG(ERROR) << "new tensor failed"; + return RET_ERROR; + } + tensor->dataType = mindspore::kNumberTypeInt32; + tensor->dims.push_back(onnx_val.size()); + tensor->format = schema::Format::Format_NCHW; + tensor->nodeType = schema::NodeType::NodeType_ValueNode; + int data_size = sizeof(int32_t) * onnx_val.size(); + tensor->data.resize(data_size); + if (data_size != 0 && + memcpy_s(static_cast(tensor->data.data()), data_size, onnx_val.data(), data_size) != EOK) { + MS_LOG(ERROR) << "memcpy_s failed"; + return RET_ERROR; + } + int tensor_num = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().size(); + std::string tensor_name = name + std::to_string(tensor_num); + OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(tensor_name, tensor.release(), GRAPH_INPUT); + onnx_node->add_input(tensor_name); + return RET_OK; +} + STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx SliceParser"; @@ -33,15 +59,15 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No return RET_NULL_PTR; } - std::unique_ptr attr = std::make_unique(); + std::unique_ptr attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - std::vector axes; std::vector starts; std::vector ends; + std::vector axes; std::vector steps; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); @@ -71,64 +97,49 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No } } } - - if (onnx_node.input_size() > 1) { - const auto &starts_name = onnx_node.input(1); - for (const auto &it : onnx_graph.initializer()) { - if (it.name() == starts_name) { - starts.clear(); - for (int i = 0; i < it.int32_data_size(); ++i) { - starts.push_back(it.int32_data(i)); - } - } + if (axes.empty()) { + for (size_t i = 0; i < starts.size(); ++i) { + axes.push_back(i); } } - - if (onnx_node.input_size() > 2) { - const auto &ends_name = onnx_node.input(2); - for (const auto &it : onnx_graph.initializer()) { - if (it.name() == ends_name) { - ends.clear(); - for (int i = 0; i < it.int32_data_size(); ++i) { - ends.push_back(it.int32_data(i)); - } - } - } + if (steps.empty()) { + steps.assign(starts.size(), 1); } - - if (onnx_node.input_size() > 3) { - const auto &axes_name = onnx_node.input(3); - for (const auto &it : onnx_graph.initializer()) { - if (it.name() == axes_name) { - axes.clear(); - for (int i = 0; i < it.int32_data_size(); ++i) { - axes.push_back(it.int32_data(i)); - } - } + onnx::NodeProto *slice_node = nullptr; + for (auto &node : onnx_graph.node()) { + if (&node == &onnx_node) { + slice_node = const_cast(&node); } } - - if (onnx_node.input_size() > 4) { - const auto &steps_name = onnx_node.input(4); - for (const auto &it : onnx_graph.initializer()) { - if (it.name() == steps_name) { - steps.clear(); - for (int i = 0; i < it.int32_data_size(); ++i) { - steps.push_back(it.int32_data(i)); - } - } + int insert_num = 5 - onnx_node.input_size(); + int status = RET_OK; + switch (insert_num) { + case 4: { + std::string name = "slice/starts/"; + status = InsertTensor(starts, name, slice_node); } + case 3: + if (status == RET_OK) { + std::string name = "slice/ends/"; + status = InsertTensor(ends, name, slice_node); + } + case 2: + if (status == RET_OK) { + std::string name = "slice/axes/"; + status = InsertTensor(axes, name, slice_node); + } + case 1: + if (status == RET_OK) { + std::string name = "slice/steps/"; + status = InsertTensor(steps, name, slice_node); + } + default: + if (status != RET_OK) { + MS_LOG(ERROR) << "onnx slice insert tensor failed"; + return RET_ERROR; + } } - - std::vector sizes(starts.size(), -1); - for (size_t i = 0; i < starts.size(); ++i) { - sizes[i] = (ends[i] < 0 ? ends[i] : ends[i] - starts[i]); - } - attr->axes = axes; - attr->begin = starts; - attr->size = sizes; - attr->step = steps; - op->primitive->value.type = schema::PrimitiveType_Slice; + op->primitive->value.type = schema::PrimitiveType_StridedSlice; op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h index 5b86d32dcc..83fde3ea95 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h @@ -17,8 +17,11 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H +#include +#include #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" +#include "tools/converter/parser/onnx/onnx_tensor_parser.h" namespace mindspore { namespace lite { @@ -27,6 +30,7 @@ class OnnxSliceParser : public OnnxNodeParser { OnnxSliceParser() : OnnxNodeParser("Slice") {} ~OnnxSliceParser() override = default; + STATUS InsertTensor(const std::vector &onnx_val, const std::string &name, onnx::NodeProto *onnx_node); STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; }; } // namespace lite diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tensor_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_tensor_parser.h new file mode 100644 index 0000000000..ad9f66ab28 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tensor_parser.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H + +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class OnnxTensorParser { + public: + ~OnnxTensorParser() = default; + static OnnxTensorParser *GetInstance() { + static OnnxTensorParser onnxTensorParser; + return &onnxTensorParser; + } + TensorCache *GetTensorCache() { return &tensor_cache_; } + + private: + OnnxTensorParser() = default; + TensorCache tensor_cache_; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TESNOR_PARSER_H