Merge pull request !4270 from songhonglei413/roitags/v0.7.0-beta
| @@ -56,6 +56,7 @@ union PrimitiveType { | |||
| BatchNorm, | |||
| BiasAdd, | |||
| Pooling, | |||
| ROIPooling, | |||
| DepthwiseConv2D, | |||
| DeDepthwiseConv2D, | |||
| Resize, | |||
| @@ -262,6 +262,12 @@ table BiasAdd { | |||
| axis: [int]; | |||
| } | |||
| table ROIPooling { | |||
| pooledH: int; | |||
| pooledW: int; | |||
| scale: float; | |||
| } | |||
| table Pooling { | |||
| format: Format = 0; | |||
| poolingMode: PoolMode; | |||
| @@ -35,6 +35,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) { | |||
| return new lite::Reduce(const_cast<schema::Primitive *>(primitive)); | |||
| case schema::PrimitiveType_Pooling: | |||
| return new lite::Pooling(const_cast<schema::Primitive *>(primitive)); | |||
| case schema::PrimitiveType_ROIPooling: | |||
| return new lite::ROIPooling(const_cast<schema::Primitive *>(primitive)); | |||
| case schema::PrimitiveType_DepthwiseConv2D: | |||
| return new lite::DepthwiseConv2D(const_cast<schema::Primitive *>(primitive)); | |||
| case schema::PrimitiveType_FusedBatchNorm: | |||
| @@ -56,6 +56,13 @@ class Primitive { | |||
| bool infer_flag_ = true; | |||
| }; | |||
| class ROIPooling : public Primitive { | |||
| public: | |||
| explicit ROIPooling(schema::Primitive *primitive) : Primitive(primitive) {} | |||
| const schema::ROIPooling *GetAttribute() const { return this->primitive->value_as_ROIPooling(); } | |||
| int InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) override; | |||
| }; | |||
| class Conv2D : public Primitive { | |||
| public: | |||
| explicit Conv2D(schema::Primitive *primitive) : Primitive(primitive) {} | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/ops.h" | |||
| #include "include/errorcode.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "src/ir/tensor.h" | |||
| namespace mindspore::lite { | |||
| int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) { | |||
| MS_ASSERT(this->primitive != nullptr); | |||
| if (inputs_.size() != kDoubleNum) { | |||
| MS_LOG(ERROR) << "inputs number is not equal to " << kDoubleNum; | |||
| return RET_ERROR; | |||
| } | |||
| auto input = inputs_.front(); | |||
| if (input == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto roi = inputs_.at(1); | |||
| if (roi == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto output = outputs_.front(); | |||
| if (output == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto ROIPooling = GetAttribute(); | |||
| auto new_h = ROIPooling->pooledH(); | |||
| auto new_w = ROIPooling->pooledW(); | |||
| auto shape_data = roi->shape(); | |||
| std::vector<int> output_shape; | |||
| output_shape.push_back(shape_data[0]); | |||
| output_shape.push_back(new_h); | |||
| output_shape.push_back(new_w); | |||
| output_shape.push_back(input->Channel()); | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -33,6 +33,7 @@ | |||
| #include "src/runtime/kernel/arm/nnacl/conv_parameter.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/pooling.h" | |||
| #include "src/runtime/kernel/arm/nnacl/matmul_parameter.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/roi_pooling.h" | |||
| #include "src/runtime/kernel/arm/nnacl/softmax_parameter.h" | |||
| #include "src/runtime/kernel/arm/nnacl/tile.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/topk.h" | |||
| @@ -74,6 +75,21 @@ | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/elu.h" | |||
| namespace mindspore::kernel { | |||
| OpParameter *PopulateROIPoolingParameter(const lite::Primitive *primitive) { | |||
| auto pooling_primitive = primitive->Value()->value_as_ROIPooling(); | |||
| ROIPoolingParameter *param = new (std::nothrow) ROIPoolingParameter(); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "new PoolingParameter failed."; | |||
| return nullptr; | |||
| } | |||
| param->op_parameter_.type_ = primitive->Type(); | |||
| param->pooledH_ = pooling_primitive->pooledH(); | |||
| param->pooledW_ = pooling_primitive->pooledW(); | |||
| param->scale_ = pooling_primitive->scale(); | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| OpParameter *PopulateBatchNorm(const lite::Primitive *primitive) { | |||
| BatchNormParameter *batch_norm_param = new (std::nothrow) BatchNormParameter(); | |||
| if (batch_norm_param == nullptr) { | |||
| @@ -1270,6 +1286,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||
| populate_parameter_funcs_[schema::PrimitiveType_Reduce] = PopulateReduceParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_Mean] = PopulateMeanParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_Pooling] = PopulatePoolingParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_ROIPooling] = PopulateROIPoolingParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_DepthwiseConv2D] = PopulateConvDwParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_DeDepthwiseConv2D] = PopulateDeconvDwParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_DeConv2D] = PopulateDeconvParameter; | |||
| @@ -0,0 +1,112 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32/roi_pooling.h" | |||
| #include <vector> | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_ROIPooling; | |||
| namespace mindspore::kernel { | |||
| int ROIPoolingCPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| int ROIPoolingCPUKernel::ReSize() { return RET_OK; } | |||
| int ROIPoolingCPUKernel::DoExecute(int task_id) { | |||
| auto ret = ROIPooling(in_ptr_, out_ptr_, roi_ptr_, in_shape_, out_shape_, dim_, task_id, param_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ROIPooling Execute error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ROIPoolingRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||
| auto Data = reinterpret_cast<ROIPoolingCPUKernel *>(cdata); | |||
| auto ret = Data->DoExecute(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ROIPooling Run error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ROIPoolingCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail! ret: " << ret; | |||
| return ret; | |||
| } | |||
| in_ptr_ = reinterpret_cast<float *>(inputs_.front()->Data()); | |||
| out_ptr_ = reinterpret_cast<float *>(outputs_.front()->Data()); | |||
| roi_ptr_ = reinterpret_cast<float *>(inputs_.at(1)->Data()); | |||
| in_shape_ = reinterpret_cast<const int *>(inputs_.front()->shape().data()); | |||
| out_shape_ = reinterpret_cast<const int *>(outputs_.front()->shape().data()); | |||
| dim_ = inputs_.front()->shape().size(); | |||
| thread_count_ = 1; | |||
| ret = LiteBackendParallelLaunch(ROIPoolingRun, this, thread_count_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ROIPooling error: error_code[" << ret << "]"; | |||
| return ret; | |||
| } | |||
| return ret; | |||
| } | |||
| kernel::LiteKernel *CpuROIPoolingFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, const lite::Primitive *primitive) { | |||
| if (opParameter == nullptr) { | |||
| MS_LOG(ERROR) << "Input opParameter is nullptr!"; | |||
| return nullptr; | |||
| } | |||
| if (ctx == nullptr) { | |||
| MS_LOG(ERROR) << "Input context is nullptr!"; | |||
| return nullptr; | |||
| } | |||
| if (ctx->thread_num_ == 0) { | |||
| MS_LOG(ERROR) << "context thread num is 0!"; | |||
| return nullptr; | |||
| } | |||
| auto *kernel = new (std::nothrow) ROIPoolingCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new ROIPoolingCPUKernel fail!"; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ROIPooling, CpuROIPoolingFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ROI_POOLING_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ROI_POOLING_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/roi_pooling.h" | |||
| namespace mindspore::kernel { | |||
| class ROIPoolingCPUKernel : public LiteKernel { | |||
| public: | |||
| ROIPoolingCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, | |||
| const lite::Primitive *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| param_ = reinterpret_cast<ROIPoolingParameter *>(parameter); | |||
| } | |||
| ~ROIPoolingCPUKernel() override = default; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int DoExecute(int task_id); | |||
| private: | |||
| float *in_ptr_; | |||
| float *out_ptr_; | |||
| float *roi_ptr_; | |||
| const int *in_shape_; | |||
| const int *out_shape_; | |||
| ROIPoolingParameter *param_; | |||
| int dim_; | |||
| int thread_count_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_H_ | |||
| @@ -0,0 +1,96 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/fp32/roi_pooling.h" | |||
| #include <math.h> | |||
| #include "nnacl/errorcode.h" | |||
| int ROIPooling(float *in_ptr, float *out_ptr, float *roi, const int *in_shape, const int *out_shape, int dim, int tid, | |||
| ROIPoolingParameter *param) { | |||
| int num_rois = out_shape[kNHWC_N]; | |||
| int batch_size = in_shape[kNHWC_N]; | |||
| int height_ = in_shape[kNHWC_H]; | |||
| int width_ = in_shape[kNHWC_W]; | |||
| int channels_ = in_shape[kNHWC_C]; | |||
| int scale = param->scale_; | |||
| int pooled_height = param->pooledH_; | |||
| int pooled_width = param->pooledW_; | |||
| int in_stride[DIMENSION_4D]; | |||
| int out_stride[DIMENSION_4D]; | |||
| int roi_stride = 5; | |||
| in_stride[DIMENSION_4D - 1] = 1; | |||
| out_stride[DIMENSION_4D - 1] = 1; | |||
| for (int i = dim - 2; i >= 0; --i) { | |||
| in_stride[i] = in_stride[i + 1] * in_shape[i + 1]; | |||
| out_stride[i] = out_stride[i + 1] * out_shape[i + 1]; | |||
| } | |||
| int roi_ind_st = 0; | |||
| for (int i = 0; i < num_rois; ++i) { | |||
| int roi_batch_ind = (int)roi[roi_ind_st]; // batch_index | |||
| if (roi_batch_ind >= batch_size) { | |||
| return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; | |||
| } | |||
| int roi_start_h = (int)roundf(roi[roi_ind_st + 1] * scale); // top-left x1 | |||
| int roi_start_w = (int)roundf(roi[roi_ind_st + 2] * scale); // top-left y1 | |||
| int roi_end_h = (int)roundf(roi[roi_ind_st + 3] * scale); // bottom-right x2 | |||
| int roi_end_w = (int)roundf(roi[roi_ind_st + 4] * scale); // bottom-fight y2 | |||
| int roi_height = MSMAX(roi_end_h - roi_start_h + 1, 1); | |||
| int roi_width = MSMAX(roi_end_w - roi_start_w + 1, 1); | |||
| float bin_size_h = (float)roi_height / (float)pooled_height; | |||
| float bin_size_w = (float)roi_width / (float)pooled_width; | |||
| float *batch_data = in_ptr + in_stride[kNHWC_N] * roi_batch_ind; | |||
| int out_ind = i * out_stride[0]; | |||
| for (int c = kNHWC_N; c < channels_; ++c) { | |||
| float max_v = -__FLT_MAX__; | |||
| for (int ph = 0; ph < pooled_height; ++ph) { | |||
| for (int pw = 0; pw < pooled_width; ++pw) { | |||
| int pooled_index = | |||
| i * out_stride[kNHWC_N] + ph * out_stride[kNHWC_H] + pw * out_stride[kNHWC_W] + c * out_stride[kNHWC_C]; | |||
| int hstart = (int)floorf(ph * bin_size_h); // block xi_1 | |||
| int wstart = (int)floorf(pw * bin_size_w); // block yi_1 | |||
| int hend = (int)ceilf((ph + 1) * bin_size_h); // block xi_2 | |||
| int wend = (int)ceilf((pw + 1) * bin_size_w); // block yi_2 | |||
| hstart = MSMIN(MSMAX(hstart + roi_start_h, 0), height_); | |||
| hend = MSMIN(MSMAX(hend + roi_start_h, 0), height_); | |||
| wstart = MSMIN(MSMAX(wstart + roi_start_w, 0), width_); | |||
| wend = MSMIN(MSMAX(wend + roi_start_w, 0), width_); | |||
| bool is_empty = (hend <= hstart) || (wend <= wstart); | |||
| if (is_empty) { | |||
| max_v = 0; | |||
| } | |||
| int bd_index = c * in_stride[kNHWC_C] + hstart * in_stride[kNHWC_H]; | |||
| for (int h = hstart; h < hend; ++h) { | |||
| int wi = bd_index + wstart * in_stride[kNHWC_W]; | |||
| for (int w = wstart; w < wend; ++w) { | |||
| max_v = MSMAX(batch_data[wi], max_v); | |||
| // printf("bd:index: %d, data: %f, max_v: %f\n",wi,batch_data[wi],max_v); | |||
| wi += in_stride[kNHWC_W]; | |||
| } | |||
| bd_index += in_stride[kNHWC_H]; | |||
| } | |||
| out_ptr[pooled_index] = max_v; | |||
| } | |||
| } | |||
| } | |||
| roi_ind_st += roi_stride; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ROI_POOLING_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ROI_POOLING_H_ | |||
| #include "nnacl/op_base.h" | |||
| typedef struct ROIPoolingParameter { | |||
| OpParameter op_parameter_; | |||
| int pooledW_; | |||
| int pooledH_; | |||
| float scale_; | |||
| } ROIPoolingParameter; | |||
| int ROIPooling(float *in_ptr, float *out_ptr, float *roi, const int *in_shape, const int *out_shape, int dim, int tid, | |||
| ROIPoolingParameter *param); | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ROI_POOLING_H_ | |||
| @@ -0,0 +1,80 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/roi_pooling.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/lite_kernel.h" | |||
| namespace mindspore { | |||
| class TestROIPoolingFp32 : public mindspore::Common { | |||
| public: | |||
| TestROIPoolingFp32() {} | |||
| }; | |||
| int ROIPoolingTestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_, | |||
| float *a_ptr, float *b_ptr, std::vector<int> a_shape, std::vector<int> b_shape, | |||
| std::vector<int> c_shape) { | |||
| auto in_t = | |||
| new lite::tensor::Tensor(kNumberTypeFloat, a_shape, schema::Format_NHWC, static_cast<schema::NodeType>(1)); | |||
| in_t->MallocData(); | |||
| memcpy(in_t->Data(), a_ptr, sizeof(float) * in_t->ElementsNum()); | |||
| inputs_->push_back(in_t); | |||
| auto roi_t = | |||
| new lite::tensor::Tensor(kNumberTypeFloat, b_shape, schema::Format_NHWC, static_cast<schema::NodeType>(1)); | |||
| roi_t->MallocData(); | |||
| memcpy(roi_t->Data(), b_ptr, sizeof(float) * roi_t->ElementsNum()); | |||
| inputs_->push_back(roi_t); | |||
| auto out_t = | |||
| new lite::tensor::Tensor(kNumberTypeFloat, c_shape, schema::Format_NHWC, static_cast<schema::NodeType>(1)); | |||
| out_t->MallocData(); | |||
| outputs_->push_back(out_t); | |||
| return out_t->ElementsNum(); | |||
| } | |||
| TEST_F(TestROIPoolingFp32, Simple) { | |||
| std::vector<lite::tensor::Tensor *> inputs_; | |||
| std::vector<lite::tensor::Tensor *> outputs_; | |||
| auto param = new ROIPoolingParameter(); | |||
| param->scale_ = 1; | |||
| param->pooledW_ = 2; | |||
| param->pooledH_ = 2; | |||
| float a[] = {1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33, 34, 35, | |||
| 1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33, 34, 35}; | |||
| float b[] = {0, 1, 1, 3, 4, 1, 1, 1, 3, 4}; | |||
| std::vector<int> a_shape = {2, 4, 5, 1}; | |||
| std::vector<int> b_shape = {2, 5}; | |||
| std::vector<int> c_shape = {2, 2, 2, 1}; | |||
| int total_size = ROIPoolingTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); | |||
| auto ctx = new lite::Context; | |||
| ctx->thread_num_ = 1; | |||
| kernel::ROIPoolingCPUKernel *op = | |||
| new kernel::ROIPoolingCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr); | |||
| op->Init(); | |||
| op->Run(); | |||
| float correct[] = {23, 25, 33, 35, 23, 25, 33, 35}; | |||
| float *output = reinterpret_cast<float *>(outputs_[0]->Data()); | |||
| for (int i = 0; i < 8; ++i) printf("%f ", output[i]); | |||
| printf("\n"); | |||
| CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001); | |||
| delete op; | |||
| for (auto t : inputs_) delete t; | |||
| for (auto t : outputs_) delete t; | |||
| } | |||
| } // namespace mindspore | |||