From: @zhaozhenlong Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -104,8 +104,12 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p | |||
| dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5; | |||
| if (param->data_type == kDataTypeFloat) { | |||
| *((float *)out_data + out_offset) = *((float *)in_data + in_offset); | |||
| } else { | |||
| } else if (param->data_type == kDataTypeInt8) { | |||
| *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); | |||
| } else if (param->data_type == kDataTypeInt) { | |||
| *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); | |||
| } else { | |||
| return NNACL_ERR; | |||
| } | |||
| out_offset++; | |||
| } | |||
| @@ -87,7 +87,7 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso | |||
| } | |||
| auto in_data = reinterpret_cast<int *>(in_tensor->data_c()); | |||
| if (in_data == nullptr) { | |||
| MS_LOG(ERROR) << "Input data is nullptr"; | |||
| MS_LOG(INFO) << "Input data is nullptr. Input tensor has not been calculated out yet."; | |||
| return RET_INFER_INVALID; | |||
| } | |||
| int size = in_tensor->ElementsNum(); | |||
| @@ -50,6 +50,7 @@ class Registry { | |||
| } | |||
| }; | |||
| OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive); | |||
| OpParameter *PopulateStridedSliceParameter(const mindspore::lite::PrimitiveC *primitive); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif | |||
| @@ -41,6 +41,7 @@ OpParameter *PopulateStridedSliceParameter(const mindspore::lite::PrimitiveC *pr | |||
| memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int)); | |||
| auto in_shape = ((lite::StridedSlice *)primitive)->GetInShape(); | |||
| memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); | |||
| strided_slice_param->in_shape_length_ = static_cast<int>(in_shape.size()); | |||
| return reinterpret_cast<OpParameter *>(strided_slice_param); | |||
| } | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "src/ops/strided_slice.h" | |||
| #include <algorithm> | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| @@ -172,7 +173,8 @@ Registry StridedSliceRegistry(schema::PrimitiveType_StridedSlice, StridedSliceCr | |||
| namespace { | |||
| constexpr size_t kStridedSliceOutputNum = 1; | |||
| constexpr size_t kStridedSliceInputNum = 1; | |||
| constexpr size_t kStridedSliceMultiInputNum = 4; | |||
| constexpr size_t kStridedSliceMultiInputNumMin = 3; | |||
| constexpr size_t kStridedSliceMultiInputNumMax = 5; | |||
| } // namespace | |||
| void StridedSlice::ApplyNewAxisMask() { | |||
| @@ -251,13 +253,91 @@ void StridedSlice::TransIndexToPositive() { | |||
| } | |||
| } | |||
| int StridedSlice::HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs) { | |||
| // when axes input exist: | |||
| // input order: data, begin, end, axes(opt), stride(opt) | |||
| auto input_tensor = inputs.at(0); | |||
| MS_ASSERT(input_tensor != nullptr); | |||
| auto begin_tensor = inputs.at(1); | |||
| MS_ASSERT(begin_tensor != nullptr); | |||
| int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData()); | |||
| auto end_tensor = inputs.at(2); | |||
| MS_ASSERT(end_tensor != nullptr); | |||
| int *end_data = reinterpret_cast<int *>(end_tensor->MutableData()); | |||
| if (begin_data == nullptr || end_data == nullptr) { | |||
| return RET_INFER_ERR; | |||
| } | |||
| // when input contains axes, begins, ends, strides will be expand to the same length as input rank | |||
| ndim_ = static_cast<int>(input_tensor->shape().size()); | |||
| int begin_ndim = begin_tensor->ElementsNum(); | |||
| int *axes_data = nullptr; | |||
| auto axes_tensor = inputs.at(3); | |||
| if (axes_tensor->ElementsNum() != 0) { | |||
| MS_ASSERT(axes_tensor->ElementsNum() == begin_ndim); | |||
| axes_data = reinterpret_cast<int *>(axes_tensor->MutableData()); | |||
| if (axes_data == nullptr) { | |||
| return RET_INFER_ERR; | |||
| } | |||
| } | |||
| int *stride_data = nullptr; | |||
| auto stride_tensor = inputs.at(4); | |||
| if (stride_tensor->ElementsNum() != 0) { | |||
| MS_ASSERT(stride_tensor->ElementsNum() == begin_ndim); | |||
| stride_data = reinterpret_cast<int *>(stride_tensor->MutableData()); | |||
| if (stride_data == nullptr) { | |||
| return RET_INFER_ERR; | |||
| } | |||
| } | |||
| std::vector<int> axes; | |||
| if (axes_data == nullptr) { | |||
| for (int i = 0; i < begin_ndim; ++i) { | |||
| axes[i] = i; | |||
| } | |||
| } else { | |||
| axes.assign(axes_data, axes_data + begin_ndim); | |||
| for (int i = 0; i < begin_ndim; ++i) { | |||
| if (axes[i] < 0) { | |||
| axes[i] += ndim_; | |||
| } | |||
| } | |||
| } | |||
| in_shape_.assign(ndim_, 0); | |||
| begins_.assign(ndim_, 0); | |||
| ends_.assign(ndim_, 0); | |||
| strides_.assign(ndim_, 0); | |||
| auto input_shape = input_tensor->shape(); | |||
| for (int i = 0; i < ndim_; ++i) { | |||
| in_shape_[i] = input_shape.at(i); | |||
| } | |||
| for (int i = 0; i < ndim_; ++i) { | |||
| auto axes_it = std::find(axes.begin(), axes.end(), i); | |||
| if (axes_it != axes.end()) { | |||
| auto axis = axes_it - axes.begin(); | |||
| // begins or ends exceed limit will be set to limit | |||
| begins_[i] = std::max(std::min(begin_data[axis], input_shape[i] - 1), -input_shape[i]); | |||
| ends_[i] = std::max(std::min(end_data[axis], input_shape[i]), -input_shape[i] - 1); | |||
| strides_[i] = stride_data[axis]; | |||
| } else { | |||
| begins_[i] = 0; | |||
| ends_[i] = input_shape[i]; | |||
| strides_[i] = 1; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||
| MS_ASSERT(this->primitive_ != nullptr); | |||
| if (outputs.size() != kStridedSliceOutputNum) { | |||
| MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (inputs.size() != kStridedSliceInputNum && inputs.size() != kStridedSliceMultiInputNum) { | |||
| if (inputs.size() != kStridedSliceInputNum && | |||
| !(inputs.size() <= kStridedSliceMultiInputNumMax && inputs.size() >= kStridedSliceMultiInputNumMin)) { | |||
| MS_LOG(ERROR) << "Invalid input size " << inputs.size(); | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| @@ -268,6 +348,10 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit | |||
| auto input_shape = input->shape(); | |||
| auto inferflag = GetInferFlag(); | |||
| in_shape_.clear(); | |||
| begins_.clear(); | |||
| ends_.clear(); | |||
| strides_.clear(); | |||
| if (inputs.size() == kStridedSliceInputNum) { | |||
| ndim_ = static_cast<int>(GetBegin().size()); | |||
| @@ -279,7 +363,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit | |||
| ends_.emplace_back((GetEnd())[i]); | |||
| strides_.emplace_back((GetStride())[i]); | |||
| } | |||
| } else { | |||
| } | |||
| if (inputs.size() == 4) { | |||
| // input order: input, begins, ends, strides. | |||
| auto begin_tensor = inputs.at(1); | |||
| int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData()); | |||
| auto end_tensor = inputs.at(2); | |||
| @@ -299,6 +385,13 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit | |||
| strides_.emplace_back(stride_data[i]); | |||
| } | |||
| } | |||
| if (inputs.size() == 5) { | |||
| // input order: input, begins, end, axes, strides | |||
| auto ret = HandleAxesInputExist(inputs); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| } | |||
| // set all mask to original input shape | |||
| begins_mask_.resize(ndim_); | |||
| @@ -333,7 +426,12 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit | |||
| if (i < ndim_ && new_axis_mask_.at(i)) { | |||
| output_shape.at(i) = 1; | |||
| } else { | |||
| output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); | |||
| if (strides_.at(i) == 0) { | |||
| MS_LOG(ERROR) << "strides should not be 0."; | |||
| return RET_INFER_ERR; | |||
| } | |||
| output_shape.at(i) = | |||
| (ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i); | |||
| } | |||
| } | |||
| @@ -81,6 +81,7 @@ class StridedSlice : public PrimitiveC { | |||
| std::vector<bool> new_axis_mask_; | |||
| std::vector<bool> shrink_axis_mask_; | |||
| void TransIndexToPositive(); | |||
| int HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -16,11 +16,11 @@ | |||
| #include "src/runtime/kernel/arm/base/strided_slice.h" | |||
| #include <vector> | |||
| #include "nnacl/strided_slice.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -44,16 +44,16 @@ int StridedSliceCPUKernel::Init() { | |||
| } | |||
| int StridedSliceCPUKernel::ReSize() { | |||
| auto input = in_tensors_.at(0); | |||
| auto parameter = reinterpret_cast<StridedSliceParameter *>(op_parameter_); | |||
| MS_ASSERT(input); | |||
| MS_ASSERT(parameter); | |||
| parameter->data_type = input->data_type() == kNumberTypeInt8 ? kDataTypeInt8 : kDataTypeFloat; | |||
| auto input_shape = input->shape(); | |||
| for (size_t i = 0; i < input_shape.size(); ++i) { | |||
| parameter->in_shape_[i] = input_shape[i]; | |||
| if (op_parameter_ != nullptr) { | |||
| free(op_parameter_); | |||
| op_parameter_ = nullptr; | |||
| } | |||
| op_parameter_ = PopulateStridedSliceParameter(primitive_); | |||
| if (op_parameter_ == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc parameter failed"; | |||
| return RET_ERROR; | |||
| } | |||
| parameter->in_shape_length_ = static_cast<int>(input_shape.size()); | |||
| param_ = reinterpret_cast<StridedSliceParameter *>(op_parameter_); | |||
| return RET_OK; | |||
| } | |||
| @@ -62,8 +62,7 @@ int StridedSliceCPUKernel::HandleMultiInputs() { | |||
| MS_LOG(ERROR) << "Inputs size should be " << kMultiInputsSize << ", got " << in_tensors_.size(); | |||
| return RET_ERROR; | |||
| } | |||
| auto param = reinterpret_cast<StridedSliceParameter *>(op_parameter_); | |||
| if (param == nullptr) { | |||
| if (param_ == nullptr) { | |||
| MS_LOG(ERROR) << "StridedSliceParamater cast nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -74,35 +73,49 @@ int StridedSliceCPUKernel::HandleMultiInputs() { | |||
| MS_LOG(ERROR) << "StridedSlice supports max dimension " << DIMENSION_6D << ", input begins dim is " << axis_num; | |||
| return RET_ERROR; | |||
| } | |||
| memcpy(param->begins_, begins->MutableData(), axis_num * sizeof(int)); | |||
| memcpy(param_->begins_, begins->MutableData(), axis_num * sizeof(int)); | |||
| auto ends = in_tensors_.at(kEndsIndex); | |||
| MS_ASSERT(ends != nullptr); | |||
| MS_ASSERT(axis_num == ends->ElementsNum()); | |||
| memcpy(param->ends_, ends->MutableData(), axis_num * sizeof(int)); | |||
| memcpy(param_->ends_, ends->MutableData(), axis_num * sizeof(int)); | |||
| auto strides = in_tensors_.at(kStridesInex); | |||
| MS_ASSERT(strides != nullptr); | |||
| MS_ASSERT(axis_num == strides->ElementsNum()); | |||
| memcpy(param->strides_, strides->MutableData(), axis_num * sizeof(int)); | |||
| memcpy(param_->strides_, strides->MutableData(), axis_num * sizeof(int)); | |||
| param->num_axes_ = axis_num; | |||
| param_->num_axes_ = axis_num; | |||
| return RET_OK; | |||
| } | |||
| int StridedSliceCPUKernel::Run() { | |||
| auto input = in_tensors_.at(0); | |||
| auto output = out_tensors_.at(0); | |||
| MS_ASSERT(input); | |||
| switch (input->data_type()) { | |||
| case kNumberTypeInt8: | |||
| param_->data_type = kDataTypeInt8; | |||
| break; | |||
| case kNumberTypeFloat32: | |||
| param_->data_type = kDataTypeFloat; | |||
| break; | |||
| case kNumberTypeInt32: | |||
| param_->data_type = kDataTypeInt; | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Not supported data type: " << input->data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| auto output = out_tensors_.at(0); | |||
| MS_ASSERT(output); | |||
| // inputs order: input, begin, end, stride | |||
| if (in_tensors().size() == kMultiInputsSize) { | |||
| auto ret = HandleMultiInputs(); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| } | |||
| auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), | |||
| reinterpret_cast<StridedSliceParameter *>(op_parameter_)); | |||
| auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), param_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]"; | |||
| return RET_ERROR; | |||
| @@ -18,7 +18,7 @@ | |||
| #define MINDSPORE_LITE_SRC_BACKEND_ARM_BASE_STRIDED_SLICE_H_ | |||
| #include <vector> | |||
| #include "nnacl/strided_slice.h" | |||
| #include "src/lite_kernel.h" | |||
| namespace mindspore::kernel { | |||
| @@ -27,7 +27,9 @@ class StridedSliceCPUKernel : public LiteKernel { | |||
| StridedSliceCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| param_ = reinterpret_cast<StridedSliceParameter *>(parameter); | |||
| } | |||
| ~StridedSliceCPUKernel() override = default; | |||
| int Init() override; | |||
| @@ -36,6 +38,9 @@ class StridedSliceCPUKernel : public LiteKernel { | |||
| private: | |||
| int HandleMultiInputs(); | |||
| private: | |||
| StridedSliceParameter *param_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -25,9 +25,3 @@ psenet_lite_mbv2.onnx;1,32,32,3 | |||
| super-resolution-10.onnx;1,224,224,1 | |||
| tinyyolov2-8.onnx;1,416,416,3 | |||
| ml_2012_ocr_cn.onnx | |||
| ml_2012_ocr_cn_noLSTM.onnx | |||
| candy-9.onnx | |||
| mosaic-9.onnx | |||
| pointilism-9.onnx | |||
| rain-princess-9.onnx | |||
| udnie-9.onnx | |||
| @@ -62,7 +62,7 @@ function Run_Converter() { | |||
| if [ $? = 0 ]; then | |||
| converter_result='converter onnx '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} | |||
| else | |||
| converter_result='converter onnx '${model_name}' failed';echo ${converter_result} >> ${run_converter_result_file} | |||
| converter_result='converter onnx '${model_name}' failed';echo ${converter_result} >> ${run_converter_result_file};return 1 | |||
| fi | |||
| done < ${models_onnx_config} | |||
| @@ -64,6 +64,9 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||
| param_value->set_tensor_addr(tensor_data); | |||
| param_value->set_tensor_size(size); | |||
| parameter->set_default_param(param_value); | |||
| } else if (std::find(meta_graph_->inputIndex.begin(), meta_graph_->inputIndex.end(), i) == | |||
| meta_graph_->inputIndex.end()) { | |||
| parameter->set_default_param(param_value); | |||
| } | |||
| AddNode(i, parameter); | |||
| } | |||
| @@ -56,11 +56,11 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo | |||
| return dims; | |||
| } | |||
| STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { | |||
| STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph) { | |||
| MS_LOG(DEBUG) << "set onnx constant tensors"; | |||
| for (const auto &onnx_const_value : onnx_graph.initializer()) { | |||
| int index; | |||
| const auto status = AddTensorProto(onnx_const_value, onnx_const_value.name(), GRAPH_INPUT, tensor_cache, &index); | |||
| const auto status = AddTensorProto(onnx_const_value, onnx_const_value.name(), GRAPH_INPUT, &index); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| @@ -77,7 +77,7 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, | |||
| if (attr.name() == "value") { | |||
| const auto &t = attr.t(); | |||
| int index; | |||
| const auto status = AddTensorProto(t, node.output(0), GRAPH_INPUT, tensor_cache, &index); | |||
| const auto status = AddTensorProto(t, node.output(0), GRAPH_INPUT, &index); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| @@ -93,7 +93,7 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, | |||
| } | |||
| STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, | |||
| TensorCache *tensor_cache, int *index) { | |||
| int *index) { | |||
| auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type())); | |||
| if (data_type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "not support onnx data type " | |||
| @@ -109,12 +109,12 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const st | |||
| tensor->dims = GetDimsFromOnnxValue(proto); | |||
| tensor->format = schema::Format::Format_NCHW; | |||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| *index = tensor_cache->AddTensor(name, tensor.release(), type); | |||
| *index = OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(name, tensor.release(), type); | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, | |||
| TensorCache *tensor_cache, int *index) { | |||
| int *index) { | |||
| auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type())); | |||
| if (data_type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type()); | |||
| @@ -137,17 +137,16 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std | |||
| if (data_type == kNumberTypeInt64) { | |||
| tensor->dataType = kNumberTypeInt32; // CopyOnnxTensorData will convert int64 to int32 | |||
| } | |||
| *index = tensor_cache->AddTensor(name, tensor.release(), type); | |||
| *index = OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(name, tensor.release(), type); | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, | |||
| TensorCache *tensor_cache) { | |||
| STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph) { | |||
| for (const auto &input_value : onnx_graph.input()) { | |||
| auto ret = tensor_cache->FindTensor(input_value.name()); | |||
| auto ret = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(input_value.name()); | |||
| if (ret < 0) { | |||
| int index; | |||
| const auto status = AddValueInfo(input_value, input_value.name(), GRAPH_INPUT, tensor_cache, &index); | |||
| const auto status = AddValueInfo(input_value, input_value.name(), GRAPH_INPUT, &index); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| @@ -158,14 +157,13 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, | |||
| TensorCache *tensor_cache) { | |||
| STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph) { | |||
| for (const auto &output_value : onnx_graph.output()) { | |||
| int index; | |||
| if (tensor_cache->FindTensor(output_value.name()) != -1) { | |||
| index = tensor_cache->FindTensor(output_value.name()); | |||
| if (OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(output_value.name()) != -1) { | |||
| index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(output_value.name()); | |||
| } else { | |||
| const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index); | |||
| const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, &index); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| @@ -178,7 +176,7 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, | |||
| void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, | |||
| TensorCache *tensor_cache, const QuantType &quant_type) { | |||
| const QuantType &quant_type) { | |||
| std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>(); | |||
| dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); | |||
| dst_op_1->quantType = quant_type; | |||
| @@ -186,8 +184,8 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons | |||
| auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0); | |||
| std::vector<string> matmul_inputs{onnx_node.input(0), onnx_node.input(1)}; | |||
| std::vector<string> matmul_outputs{matmul_output_id}; | |||
| SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node, tensor_cache); | |||
| SetOpOutputIndex(matmul_outputs, dst_op_1.get(), tensor_cache); | |||
| SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node); | |||
| SetOpOutputIndex(matmul_outputs, dst_op_1.get()); | |||
| graph->nodes.emplace_back(std::move(dst_op_1)); | |||
| sub_graph->nodeIndices.push_back(graph->nodes.size() - 1); | |||
| @@ -197,15 +195,15 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons | |||
| ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get()); | |||
| std::vector<string> biasadd_inputs{matmul_output_id, onnx_node.input(2)}; | |||
| std::vector<string> biasadd_outputs{onnx_node.output(0)}; | |||
| SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node, tensor_cache); | |||
| SetOpOutputIndex(biasadd_outputs, dst_op_2.get(), tensor_cache); | |||
| SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node); | |||
| SetOpOutputIndex(biasadd_outputs, dst_op_2.get()); | |||
| graph->nodes.emplace_back(std::move(dst_op_2)); | |||
| sub_graph->nodeIndices.push_back(graph->nodes.size() - 1); | |||
| } | |||
| STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { | |||
| STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node) { | |||
| // convert GivenTensorFill node to a weight/bias tensor | |||
| auto ret = tensor_cache->FindTensor(onnx_node.output(0)); | |||
| auto ret = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node.output(0)); | |||
| if (ret < 0) { | |||
| std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>(); | |||
| std::vector<int> shape; | |||
| @@ -259,15 +257,16 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| auto index = tensor_cache->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT); | |||
| auto index = | |||
| OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT); | |||
| MS_LOG(DEBUG) << "add given tensor: " << index; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *dst_op, TensorCache *tensor_cache, | |||
| const QuantType &quantType, schema::MetaGraphT *dst_graph) { | |||
| schema::CNodeT *dst_op, const QuantType &quantType, | |||
| schema::MetaGraphT *dst_graph) { | |||
| // change op_type() to name(), that is unique | |||
| static bool interrupt = false; | |||
| dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); | |||
| @@ -308,41 +307,43 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, | |||
| // set op input index | |||
| std::vector<string> node_inputs; | |||
| (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end()); | |||
| if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) { | |||
| if (SetOpInputIndex(node_inputs, dst_op, onnx_node)) { | |||
| interrupt = true; | |||
| MS_LOG(ERROR) << "SetOpInputIndex failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (dst_op->primitive->value.type == schema::PrimitiveType_Conv2D) { | |||
| auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex)); | |||
| auto &weight_tensor = | |||
| OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex)); | |||
| weight_tensor->format = dst_op->primitive->value.AsConv2D()->format; | |||
| } else if (dst_op->primitive->value.type == schema::PrimitiveType_DeConv2D) { | |||
| auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex)); | |||
| auto &weight_tensor = | |||
| OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex)); | |||
| weight_tensor->format = dst_op->primitive->value.AsDeConv2D()->format; | |||
| } | |||
| // set op output index | |||
| std::vector<string> node_outputs; | |||
| (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end()); | |||
| if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) { | |||
| if (SetOpOutputIndex(node_outputs, dst_op) != RET_OK) { | |||
| interrupt = true; | |||
| MS_LOG(ERROR) << "SetOpOutputIndex failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto &output_tensor = tensor_cache->GetCachedTensor().at(dst_op->outputIndex.front()); | |||
| auto &output_tensor = | |||
| OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->outputIndex.front()); | |||
| if (output_tensor == nullptr) { | |||
| interrupt = true; | |||
| MS_LOG(ERROR) << "Output tensor of node " << onnx_node.op_type() << "is nullptr."; | |||
| return RET_ERROR; | |||
| } | |||
| SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor, tensor_cache); | |||
| SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor); | |||
| return RET_OK; | |||
| } | |||
| void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) { | |||
| schema::CNodeT *dst_op, schema::TensorT *dst_tensor) { | |||
| MS_ASSERT(dst_op != nullptr); | |||
| MS_ASSERT(tensor_cache != nullptr); | |||
| std::vector<string> quant_node_name; | |||
| quant_node_name.insert(quant_node_name.begin(), onnx_node.input().begin(), onnx_node.input().end()); | |||
| quant_node_name.insert(quant_node_name.end(), onnx_node.output().begin(), onnx_node.output().end()); | |||
| @@ -404,10 +405,10 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co | |||
| } | |||
| STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, | |||
| const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { | |||
| const onnx::NodeProto &onnx_node) { | |||
| for (const auto &onnx_node_input : node_inputs) { | |||
| if (onnx_node_input != "") { | |||
| int index = tensor_cache->FindTensor(onnx_node_input); | |||
| int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node_input); | |||
| if (index < 0) { | |||
| MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found"; | |||
| return RET_ERROR; | |||
| @@ -419,14 +420,14 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, | |||
| TensorCache *tensor_cache) { | |||
| STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op) { | |||
| for (const auto &onnx_node_output : node_outputs) { | |||
| auto index = tensor_cache->FindTensor(onnx_node_output); | |||
| auto index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node_output); | |||
| if (index < 0) { // when index >= 0, it's graph's output | |||
| std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>(); | |||
| tensor->nodeType = schema::NodeType_Parameter; | |||
| index = tensor_cache->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT); | |||
| index = | |||
| OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT); | |||
| } | |||
| MS_LOG(DEBUG) << "node: " << onnx_node_output << ", output index: " << index; | |||
| dst_op->outputIndex.emplace_back(index); | |||
| @@ -495,8 +496,8 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) { | |||
| std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor(); | |||
| STATUS OnnxModelParser::SetAllTensors(schema::MetaGraphT *graphDef) { | |||
| std::vector<schema::TensorT *> tensors = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor(); | |||
| for (auto iter : tensors) { | |||
| std::unique_ptr<schema::TensorT> temp(iter); | |||
| graphDef->allTensors.emplace_back(move(temp)); | |||
| @@ -549,12 +550,11 @@ STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodePr | |||
| int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, | |||
| const onnx::GraphProto &onnx_graph, const QuantType &quantType) { | |||
| TensorCache tensor_cache; | |||
| // dst_graph->name = onnx_graph.name(); // this is not used | |||
| // find out input names and const names | |||
| FindGraphInputAndConst(onnx_graph); | |||
| // set const tensor | |||
| int status = SetGraphConstTensor(onnx_graph, &tensor_cache); | |||
| int status = SetGraphConstTensor(onnx_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetGraphConstTensor failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| @@ -563,7 +563,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT | |||
| // init onnx model graph input tensor | |||
| status = SetGraphInputTensor(onnx_graph, dst_sub_graph, &tensor_cache); | |||
| status = SetGraphInputTensor(onnx_graph, dst_sub_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetGraphInputTensor failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| @@ -579,12 +579,12 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT | |||
| } | |||
| if (onnx_node.op_type() == "Gemm") { | |||
| if (status == RET_OK) { | |||
| ParseOnnxGemmNode(onnx_graph, onnx_node, dst_sub_graph, dst_graph, &tensor_cache, quantType); | |||
| ParseOnnxGemmNode(onnx_graph, onnx_node, dst_sub_graph, dst_graph, quantType); | |||
| } | |||
| continue; | |||
| } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { | |||
| if (status == RET_OK) { | |||
| status_node = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); | |||
| status_node = ParseOnnxGivenFillNode(onnx_node); | |||
| if (status_node != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status_node; | |||
| status = (status == RET_OK ? status_node : status); | |||
| @@ -594,7 +594,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT | |||
| } | |||
| std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>(); | |||
| status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), &tensor_cache, quantType, dst_graph); | |||
| status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), quantType, dst_graph); | |||
| if (status_node != RET_OK) { | |||
| status = (status == RET_OK ? status_node : status); | |||
| continue; | |||
| @@ -604,19 +604,19 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT | |||
| } | |||
| if (status != RET_OK) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| for (auto &tensor : tensor_cache.GetCachedTensor()) { | |||
| for (auto &tensor : OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()) { | |||
| delete tensor; | |||
| } | |||
| return RET_ERROR; | |||
| } | |||
| // init onnx model graph output tensor | |||
| status = SetGraphOutputTensor(onnx_graph, dst_sub_graph, &tensor_cache); | |||
| status = SetGraphOutputTensor(onnx_graph, dst_sub_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetGraphOutputTensor failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return RET_ERROR; | |||
| } | |||
| SetAllTensors(tensor_cache, dst_graph); | |||
| SetAllTensors(dst_graph); | |||
| return RET_OK; | |||
| } | |||
| @@ -30,7 +30,7 @@ | |||
| #include "securec/include/securec.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/parser/onnx/onnx_tensor_parser.h" | |||
| #include "proto/onnx.pb.h" | |||
| namespace mindspore { | |||
| @@ -53,42 +53,38 @@ class OnnxModelParser : public ModelParser { | |||
| private: | |||
| std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | |||
| STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); | |||
| STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph); | |||
| STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, TensorCache *tensor_cache); | |||
| STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph); | |||
| STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, TensorCache *tensor_cache); | |||
| STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph); | |||
| STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, | |||
| TensorCache *tensor_cache, int *index); | |||
| STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, int *index); | |||
| STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, | |||
| TensorCache *tensor_cache, int *index); | |||
| STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, int *index); | |||
| STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *dst_op, TensorCache *tensor_cache, const QuantType &quantType, | |||
| schema::MetaGraphT *dst_graph); | |||
| schema::CNodeT *dst_op, const QuantType &quantType, schema::MetaGraphT *dst_graph); | |||
| void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache, | |||
| const QuantType &quant_type); | |||
| schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, const QuantType &quant_type); | |||
| STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); | |||
| STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node); | |||
| STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| const string &onnx_op_type, schema::CNodeT *dst_op); | |||
| void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, | |||
| schema::TensorT *dst_tensor, TensorCache *tensor_cache); | |||
| schema::TensorT *dst_tensor); | |||
| STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, | |||
| const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); | |||
| const onnx::NodeProto &onnx_node); | |||
| STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache); | |||
| STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op); | |||
| STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); | |||
| STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef); | |||
| STATUS SetAllTensors(schema::MetaGraphT *graphDef); | |||
| void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); | |||
| @@ -17,9 +17,35 @@ | |||
| #include "tools/converter/parser/onnx/onnx_slice_parser.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSliceParser::InsertTensor(const std::vector<int> &onnx_val, const std::string &name, | |||
| onnx::NodeProto *onnx_node) { | |||
| std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>(); | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new tensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| tensor->dataType = mindspore::kNumberTypeInt32; | |||
| tensor->dims.push_back(onnx_val.size()); | |||
| tensor->format = schema::Format::Format_NCHW; | |||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| int data_size = sizeof(int32_t) * onnx_val.size(); | |||
| tensor->data.resize(data_size); | |||
| if (data_size != 0 && | |||
| memcpy_s(static_cast<void *>(tensor->data.data()), data_size, onnx_val.data(), data_size) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| int tensor_num = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().size(); | |||
| std::string tensor_name = name + std::to_string(tensor_num); | |||
| OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(tensor_name, tensor.release(), GRAPH_INPUT); | |||
| onnx_node->add_input(tensor_name); | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx SliceParser"; | |||
| @@ -33,15 +59,15 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>(); | |||
| std::unique_ptr<schema::StridedSliceT> attr = std::make_unique<schema::StridedSliceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::vector<int> axes; | |||
| std::vector<int> starts; | |||
| std::vector<int> ends; | |||
| std::vector<int> axes; | |||
| std::vector<int> steps; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| @@ -71,64 +97,49 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||
| } | |||
| } | |||
| } | |||
| if (onnx_node.input_size() > 1) { | |||
| const auto &starts_name = onnx_node.input(1); | |||
| for (const auto &it : onnx_graph.initializer()) { | |||
| if (it.name() == starts_name) { | |||
| starts.clear(); | |||
| for (int i = 0; i < it.int32_data_size(); ++i) { | |||
| starts.push_back(it.int32_data(i)); | |||
| } | |||
| } | |||
| if (axes.empty()) { | |||
| for (size_t i = 0; i < starts.size(); ++i) { | |||
| axes.push_back(i); | |||
| } | |||
| } | |||
| if (onnx_node.input_size() > 2) { | |||
| const auto &ends_name = onnx_node.input(2); | |||
| for (const auto &it : onnx_graph.initializer()) { | |||
| if (it.name() == ends_name) { | |||
| ends.clear(); | |||
| for (int i = 0; i < it.int32_data_size(); ++i) { | |||
| ends.push_back(it.int32_data(i)); | |||
| } | |||
| } | |||
| } | |||
| if (steps.empty()) { | |||
| steps.assign(starts.size(), 1); | |||
| } | |||
| if (onnx_node.input_size() > 3) { | |||
| const auto &axes_name = onnx_node.input(3); | |||
| for (const auto &it : onnx_graph.initializer()) { | |||
| if (it.name() == axes_name) { | |||
| axes.clear(); | |||
| for (int i = 0; i < it.int32_data_size(); ++i) { | |||
| axes.push_back(it.int32_data(i)); | |||
| } | |||
| } | |||
| onnx::NodeProto *slice_node = nullptr; | |||
| for (auto &node : onnx_graph.node()) { | |||
| if (&node == &onnx_node) { | |||
| slice_node = const_cast<onnx::NodeProto *>(&node); | |||
| } | |||
| } | |||
| if (onnx_node.input_size() > 4) { | |||
| const auto &steps_name = onnx_node.input(4); | |||
| for (const auto &it : onnx_graph.initializer()) { | |||
| if (it.name() == steps_name) { | |||
| steps.clear(); | |||
| for (int i = 0; i < it.int32_data_size(); ++i) { | |||
| steps.push_back(it.int32_data(i)); | |||
| } | |||
| } | |||
| int insert_num = 5 - onnx_node.input_size(); | |||
| int status = RET_OK; | |||
| switch (insert_num) { | |||
| case 4: { | |||
| std::string name = "slice/starts/"; | |||
| status = InsertTensor(starts, name, slice_node); | |||
| } | |||
| case 3: | |||
| if (status == RET_OK) { | |||
| std::string name = "slice/ends/"; | |||
| status = InsertTensor(ends, name, slice_node); | |||
| } | |||
| case 2: | |||
| if (status == RET_OK) { | |||
| std::string name = "slice/axes/"; | |||
| status = InsertTensor(axes, name, slice_node); | |||
| } | |||
| case 1: | |||
| if (status == RET_OK) { | |||
| std::string name = "slice/steps/"; | |||
| status = InsertTensor(steps, name, slice_node); | |||
| } | |||
| default: | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "onnx slice insert tensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| std::vector<int> sizes(starts.size(), -1); | |||
| for (size_t i = 0; i < starts.size(); ++i) { | |||
| sizes[i] = (ends[i] < 0 ? ends[i] : ends[i] - starts[i]); | |||
| } | |||
| attr->axes = axes; | |||
| attr->begin = starts; | |||
| attr->size = sizes; | |||
| attr->step = steps; | |||
| op->primitive->value.type = schema::PrimitiveType_Slice; | |||
| op->primitive->value.type = schema::PrimitiveType_StridedSlice; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| } | |||
| @@ -17,8 +17,11 @@ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_tensor_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -27,6 +30,7 @@ class OnnxSliceParser : public OnnxNodeParser { | |||
| OnnxSliceParser() : OnnxNodeParser("Slice") {} | |||
| ~OnnxSliceParser() override = default; | |||
| STATUS InsertTensor(const std::vector<int> &onnx_val, const std::string &name, onnx::NodeProto *onnx_node); | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| }; | |||
| } // namespace lite | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H | |||
| #include "tools/common/tensor_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnnxTensorParser { | |||
| public: | |||
| ~OnnxTensorParser() = default; | |||
| static OnnxTensorParser *GetInstance() { | |||
| static OnnxTensorParser onnxTensorParser; | |||
| return &onnxTensorParser; | |||
| } | |||
| TensorCache *GetTensorCache() { return &tensor_cache_; } | |||
| private: | |||
| OnnxTensorParser() = default; | |||
| TensorCache tensor_cache_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TESNOR_PARSER_H | |||