| @@ -156,6 +156,7 @@ union PrimitiveType { | |||||
| Select, | Select, | ||||
| Scatter, | Scatter, | ||||
| ScatterND, | ScatterND, | ||||
| ConstantOfShape, | |||||
| Unique, | Unique, | ||||
| Unstack, | Unstack, | ||||
| LogicalAnd, | LogicalAnd, | ||||
| @@ -249,6 +249,10 @@ table PoolingGrad { | |||||
| table Shape { | table Shape { | ||||
| } | } | ||||
| table ConstantOfShape{ | |||||
| value: float = 0; | |||||
| } | |||||
| table Nchw2Nhwc { | table Nchw2Nhwc { | ||||
| } | } | ||||
| @@ -0,0 +1,55 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/ops.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "src/ir/tensor.h" | |||||
| namespace mindspore::lite { | |||||
| namespace { | |||||
| constexpr int kShapeInputNum = 1; | |||||
| constexpr int kShapeOutputNum = 1; | |||||
| } // namespace | |||||
| int ConstantOfShape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) { | |||||
| if (inputs_.size() != kShapeInputNum) { | |||||
| MS_LOG(ERROR) << "inputs to ConstantOfShape operator should be 1, but " << inputs_.size() << " is given."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (inputs_.front() == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is nullptr!"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (outputs_.size() != kShapeOutputNum) { | |||||
| MS_LOG(ERROR) << "outputs to ConstantOfShape operator should be 1, but " << outputs_.size() << " is given."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto in_tensor = inputs_.front(); | |||||
| auto in_data = reinterpret_cast<int *>(in_tensor->Data()); | |||||
| auto out_tensor = outputs_.front(); | |||||
| int size = in_tensor->ElementsNum(); | |||||
| std::vector<int> out_shape(size); | |||||
| for (int i = 0; i < size; ++i) { | |||||
| out_shape[i] = in_data[i]; | |||||
| } | |||||
| out_tensor->set_shape(out_shape); | |||||
| out_tensor->set_data_type(kNumberTypeFloat32); | |||||
| out_tensor->SetFormat(in_tensor->GetFormat()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::lite | |||||
| @@ -145,6 +145,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) { | |||||
| return new lite::MatMul(const_cast<schema::Primitive *>(primitive)); | return new lite::MatMul(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_EmbeddingLookup: | case schema::PrimitiveType_EmbeddingLookup: | ||||
| return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(primitive)); | return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(primitive)); | ||||
| case schema::PrimitiveType_ConstantOfShape: | |||||
| return new lite::ConstantOfShape(const_cast<schema::Primitive *>(primitive)); | |||||
| default: | default: | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -717,6 +717,13 @@ class Shape : public Primitive { | |||||
| int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override; | int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override; | ||||
| }; | }; | ||||
| class ConstantOfShape : public Primitive { | |||||
| public: | |||||
| explicit ConstantOfShape(schema::Primitive *primitive) : Primitive(primitive) {} | |||||
| const schema::ConstantOfShape *GetAttribute() const { return this->primitive->value_as_ConstantOfShape(); } | |||||
| int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override; | |||||
| }; | |||||
| class ScatterND : public Primitive { | class ScatterND : public Primitive { | ||||
| public: | public: | ||||
| explicit ScatterND(schema::Primitive *primitive) : Primitive(primitive) {} | explicit ScatterND(schema::Primitive *primitive) : Primitive(primitive) {} | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h" | #include "src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/reshape_parameter.h" | #include "src/runtime/kernel/arm/nnacl/reshape_parameter.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/shape.h" | #include "src/runtime/kernel/arm/nnacl/shape.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/fp32/constant_of_shape.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/stack.h" | #include "src/runtime/kernel/arm/nnacl/fp32/stack.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/unstack.h" | #include "src/runtime/kernel/arm/nnacl/unstack.h" | ||||
| #include "src/runtime/kernel/arm/nnacl/depth_to_space.h" | #include "src/runtime/kernel/arm/nnacl/depth_to_space.h" | ||||
| @@ -937,6 +938,18 @@ OpParameter *PopulateShapeParameter(const lite::Primitive *primitive) { | |||||
| return reinterpret_cast<OpParameter *>(shape_param); | return reinterpret_cast<OpParameter *>(shape_param); | ||||
| } | } | ||||
| OpParameter *PopulateConstantOfShapeParameter(const lite::Primitive *primitive) { | |||||
| auto attr = primitive->Value()->value_as_ConstantOfShape(); | |||||
| ConstantOfShapeParameter *param = new (std::nothrow) ConstantOfShapeParameter(); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "new ConstantOfShapeParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| param->op_parameter_.type_ = primitive->Type(); | |||||
| param->value_ = attr->value(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | |||||
| OpParameter *PopulateReverseParameter(const lite::Primitive *primitive) { | OpParameter *PopulateReverseParameter(const lite::Primitive *primitive) { | ||||
| auto reverse_attr = primitive->Value()->value_as_Reverse(); | auto reverse_attr = primitive->Value()->value_as_Reverse(); | ||||
| ReverseParameter *reverse_param = new (std::nothrow) ReverseParameter(); | ReverseParameter *reverse_param = new (std::nothrow) ReverseParameter(); | ||||
| @@ -1370,6 +1383,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||||
| populate_parameter_funcs_[schema::PrimitiveType_Cast] = PopulateCastParameter; | populate_parameter_funcs_[schema::PrimitiveType_Cast] = PopulateCastParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Scale] = PopulateScaleParameter; | populate_parameter_funcs_[schema::PrimitiveType_Scale] = PopulateScaleParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Reshape] = PopulateReshapeParameter; | populate_parameter_funcs_[schema::PrimitiveType_Reshape] = PopulateReshapeParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_ConstantOfShape] = PopulateConstantOfShapeParameter; | |||||
| populate_parameter_funcs_[schema::PrimitiveType_Shape] = PopulateShapeParameter; | populate_parameter_funcs_[schema::PrimitiveType_Shape] = PopulateShapeParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Concat] = PopulateConcatParameter; | populate_parameter_funcs_[schema::PrimitiveType_Concat] = PopulateConcatParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Tile] = PopulateTileParameter; | populate_parameter_funcs_[schema::PrimitiveType_Tile] = PopulateTileParameter; | ||||
| @@ -0,0 +1,106 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32/constant_of_shape.h" | |||||
| #include <vector> | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_ConstantOfShape; | |||||
| namespace mindspore::kernel { | |||||
| namespace { | |||||
| constexpr int kInputNum = 1; | |||||
| constexpr int kOutputNum = 1; | |||||
| } // namespace | |||||
| int ConstantOfShapeCPUKernel::Init() { return RET_OK; } | |||||
| int ConstantOfShapeCPUKernel::ReSize() { return RET_OK; } | |||||
| int ConstantOfShapeCPUKernel::DoExecute(int task_id) { | |||||
| int ret = ConstantOfShape(out_ptr_, task_id, param_); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ConstantOfShapeRun error task_id[" << task_id << "] error_code[" << ret << "]"; | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ConstantOfShapeRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||||
| auto g_kernel = reinterpret_cast<ConstantOfShapeCPUKernel *>(cdata); | |||||
| auto ret = g_kernel->DoExecute(task_id); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ConstantOfShapeRun error task_id[" << task_id << "] error_code[" << ret << "]"; | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ConstantOfShapeCPUKernel::Run() { | |||||
| auto prepare_ret = Prepare(); | |||||
| if (prepare_ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | |||||
| return prepare_ret; | |||||
| } | |||||
| param_->element_sz_ = out_tensors_.front()->ElementsNum(); | |||||
| int thread_num = MSMIN(param_->op_parameter_.thread_num_, param_->element_sz_); | |||||
| param_->unit_ = UP_DIV(param_->element_sz_, thread_num); | |||||
| param_->op_parameter_.thread_num_ = thread_num; | |||||
| out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->Data()); | |||||
| auto ret = LiteBackendParallelLaunch(ConstantOfShapeRun, this, thread_num); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ConstantOfShapeRun error error_code[" << ret << "]"; | |||||
| return ret; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| kernel::LiteKernel *CpuConstantOfShapeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc, | |||||
| const lite::Primitive *primitive) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| if (opParameter == nullptr) { | |||||
| MS_LOG(ERROR) << "Create kernel failed, opParameter is nullptr, type: PrimitiveType_ConstantOfShape. "; | |||||
| return nullptr; | |||||
| } | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_ConstantOfShape); | |||||
| auto *kernel = new (std::nothrow) ConstantOfShapeCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "new ConstantOfShapeCPUKernel fail!"; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, CpuConstantOfShapeFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "include/context.h" | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/constant_of_shape.h" | |||||
| using mindspore::lite::Context; | |||||
| namespace mindspore::kernel { | |||||
| class ConstantOfShapeCPUKernel : public LiteKernel { | |||||
| public: | |||||
| ConstantOfShapeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, | |||||
| const lite::Primitive *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| param_ = reinterpret_cast<ConstantOfShapeParameter *>(parameter); | |||||
| } | |||||
| ~ConstantOfShapeCPUKernel() override = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int DoExecute(int task_id); | |||||
| private: | |||||
| ConstantOfShapeParameter *param_; | |||||
| float *out_ptr_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_ | |||||
| @@ -0,0 +1,28 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/fp32/constant_of_shape.h" | |||||
| int ConstantOfShape(float *output, int tid, ConstantOfShapeParameter *param) { | |||||
| int size = param->unit_; | |||||
| float data = param->value_; | |||||
| int ind_st = MSMIN(tid * size, param->element_sz_); | |||||
| int ind_end = MSMIN(param->element_sz_, (tid + 1) * size); | |||||
| for (int i = ind_st; i < ind_end; ++i) { | |||||
| output[i] = data; | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONSTANT_OF_SHAPE_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONSTANT_OF_SHAPE_H_ | |||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include "nnacl/op_base.h" | |||||
| #include "nnacl/errorcode.h" | |||||
| typedef struct ConstantOfShapeParameter { | |||||
| OpParameter op_parameter_; | |||||
| float value_; | |||||
| int unit_; | |||||
| int element_sz_; | |||||
| } ConstantOfShapeParameter; | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| int ConstantOfShape(float *output, int tid, ConstantOfShapeParameter *param); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONSTANT_OF_SHAPE_H_ | |||||
| @@ -0,0 +1,72 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "mindspore/core/utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/lite_kernel.h" | |||||
| namespace mindspore { | |||||
| class TestConstantOfShapeFp32 : public mindspore::CommonTest { | |||||
| public: | |||||
| TestConstantOfShapeFp32() {} | |||||
| }; | |||||
| int ConstantOfShapeTestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_, | |||||
| float *a_ptr, std::vector<int> a_shape) { | |||||
| auto in_t = | |||||
| new lite::tensor::Tensor(kNumberTypeInt32, a_shape, schema::Format_NHWC, static_cast<schema::NodeType>(1)); | |||||
| in_t->MallocData(); | |||||
| memcpy(in_t->Data(), a_ptr, sizeof(float) * in_t->ElementsNum()); | |||||
| inputs_->push_back(in_t); | |||||
| std::vector<int> c_shape(in_t->ElementsNum()); | |||||
| for (int i = 0; i < c_shape.size(); ++i) { | |||||
| c_shape[i] = a_ptr[i]; | |||||
| } | |||||
| auto out_t = | |||||
| new lite::tensor::Tensor(kNumberTypeFloat, c_shape, schema::Format_NHWC, static_cast<schema::NodeType>(1)); | |||||
| out_t->MallocData(); | |||||
| outputs_->push_back(out_t); | |||||
| return out_t->ElementsNum(); | |||||
| } | |||||
| TEST_F(TestConstantOfShapeFp32, Simple) { | |||||
| std::vector<lite::tensor::Tensor *> inputs_; | |||||
| std::vector<lite::tensor::Tensor *> outputs_; | |||||
| auto param = new ConstantOfShapeParameter(); | |||||
| param->value_ = 1; | |||||
| float a[] = {1, 2, 3, 4}; | |||||
| std::vector<int> a_shape = {4, 1, 1, 1}; | |||||
| // std::vector<int> c_shape = {2, 2, 2, 1}; | |||||
| int total_size = ConstantOfShapeTestInit(&inputs_, &outputs_, a, a_shape); | |||||
| auto ctx = new lite::Context; | |||||
| ctx->thread_num_ = 4; | |||||
| kernel::ConstantOfShapeCPUKernel *op = | |||||
| new kernel::ConstantOfShapeCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr); | |||||
| op->Init(); | |||||
| op->Run(); | |||||
| float correct[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; | |||||
| float *output = reinterpret_cast<float *>(outputs_[0]->Data()); | |||||
| for (int i = 0; i < 8; ++i) printf("%f ", output[i]); | |||||
| printf("\n"); | |||||
| CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001); | |||||
| delete op; | |||||
| for (auto t : inputs_) delete t; | |||||
| for (auto t : outputs_) delete t; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -63,7 +63,7 @@ TEST_F(TestROIPoolingFp32, Simple) { | |||||
| std::vector<int> c_shape = {2, 2, 2, 1}; | std::vector<int> c_shape = {2, 2, 2, 1}; | ||||
| int total_size = ROIPoolingTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); | int total_size = ROIPoolingTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); | ||||
| auto ctx = new lite::Context; | auto ctx = new lite::Context; | ||||
| ctx->thread_num_ = 1; | |||||
| ctx->thread_num_ = 3; | |||||
| kernel::ROIPoolingCPUKernel *op = | kernel::ROIPoolingCPUKernel *op = | ||||
| new kernel::ROIPoolingCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr); | new kernel::ROIPoolingCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr); | ||||
| op->Init(); | op->Init(); | ||||