| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/fp32/sparse_to_dense.h" | |||||
| void SparseToDense(int **sparse_indices, int *output_shape, | |||||
| float *sparse_values, float default_value, float *output, | |||||
| bool isScalar, int index_start, int index_end, int out_width) { | |||||
| for (int i = index_start; i < index_end; i++) { | |||||
| for (int j = 0; j < out_width; j++) { | |||||
| output[i * out_width + j] = default_value; | |||||
| } | |||||
| } | |||||
| int d1 = output_shape[1] * output_shape[2] * output_shape[3]; | |||||
| int d2 = output_shape[2] * output_shape[3]; | |||||
| int d3 = output_shape[3]; | |||||
| int index; | |||||
| if (isScalar == true) { | |||||
| for (int i = index_start; i < index_end; i++) { | |||||
| index = d1 * sparse_indices[i][0] + d2 * sparse_indices[i][1] + | |||||
| d3 * sparse_indices[i][2] + sparse_indices[i][3]; | |||||
| output[index] = sparse_values[0]; | |||||
| } | |||||
| } else { | |||||
| for (int i = index_start; i < index_end; i++) { | |||||
| index = d1 * sparse_indices[i][0] + d2 * sparse_indices[i][1] + | |||||
| d3 * sparse_indices[i][2] + sparse_indices[i][3]; | |||||
| output[index] = sparse_values[i]; | |||||
| } | |||||
| } | |||||
| } | |||||
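Reviewer note: the loops above flatten a (padded) 4D coordinate into an offset in the dense output buffer. Below is a minimal, stand-alone sketch of that arithmetic, not part of the patch; the shape {1, 1, 6, 10} and the index {0, 0, 2, 3} are illustrative values taken from the unit tests further down.

#include <cstdio>

int main() {
  // Illustrative padded 4D output shape {1, 1, 6, 10}.
  int output_shape[4] = {1, 1, 6, 10};
  int d1 = output_shape[1] * output_shape[2] * output_shape[3];  // 60
  int d2 = output_shape[2] * output_shape[3];                    // 60
  int d3 = output_shape[3];                                      // 10
  int idx[4] = {0, 0, 2, 3};  // a 2D index (2, 3) padded to 4D
  int flat = d1 * idx[0] + d2 * idx[1] + d3 * idx[2] + idx[3];
  printf("flat offset = %d\n", flat);  // prints 23
  return 0;
}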
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_FP32_SPARSETODENSE_H_ | |||||
| #define MINDSPORE_LITE_NNACL_FP32_SPARSETODENSE_H_ | |||||
| #include "nnacl/sparse_to_dense_parameter.h" | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| void SparseToDense(int **sparse_indices_vect, int *output_shape, | |||||
| float *sparse_values, float default_value, float *output, | |||||
| bool isScalar, int index_start, int index_end, int out_width); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_LITE_NNACL_FP32_SPARSETODENSE_H_ | |||||
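Reviewer note: a hedged usage sketch of the exported C entry point, assuming the caller links against nnacl and has already padded its indices to four components (as the CPU kernel below does). The buffer sizes and values here are illustrative only.

#include <cstdio>
#include "nnacl/fp32/sparse_to_dense.h"

int main() {
  int row0[4] = {0, 0, 0, 1};
  int row1[4] = {0, 0, 1, 4};
  int *indices[2] = {row0, row1};
  int output_shape[4] = {1, 1, 2, 5};  // dense output shape padded to 4D
  float values[2] = {7.0f, 9.0f};
  float output[10];
  // One call covering both rows: index range [0, 2), out_width = 10 / 2 = 5.
  SparseToDense(indices, output_shape, values, 0.0f, output, false, 0, 2, 5);
  for (int i = 0; i < 10; ++i) printf("%.1f ", output[i]);  // 0 7 0 0 0 0 0 0 0 9
  printf("\n");
  return 0;
}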
| @@ -1,34 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| // * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/sparse_to_dense.h" | |||||
| void SparseToDense(int *input, int *output_shape_, float *snum, float *dnum, int sp_num, float *output, | |||||
| SparseToDenseParameter *s2d_param_, int task_id) { | |||||
| int m; | |||||
| for (int i = task_id; i < output_shape_[0]; i += s2d_param_->op_parameter_.thread_num_) { | |||||
| for (int j = 0; j < output_shape_[1]; j++) { | |||||
| m = i * output_shape_[1] + j; | |||||
| output[m] = dnum[0]; | |||||
| } | |||||
| } | |||||
| for (int j = 0; j < sp_num; j++) { | |||||
| int temp = j * 2; | |||||
| int temp1 = j * 2 + 1; | |||||
| int tempout1 = input[temp] * output_shape_[1] + input[temp1]; | |||||
| output[tempout1] = snum[j]; | |||||
| } | |||||
| } | |||||
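Reviewer note: for context, the removed kernel above only supported a 2D dense output and read the indices as flat {row, col} pairs. A minimal sketch of that older flattening, with made-up sample values, shows what the new 4D path generalizes:

#include <cstdio>

int main() {
  int input[4] = {0, 1, 1, 4};   // two sparse entries as flat {row, col} pairs
  int output_shape[2] = {2, 5};  // 2D dense output only
  for (int j = 0; j < 2; ++j) {
    int flat = input[j * 2] * output_shape[1] + input[j * 2 + 1];
    printf("entry %d -> flat %d\n", j, flat);  // prints 1 and 9
  }
  return 0;
}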
| @@ -13,8 +13,9 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_NNACL_SPARSETODENSE_H_ | |||||
| #define MINDSPORE_LITE_NNACL_SPARSETODENSE_H_ | |||||
| #ifndef MINDSPORE_LITE_NNACL_SPARSE_TO_DENSE_PARAMETER_H_ | |||||
| #define MINDSPORE_LITE_NNACL_SPARSE_TO_DENSE_PARAMETER_H_ | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| @@ -22,16 +23,6 @@ typedef struct SparseToDenseParameter { | |||||
| OpParameter op_parameter_; | OpParameter op_parameter_; | ||||
| bool validate_indices_; | bool validate_indices_; | ||||
| int thread_num_; | int thread_num_; | ||||
| int count_; | |||||
| } SparseToDenseParameter; | } SparseToDenseParameter; | ||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| void SparseToDense(int *input, int *output_shape_, float *snum, float *dnum, int sp_num, float *output, | |||||
| SparseToDenseParameter *s2d_param_, int task_id); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_LITE_NNACL_SPARSETODENCE_H_ | |||||
| #endif // MINDSPORE_LITE_NNACL_SPARSE_TO_DENSE_PARAMETER_H_ | |||||
| @@ -42,5 +42,34 @@ int SparseToDense::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #endif | #endif | ||||
| int SparseToDense::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| MS_ASSERT(this->primitive_ != nullptr); | |||||
| MS_ASSERT(!outputs_.empty()); | |||||
| auto output = outputs_.front(); | |||||
| if (output == nullptr) { | |||||
| MS_LOG(ERROR) << "output null pointer dereferencing."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto input2 = inputs_.at(2); | |||||
| outputs_[0]->set_data_type(input2->data_type()); | |||||
| outputs_[0]->SetFormat(input2->GetFormat()); | |||||
| if (!GetInferFlag()) { | |||||
| return RET_OK; | |||||
| } | |||||
| if (this->primitive_ == nullptr) { | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto input1 = inputs_.at(1); | |||||
| int *input1_data = reinterpret_cast<int *>(input1->MutableData()); | |||||
| std::vector<int> output_shape; | |||||
| for (int i = 0; i < input1->ElementsNum(); i++) { | |||||
| output_shape.push_back(input1_data[i]); | |||||
| } | |||||
| outputs_[0]->set_shape(output_shape); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
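Reviewer note: InferShape above copies the contents of the second input tensor (the dense shape tensor) straight into the output tensor's shape. A small stand-alone sketch of that behavior, using a plain array instead of the real Tensor API; the helper name is illustrative only.

#include <cstdio>
#include <vector>

// Illustrative stand-in: the dense output shape is simply the contents of the
// "output shape" input tensor.
std::vector<int> InferDenseShape(const int *shape_tensor_data, int element_num) {
  std::vector<int> output_shape;
  for (int i = 0; i < element_num; ++i) {
    output_shape.push_back(shape_tensor_data[i]);
  }
  return output_shape;
}

int main() {
  int shape_data[2] = {6, 10};  // matches the unit tests below
  std::vector<int> out = InferDenseShape(shape_data, 2);
  printf("inferred shape: %d x %d\n", out[0], out[1]);  // 6 x 10
  return 0;
}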
| @@ -45,6 +45,7 @@ class SparseToDense : public PrimitiveC { | |||||
| std::vector<int> GetSparseValue() const; | std::vector<int> GetSparseValue() const; | ||||
| std::vector<int> GetDefaultValue() const; | std::vector<int> GetDefaultValue() const; | ||||
| bool GetValidateIndices() const; | bool GetValidateIndices() const; | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -170,7 +170,7 @@ | |||||
| #include "nnacl/fp32/embedding_lookup.h" | #include "nnacl/fp32/embedding_lookup.h" | ||||
| #include "nnacl/fp32/elu.h" | #include "nnacl/fp32/elu.h" | ||||
| #include "nnacl/leaky_relu_parameter.h" | #include "nnacl/leaky_relu_parameter.h" | ||||
| #include "nnacl/sparse_to_dense.h" | |||||
| #include "mindspore/lite/nnacl/fp32/sparse_to_dense.h" | |||||
| #include "nnacl/l2_norm_parameter.h" | #include "nnacl/l2_norm_parameter.h" | ||||
| #include "nnacl/detection_post_process_parameter.h" | #include "nnacl/detection_post_process_parameter.h" | ||||
| #include "nnacl/fp32/exp.h" | #include "nnacl/fp32/exp.h" | ||||
| @@ -14,12 +14,14 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/arm/fp32/sparse_to_dense.h" | #include "src/runtime/kernel/arm/fp32/sparse_to_dense.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include "include/errorcode.h" | |||||
| #include "mindspore/lite/nnacl/fp32/sparse_to_dense.h" | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "schema/ops_generated.h" | #include "schema/ops_generated.h" | ||||
| #include "nnacl/sparse_to_dense.h" | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -30,12 +32,45 @@ using mindspore::schema::PrimitiveType_SparseToDense; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int SparseToDenseCPUKernel::Init() { | int SparseToDenseCPUKernel::Init() { | ||||
| s2d_param_->op_parameter_.thread_num_ = thread_count_; | |||||
| auto input2 = in_tensors_.at(2); | |||||
| auto input3 = in_tensors_.at(3); | |||||
| sparse_values = reinterpret_cast<float *>(input2->MutableData()); | |||||
| default_value = reinterpret_cast<float *>(input3->MutableData())[0]; | |||||
| if (input2->ElementsNum() == 1) { | |||||
| isScalar = true; | |||||
| } | |||||
| if (!InferShapeDone()) { | |||||
| return RET_OK; | |||||
| } | |||||
| return ReSize(); | |||||
| } | |||||
| int SparseToDenseCPUKernel::ReSize() { | |||||
| auto output0 = out_tensors_.at(0); | |||||
| std::vector<int> out_shape_tensor = output0->shape(); | |||||
| auto output_shape_tmp = reinterpret_cast<int *>(out_shape_tensor.data()); | |||||
| int output_dim = output0->shape().size(); | |||||
| for (int i = 0; i < DIMENSION_4D - output_dim; i++) { | |||||
| output_shape[i] = 1; | |||||
| } | |||||
| for (int i = 0; i < output_dim; i++) { | |||||
| output_shape[i + DIMENSION_4D - output_dim] = output_shape_tmp[i]; | |||||
| } | |||||
| output_num = output0->ElementsNum(); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
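Reviewer note: ReSize left-pads the output tensor's shape with 1s up to DIMENSION_4D so the compute kernel can always treat it as 4D. A minimal sketch with the {6, 10} shape used in the tests; kDim4 stands in for DIMENSION_4D.

#include <cstdio>

int main() {
  const int kDim4 = 4;            // stands in for DIMENSION_4D
  int actual_shape[2] = {6, 10};  // output tensor shape from the test cases
  int output_dim = 2;
  int output_shape[4];
  // Left-pad with 1s so the lower-rank shape becomes {1, 1, 6, 10}.
  for (int i = 0; i < kDim4 - output_dim; ++i) output_shape[i] = 1;
  for (int i = 0; i < output_dim; ++i) output_shape[i + kDim4 - output_dim] = actual_shape[i];
  printf("%d %d %d %d\n", output_shape[0], output_shape[1], output_shape[2], output_shape[3]);
  return 0;
}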
| int SparseToDenseCPUKernel::DoExcute(int task_id) { | int SparseToDenseCPUKernel::DoExcute(int task_id) { | ||||
| SparseToDense(input_data_, output_shape_, snum_, dnum_, sp_num_, output_data, s2d_param_, task_id); | |||||
| int real_dst_count = MSMIN(index_num - task_id * count_unit_, count_unit_); | |||||
| if (real_dst_count <= 0) { | |||||
| return RET_OK; | |||||
| } | |||||
| int index_start = task_id * count_unit_; | |||||
| int index_end = index_start + real_dst_count; | |||||
| int out_width = output_num / index_num; | |||||
| SparseToDense(sparse_indices_vect, output_shape, sparse_values, | |||||
| default_value, output_data, isScalar, | |||||
| index_start, index_end, out_width); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
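Reviewer note: DoExcute relies on Run() having set count_unit_ = UP_DIV(index_num, thread_count_); each task then handles a half-open range of sparse rows, and out_width is the number of dense elements per row. A worked sketch with the figures from the first unit test (6 rows, 60 output elements, 3 threads); std::min stands in for MSMIN and the ceiling division for UP_DIV.

#include <algorithm>
#include <cstdio>

int main() {
  int index_num = 6, output_num = 60, thread_count = 3;  // figures from test1
  int count_unit = thread_count > 1 ? (index_num + thread_count - 1) / thread_count : index_num;
  int out_width = output_num / index_num;  // 10 dense elements per sparse row
  printf("count_unit = %d, out_width = %d\n", count_unit, out_width);
  for (int task_id = 0; task_id < thread_count; ++task_id) {
    int real_dst_count = std::min(index_num - task_id * count_unit, count_unit);
    if (real_dst_count <= 0) continue;
    int index_start = task_id * count_unit;
    int index_end = index_start + real_dst_count;
    printf("task %d handles rows [%d, %d)\n", task_id, index_start, index_end);
  }
  return 0;
}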
| @@ -43,38 +78,117 @@ int SparseToDenseRun(void *cdata, int task_id) { | |||||
| auto s2ddata = reinterpret_cast<SparseToDenseCPUKernel *>(cdata); | auto s2ddata = reinterpret_cast<SparseToDenseCPUKernel *>(cdata); | ||||
| auto ret = s2ddata->DoExcute(task_id); | auto ret = s2ddata->DoExcute(task_id); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "SparseToDenseRun error task_id[" << task_id << "] error_code[" << ret << "]"; | |||||
| MS_LOG(ERROR) << "SparseToDenseRun error task_id[" << task_id | |||||
| << "] error_code[" << ret << "]"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int SparseToDenseCPUKernel::GenerateIndices() { | |||||
| auto input0 = in_tensors_.at(0); | |||||
| index_dim = input0->shape().size(); | |||||
| index_num = input0->shape()[0]; | |||||
| int *sparse_indices = reinterpret_cast<int *>(input0->MutableData()); | |||||
| sparse_indices_vect = reinterpret_cast<int **>(ctx_->allocator->Malloc(sizeof(int *) * index_num)); | |||||
| if (sparse_indices_vect == nullptr) { | |||||
| MS_LOG(ERROR) << "Null pointer reference: sparse_indices_vect."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| switch (index_dim) { | |||||
| case 0: | |||||
| case 1: { | |||||
| for (int i = 0; i < index_num; i++) { | |||||
| sparse_indices_vect[i] = new int[DIMENSION_4D]; | |||||
| if (sparse_indices_vect[i] == nullptr) { | |||||
| MS_LOG(ERROR) << "Null pointer reference: sparse_indices_vect[" << i << "]."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| for (int j = 0; j < DIMENSION_4D - 1; j++) { | |||||
| sparse_indices_vect[i][j] = 0; | |||||
| } | |||||
| sparse_indices_vect[i][DIMENSION_4D - 1] = sparse_indices[i]; | |||||
| } | |||||
| break; | |||||
| } | |||||
| case 2: { | |||||
| int true_dims = input0->shape()[1]; | |||||
| MS_ASSERT(true_dims <= DIMENSION_4D); | |||||
| for (int i = 0; i < index_num; i++) { | |||||
| sparse_indices_vect[i] = new int[DIMENSION_4D]; | |||||
| if (sparse_indices_vect[i] == nullptr) { | |||||
| MS_LOG(ERROR) << "Null pointer reference: sparse_indices_vect[" << i << "]."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| for (int j = 0; j < DIMENSION_4D - true_dims; j++) { | |||||
| sparse_indices_vect[i][j] = 0; | |||||
| } | |||||
| for (int j = 0; j < true_dims; j++) { | |||||
| sparse_indices_vect[i][j + DIMENSION_4D - true_dims] = sparse_indices[i * true_dims + j]; | |||||
| } | |||||
| } | |||||
| break; | |||||
| } | |||||
| default: { | |||||
| MS_LOG(ERROR) << "Indices dimensions is " << index_dim << ", which must be 0, 1 or 2"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
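Reviewer note: GenerateIndices normalizes every raw index row to four components by zero-padding on the left. An illustrative stand-alone helper showing the 1D and 2D cases; the helper name and sample values are not part of the patch.

#include <cstdio>

// Pads one raw index row of `true_dims` components into a 4-component row,
// mirroring the 1D/2D branches of GenerateIndices above (illustrative only).
void PadIndexRow(const int *src, int true_dims, int dst[4]) {
  for (int j = 0; j < 4 - true_dims; ++j) dst[j] = 0;
  for (int j = 0; j < true_dims; ++j) dst[j + 4 - true_dims] = src[j];
}

int main() {
  int raw1d = 5;           // scalar/1D case: one component per row
  int raw2d[2] = {2, 3};   // 2D case: {row, col}
  int padded[4];
  PadIndexRow(&raw1d, 1, padded);
  printf("1D -> {%d, %d, %d, %d}\n", padded[0], padded[1], padded[2], padded[3]);  // {0, 0, 0, 5}
  PadIndexRow(raw2d, 2, padded);
  printf("2D -> {%d, %d, %d, %d}\n", padded[0], padded[1], padded[2], padded[3]);  // {0, 0, 2, 3}
  return 0;
}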
| int SparseToDenseCPUKernel::IndicesValidCheck() { | |||||
| int d1 = output_shape[1] * output_shape[2] * output_shape[3]; | |||||
| int d2 = output_shape[2] * output_shape[3]; | |||||
| int d3 = output_shape[3]; | |||||
| int index_before = -1; | |||||
| for (int i = 0; i < index_num; i++) { | |||||
| int index = d1 * sparse_indices_vect[i][0] + d2 * sparse_indices_vect[i][1] + | |||||
| d3 * sparse_indices_vect[i][2] + sparse_indices_vect[i][3]; | |||||
| if (index <= index_before) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| index_before = index; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
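Reviewer note: IndicesValidCheck accepts the indices only when their flattened offsets are strictly increasing, so duplicate or out-of-order rows are rejected. A self-contained sketch of the same check, with one valid and one duplicated pair as made-up inputs.

#include <cstdio>

// Returns true when the flattened offsets are strictly increasing, mirroring
// IndicesValidCheck above (illustrative stand-alone version).
bool IndicesAreValid(int indices[][4], int index_num, const int output_shape[4]) {
  int d1 = output_shape[1] * output_shape[2] * output_shape[3];
  int d2 = output_shape[2] * output_shape[3];
  int d3 = output_shape[3];
  int index_before = -1;
  for (int i = 0; i < index_num; ++i) {
    int index = d1 * indices[i][0] + d2 * indices[i][1] + d3 * indices[i][2] + indices[i][3];
    if (index <= index_before) return false;
    index_before = index;
  }
  return true;
}

int main() {
  int shape[4] = {1, 1, 6, 10};
  int ok[2][4] = {{0, 0, 0, 0}, {0, 0, 1, 2}};   // offsets 0, 12: increasing
  int dup[2][4] = {{0, 0, 2, 3}, {0, 0, 2, 3}};  // offset 23 twice: rejected
  printf("%d %d\n", IndicesAreValid(ok, 2, shape), IndicesAreValid(dup, 2, shape));  // 1 0
  return 0;
}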
| int SparseToDenseCPUKernel::Run() { | int SparseToDenseCPUKernel::Run() { | ||||
| auto ret = Prepare(); | auto ret = Prepare(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Prepare failed."; | MS_LOG(ERROR) << "Prepare failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto input = in_tensors_.at(0); | |||||
| auto input1 = in_tensors_.at(1); | |||||
| auto input2 = in_tensors_.at(2); | |||||
| auto input3 = in_tensors_.at(3); | |||||
| auto output0 = out_tensors_.at(0); | |||||
| input_data_ = reinterpret_cast<int *>(input->MutableData()); | |||||
| total_number_ = reinterpret_cast<int *>(input1->MutableData()); | |||||
| snum_ = reinterpret_cast<float *>(input2->MutableData()); | |||||
| dnum_ = reinterpret_cast<float *>(input3->MutableData()); | |||||
| sp_num_ = static_cast<int>(input->ElementsNum() / 2); | |||||
| auto ret1 = GenerateIndices(); | |||||
| if (ret1 != RET_OK) { | |||||
| MS_LOG(ERROR) << "Generate Indices failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (s2d_param->validate_indices_ == true) { | |||||
| auto ret2 = IndicesValidCheck(); | |||||
| if (ret2 != RET_OK) { | |||||
| MS_LOG(ERROR) << "The sparse indices is not valid."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| output_data = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | output_data = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| std::vector<int> temp_shape = output0->shape(); | |||||
| output_shape_ = reinterpret_cast<int *>(temp_shape.data()); | |||||
| ret = ParallelLaunch(THREAD_POOL_DEFAULT, SparseToDenseRun, this, s2d_param_->thread_num_); | |||||
| count_unit_ = thread_count_ > 1 ? UP_DIV(index_num, thread_count_) : index_num; | |||||
| ret = ParallelLaunch(THREAD_POOL_DEFAULT, SparseToDenseRun, this, | |||||
| s2d_param->thread_num_); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "SparseToDenseRun error: error_code[" << ret << "]"; | MS_LOG(ERROR) << "SparseToDenseRun error: error_code[" << ret << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| for (int i = 0; i < index_num; i++) { | |||||
| if (sparse_indices_vect[i] != nullptr) { | |||||
| delete[] sparse_indices_vect[i]; | |||||
| } | |||||
| } | |||||
| if (sparse_indices_vect != nullptr) { | |||||
| ctx_->allocator->Free(sparse_indices_vect); | |||||
| sparse_indices_vect = nullptr; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
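Reviewer note: the cleanup at the end of Run() has to mirror the two allocation paths in GenerateIndices(): the pointer table comes from the context allocator, while each row comes from new[]. A minimal sketch of that pairing, using plain malloc/free as a stand-in for ctx_->allocator.

#include <cstdlib>

int main() {
  const int index_num = 3;
  // The pointer table mimics ctx_->allocator->Malloc; each row mimics new int[DIMENSION_4D].
  int **rows = static_cast<int **>(malloc(sizeof(int *) * index_num));
  if (rows == nullptr) return 1;
  for (int i = 0; i < index_num; ++i) rows[i] = new int[4]();
  for (int i = 0; i < index_num; ++i) delete[] rows[i];  // new[] pairs with delete[]
  free(rows);                                            // Malloc pairs with Free
  return 0;
}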
| @@ -88,20 +202,25 @@ kernel::LiteKernel *CpuSparseToDenseFp32KernelCreator(const std::vector<lite::Te | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_SparseToDense); | MS_ASSERT(desc.type == schema::PrimitiveType_SparseToDense); | ||||
| auto *kernel = new (std::nothrow) SparseToDenseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| auto *kernel = new (std::nothrow) | |||||
| SparseToDenseCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "new SparseToDenseCPUKernel fail!"; | MS_LOG(ERROR) << "new SparseToDenseCPUKernel fail!"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto ret = kernel->Init(); | auto ret = kernel->Init(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ | |||||
| << ", type: " | |||||
| << schema::EnumNamePrimitiveType( | |||||
| static_cast<schema::PrimitiveType>( | |||||
| opParameter->type_)); | |||||
| delete kernel; | delete kernel; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseToDense, CpuSparseToDenseFp32KernelCreator) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseToDense, | |||||
| CpuSparseToDenseFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -20,7 +20,7 @@ | |||||
| #include "src/lite_kernel.h" | #include "src/lite_kernel.h" | ||||
| #include "include/context.h" | #include "include/context.h" | ||||
| #include "nnacl/sparse_to_dense.h" | |||||
| #include "mindspore/lite/nnacl/fp32/sparse_to_dense.h" | |||||
| #include "src/runtime/kernel/arm/base/layout_transform.h" | #include "src/runtime/kernel/arm/base/layout_transform.h" | ||||
| using mindspore::lite::Context; | using mindspore::lite::Context; | ||||
| @@ -32,28 +32,34 @@ class SparseToDenseCPUKernel : public LiteKernel { | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::Context *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::Context *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { | : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { | ||||
| s2d_param_ = (reinterpret_cast<SparseToDenseParameter *>(op_parameter_)); | |||||
| s2d_param = (reinterpret_cast<SparseToDenseParameter *>(op_parameter_)); | |||||
| s2d_param->thread_num_ = thread_count_; | |||||
| } | } | ||||
| ~SparseToDenseCPUKernel() = default; | ~SparseToDenseCPUKernel() = default; | ||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override { return 0; } | |||||
| int ReSize() override; | |||||
| int Run() override; | int Run() override; | ||||
| int DoExcute(int task_id); | int DoExcute(int task_id); | ||||
| int GenerateIndices(); | |||||
| int IndicesValidCheck(); | |||||
| protected: | protected: | ||||
| const Context *ctx_; | const Context *ctx_; | ||||
| int thread_count_; | int thread_count_; | ||||
| SparseToDenseParameter *s2d_param_; | |||||
| SparseToDenseParameter *s2d_param; | |||||
| private: | private: | ||||
| int *input_data_; | |||||
| int *total_number_; | |||||
| int sp_num_; | |||||
| float *snum_; | |||||
| float *dnum_; | |||||
| float *output_data; | |||||
| int *output_shape_; | |||||
| int **sparse_indices_vect = nullptr; | |||||
| float *sparse_values = nullptr; | |||||
| float default_value; | |||||
| bool isScalar = false; | |||||
| int index_num; | |||||
| int index_dim; | |||||
| float *output_data = nullptr; | |||||
| int output_shape[4]; | |||||
| int output_num; | |||||
| int64_t count_unit_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSETODENSE_H_ | #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSETODENSE_H_ | ||||
| @@ -0,0 +1,452 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "mindspore/lite/nnacl/fp32/sparse_to_dense.h" | |||||
| #include "mindspore/lite/src/kernel_registry.h" | |||||
| #include "mindspore/lite/src/lite_kernel.h" | |||||
| #include "mindspore/lite/src/tensor.h" | |||||
| namespace mindspore { | |||||
| class TestSparseToDenseFp32 : public mindspore::CommonTest { | |||||
| public: | |||||
| TestSparseToDenseFp32() {} | |||||
| }; | |||||
| TEST_F(TestSparseToDenseFp32, SparseToDense_test1) { | |||||
| std::vector<int> input1 = {0, 0, 1, 2, 2, 3, 3, 6, 4, 7, 5, 9}; | |||||
| std::vector<int> shape1 = {6, 2}; | |||||
| std::vector<int> input2 = {6, 10}; | |||||
| std::vector<int> shape2 = {2}; | |||||
| std::vector<float> input3 = {1}; | |||||
| std::vector<int> shape3 = {1}; | |||||
| std::vector<float> input4 = {0}; | |||||
| std::vector<int> shape4 = {1}; | |||||
| TypeId tid = kNumberTypeFloat32; | |||||
| lite::Tensor *input_tensor1 = new lite::Tensor; | |||||
| input_tensor1->SetData(input1.data()); | |||||
| input_tensor1->set_shape(shape1); | |||||
| input_tensor1->set_data_type(tid); | |||||
| lite::Tensor *input_tensor2 = new lite::Tensor; | |||||
| input_tensor2->SetData(input2.data()); | |||||
| input_tensor2->set_shape(shape2); | |||||
| input_tensor2->set_data_type(tid); | |||||
| lite::Tensor *input_tensor3 = new lite::Tensor; | |||||
| input_tensor3->SetData(input3.data()); | |||||
| input_tensor3->set_shape(shape3); | |||||
| input_tensor3->set_data_type(tid); | |||||
| lite::Tensor *input_tensor4 = new lite::Tensor; | |||||
| input_tensor4->SetData(input4.data()); | |||||
| input_tensor4->set_shape(shape4); | |||||
| input_tensor4->set_data_type(tid); | |||||
| std::vector<lite::Tensor *> inputs_tensor(4); | |||||
| inputs_tensor[0] = input_tensor1; | |||||
| inputs_tensor[1] = input_tensor2; | |||||
| inputs_tensor[2] = input_tensor3; | |||||
| inputs_tensor[3] = input_tensor4; | |||||
| const int output_size = 60; | |||||
| float output[60]; | |||||
| std::vector<int> output_shape = {6, 10}; | |||||
| lite::Tensor *output0_tensor = new lite::Tensor; | |||||
| output0_tensor->SetData(output); | |||||
| output0_tensor->set_shape(output_shape); | |||||
| output0_tensor->set_data_type(tid); | |||||
| std::vector<lite::Tensor *> outputs_tensor(1); | |||||
| outputs_tensor[0] = output0_tensor; | |||||
| SparseToDenseParameter op_param; | |||||
| op_param.op_parameter_.type_ = schema::PrimitiveType_SparseToDense; | |||||
| lite::Context *ctx = new lite::Context; | |||||
| ctx->thread_num_ = 3; | |||||
| op_param.validate_indices_ = false; | |||||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, tid, schema::PrimitiveType_SparseToDense}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| ASSERT_NE(creator, nullptr); | |||||
| kernel::LiteKernel *kernel = | |||||
| creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), ctx, desc, nullptr); | |||||
| ASSERT_NE(kernel, nullptr); | |||||
| auto output_tensor_shape = output0_tensor->shape(); | |||||
| ASSERT_EQ(output_tensor_shape, output_shape); | |||||
| kernel->Run(); | |||||
| std::vector<float> except_result = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; | |||||
| PrintData("output data", output, output_size); | |||||
| PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); | |||||
| CompareOutputData(output, except_result.data(), output_size, 0.000001); | |||||
| input_tensor1->SetData(nullptr); | |||||
| input_tensor2->SetData(nullptr); | |||||
| input_tensor3->SetData(nullptr); | |||||
| input_tensor4->SetData(nullptr); | |||||
| output0_tensor->SetData(nullptr); | |||||
| delete input_tensor1; | |||||
| delete input_tensor2; | |||||
| delete input_tensor3; | |||||
| delete input_tensor4; | |||||
| delete output0_tensor; | |||||
| delete ctx; | |||||
| } | |||||
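Reviewer note: a quick worked check of SparseToDense_test1. With a 6 x 10 dense output, each {row, col} pair lands at row * 10 + col, which matches the positions of the six 1s in except_result. Stand-alone sketch:

#include <cstdio>

int main() {
  // The six {row, col} pairs from SparseToDense_test1 in a 6 x 10 dense output.
  int indices[6][2] = {{0, 0}, {1, 2}, {2, 3}, {3, 6}, {4, 7}, {5, 9}};
  for (int i = 0; i < 6; ++i) {
    // Flat position in the 60-element buffer; these are the six cells set to 1.
    printf("row %d -> flat %d\n", i, indices[i][0] * 10 + indices[i][1]);
  }
  return 0;  // prints 0, 12, 23, 36, 47, 59
}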
| TEST_F(TestSparseToDenseFp32, SparseToDense_test2) { | |||||
| std::vector<int> input1 = {0, 0, 1, 2, 2, 3, 3, 6, 4, 7, 5, 9}; | |||||
| std::vector<int> shape1 = {6, 2}; | |||||
| std::vector<int> input2 = {6, 10}; | |||||
| std::vector<int> shape2 = {2}; | |||||
| std::vector<float> input3 = {1, 2, 3, 4, 5, 6}; | |||||
| std::vector<int> shape3 = {6}; | |||||
| std::vector<float> input4 = {0}; | |||||
| std::vector<int> shape4 = {1}; | |||||
| TypeId tid = kNumberTypeFloat32; | |||||
| lite::Tensor *input_tensor1 = new lite::Tensor; | |||||
| input_tensor1->SetData(input1.data()); | |||||
| input_tensor1->set_shape(shape1); | |||||
| input_tensor1->set_data_type(tid); | |||||
| lite::Tensor *input_tensor2 = new lite::Tensor; | |||||
| input_tensor2->SetData(input2.data()); | |||||
| input_tensor2->set_shape(shape2); | |||||
| input_tensor2->set_data_type(tid); | |||||
| lite::Tensor *input_tensor3 = new lite::Tensor; | |||||
| input_tensor3->SetData(input3.data()); | |||||
| input_tensor3->set_shape(shape3); | |||||
| input_tensor3->set_data_type(tid); | |||||
| lite::Tensor *input_tensor4 = new lite::Tensor; | |||||
| input_tensor4->SetData(input4.data()); | |||||
| input_tensor4->set_shape(shape4); | |||||
| input_tensor4->set_data_type(tid); | |||||
| std::vector<lite::Tensor *> inputs_tensor(4); | |||||
| inputs_tensor[0] = input_tensor1; | |||||
| inputs_tensor[1] = input_tensor2; | |||||
| inputs_tensor[2] = input_tensor3; | |||||
| inputs_tensor[3] = input_tensor4; | |||||
| const int output_size = 60; | |||||
| float output[60]; | |||||
| std::vector<int> output_shape = {6, 10}; | |||||
| lite::Tensor *output0_tensor = new lite::Tensor; | |||||
| output0_tensor->SetData(output); | |||||
| output0_tensor->set_shape(output_shape); | |||||
| output0_tensor->set_data_type(tid); | |||||
| std::vector<lite::Tensor *> outputs_tensor(1); | |||||
| outputs_tensor[0] = output0_tensor; | |||||
| SparseToDenseParameter op_param; | |||||
| op_param.op_parameter_.type_ = schema::PrimitiveType_SparseToDense; | |||||
| lite::Context *ctx = new lite::Context; | |||||
| ctx->thread_num_ = 2; | |||||
| op_param.validate_indices_ = false; | |||||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, tid, schema::PrimitiveType_SparseToDense}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| ASSERT_NE(creator, nullptr); | |||||
| kernel::LiteKernel *kernel = | |||||
| creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), ctx, desc, nullptr); | |||||
| ASSERT_NE(kernel, nullptr); | |||||
| auto output_tensor_shape = output0_tensor->shape(); | |||||
| ASSERT_EQ(output_tensor_shape, output_shape); | |||||
| kernel->Run(); | |||||
| std::vector<float> except_result = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 6}; | |||||
| PrintData("output data", output, output_size); | |||||
| PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); | |||||
| CompareOutputData(output, except_result.data(), output_size, 0.000001); | |||||
| input_tensor1->SetData(nullptr); | |||||
| input_tensor2->SetData(nullptr); | |||||
| input_tensor3->SetData(nullptr); | |||||
| input_tensor4->SetData(nullptr); | |||||
| output0_tensor->SetData(nullptr); | |||||
| delete input_tensor1; | |||||
| delete input_tensor2; | |||||
| delete input_tensor3; | |||||
| delete input_tensor4; | |||||
| delete output0_tensor; | |||||
| delete ctx; | |||||
| } | |||||
| TEST_F(TestSparseToDenseFp32, SparseToDense_test3) { | |||||
| std::vector<int> input1 = {1, 3, 4}; | |||||
| std::vector<int> shape1 = {3}; | |||||
| std::vector<int> input2 = {1, 10}; | |||||
| std::vector<int> shape2 = {2}; | |||||
| std::vector<float> input3 = {1}; | |||||
| std::vector<int> shape3 = {1}; | |||||
| std::vector<float> input4 = {0}; | |||||
| std::vector<int> shape4 = {1}; | |||||
| TypeId tid = kNumberTypeFloat32; | |||||
| lite::Tensor *input_tensor1 = new lite::Tensor; | |||||
| input_tensor1->SetData(input1.data()); | |||||
| input_tensor1->set_shape(shape1); | |||||
| input_tensor1->set_data_type(tid); | |||||
| lite::Tensor *input_tensor2 = new lite::Tensor; | |||||
| input_tensor2->SetData(input2.data()); | |||||
| input_tensor2->set_shape(shape2); | |||||
| input_tensor2->set_data_type(tid); | |||||
| lite::Tensor *input_tensor3 = new lite::Tensor; | |||||
| input_tensor3->SetData(input3.data()); | |||||
| input_tensor3->set_shape(shape3); | |||||
| input_tensor3->set_data_type(tid); | |||||
| lite::Tensor *input_tensor4 = new lite::Tensor; | |||||
| input_tensor4->SetData(input4.data()); | |||||
| input_tensor4->set_shape(shape4); | |||||
| input_tensor4->set_data_type(tid); | |||||
| std::vector<lite::Tensor *> inputs_tensor(4); | |||||
| inputs_tensor[0] = input_tensor1; | |||||
| inputs_tensor[1] = input_tensor2; | |||||
| inputs_tensor[2] = input_tensor3; | |||||
| inputs_tensor[3] = input_tensor4; | |||||
| const int output_size = 10; | |||||
| float output[10]; | |||||
| std::vector<int> output_shape = {1, 10}; | |||||
| lite::Tensor *output0_tensor = new lite::Tensor; | |||||
| output0_tensor->SetData(output); | |||||
| output0_tensor->set_shape(output_shape); | |||||
| output0_tensor->set_data_type(tid); | |||||
| std::vector<lite::Tensor *> outputs_tensor(1); | |||||
| outputs_tensor[0] = output0_tensor; | |||||
| SparseToDenseParameter op_param; | |||||
| op_param.op_parameter_.type_ = schema::PrimitiveType_SparseToDense; | |||||
| lite::Context *ctx = new lite::Context; | |||||
| ctx->thread_num_ = 2; | |||||
| op_param.validate_indices_ = true; | |||||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, tid, schema::PrimitiveType_SparseToDense}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| ASSERT_NE(creator, nullptr); | |||||
| kernel::LiteKernel *kernel = | |||||
| creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), ctx, desc, nullptr); | |||||
| ASSERT_NE(kernel, nullptr); | |||||
| auto output_tensor_shape = output0_tensor->shape(); | |||||
| ASSERT_EQ(output_tensor_shape, output_shape); | |||||
| kernel->Run(); | |||||
| std::vector<float> except_result = {0, 1, 0, 1, 1, 0, 0, 0, 0, 0}; | |||||
| PrintData("output data", output, output_size); | |||||
| PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); | |||||
| CompareOutputData(output, except_result.data(), output_size, 0.000001); | |||||
| input_tensor1->SetData(nullptr); | |||||
| input_tensor2->SetData(nullptr); | |||||
| input_tensor3->SetData(nullptr); | |||||
| input_tensor4->SetData(nullptr); | |||||
| output0_tensor->SetData(nullptr); | |||||
| delete input_tensor1; | |||||
| delete input_tensor2; | |||||
| delete input_tensor3; | |||||
| delete input_tensor4; | |||||
| delete output0_tensor; | |||||
| delete ctx; | |||||
| } | |||||
| TEST_F(TestSparseToDenseFp32, SparseToDense_test4) { | |||||
| std::vector<int> input1 = {5}; | |||||
| std::vector<int> shape1 = {1}; | |||||
| std::vector<int> input2 = {10}; | |||||
| std::vector<int> shape2 = {1}; | |||||
| std::vector<float> input3 = {1}; | |||||
| std::vector<int> shape3 = {1}; | |||||
| std::vector<float> input4 = {0}; | |||||
| std::vector<int> shape4 = {1}; | |||||
| TypeId tid = kNumberTypeFloat32; | |||||
| lite::Tensor *input_tensor1 = new lite::Tensor; | |||||
| input_tensor1->SetData(input1.data()); | |||||
| input_tensor1->set_shape(shape1); | |||||
| input_tensor1->set_data_type(tid); | |||||
| lite::Tensor *input_tensor2 = new lite::Tensor; | |||||
| input_tensor2->SetData(input2.data()); | |||||
| input_tensor2->set_shape(shape2); | |||||
| input_tensor2->set_data_type(tid); | |||||
| lite::Tensor *input_tensor3 = new lite::Tensor; | |||||
| input_tensor3->SetData(input3.data()); | |||||
| input_tensor3->set_shape(shape3); | |||||
| input_tensor3->set_data_type(tid); | |||||
| lite::Tensor *input_tensor4 = new lite::Tensor; | |||||
| input_tensor4->SetData(input4.data()); | |||||
| input_tensor4->set_shape(shape4); | |||||
| input_tensor4->set_data_type(tid); | |||||
| std::vector<lite::Tensor *> inputs_tensor(4); | |||||
| inputs_tensor[0] = input_tensor1; | |||||
| inputs_tensor[1] = input_tensor2; | |||||
| inputs_tensor[2] = input_tensor3; | |||||
| inputs_tensor[3] = input_tensor4; | |||||
| const int output_size = 10; | |||||
| float output[10]; | |||||
| std::vector<int> output_shape = {1, 10}; | |||||
| lite::Tensor *output0_tensor = new lite::Tensor; | |||||
| output0_tensor->SetData(output); | |||||
| output0_tensor->set_shape(output_shape); | |||||
| output0_tensor->set_data_type(tid); | |||||
| std::vector<lite::Tensor *> outputs_tensor(1); | |||||
| outputs_tensor[0] = output0_tensor; | |||||
| SparseToDenseParameter op_param; | |||||
| op_param.op_parameter_.type_ = schema::PrimitiveType_SparseToDense; | |||||
| lite::Context *ctx = new lite::Context; | |||||
| ctx->thread_num_ = 2; | |||||
| op_param.validate_indices_ = true; | |||||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, tid, schema::PrimitiveType_SparseToDense}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| ASSERT_NE(creator, nullptr); | |||||
| kernel::LiteKernel *kernel = | |||||
| creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), ctx, desc, nullptr); | |||||
| ASSERT_NE(kernel, nullptr); | |||||
| auto output_tensor_shape = output0_tensor->shape(); | |||||
| ASSERT_EQ(output_tensor_shape, output_shape); | |||||
| kernel->Run(); | |||||
| std::vector<float> except_result = {0, 0, 0, 0, 0, 1, 0, 0, 0, 0}; | |||||
| PrintData("output data", output, output_size); | |||||
| PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); | |||||
| CompareOutputData(output, except_result.data(), output_size, 0.000001); | |||||
| input_tensor1->SetData(nullptr); | |||||
| input_tensor2->SetData(nullptr); | |||||
| input_tensor3->SetData(nullptr); | |||||
| input_tensor4->SetData(nullptr); | |||||
| output0_tensor->SetData(nullptr); | |||||
| delete input_tensor1; | |||||
| delete input_tensor2; | |||||
| delete input_tensor3; | |||||
| delete input_tensor4; | |||||
| delete output0_tensor; | |||||
| delete ctx; | |||||
| } | |||||
| TEST_F(TestSparseToDenseFp32, SparseToDense_test5) { | |||||
| std::vector<int> input1 = {0, 0, 1, 2, 2, 3, 2, 3, 4, 7, 5, 9}; | |||||
| std::vector<int> shape1 = {6, 2}; | |||||
| std::vector<int> input2 = {6, 10}; | |||||
| std::vector<int> shape2 = {2}; | |||||
| std::vector<float> input3 = {1, 2, 3, 4, 5, 6}; | |||||
| std::vector<int> shape3 = {6}; | |||||
| std::vector<float> input4 = {0}; | |||||
| std::vector<int> shape4 = {1}; | |||||
| TypeId tid = kNumberTypeFloat32; | |||||
| lite::Tensor *input_tensor1 = new lite::Tensor; | |||||
| input_tensor1->SetData(input1.data()); | |||||
| input_tensor1->set_shape(shape1); | |||||
| input_tensor1->set_data_type(tid); | |||||
| lite::Tensor *input_tensor2 = new lite::Tensor; | |||||
| input_tensor2->SetData(input2.data()); | |||||
| input_tensor2->set_shape(shape2); | |||||
| input_tensor2->set_data_type(tid); | |||||
| lite::Tensor *input_tensor3 = new lite::Tensor; | |||||
| input_tensor3->SetData(input3.data()); | |||||
| input_tensor3->set_shape(shape3); | |||||
| input_tensor3->set_data_type(tid); | |||||
| lite::Tensor *input_tensor4 = new lite::Tensor; | |||||
| input_tensor4->SetData(input4.data()); | |||||
| input_tensor4->set_shape(shape4); | |||||
| input_tensor4->set_data_type(tid); | |||||
| std::vector<lite::Tensor *> inputs_tensor(4); | |||||
| inputs_tensor[0] = input_tensor1; | |||||
| inputs_tensor[1] = input_tensor2; | |||||
| inputs_tensor[2] = input_tensor3; | |||||
| inputs_tensor[3] = input_tensor4; | |||||
| const int output_size = 60; | |||||
| float output[60]; | |||||
| std::vector<int> output_shape = {6, 10}; | |||||
| lite::Tensor *output0_tensor = new lite::Tensor; | |||||
| output0_tensor->SetData(output); | |||||
| output0_tensor->set_shape(output_shape); | |||||
| output0_tensor->set_data_type(tid); | |||||
| std::vector<lite::Tensor *> outputs_tensor(1); | |||||
| outputs_tensor[0] = output0_tensor; | |||||
| SparseToDenseParameter op_param; | |||||
| op_param.op_parameter_.type_ = schema::PrimitiveType_SparseToDense; | |||||
| lite::Context *ctx = new lite::Context; | |||||
| ctx->thread_num_ = 2; | |||||
| op_param.validate_indices_ = true; | |||||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, tid, schema::PrimitiveType_SparseToDense}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| ASSERT_NE(creator, nullptr); | |||||
| kernel::LiteKernel *kernel = | |||||
| creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), ctx, desc, nullptr); | |||||
| ASSERT_NE(kernel, nullptr); | |||||
| auto output_tensor_shape = output0_tensor->shape(); | |||||
| ASSERT_EQ(output_tensor_shape, output_shape); | |||||
| kernel->Run(); | |||||
| std::vector<float> except_result = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 6}; | |||||
| PrintData("output data", output, output_size); | |||||
| PrintData("output data shape", output_tensor_shape.data(), output_tensor_shape.size()); | |||||
| CompareOutputData(output, except_result.data(), output_size, 0.000001); | |||||
| input_tensor1->SetData(nullptr); | |||||
| input_tensor2->SetData(nullptr); | |||||
| input_tensor3->SetData(nullptr); | |||||
| input_tensor4->SetData(nullptr); | |||||
| output0_tensor->SetData(nullptr); | |||||
| delete input_tensor1; | |||||
| delete input_tensor2; | |||||
| delete input_tensor3; | |||||
| delete input_tensor4; | |||||
| delete output0_tensor; | |||||
| delete ctx; | |||||
| } | |||||
| } // namespace mindspore | |||||