Merge pull request !7601 from wangshaocong/bugfix_master
@@ -657,6 +657,22 @@ int ElementAddRelu6(float *input0, float *input1, float *output, int element_size) {
  return NNACL_OK;
}
int ElementAddInt(int *input0, int *input1, int *output, int element_size) {
  int index = 0;
#ifdef ENABLE_NEON
  for (; index <= element_size - 4; index += C4NUM) {
    int32x4_t vin0 = vld1q_s32(input0 + index);
    int32x4_t vin1 = vld1q_s32(input1 + index);
    int32x4_t vout = vaddq_s32(vin0, vin1);
    vst1q_s32(output + index, vout);
  }
#endif
  for (; index < element_size; index++) {
    output[index] = input0[index] + input1[index];
  }
  return NNACL_OK;
}
int ElementAddInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = input0[i] + input1[i];
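The new ElementAddInt follows the usual nnacl pattern: a NEON body that consumes four int32 lanes per iteration, then a scalar loop for the remaining tail. A minimal harness showing the split (a sketch; it assumes C4NUM is 4 and NNACL_OK is 0, as elsewhere in nnacl):

```cpp
#include <cassert>

extern "C" int ElementAddInt(int *input0, int *input1, int *output, int element_size);

int main() {
  int a[7] = {1, 2, 3, 4, 5, 6, 7};
  int b[7] = {10, 20, 30, 40, 50, 60, 70};
  int out[7] = {0};
  // With element_size = 7, the NEON loop (when ENABLE_NEON is set) covers
  // indices 0..3 and the scalar tail loop covers indices 4..6.
  assert(ElementAddInt(a, b, out, 7) == 0 /* NNACL_OK */);
  assert(out[0] == 11 && out[6] == 77);
  return 0;
}
```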
@@ -54,6 +54,7 @@ int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
int ElementAdd(float *input0, float *input1, float *output, int element_size);
int ElementAddRelu(float *input0, float *input1, float *output, int element_size);
int ElementAddRelu6(float *input0, float *input1, float *output, int element_size);
int ElementAddInt(int *input0, int *input1, int *output, int element_size);
int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
                 ArithmeticParameter *param);
int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output,
@@ -26,3 +26,14 @@ int ConstantOfShape(float *output, int tid, ConstantOfShapeParameter *param) {
  }
  return NNACL_OK;
}
int ConstantOfShapeInt(int32_t *output, int tid, ConstantOfShapeParameter *param) {
  int size = param->unit_;
  int32_t data = (int32_t)param->value_;
  int ind_st = MSMIN(tid * size, param->element_sz_);
  int ind_end = MSMIN(param->element_sz_, (tid + 1) * size);
  for (int i = ind_st; i < ind_end; ++i) {
    output[i] = data;
  }
  return NNACL_OK;
}
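Each worker fills a half-open slice of the output: unit_ is the per-thread chunk (UP_DIV of element_sz_ over the thread count, set in the Run() hunk further down) and both bounds are clamped to element_sz_. A small sketch of the partition arithmetic with hypothetical numbers:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  // Hypothetical split: 10 elements across 3 threads, unit = UP_DIV(10, 3) = 4.
  const int element_sz = 10, thread_num = 3;
  const int unit = (element_sz + thread_num - 1) / thread_num;
  for (int tid = 0; tid < thread_num; ++tid) {
    int ind_st = std::min(tid * unit, element_sz);
    int ind_end = std::min(element_sz, (tid + 1) * unit);
    printf("tid %d fills [%d, %d)\n", tid, ind_st, ind_end);  // [0,4) [4,8) [8,10)
  }
  return 0;
}
```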
@@ -25,6 +25,7 @@
typedef struct ConstantOfShapeParameter {
  OpParameter op_parameter_;
  float value_;
  int data_type_;
  int unit_;
  int element_sz_;
} ConstantOfShapeParameter;
@@ -33,6 +34,7 @@ typedef struct ConstantOfShapeParameter {
extern "C" {
#endif
int ConstantOfShape(float *output, int tid, ConstantOfShapeParameter *param);
int ConstantOfShapeInt(int32_t *output, int tid, ConstantOfShapeParameter *param);
#ifdef __cplusplus
}
#endif
@@ -123,6 +123,27 @@ int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
  }
  return NNACL_OK;
}
int IntReduceMin(const int outer_size, const int inner_size, const int axis_size, const int *src_data, int *dst_data,
                 const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const int *outer_src = src_data + j * axis_size * inner_size;
    int *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const int *inner_src = outer_src + k;
      int *inner_dst = outer_dst + k;
      int tmp = INT32_MAX;
      for (i = 0; i < axis_size; i++) {
        tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
      }
      *inner_dst = tmp;
    }
  }
  return NNACL_OK;
}
int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
               const int tid, const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
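IntReduceMin views the source as an [outer_size, axis_size, inner_size] block and takes the minimum along the middle axis, so reducing a [2, 3, 4] tensor over axis 1 means outer_size = 2, axis_size = 3, inner_size = 4. A single-threaded sketch of the index layout (tid = 0, thread_num = 1; assumes NNACL_OK is 0):

```cpp
#include <cassert>

extern "C" int IntReduceMin(const int outer_size, const int inner_size, const int axis_size, const int *src_data,
                            int *dst_data, const int tid, const int thread_num);

int main() {
  int src[24];
  for (int i = 0; i < 24; ++i) src[i] = 24 - i;  // strictly decreasing values
  int dst[8] = {0};
  assert(IntReduceMin(2, 4, 3, src, dst, 0, 1) == 0 /* NNACL_OK */);
  // For each (outer j, inner k), the minimum sits in the last axis slice:
  // dst[j * 4 + k] == src[j * 12 + 2 * 4 + k].
  assert(dst[0] == src[8] && dst[7] == src[23]);
  return 0;
}
```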
@@ -30,6 +30,8 @@ int ReduceMax(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
              const int tid, const int thread_num);
int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
              const int tid, const int thread_num);
int IntReduceMin(const int outer_size, const int inner_size, const int axis_size, const int *src_data, int *dst_data,
                 const int tid, const int thread_num);
int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
               const int tid, const int thread_num);
int IntReduceProd(const int outer_size, const int inner_size, const int axis_size, const int *src_data, int *dst_data,
@@ -308,7 +308,8 @@ table Shape {
}
table ConstantOfShape{
    value: float = 0;
    dataType: int;
    value: [float];
}
table Nchw2Nhwc {
@@ -28,9 +28,9 @@ constexpr int kShapeInputNum = 1;
constexpr int kShapeOutputNum = 1;
}  // namespace
#ifdef PRIMITIVE_WRITEABLE
float ConstantOfShape::GetValue() const { return this->primitive_->value.AsConstantOfShape()->value; }
std::vector<float> ConstantOfShape::GetValue() const { return this->primitive_->value.AsConstantOfShape()->value; }
void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstantOfShape()->value = value; }
int ConstantOfShape::GetDataType() const { return this->primitive_->value.AsConstantOfShape()->dataType; }
#else
int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
@@ -41,12 +41,22 @@ int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
    MS_LOG(ERROR) << "value_as_ConstantOfShape return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateConstantOfShape(*fbb, attr->value());
  std::vector<float> value;
  if (attr->value() != nullptr) {
    for (int i = 0; i < static_cast<int>(attr->value()->size()); i++) {
      value.push_back(attr->value()->data()[i]);
    }
  }
  auto val_offset = schema::CreateConstantOfShapeDirect(*fbb, attr->dataType(), &value);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ConstantOfShape, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); }
std::vector<float> ConstantOfShape::GetValue() const {
  auto fb_vector = this->primitive_->value_as_ConstantOfShape()->value();
  return std::vector<float>(fb_vector->begin(), fb_vector->end());
}
int ConstantOfShape::GetDataType() const { return this->primitive_->value_as_ConstantOfShape()->dataType(); }
PrimitiveC *ConstantOfShapeCreator(const schema::Primitive *primitive) {
  return PrimitiveC::NewPrimitiveC<ConstantOfShape>(primitive);
@@ -70,7 +80,7 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  }
  auto in_tensor = inputs_.front();
  auto out_tensor = outputs_.front();
  out_tensor->set_data_type(kNumberTypeFloat32);
  out_tensor->set_data_type(static_cast<TypeId>(GetDataType()));
  out_tensor->SetFormat(in_tensor->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
@@ -30,14 +30,14 @@ class ConstantOfShape : public PrimitiveC {
  MS_DECLARE_PARENT(ConstantOfShape, PrimitiveC);
  ConstantOfShape() = default;
  explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetValue(float value);
#else
  ConstantOfShape() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
  float GetValue() const;
  std::vector<float> GetValue() const;
  int GetDataType() const;
};
}  // namespace lite
}  // namespace mindspore
@@ -34,7 +34,13 @@ OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC *primitive) {
  }
  memset(param, 0, sizeof(ConstantOfShapeParameter));
  param->op_parameter_.type_ = primitive->Type();
  param->value_ = attr->GetValue();
  auto value = attr->GetValue();
  if (value.size() != 1) {
    MS_LOG(ERROR) << "ConstantOfShape expects exactly one value, but got " << value.size();
  } else {
    param->value_ = value[0];
  }
  param->data_type_ = attr->GetDataType();
  return reinterpret_cast<OpParameter *>(param);
}
Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter);
@@ -28,6 +28,7 @@ namespace lite {
namespace {
constexpr int kSliceInputNum = 1;
constexpr int kSliceOutputNum = 1;
constexpr int kSliceMaxInputNum = 5;
}  // namespace
#ifdef PRIMITIVE_WRITEABLE
int Slice::GetFormat() const { return this->primitive_->value.AsSlice()->format; }
@@ -175,6 +176,29 @@ int Slice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
  std::vector<int32_t> slice_size(GetSize());
  std::vector<int32_t> slice_axes(GetAxes());
  std::vector<int32_t> output_shape(input_shape.size());
  if (inputs.size() == kSliceMaxInputNum) {
    if (slice_begin.empty() && inputs.at(1)->data_c() != nullptr) {
      for (int i = 0; i < inputs.at(1)->ElementsNum(); i++) {
        slice_begin.emplace_back(static_cast<int *>(inputs.at(1)->data_c())[i]);
      }
    }
    if (slice_size.empty() && inputs.at(2)->data_c() != nullptr) {
      for (int i = 0; i < inputs.at(2)->ElementsNum(); i++) {
        auto end = static_cast<int *>(inputs.at(2)->data_c())[i];
        auto size = end < 0 ? end : (end == INT32_MAX ? -1 : end - slice_begin[i]);
        slice_size.emplace_back(size);
      }
    }
    if (slice_axes.empty() && inputs.at(3)->data_c() != nullptr) {
      for (int i = 0; i < inputs.at(3)->ElementsNum(); i++) {
        slice_axes.emplace_back(static_cast<int *>(inputs.at(3)->data_c())[i]);
      }
    }
  }
  if (slice_begin.empty() || slice_size.empty() || slice_axes.empty()) {
    MS_LOG(ERROR) << "Slice InferShape failed: begin, size or axes is unknown.";
    return RET_INFER_INVALID;
  }
  begin.assign(input_shape.size(), 0);
  size.assign(input_shape.size(), -1);
  for (size_t i = 0; i < slice_axes.size(); ++i) {
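When Slice arrives in its five-input form (ONNX-style data/starts/ends/axes/steps), the size vector is derived from the ends tensor: negative ends pass through, INT32_MAX maps to -1 ("to the end of the axis"), and otherwise size is end minus begin. That conversion in isolation (SizeFromEnd is a hypothetical name, not a helper in the tree):

```cpp
#include <cassert>
#include <climits>

int SizeFromEnd(int end, int begin) { return end < 0 ? end : (end == INT_MAX ? -1 : end - begin); }

int main() {
  assert(SizeFromEnd(5, 2) == 3);         // ordinary range: elements 2..4
  assert(SizeFromEnd(INT_MAX, 1) == -1);  // INT32_MAX means "slice to the end"
  assert(SizeFromEnd(-1, 0) == -1);       // negative ends are passed through
  return 0;
}
```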
@@ -54,7 +54,7 @@ Registry TopKRegistry(schema::PrimitiveType_TopK, TopKCreator);
int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  MS_ASSERT(this->primitive_ != nullptr);
  if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) {
  if ((inputs_.size() != kSingleNum && inputs_.size() != kDoubleNum) || outputs_.size() != kDoubleNum) {
    MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
    return RET_INPUT_TENSOR_ERROR;
  }
@@ -74,6 +74,9 @@ int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  MS_ASSERT(topk_prim != nullptr);
  auto out_shape = input->shape();
  out_shape[out_shape.size() - 1] = GetK();
  if (inputs_.size() == kDoubleNum && inputs_.at(1)->data_c() != nullptr) {
    out_shape[out_shape.size() - 1] = reinterpret_cast<int *>(inputs_.at(1)->data_c())[0];
  }
  output0->set_shape(out_shape);
  output1->set_shape(out_shape);
  return RET_OK;
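The shape rule after this change: the last dimension becomes k, taken from the optional second input when its data is already known at infer time, otherwise from the k attribute. A minimal sketch of that rule (TopKOutShape is a hypothetical stand-in, not a function in the tree):

```cpp
#include <cassert>
#include <vector>

std::vector<int> TopKOutShape(std::vector<int> in_shape, int attr_k, const int *k_data) {
  in_shape.back() = (k_data != nullptr) ? k_data[0] : attr_k;  // constant k tensor wins
  return in_shape;
}

int main() {
  int k = 3;
  assert((TopKOutShape({4, 10}, 5, &k) == std::vector<int>{4, 3}));       // from the k tensor
  assert((TopKOutShape({4, 10}, 5, nullptr) == std::vector<int>{4, 5}));  // fall back to GetK()
  return 0;
}
```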
@@ -104,9 +104,7 @@ int Unsqueeze::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
      out_shape.emplace_back(1);
      ax_itr++;
    } else {
      if (in_shape[in_itr] > 1) {
        out_shape.emplace_back(in_shape[in_itr]);
      }
      out_shape.emplace_back(in_shape[in_itr]);
      in_itr++;
    }
  }
@@ -83,6 +83,7 @@ class ArithmeticCPUKernel : public LiteKernel {
        break;
      default:
        arithmetic_run_ = ElementAdd;
        arithmetic_run_int_ = ElementAddInt;
        break;
    }
    break;
@@ -77,6 +77,9 @@ int CastCPUKernel::DoCast(int thread_id) {
  } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) {
    Float32ToFp16(reinterpret_cast<float *>(input->data_c()) + offset,
                  reinterpret_cast<uint16_t *>(output_data) + offset, data_num);
  } else if (input_data_type == kNumberTypeInt32 &&
             (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
    memcpy(reinterpret_cast<int32_t *>(output_data) + offset, reinterpret_cast<int32_t *>(input->data_c()) + offset,
           data_num * sizeof(int32_t));
  } else {
    MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type;
    return RET_ERROR;
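A caveat on the int32 branch above: for kNumberTypeInt64 outputs it copies 4-byte elements verbatim, which is only safe if the runtime backs int64 tensors with 4-byte payloads. If the output buffer holds real 8-byte elements, an element-wise widening copy would be needed instead; a hedged sketch (Int32ToInt64 is a hypothetical helper here):

```cpp
#include <cstdint>

// Sign-extends each element instead of memcpy-ing raw bytes.
void Int32ToInt64(const int32_t *input, int64_t *output, int data_num) {
  for (int i = 0; i < data_num; ++i) {
    output[i] = static_cast<int64_t>(input[i]);
  }
}
```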
@@ -33,7 +33,18 @@ int ConstantOfShapeCPUKernel::Init() { return RET_OK; }
int ConstantOfShapeCPUKernel::ReSize() { return RET_OK; }
int ConstantOfShapeCPUKernel::DoExecute(int task_id) {
  int ret = ConstantOfShape(out_ptr_, task_id, param_);
  int ret = RET_ERROR;
  switch (param_->data_type_) {
    case kNumberTypeFloat32:
      ret = ConstantOfShape(reinterpret_cast<float *>(out_ptr_), task_id, param_);
      break;
    case kNumberTypeInt32:
      ret = ConstantOfShapeInt(reinterpret_cast<int32_t *>(out_ptr_), task_id, param_);
      break;
    default:
      MS_LOG(ERROR) << "ConstantOfShape does not support the output data type.";
      return RET_ERROR;
  }
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConstantOfShapeRun error task_id[" << task_id << "] error_code[" << ret << "]";
    return ret;
@@ -56,7 +67,17 @@ int ConstantOfShapeCPUKernel::Run() {
  int thread_num = MSMIN(param_->op_parameter_.thread_num_, param_->element_sz_);
  param_->unit_ = UP_DIV(param_->element_sz_, thread_num);
  param_->op_parameter_.thread_num_ = thread_num;
  out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
  switch (param_->data_type_) {
    case kNumberTypeFloat32:
      out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
      break;
    case kNumberTypeInt32:
      out_ptr_ = reinterpret_cast<int32_t *>(out_tensors_.front()->MutableData());
      break;
    default:
      MS_LOG(ERROR) << "ConstantOfShape does not support the output data type.";
      return RET_ERROR;
  }
  auto ret = ParallelLaunch(this->context_->thread_pool_, ConstantOfShapeRun, this, thread_num);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConstantOfShapeRun error error_code[" << ret << "]";
@@ -93,4 +114,5 @@ kernel::LiteKernel *CpuConstantOfShapeFp32KernelCreator(const std::vector<lite::
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, CpuConstantOfShapeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, CpuConstantOfShapeFp32KernelCreator)
}  // namespace mindspore::kernel
@@ -41,7 +41,7 @@ class ConstantOfShapeCPUKernel : public LiteKernel {
 private:
  ConstantOfShapeParameter *param_;
  float *out_ptr_;
  void *out_ptr_;
};
}  // namespace mindspore::kernel
@@ -74,8 +74,8 @@ int NonMaxSuppressionCPUKernel::GetParams() {
  max_output_per_class_ = 0;
  if (in_tensors_.size() >= 3) {
    auto max_output_tensor = in_tensors_.at(kMaxOutputNumTensorIndex);
    if (max_output_tensor != nullptr && reinterpret_cast<int64_t *>(max_output_tensor->data_c()) != nullptr) {
      max_output_per_class_ = *(reinterpret_cast<int64_t *>(max_output_tensor->data_c()));
    if (max_output_tensor != nullptr && reinterpret_cast<int32_t *>(max_output_tensor->data_c()) != nullptr) {
      max_output_per_class_ = *(reinterpret_cast<int32_t *>(max_output_tensor->data_c()));
    }
  }
  iou_threshold_ = 0.0f;
@@ -61,6 +61,7 @@ int ReduceCPUKernel::Init() {
    }
    case static_cast<int>(ReduceMode_ReduceMin): {
      reducer_ = ReduceMin;
      int_reducer_ = IntReduceMin;
      break;
    }
    case static_cast<int>(ReduceMode_ReduceProd): {
@@ -51,6 +51,14 @@ int TopKCPUKernel::Run() {
  MS_ASSERT(context_->allocator != nullptr);
  TopkParameter *parameter = reinterpret_cast<TopkParameter *>(op_parameter_);
  if (in_tensors_.size() == lite::kDoubleNum) {
    auto input_k = reinterpret_cast<int *>(in_tensors_.at(1)->MutableData());
    parameter->k_ = input_k[0];
  }
  if (parameter->k_ > in_tensors_.at(0)->ElementsNum()) {
    MS_LOG(ERROR) << "The k value exceeds the number of elements in the input tensor.";
    return RET_ERROR;
  }
  parameter->topk_node_list_ = context_->allocator->Malloc(sizeof(TopkNode) * parameter->last_dim_size_);
  if (parameter->topk_node_list_ == nullptr) {
    MS_LOG(ERROR) << "Memory allocation failed";
@@ -16,6 +16,7 @@
#include "tools/converter/parser/onnx/onnx_constant_of_shape_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
namespace mindspore {
namespace lite {
@@ -41,13 +42,25 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
  for (const auto &onnx_node_attr : onnx_node.attribute()) {
    const auto &attribute_name = onnx_node_attr.name();
    if (attribute_name == "value") {
      if (onnx_node_attr.type() == onnx::AttributeProto_AttributeType_TENSOR) {
        auto tensor = onnx_node_attr.t();
        if (tensor.data_type() == onnx::AttributeProto_AttributeType_FLOAT) {
          attr->value = onnx_node_attr.f();
        } else if (tensor.data_type() == onnx::AttributeProto_AttributeType_INT) {
          attr->value = static_cast<int32_t>(onnx_node_attr.i());
        }
      switch (onnx_node_attr.type()) {
        case onnx::AttributeProto_AttributeType_FLOAT:
          attr->dataType = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
          attr->value.push_back(onnx_node_attr.f());
          break;
        case onnx::AttributeProto_AttributeType_INT:
          attr->dataType = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
          attr->value.push_back(static_cast<float>(onnx_node_attr.i()));
          break;
        case onnx::AttributeProto_AttributeType_TENSOR: {
          auto tensor = onnx_node_attr.t();
          auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType);
          if (ret != RET_OK) {
            return ret;
          }
        } break;
        default:
          MS_LOG(ERROR) << "The data type is not supported.";
          return RET_ERROR;
      }
    }
  }
@@ -445,11 +445,8 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value,
  }
  for (size_t i = 0; i < data_count; ++i) {
    if (in_data[i] > static_cast<int64_t>(INT32_MAX) || in_data[i] < static_cast<int64_t>(INT32_MIN)) {
      if (llabs(in_data[i]) == INT64_MAX || in_data[i] == INT64_MIN) {
        buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN;
      }
      MS_LOG(ERROR) << "int64 data " << in_data[i] << " is too big to fit into int32";
      return RET_ERROR;
      MS_LOG(WARNING) << "int64 data " << in_data[i] << " is too big to fit into int32";
      buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN;
    } else {
      buffer[i] = static_cast<int>(in_data[i]);
    }
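The converter now saturates out-of-range int64 initializers to the nearest int32 bound and warns, instead of aborting the conversion. Per element this is equivalent to (a sketch; SaturateToInt32 is a hypothetical name):

```cpp
#include <algorithm>
#include <cstdint>

int32_t SaturateToInt32(int64_t v) {
  // Clamp into [INT32_MIN, INT32_MAX] before narrowing, matching the branch above.
  return static_cast<int32_t>(std::max<int64_t>(INT32_MIN, std::min<int64_t>(INT32_MAX, v)));
}
```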
@@ -16,6 +16,7 @@
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include <vector>
#include <algorithm>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
namespace mindspore {
namespace lite {
@@ -34,6 +35,36 @@ schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) {
  }
}
STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
                                             int *type) {
  size_t data_count = 1;
  std::for_each(onnx_tensor.dims().begin(), onnx_tensor.dims().end(), [&data_count](int dim) { data_count *= dim; });
  switch (onnx_tensor.data_type()) {
    case onnx::TensorProto_DataType_FLOAT:
      *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
      for (size_t i = 0; i < data_count; i++) {
        value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]);
      }
      break;
    case onnx::TensorProto_DataType_INT32:
      *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
      for (size_t i = 0; i < data_count; i++) {
        value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i]));
      }
      break;
    case onnx::TensorProto_DataType_INT64:
      *type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
      for (size_t i = 0; i < data_count; i++) {
        value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i]));
      }
      break;
    default:
      MS_LOG(ERROR) << "The data type is not supported.";
      return RET_ERROR;
  }
  return RET_OK;
}
void OnnxNodeParser::Split(const std::string &src_str, std::vector<std::string> *dst_str, const std::string &chr) {
  std::string::size_type p1 = 0, p2 = src_str.find(chr);
  while (std::string::npos != p2) {
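One reviewer-style caveat on GetTensorDataFromOnnx: it only decodes the raw_data layout, but ONNX also allows tensors to carry values in the typed repeated fields (float_data, int32_data, int64_data). A hedged sketch of a FLOAT fallback for that layout (AppendFloatFieldData and the include path are hypothetical):

```cpp
#include <onnx/onnx.pb.h>  // assumed header path for the generated onnx::TensorProto
#include <vector>

// Read from the typed float_data field when raw_data is absent.
bool AppendFloatFieldData(const onnx::TensorProto &onnx_tensor, std::vector<float> *value) {
  if (!onnx_tensor.raw_data().empty()) return false;  // raw layout: handled by the switch above
  for (int i = 0; i < onnx_tensor.float_data_size(); i++) {
    value->push_back(onnx_tensor.float_data(i));
  }
  return true;
}
```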
@@ -35,6 +35,8 @@ class OnnxNodeParser {
  virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0;
  STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type);
 protected:
  schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);