From fef75213ed06b861df468803a9643d3e49056851 Mon Sep 17 00:00:00 2001
From: wang_shaocong
Date: Thu, 22 Oct 2020 10:02:58 +0800
Subject: [PATCH] Modify the implementation of constant_of_shape; fix bugs in
 the cast and reduce implementations and in the infer-shape of slice and topk.

---
 mindspore/lite/nnacl/fp32/arithmetic.c        | 16 ++++++++++
 mindspore/lite/nnacl/fp32/arithmetic.h        |  1 +
 mindspore/lite/nnacl/fp32/constant_of_shape.c | 11 +++++++
 mindspore/lite/nnacl/fp32/constant_of_shape.h |  2 ++
 mindspore/lite/nnacl/fp32/reduce.c            | 21 +++++++++++
 mindspore/lite/nnacl/fp32/reduce.h            |  2 ++
 mindspore/lite/schema/ops.fbs                 |  3 +-
 mindspore/lite/src/ops/constant_of_shape.cc   | 20 +++++++++---
 mindspore/lite/src/ops/constant_of_shape.h    |  4 +--
 .../populate/constant_of_shape_populate.cc    |  8 ++++-
 mindspore/lite/src/ops/slice.cc               | 24 ++++++++++++++
 mindspore/lite/src/ops/topk.cc                |  5 ++-
 mindspore/lite/src/ops/unsqueeze.cc           |  4 +--
 .../src/runtime/kernel/arm/fp32/arithmetic.h  |  1 +
 .../lite/src/runtime/kernel/arm/fp32/cast.cc  |  3 ++
 .../kernel/arm/fp32/constant_of_shape.cc      | 26 ++++++++++++++--
 .../kernel/arm/fp32/constant_of_shape.h       |  2 +-
 .../kernel/arm/fp32/non_max_suppression.cc    |  4 +--
 .../src/runtime/kernel/arm/fp32/reduce.cc     |  1 +
 .../lite/src/runtime/kernel/arm/fp32/topk.cc  |  8 +++++
 .../onnx/onnx_constant_of_shape_parser.cc     | 27 +++++++++++-----
 .../parser/onnx/onnx_model_parser.cc          |  7 ++---
 .../converter/parser/onnx/onnx_node_parser.cc | 31 +++++++++++++++++++
 .../converter/parser/onnx/onnx_node_parser.h  |  2 ++
 24 files changed, 203 insertions(+), 30 deletions(-)

diff --git a/mindspore/lite/nnacl/fp32/arithmetic.c b/mindspore/lite/nnacl/fp32/arithmetic.c
index 2b3fcff28d..3d8ec7fdcc 100644
--- a/mindspore/lite/nnacl/fp32/arithmetic.c
+++ b/mindspore/lite/nnacl/fp32/arithmetic.c
@@ -657,6 +657,22 @@ int ElementAddRelu6(float *input0, float *input1, float *output, int element_siz
   return NNACL_OK;
 }
 
+int ElementAddInt(int *input0, int *input1, int *output, int element_size) {
+  int index = 0;
+#ifdef ENABLE_NEON
+  for (; index <= element_size - 4; index += C4NUM) {
+    int32x4_t vin0 = vld1q_s32(input0 + index);
+    int32x4_t vin1 = vld1q_s32(input1 + index);
+    int32x4_t vout = vaddq_s32(vin0, vin1);
+    vst1q_s32(output + index, vout);
+  }
+#endif
+  for (; index < element_size; index++) {
+    output[index] = input0[index] + input1[index];
+  }
+  return NNACL_OK;
+}
+
 int ElementAddInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size) {
   for (int i = 0; i < element_size; i++) {
     output[i] = input0[i] + input1[i];
diff --git a/mindspore/lite/nnacl/fp32/arithmetic.h b/mindspore/lite/nnacl/fp32/arithmetic.h
index 7c2050a425..aac6a99120 100644
--- a/mindspore/lite/nnacl/fp32/arithmetic.h
+++ b/mindspore/lite/nnacl/fp32/arithmetic.h
@@ -54,6 +54,7 @@ int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_i
 int ElementAdd(float *input0, float *input1, float *output, int element_size);
 int ElementAddRelu(float *input0, float *input1, float *output, int element_size);
 int ElementAddRelu6(float *input0, float *input1, float *output, int element_size);
+int ElementAddInt(int *input0, int *input1, int *output, int element_size);
 int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
                  ArithmeticParameter *param);
 int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output,
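For reference, the vectorize-then-scalar-tail pattern used by the new `ElementAddInt` can be exercised standalone. A minimal sketch, with a hypothetical `add_int` harness (only the NEON loop structure mirrors the patch; everything else is illustrative):

```cpp
#include <cstdio>
#include <cstdint>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif

// Same pattern as ElementAddInt: 4 lanes per NEON step, scalar tail for the rest.
static void add_int(const int32_t *a, const int32_t *b, int32_t *out, int n) {
  int i = 0;
#ifdef ENABLE_NEON
  for (; i <= n - 4; i += 4) {
    vst1q_s32(out + i, vaddq_s32(vld1q_s32(a + i), vld1q_s32(b + i)));
  }
#endif
  for (; i < n; i++) {
    out[i] = a[i] + b[i];
  }
}

int main() {
  int32_t a[6] = {1, 2, 3, 4, 5, 6}, b[6] = {10, 20, 30, 40, 50, 60}, out[6];
  add_int(a, b, out, 6);
  for (int i = 0; i < 6; i++) printf("%d ", out[i]);  // 11 22 33 44 55 66
  printf("\n");
  return 0;
}
```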
diff --git a/mindspore/lite/nnacl/fp32/constant_of_shape.c b/mindspore/lite/nnacl/fp32/constant_of_shape.c
index 1a176d0a6d..2a81cc71c2 100644
--- a/mindspore/lite/nnacl/fp32/constant_of_shape.c
+++ b/mindspore/lite/nnacl/fp32/constant_of_shape.c
@@ -26,3 +26,14 @@ int ConstantOfShape(float *output, int tid, ConstantOfShapeParameter *param) {
   }
   return NNACL_OK;
 }
+
+int ConstantOfShapeInt(int32_t *output, int tid, ConstantOfShapeParameter *param) {
+  int size = param->unit_;
+  int32_t data = (int32_t)param->value_;
+  int ind_st = MSMIN(tid * size, param->element_sz_);
+  int ind_end = MSMIN(param->element_sz_, (tid + 1) * size);
+  for (int i = ind_st; i < ind_end; ++i) {
+    output[i] = data;
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore/lite/nnacl/fp32/constant_of_shape.h b/mindspore/lite/nnacl/fp32/constant_of_shape.h
index 355871776b..e34ec3b569 100644
--- a/mindspore/lite/nnacl/fp32/constant_of_shape.h
+++ b/mindspore/lite/nnacl/fp32/constant_of_shape.h
@@ -25,6 +25,7 @@
 typedef struct ConstantOfShapeParameter {
   OpParameter op_parameter_;
   float value_;
+  int data_type_;
   int unit_;
   int element_sz_;
 } ConstantOfShapeParameter;
@@ -33,6 +34,7 @@
 extern "C" {
 #endif
 int ConstantOfShape(float *output, int tid, ConstantOfShapeParameter *param);
+int ConstantOfShapeInt(int32_t *output, int tid, ConstantOfShapeParameter *param);
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/nnacl/fp32/reduce.c b/mindspore/lite/nnacl/fp32/reduce.c
index 0d8b74f5b4..156579a276 100644
--- a/mindspore/lite/nnacl/fp32/reduce.c
+++ b/mindspore/lite/nnacl/fp32/reduce.c
@@ -123,6 +123,27 @@ int ReduceMin(const int outer_size, const int inner_size, const int axis_size, c
   }
   return NNACL_OK;
 }
+int IntReduceMin(const int outer_size, const int inner_size, const int axis_size, const int *src_data, int *dst_data,
+                 const int tid, const int thread_num) {
+  if (src_data == NULL || dst_data == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  int i, j, k;
+  for (j = tid; j < outer_size; j += thread_num) {
+    const int *outer_src = src_data + j * axis_size * inner_size;
+    int *outer_dst = dst_data + j * inner_size;
+    for (k = 0; k < inner_size; k++) {
+      const int *inner_src = outer_src + k;
+      int *inner_dst = outer_dst + k;
+      int tmp = INT32_MAX;
+      for (i = 0; i < axis_size; i++) {
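The `IntReduceMin` loop structure mirrors the existing float reducers: the data is viewed as `[outer_size, axis_size, inner_size]`, threads stride over the outer dimension, and the minimum is taken along the middle axis. A standalone, single-threaded sketch of just that indexing (hypothetical `reduce_min_axis1` name):

```cpp
#include <cstdio>
#include <climits>

// View src as [outer, axis, inner] and reduce-min along the middle axis,
// using the same index arithmetic as IntReduceMin (single-threaded here).
static void reduce_min_axis1(const int *src, int *dst, int outer, int axis, int inner) {
  for (int j = 0; j < outer; j++) {
    const int *outer_src = src + j * axis * inner;
    int *outer_dst = dst + j * inner;
    for (int k = 0; k < inner; k++) {
      int tmp = INT_MAX;
      for (int i = 0; i < axis; i++) {
        int v = outer_src[i * inner + k];
        tmp = tmp < v ? tmp : v;
      }
      outer_dst[k] = tmp;
    }
  }
}

int main() {
  // A [1, 3, 2] tensor reduced along axis 1 yields shape [1, 2].
  int src[6] = {4, 9, 2, 7, 6, 1};
  int dst[2];
  reduce_min_axis1(src, dst, 1, 3, 2);
  printf("%d %d\n", dst[0], dst[1]);  // 2 1
  return 0;
}
```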
+        tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
+      }
+      *inner_dst = tmp;
+    }
+  }
+  return NNACL_OK;
+}
 int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
                const int tid, const int thread_num) {
   if (src_data == NULL || dst_data == NULL) {
 *ConstantOfShapeCreator(const schema::Primitive *primitive) {
   return PrimitiveC::NewPrimitiveC<ConstantOfShape>(primitive);
@@ -70,7 +80,7 @@ int ConstantOfShape::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
   }
   auto in_tensor = inputs_.front();
   auto out_tensor = outputs_.front();
-  out_tensor->set_data_type(kNumberTypeFloat32);
+  out_tensor->set_data_type(static_cast<TypeId>(GetDataType()));
   out_tensor->SetFormat(in_tensor->GetFormat());
   if (!GetInferFlag()) {
     return RET_OK;
diff --git a/mindspore/lite/src/ops/constant_of_shape.h b/mindspore/lite/src/ops/constant_of_shape.h
index 26917929ed..6dbf5ec527 100644
--- a/mindspore/lite/src/ops/constant_of_shape.h
+++ b/mindspore/lite/src/ops/constant_of_shape.h
@@ -30,14 +30,14 @@ class ConstantOfShape : public PrimitiveC {
   MS_DECLARE_PARENT(ConstantOfShape, PrimitiveC);
   ConstantOfShape() = default;
   explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
-  void SetValue(float value);
 #else
   ConstantOfShape() = default;
 
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
 
   int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
-  float GetValue() const;
+  std::vector<float> GetValue() const;
+  int GetDataType() const;
 };
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/ops/populate/constant_of_shape_populate.cc b/mindspore/lite/src/ops/populate/constant_of_shape_populate.cc
index 90889afbcc..dee7ac9a59 100644
--- a/mindspore/lite/src/ops/populate/constant_of_shape_populate.cc
+++ b/mindspore/lite/src/ops/populate/constant_of_shape_populate.cc
@@ -34,7 +34,13 @@ OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC
   }
   memset(param, 0, sizeof(ConstantOfShapeParameter));
   param->op_parameter_.type_ = primitive->Type();
-  param->value_ = attr->GetValue();
+  auto value = attr->GetValue();
+  if (value.size() != 1) {
+    MS_LOG(ERROR) << "The value of ConstantOfShape must contain exactly one element, but got " << value.size();
+  } else {
+    param->value_ = value[0];
+  }
+  param->data_type_ = attr->GetDataType();
   return reinterpret_cast<OpParameter *>(param);
 }
 Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter);
diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc
index 94dbb4e8ce..87854669a9 100644
--- a/mindspore/lite/src/ops/slice.cc
+++ b/mindspore/lite/src/ops/slice.cc
@@ -28,6 +28,7 @@ namespace lite {
 namespace {
 constexpr int kSliceInputNum = 1;
 constexpr int kSliceOutputNum = 1;
+constexpr int kSliceMaxInputNum = 5;
 }  // namespace
 #ifdef PRIMITIVE_WRITEABLE
 int Slice::GetFormat() const { return this->primitive_->value.AsSlice()->format; }
@@ -175,6 +176,29 @@ int Slice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
   std::vector<int> slice_begin(GetBegin());
   std::vector<int> slice_size(GetSize());
   std::vector<int> slice_axes(GetAxes());
   std::vector<int> output_shape(input_shape.size());
+  if (inputs.size() == kSliceMaxInputNum) {
+    if (slice_begin.empty() && inputs.at(1)->data_c() != nullptr) {
+      for (int i = 0; i < inputs.at(1)->ElementsNum(); i++) {
+        slice_begin.emplace_back(static_cast<int32_t *>(inputs.at(1)->data_c())[i]);
+      }
+    }
+    if (slice_size.empty() && inputs.at(2)->data_c() != nullptr) {
+      for (int i = 0; i < inputs.at(2)->ElementsNum(); i++) {
+        auto end = static_cast<int32_t *>(inputs.at(2)->data_c())[i];
+        auto size = end < 0 ? end : (end == INT32_MAX ? -1 : end - slice_begin[i]);
+        slice_size.emplace_back(size);
+      }
+    }
+    if (slice_axes.empty() && inputs.at(3)->data_c() != nullptr) {
+      for (int i = 0; i < inputs.at(3)->ElementsNum(); i++) {
+        slice_axes.emplace_back(static_cast<int32_t *>(inputs.at(3)->data_c())[i]);
+      }
+    }
+  }
+  if (slice_begin.empty() || slice_size.empty() || slice_axes.empty()) {
+    MS_LOG(ERROR) << "Slice infer shape failed: begin, size, or axes is unknown.";
+    return RET_INFER_INVALID;
+  }
   begin.assign(input_shape.size(), 0);
   size.assign(input_shape.size(), -1);
   for (size_t i = 0; i < slice_axes.size(); ++i) {
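The end-to-size conversion that `Slice::InferShape` now performs on the ONNX-style ends input is subtle (negative ends pass through, `INT32_MAX` means "to the end"), so here is the same rule as a standalone sketch with worked cases (`end_to_size` is a hypothetical name):

```cpp
#include <cstdio>
#include <cstdint>

// Mirrors the size derivation added to Slice::InferShape.
static int end_to_size(int begin, int end) {
  return end < 0 ? end : (end == INT32_MAX ? -1 : end - begin);
}

int main() {
  printf("%d\n", end_to_size(2, 7));          // 5: plain [2, 7) range
  printf("%d\n", end_to_size(1, INT32_MAX));  // -1: slice to the end of the axis
  printf("%d\n", end_to_size(0, -1));         // -1: negative end passed through
  return 0;
}
```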
diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc
index 84a3c85d6a..ad27295547 100644
--- a/mindspore/lite/src/ops/topk.cc
+++ b/mindspore/lite/src/ops/topk.cc
@@ -54,7 +54,7 @@ Registry TopKRegistry(schema::PrimitiveType_TopK, TopKCreator);
 
 int TopK::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
   MS_ASSERT(this->primitive_ != nullptr);
-  if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) {
+  if ((inputs_.size() != kSingleNum && inputs_.size() != kDoubleNum) || outputs_.size() != kDoubleNum) {
     MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
     return RET_INPUT_TENSOR_ERROR;
   }
@@ -74,6 +74,9 @@ int TopK::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> output
   MS_ASSERT(topk_prim != nullptr);
   auto out_shape = input->shape();
   out_shape[out_shape.size() - 1] = GetK();
+  if (inputs_.size() == kDoubleNum && inputs_.at(1)->data_c() != nullptr) {
+    out_shape[out_shape.size() - 1] = reinterpret_cast<int *>(inputs_.at(1)->data_c())[0];
+  }
   output0->set_shape(out_shape);
   output1->set_shape(out_shape);
   return RET_OK;
diff --git a/mindspore/lite/src/ops/unsqueeze.cc b/mindspore/lite/src/ops/unsqueeze.cc
index e9cd9ed58b..44629e4d58 100644
--- a/mindspore/lite/src/ops/unsqueeze.cc
+++ b/mindspore/lite/src/ops/unsqueeze.cc
@@ -104,9 +104,7 @@ int Unsqueeze::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> o
       out_shape.emplace_back(1);
       ax_itr++;
     } else {
-      if (in_shape[in_itr] > 1) {
-        out_shape.emplace_back(in_shape[in_itr]);
-      }
+      out_shape.emplace_back(in_shape[in_itr]);
       in_itr++;
     }
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
index 51ebf5ab62..4053783039 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
@@ -83,6 +83,7 @@ class ArithmeticCPUKernel : public LiteKernel {
           break;
         default:
           arithmetic_run_ = ElementAdd;
+          arithmetic_run_int_ = ElementAddInt;
           break;
       }
       break;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc
index cdab2a9653..56564a609b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc
@@ -77,6 +77,9 @@ int CastCPUKernel::DoCast(int thread_id) {
   } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) {
     Float32ToFp16(reinterpret_cast<float *>(input->data_c()) + offset,
                   reinterpret_cast<float16_t *>(output_data) + offset, data_num);
+  } else if (input_data_type == kNumberTypeInt32 &&
+             (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
+    memcpy(output_data, input->data_c(), data_num * sizeof(int32_t));
   } else {
     MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type;
     return RET_ERROR;
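One caveat on the new Int32 branch in `CastCPUKernel::DoCast`: when the destination tensor is typed Int64, the `memcpy` still copies raw int32 words, which appears to assume the runtime reads Int64-typed lite tensors as int32 storage. A standalone sketch of how that raw copy differs from element-wise widening (little-endian layout assumed):

```cpp
#include <cstdio>
#include <cstdint>
#include <cstring>

int main() {
  int32_t src[2] = {1, -2};
  int64_t widened[2], raw[2] = {0, 0};
  // Element-wise widening: each int64 slot holds the corresponding int32 value.
  for (int i = 0; i < 2; i++) widened[i] = src[i];
  // Raw copy as in the patch: only the first 2 * sizeof(int32_t) bytes are
  // written, so reading 'raw' as int64 elements gives different numbers.
  memcpy(raw, src, 2 * sizeof(int32_t));
  printf("widened: %lld %lld\n", (long long)widened[0], (long long)widened[1]);
  printf("raw    : %lld %lld\n", (long long)raw[0], (long long)raw[1]);
  return 0;
}
```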
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.cc
index b025b607d5..097dcbb2f0 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.cc
@@ -33,7 +33,18 @@ int ConstantOfShapeCPUKernel::Init() { return RET_OK; }
 int ConstantOfShapeCPUKernel::ReSize() { return RET_OK; }
 
 int ConstantOfShapeCPUKernel::DoExecute(int task_id) {
-  int ret = ConstantOfShape(out_ptr_, task_id, param_);
+  int ret = RET_ERROR;
+  switch (param_->data_type_) {
+    case kNumberTypeFloat32:
+      ret = ConstantOfShape(reinterpret_cast<float *>(out_ptr_), task_id, param_);
+      break;
+    case kNumberTypeInt32:
+      ret = ConstantOfShapeInt(reinterpret_cast<int32_t *>(out_ptr_), task_id, param_);
+      break;
+    default:
+      MS_LOG(ERROR) << "ConstantOfShape does not support output data type " << param_->data_type_;
+      return RET_ERROR;
+  }
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConstantOfShapeRun error task_id[" << task_id << "] error_code[" << ret << "]";
     return ret;
@@ -56,7 +67,17 @@ int ConstantOfShapeCPUKernel::Run() {
   int thread_num = MSMIN(param_->op_parameter_.thread_num_, param_->element_sz_);
   param_->unit_ = UP_DIV(param_->element_sz_, thread_num);
   param_->op_parameter_.thread_num_ = thread_num;
-  out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
+  switch (param_->data_type_) {
+    case kNumberTypeFloat32:
+      out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
+      break;
+    case kNumberTypeInt32:
+      out_ptr_ = reinterpret_cast<int32_t *>(out_tensors_.front()->MutableData());
+      break;
+    default:
+      MS_LOG(ERROR) << "ConstantOfShape does not support output data type " << param_->data_type_;
+      return RET_ERROR;
+  }
   auto ret = ParallelLaunch(this->context_->thread_pool_, ConstantOfShapeRun, this, thread_num);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConstantOfShapeRun error error_code[" << ret << "]";
@@ -93,4 +114,5 @@ kernel::LiteKernel *CpuConstantOfShapeFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
 }
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, CpuConstantOfShapeFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, CpuConstantOfShapeFp32KernelCreator)
 }  // namespace kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.h b/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.h
@@ ... @@ class ConstantOfShapeCPUKernel : public LiteKernel {
-  float *out_ptr_ = nullptr;
+  void *out_ptr_ = nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/non_max_suppression.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/non_max_suppression.cc
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/non_max_suppression.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/non_max_suppression.cc
@@ ... @@ int NonMaxSuppressionCPUKernel::GetParams() {
   if (in_tensors_.size() >= 3) {
     auto max_output_tensor = in_tensors_.at(kMaxOutputNumTensorIndex);
-    if (max_output_tensor != nullptr && reinterpret_cast<float *>(max_output_tensor->data_c()) != nullptr) {
-      max_output_per_class_ = *(reinterpret_cast<float *>(max_output_tensor->data_c()));
+    if (max_output_tensor != nullptr && reinterpret_cast<int32_t *>(max_output_tensor->data_c()) != nullptr) {
+      max_output_per_class_ = *(reinterpret_cast<int32_t *>(max_output_tensor->data_c()));
     }
   }
   iou_threshold_ = 0.0f;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
index df2565587d..6e9f1eff78 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
@@ -61,6 +61,7 @@ int ReduceCPUKernel::Init() {
     }
     case static_cast<int>(ReduceMode_ReduceMin): {
       reducer_ = ReduceMin;
+      int_reducer_ = IntReduceMin;
       break;
     }
     case static_cast<int>(ReduceMode_ReduceProd): {
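The reduce change follows the kernel's existing dual-dispatch pattern: one function pointer per element type, chosen once in `Init()` and invoked by the typed run path. A minimal standalone sketch of that pattern (`FloatReducer`, `IntReducer`, and the min functions are hypothetical stand-ins for the kernel's `reducer_` and `int_reducer_` members):

```cpp
#include <cstdio>
#include <climits>

typedef int (*FloatReducer)(const float *src, float *dst, int n);
typedef int (*IntReducer)(const int *src, int *dst, int n);

static int FloatMin(const float *src, float *dst, int n) {
  float tmp = src[0];
  for (int i = 1; i < n; i++) tmp = tmp < src[i] ? tmp : src[i];
  *dst = tmp;
  return 0;
}

static int IntMin(const int *src, int *dst, int n) {
  int tmp = INT_MAX;
  for (int i = 0; i < n; i++) tmp = tmp < src[i] ? tmp : src[i];
  *dst = tmp;
  return 0;
}

int main() {
  FloatReducer reducer = FloatMin;  // analogous to reducer_ = ReduceMin;
  IntReducer int_reducer = IntMin;  // analogous to int_reducer_ = IntReduceMin;
  float fsrc[3] = {3.0f, 1.0f, 2.0f}, fdst = 0.0f;
  int isrc[3] = {3, 1, 2}, idst = 0;
  reducer(fsrc, &fdst, 3);
  int_reducer(isrc, &idst, 3);
  printf("%g %d\n", fdst, idst);  // 1 1
  return 0;
}
```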
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc
index b3fa29b4ad..3aff858c24 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc
@@ -51,6 +51,14 @@ int TopKCPUKernel::Run() {
   MS_ASSERT(context_->allocator != nullptr);
   TopkParameter *parameter = reinterpret_cast<TopkParameter *>(op_parameter_);
+  if (in_tensors_.size() == lite::kDoubleNum) {
+    auto input_k = reinterpret_cast<int *>(in_tensors_.at(1)->MutableData());
+    parameter->k_ = input_k[0];
+  }
+  if (parameter->k_ > in_tensors_.at(0)->ElementsNum()) {
+    MS_LOG(ERROR) << "The k value is out of the data size range.";
+    return RET_ERROR;
+  }
   parameter->topk_node_list_ = context_->allocator->Malloc(sizeof(TopkNode) * parameter->last_dim_size_);
   if (parameter->topk_node_list_ == nullptr) {
     MS_LOG(ERROR) << "Memory allocation failed";
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc
index 8908ebcd62..4210d73550 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc
@@ -16,6 +16,7 @@
 
 #include "tools/converter/parser/onnx/onnx_constant_of_shape_parser.h"
 #include <memory>
+#include "tools/converter/parser/onnx/onnx_model_parser.h"
 
 namespace mindspore {
 namespace lite {
@@ -41,13 +42,25 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "value") {
-      if (onnx_node_attr.type() == onnx::AttributeProto_AttributeType_TENSOR) {
-        auto tensor = onnx_node_attr.t();
-        if (tensor.data_type() == onnx::AttributeProto_AttributeType_FLOAT) {
-          attr->value = onnx_node_attr.f();
-        } else if (tensor.data_type() == onnx::AttributeProto_AttributeType_INT) {
-          attr->value = static_cast<float>(onnx_node_attr.i());
-        }
-      }
+      switch (onnx_node_attr.type()) {
+        case onnx::AttributeProto_AttributeType_FLOAT:
+          attr->dataType = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
+          attr->value.push_back(onnx_node_attr.f());
+          break;
+        case onnx::AttributeProto_AttributeType_INT:
+          attr->dataType = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
+          attr->value.push_back(static_cast<float>(onnx_node_attr.i()));
+          break;
+        case onnx::AttributeProto_AttributeType_TENSOR: {
+          auto tensor = onnx_node_attr.t();
+          auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType);
+          if (ret != RET_OK) {
+            return ret;
+          }
+        } break;
+        default:
+          MS_LOG(ERROR) << "Unsupported ConstantOfShape value attribute type: " << onnx_node_attr.type();
+          return RET_ERROR;
+      }
     }
   }
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
index affba69b3c..269c634d67 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
@@ -445,11 +445,8 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
       }
       for (size_t i = 0; i < data_count; ++i) {
         if (in_data[i] > static_cast<int64_t>(INT32_MAX) || in_data[i] < static_cast<int64_t>(INT32_MIN)) {
-          if (llabs(in_data[i]) == INT64_MAX || in_data[i] == INT64_MIN) {
-            buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN;
-          }
-          MS_LOG(ERROR) << "int64 data " << in_data[i] << "too big to fit into int32";
-          return RET_ERROR;
+          MS_LOG(WARNING) << "int64 data " << in_data[i] << " is too big to fit into int32; clamping";
+          buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN;
         } else {
           buffer[i] = static_cast<int32_t>(in_data[i]);
         }
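For clarity, the converter's new behavior in `CopyOnnxTensorData` (warn and clamp instead of failing) reduces to a saturating int64-to-int32 narrowing. A standalone sketch with a hypothetical `clamp_to_i32` helper:

```cpp
#include <cstdio>
#include <cstdint>

// Mirrors the new CopyOnnxTensorData path: out-of-range int64 weights are
// clamped to the int32 limits with a warning instead of aborting conversion.
static int32_t clamp_to_i32(int64_t v) {
  if (v > (int64_t)INT32_MAX || v < (int64_t)INT32_MIN) {
    fprintf(stderr, "warning: int64 value %lld clamped\n", (long long)v);
    return v > 0 ? INT32_MAX : INT32_MIN;
  }
  return (int32_t)v;
}

int main() {
  printf("%d\n", clamp_to_i32(42));             // 42
  printf("%d\n", clamp_to_i32(4294967296LL));   // 2147483647
  printf("%d\n", clamp_to_i32(-4294967296LL));  // -2147483648
  return 0;
}
```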
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
index 24a9ac8dcc..84fca4c794 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
@@ -16,6 +16,7 @@
 
 #include "tools/converter/parser/onnx/onnx_node_parser.h"
 #include <vector>
+#include "tools/converter/parser/onnx/onnx_model_parser.h"
 
 namespace mindspore {
 namespace lite {
@@ -34,6 +35,36 @@ schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_
   }
 }
 
+STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
+                                             int *type) {
+  size_t data_count = 1;
+  std::for_each(onnx_tensor.dims().begin(), onnx_tensor.dims().end(), [&data_count](int dim) { data_count *= dim; });
+  switch (onnx_tensor.data_type()) {
+    case onnx::TensorProto_DataType_FLOAT:
+      *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
+      for (size_t i = 0; i < data_count; i++) {
+        value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]);
+      }
+      break;
+    case onnx::TensorProto_DataType_INT32:
+      *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
+      for (size_t i = 0; i < data_count; i++) {
+        value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i]));
+      }
+      break;
+    case onnx::TensorProto_DataType_INT64:
+      *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
+      for (size_t i = 0; i < data_count; i++) {
+        value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i]));
+      }
+      break;
+    default:
+      MS_LOG(ERROR) << "Unsupported ONNX tensor data type: " << onnx_tensor.data_type();
+      return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 void OnnxNodeParser::Split(const std::string &src_str, std::vector<std::string> *dst_str, const std::string &chr) {
   std::string ::size_type p1 = 0, p2 = src_str.find(chr);
   while (std::string::npos != p2) {
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
index 10b6a2bffb..43c7eb1032 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
@@ -35,6 +35,8 @@ class OnnxNodeParser {
 
   virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0;
 
+  STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type);
+
  protected:
   schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);
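To illustrate what `GetTensorDataFromOnnx` does with an initializer's bytes: ONNX stores `raw_data` as an untyped byte string, and the helper reinterprets it according to the tensor's declared element type, widening or narrowing everything into a `std::vector<float>`. A standalone sketch with a hypothetical `ReadRaw` template (the real function additionally maps the ONNX enum through `OnnxModelParser::GetDataTypeFromOnnx`):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Reinterpret a raw byte string as 'count' elements of T and collect them as floats.
template <typename T>
std::vector<float> ReadRaw(const std::string &raw, size_t count) {
  std::vector<float> out;
  out.reserve(count);
  const T *data = reinterpret_cast<const T *>(raw.data());
  for (size_t i = 0; i < count; ++i) {
    out.push_back(static_cast<float>(data[i]));
  }
  return out;
}

int main() {
  int64_t src[2] = {7, -3};
  std::string raw(reinterpret_cast<const char *>(src), sizeof(src));
  for (float v : ReadRaw<int64_t>(raw, 2)) std::cout << v << " ";  // 7 -3
  std::cout << "\n";
  return 0;
}
```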