Merge pull request !7250 from liuwenhao/mastertags/v1.1.0
| @@ -0,0 +1,35 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_LSH_PROJECTION_PARAMETER_H_ | |||||
| #define MINDSPORE_LITE_NNACL_LSH_PROJECTION_PARAMETER_H_ | |||||
| #include "nnacl/op_base.h" | |||||
| typedef struct LshProjectionParameter { | |||||
| OpParameter op_parameter_; | |||||
| int lsh_type_; | |||||
| int hash_shape_[2]; | |||||
| int in_item_num_; | |||||
| size_t in_item_size_; | |||||
| size_t seed_size_; | |||||
| size_t key_size_; | |||||
| int64_t real_dst_count; | |||||
| int task_id_; | |||||
| int64_t count_unit_; | |||||
| } LshProjectionParameter; | |||||
| #endif // MINDSPORE_LITE_NNACL_LSH_PROJECTION_PARAMETER_H_ | |||||
| @@ -14,12 +14,16 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/lsh_projection.h" | #include "src/ops/lsh_projection.h" | ||||
| #include "nnacl/lsh_projection_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| int LshProjection::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_OK; } | int LshProjection::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_OK; } | ||||
| int LshProjection::GetLshType() const { return this->primitive_->value.AsLshProjection()->type; } | |||||
| #else | #else | ||||
| int LshProjection::GetLshType() const { return this->primitive_->value_as_LshProjection()->type(); } | |||||
| int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| MS_ASSERT(nullptr != fbb); | MS_ASSERT(nullptr != fbb); | ||||
| @@ -29,9 +33,51 @@ int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #endif | #endif | ||||
| namespace { | |||||
| constexpr int kSparseType = 1; | |||||
| constexpr int kDenseType = 2; | |||||
| } // namespace | |||||
| int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| PrimitiveC::InferShape(inputs_, outputs_); | |||||
| return RET_INFER_INVALID; | |||||
| if (inputs_.size() != kDoubleNum || inputs_.size() != kMultiNum) { | |||||
| MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (outputs_.size() != kSingleNum) { | |||||
| MS_LOG(ERROR) << "outputs to Shape operator should be 1, but " << outputs_.size() << " is given."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto in_hash = inputs_.at(kSingleNum); | |||||
| MS_ASSERT(in_hash->shape().size() == 2); | |||||
| MS_ASSERT(in_hash->DimensionSize(1) <= 32); | |||||
| MS_ASSERT(inputs_.at(kDoubleNum)->shape().size() >= 1); | |||||
| if (inputs_.size() == kMultiNum) { | |||||
| MS_ASSERT(inputs_.at(kMultiNum)->shape().size() == 1); | |||||
| MS_ASSERT(inputs_.at(kMultiNum)->DimensionSize(0) == in_value->DimensionSize(0)); | |||||
| } | |||||
| auto out_tensor = outputs_.front(); | |||||
| out_tensor->set_data_type(kNumberTypeInt32); | |||||
| out_tensor->SetFormat(schema::Format::Format_NHWC); | |||||
| if (!GetInferFlag()) { | |||||
| return RET_OK; | |||||
| } | |||||
| std::vector<int> out_shape; | |||||
| switch (GetLshType()) { | |||||
| case kSparseType: | |||||
| out_shape.push_back(in_hash->DimensionSize(0)); | |||||
| break; | |||||
| case kDenseType: | |||||
| out_shape.push_back(in_hash->DimensionSize(0) * in_hash->DimensionSize(1)); | |||||
| break; | |||||
| default: | |||||
| return RET_ERROR; | |||||
| } | |||||
| out_tensor->set_shape(out_shape); | |||||
| return RET_OK; | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,6 +33,7 @@ class LshProjection : public PrimitiveC { | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) override; | int InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) override; | ||||
| int GetLshType() const; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -54,6 +54,7 @@ | |||||
| #include "src/ops/resize.h" | #include "src/ops/resize.h" | ||||
| #include "src/ops/tile.h" | #include "src/ops/tile.h" | ||||
| #include "src/ops/one_hot.h" | #include "src/ops/one_hot.h" | ||||
| #include "src/ops/lsh_projection.h" | |||||
| #include "src/ops/space_to_depth.h" | #include "src/ops/space_to_depth.h" | ||||
| #include "src/ops/split.h" | #include "src/ops/split.h" | ||||
| #include "src/ops/argmax.h" | #include "src/ops/argmax.h" | ||||
| @@ -131,6 +132,7 @@ | |||||
| #include "nnacl/unstack.h" | #include "nnacl/unstack.h" | ||||
| #include "nnacl/depth_to_space.h" | #include "nnacl/depth_to_space.h" | ||||
| #include "nnacl/conv_parameter.h" | #include "nnacl/conv_parameter.h" | ||||
| #include "nnacl/lsh_projection_parameter.h" | |||||
| #include "nnacl/fp32/pooling.h" | #include "nnacl/fp32/pooling.h" | ||||
| #include "nnacl/matmul_parameter.h" | #include "nnacl/matmul_parameter.h" | ||||
| #include "nnacl/fp32/roi_pooling.h" | #include "nnacl/fp32/roi_pooling.h" | ||||
| @@ -1323,6 +1325,20 @@ OpParameter *PopulateCropParameter(const mindspore::lite::PrimitiveC *primitive) | |||||
| return reinterpret_cast<OpParameter *>(crop_param); | return reinterpret_cast<OpParameter *>(crop_param); | ||||
| } | } | ||||
| OpParameter *PopulateLshProjectionParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| LshProjectionParameter *lsh_project_param = | |||||
| reinterpret_cast<LshProjectionParameter *>(malloc(sizeof(LshProjectionParameter))); | |||||
| if (lsh_project_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc LshProjectionParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(lsh_project_param, 0, sizeof(LshProjectionParameter)); | |||||
| lsh_project_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto param = reinterpret_cast<mindspore::lite::LshProjection *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| lsh_project_param->lsh_type_ = param->GetLshType(); | |||||
| return reinterpret_cast<OpParameter *>(lsh_project_param); | |||||
| } | |||||
| OpParameter *PopulateOneHotParameter(const mindspore::lite::PrimitiveC *primitive) { | OpParameter *PopulateOneHotParameter(const mindspore::lite::PrimitiveC *primitive) { | ||||
| OneHotParameter *one_hot_param = reinterpret_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter))); | OneHotParameter *one_hot_param = reinterpret_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter))); | ||||
| if (one_hot_param == nullptr) { | if (one_hot_param == nullptr) { | ||||
| @@ -1747,6 +1763,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||||
| populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter; | populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter; | populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter; | populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_LshProjection] = PopulateLshProjectionParameter; | |||||
| } | } | ||||
| PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { | PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { | ||||
| @@ -0,0 +1,184 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32/lsh_projection.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| #include "src/common/string_util.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_MEMORY_FAILED; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_LshProjection; | |||||
| namespace mindspore::kernel { | |||||
| namespace { | |||||
| constexpr int kSparseType = 1; | |||||
| constexpr int kDenseType = 2; | |||||
| } // namespace | |||||
| int LshProjectionCPUKernel::Init() { | |||||
| if (!InferShapeDone()) { | |||||
| return RET_OK; | |||||
| } | |||||
| return ReSize(); | |||||
| } | |||||
| int LshProjectionCPUKernel::ReSize() { return RET_OK; } | |||||
| int LshProjectionCPUKernel::Run() { | |||||
| auto ret = Prepare(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Prepare fail!ret: " << ret; | |||||
| return ret; | |||||
| } | |||||
| auto input_tensor0 = in_tensors_.at(0); | |||||
| auto input_tensor1 = in_tensors_.at(1); | |||||
| auto out_tensor0 = out_tensors_.at(0); | |||||
| hash = reinterpret_cast<float *>(input_tensor0->MutableData()); | |||||
| in_data = reinterpret_cast<char *>(input_tensor1->MutableData()); | |||||
| weight = in_tensors_.size() == 2 ? nullptr : reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| output = reinterpret_cast<int32_t *>(out_tensor0->MutableData()); | |||||
| const size_t seed_size = sizeof(float); | |||||
| const size_t input_item_size = | |||||
| input_tensor1->ElementsNum() * sizeof(input_tensor1->data_type()) / input_tensor1->DimensionSize(0); | |||||
| const size_t key_size = seed_size + input_item_size; | |||||
| lsh_param_->seed_size_ = seed_size; | |||||
| lsh_param_->in_item_size_ = input_item_size; | |||||
| lsh_param_->key_size_ = key_size; | |||||
| lsh_param_->in_item_num_ = input_tensor1->DimensionSize(0); | |||||
| memcpy(lsh_param_->hash_shape_, input_tensor0->shape().data(), sizeof(int) * input_tensor0->shape().size()); | |||||
| elements_num_ = input_tensor0->DimensionSize(0); | |||||
| count_unit_ = thread_num_ > 1 ? UP_DIV(elements_num_, thread_num_) : elements_num_; | |||||
| ret = ParallelLaunch(this->context_->thread_pool_, LshProjectionRun, this, thread_num_); | |||||
| return ret; | |||||
| } | |||||
| int LshProjectionRun(void *cdata, int task_id) { | |||||
| auto lsh_projection = reinterpret_cast<LshProjectionCPUKernel *>(cdata); | |||||
| lsh_projection->DoExecute(task_id); | |||||
| return RET_OK; | |||||
| } | |||||
| int LshProjectionCPUKernel::DoExecute(int task_id) { | |||||
| int64_t real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_); | |||||
| lsh_param_->real_dst_count = real_dst_count; | |||||
| lsh_param_->task_id_ = task_id; | |||||
| lsh_param_->count_unit_ = count_unit_; | |||||
| if (real_dst_count <= 0) { | |||||
| return lite::RET_OK; | |||||
| } | |||||
| switch (lsh_param_->lsh_type_) { | |||||
| case kSparseType: | |||||
| LshProjectionSparse(hash, in_data, weight, output, lsh_param_); | |||||
| break; | |||||
| case kDenseType: | |||||
| LshProjectionDense(hash, in_data, weight, output, lsh_param_); | |||||
| break; | |||||
| default: | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int LshProjectionCPUKernel::GetSignBit(char *in_data, float *weight, float seed, LshProjectionParameter *para) { | |||||
| double score = 0.0; | |||||
| for (int i = 0; i < para->in_item_num_; i++) { | |||||
| char *key = static_cast<char *>(ctx_->allocator->Malloc(lsh_param_->key_size_)); | |||||
| if (key == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc key failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memcpy(key, &seed, para->seed_size_); | |||||
| memcpy(key + para->seed_size_, in_data, para->in_item_size_); | |||||
| in_data += para->in_item_size_; | |||||
| double hash_sign = static_cast<double>(mindspore::lite::StringHash64(key, para->key_size_)); | |||||
| if (weight == nullptr) { | |||||
| score += hash_sign; | |||||
| } else { | |||||
| score += weight[i] * hash_sign; | |||||
| } | |||||
| ctx_->allocator->Free(key); | |||||
| } | |||||
| return (score > 0) ? 1 : 0; | |||||
| } | |||||
| void LshProjectionCPUKernel::LshProjectionSparse(float *hash, char *in_data, float *weight, int32_t *output, | |||||
| LshProjectionParameter *para) { | |||||
| int start = para->task_id_ * para->count_unit_; | |||||
| int end = start + para->real_dst_count; | |||||
| for (int i = start; i < end; i++) { | |||||
| int32_t hash_sign = 0; | |||||
| for (int j = 0; j < para->hash_shape_[1]; j++) { | |||||
| int bit = GetSignBit(in_data, weight, hash[i * para->hash_shape_[1] + j], para); | |||||
| hash_sign = (hash_sign << 1) | bit; | |||||
| } | |||||
| output[i] = hash_sign + i * (1 << para->hash_shape_[1]); | |||||
| } | |||||
| } | |||||
| void LshProjectionCPUKernel::LshProjectionDense(float *hash, char *in_data, float *weight, int32_t *output, | |||||
| LshProjectionParameter *para) { | |||||
| int start = para->task_id_ * para->count_unit_; | |||||
| int end = start + para->real_dst_count; | |||||
| for (int i = start; i < end; i++) { | |||||
| for (int j = 0; j < para->hash_shape_[1]; j++) { | |||||
| output[i * para->hash_shape_[1] + j] = GetSignBit(in_data, weight, hash[i * para->hash_shape_[1] + j], para); | |||||
| } | |||||
| } | |||||
| } | |||||
| kernel::LiteKernel *CpuLshProjectionFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, | |||||
| OpParameter *op_parameter, const lite::InnerContext *ctx, | |||||
| const kernel::KernelKey &desc, | |||||
| const mindspore::lite::PrimitiveC *primitive) { | |||||
| if (op_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "Input op_parameter is nullptr!"; | |||||
| return nullptr; | |||||
| } | |||||
| if (ctx == nullptr) { | |||||
| MS_LOG(ERROR) << "Input context is nullptr!"; | |||||
| return nullptr; | |||||
| } | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_LshProjection); | |||||
| auto *kernel = new (std::nothrow) LshProjectionCPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "new LshProjectionCPUKernel fail!"; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed! name: " << op_parameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LshProjection, CpuLshProjectionFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,61 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSH_PROJECTION_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSH_PROJECTION_H_ | |||||
| #include <vector> | |||||
| #include "nnacl/lsh_projection_parameter.h" | |||||
| #include "src/lite_kernel.h" | |||||
| #include "schema/model_generated.h" | |||||
| namespace mindspore::kernel { | |||||
| class LshProjectionCPUKernel : public LiteKernel { | |||||
| public: | |||||
| LshProjectionCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) { | |||||
| lsh_param_ = reinterpret_cast<LshProjectionParameter *>(op_parameter_); | |||||
| } | |||||
| ~LshProjectionCPUKernel() = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int DoExecute(int task_id); | |||||
| int GetSignBit(char *in_data, float *weight, float seed, LshProjectionParameter *para); | |||||
| void LshProjectionSparse(float *hash, char *in_data, float *weight, int32_t *output, LshProjectionParameter *param); | |||||
| void LshProjectionDense(float *hash, char *in_data, float *weight, int32_t *output, LshProjectionParameter *param); | |||||
| private: | |||||
| LshProjectionParameter *lsh_param_ = nullptr; | |||||
| const lite::InnerContext *ctx_; | |||||
| int thread_num_; | |||||
| int64_t elements_num_; | |||||
| int64_t count_unit_; | |||||
| float *hash = nullptr; | |||||
| char *in_data = nullptr; | |||||
| float *weight = nullptr; | |||||
| int32_t *output = nullptr; | |||||
| }; | |||||
| int LshProjectionRun(void *cdata, int task_id); | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSH_PROJECTION_H_ | |||||
| @@ -0,0 +1,164 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "mindspore/lite/nnacl/lsh_projection_parameter.h" | |||||
| #include "mindspore/lite/src/kernel_registry.h" | |||||
| #include "mindspore/lite/src/lite_kernel.h" | |||||
| #include "mindspore/lite/src/tensor.h" | |||||
| namespace mindspore { | |||||
| namespace { | |||||
| constexpr int kSparseType = 1; | |||||
| constexpr int kDenseType = 2; | |||||
| } // namespace | |||||
| class TestLshProjectionFp32 : public mindspore::CommonTest { | |||||
| public: | |||||
| TestLshProjectionFp32() {} | |||||
| }; | |||||
| TEST_F(TestLshProjectionFp32, Dense1DInputs) { | |||||
| lite::Tensor in_tensor0(kNumberTypeFloat, {3, 2}); | |||||
| lite::Tensor in_tensor1(kNumberTypeInt32, {5}); | |||||
| lite::Tensor in_tensor2(kNumberTypeFloat, {5}); | |||||
| lite::Tensor out_tensor(kNumberTypeInt32, {6}); | |||||
| float input_data0[] = {0.123, 0.456, -0.321, 1.234, 5.678, -4.321}; | |||||
| int32_t input_data1[] = {12345, 54321, 67890, 9876, -12345678}; | |||||
| float input_data2[] = {1.0, 1.0, 1.0, 1.0, 1.0}; | |||||
| int32_t output_data[6] = {0}; | |||||
| in_tensor0.SetData(input_data0); | |||||
| in_tensor1.SetData(input_data1); | |||||
| in_tensor2.SetData(input_data2); | |||||
| out_tensor.SetData(output_data); | |||||
| std::vector<lite::Tensor *> inputs = {&in_tensor0, &in_tensor1, &in_tensor2}; | |||||
| std::vector<lite::Tensor *> outputs = {&out_tensor}; | |||||
| LshProjectionParameter parameter = {}; | |||||
| parameter.lsh_type_ = kDenseType; | |||||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_LshProjection}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| ASSERT_NE(creator, nullptr); | |||||
| auto ctx = std::make_shared<lite::InnerContext>(); | |||||
| ctx->thread_num_ = 3; | |||||
| ASSERT_EQ(lite::RET_OK, ctx->Init()); | |||||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(¶meter), ctx.get(), desc, nullptr); | |||||
| ASSERT_NE(kernel, nullptr); | |||||
| auto ret = kernel->Run(); | |||||
| EXPECT_EQ(0, ret); | |||||
| std::vector<int32_t> except_result = {0, 0, 0, 1, 0, 0}; | |||||
| PrintData("output data", output_data, 6); | |||||
| CompareOutputData(output_data, except_result.data(), 6, 0.000001); | |||||
| in_tensor0.SetData(nullptr); | |||||
| in_tensor1.SetData(nullptr); | |||||
| out_tensor.SetData(nullptr); | |||||
| } | |||||
| TEST_F(TestLshProjectionFp32, Sparse1DInputs) { | |||||
| lite::Tensor in_tensor0(kNumberTypeFloat, {3, 2}); | |||||
| lite::Tensor in_tensor1(kNumberTypeInt32, {5}); | |||||
| lite::Tensor out_tensor(kNumberTypeInt32, {3}); | |||||
| float input_data0[] = {0.123, 0.456, -0.321, 1.234, 5.678, -4.321}; | |||||
| int32_t input_data1[] = {12345, 54321, 67890, 9876, -12345678}; | |||||
| int32_t output_data[3] = {0}; | |||||
| in_tensor0.SetData(input_data0); | |||||
| in_tensor1.SetData(input_data1); | |||||
| out_tensor.SetData(output_data); | |||||
| std::vector<lite::Tensor *> inputs = {&in_tensor0, &in_tensor1}; | |||||
| std::vector<lite::Tensor *> outputs = {&out_tensor}; | |||||
| LshProjectionParameter parameter = {}; | |||||
| parameter.lsh_type_ = kSparseType; | |||||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_LshProjection}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| ASSERT_NE(creator, nullptr); | |||||
| auto ctx = std::make_shared<lite::InnerContext>(); | |||||
| ctx->thread_num_ = 1; | |||||
| ASSERT_EQ(lite::RET_OK, ctx->Init()); | |||||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(¶meter), ctx.get(), desc, nullptr); | |||||
| ASSERT_NE(kernel, nullptr); | |||||
| auto ret = kernel->Run(); | |||||
| EXPECT_EQ(0, ret); | |||||
| std::vector<int32_t> except_result = {0, 5, 8}; | |||||
| PrintData("output data", output_data, 3); | |||||
| CompareOutputData(output_data, except_result.data(), 3, 0.000001); | |||||
| in_tensor0.SetData(nullptr); | |||||
| in_tensor1.SetData(nullptr); | |||||
| out_tensor.SetData(nullptr); | |||||
| } | |||||
| TEST_F(TestLshProjectionFp32, Sparse3DInputs) { | |||||
| lite::Tensor in_tensor0(kNumberTypeFloat, {3, 2}); | |||||
| lite::Tensor in_tensor1(kNumberTypeInt32, {5, 2, 2}); | |||||
| lite::Tensor in_tensor2(kNumberTypeFloat, {5}); | |||||
| lite::Tensor out_tensor(kNumberTypeInt32, {3}); | |||||
| float input_data0[] = {0.123, 0.456, -0.321, 1.234, 5.678, -4.321}; | |||||
| int32_t input_data1[] = {1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912, | |||||
| 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543}; | |||||
| float input_data2[] = {0.12, 0.34, 0.56, 0.67, 0.78}; | |||||
| int32_t output_data[3] = {0}; | |||||
| in_tensor0.SetData(input_data0); | |||||
| in_tensor1.SetData(input_data1); | |||||
| in_tensor2.SetData(input_data2); | |||||
| out_tensor.SetData(output_data); | |||||
| std::vector<lite::Tensor *> inputs = {&in_tensor0, &in_tensor1, &in_tensor2}; | |||||
| std::vector<lite::Tensor *> outputs = {&out_tensor}; | |||||
| LshProjectionParameter parameter = {}; | |||||
| parameter.lsh_type_ = kSparseType; | |||||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_LshProjection}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| ASSERT_NE(creator, nullptr); | |||||
| auto ctx = std::make_shared<lite::InnerContext>(); | |||||
| ctx->thread_num_ = 3; | |||||
| ASSERT_EQ(lite::RET_OK, ctx->Init()); | |||||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(¶meter), ctx.get(), desc, nullptr); | |||||
| ASSERT_NE(kernel, nullptr); | |||||
| auto ret = kernel->Run(); | |||||
| EXPECT_EQ(0, ret); | |||||
| std::vector<int32_t> except_result = {2, 5, 9}; | |||||
| PrintData("output data", output_data, 3); | |||||
| CompareOutputData(output_data, except_result.data(), 3, 0.000001); | |||||
| in_tensor0.SetData(nullptr); | |||||
| in_tensor1.SetData(nullptr); | |||||
| out_tensor.SetData(nullptr); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -290,7 +290,7 @@ TEST_F(TestMulInt8, Mul_quant1_thread1) { | |||||
| MulParameter op_param; | MulParameter op_param; | ||||
| op_param.op_parameter_.type_ = schema::PrimitiveType_Mul; | op_param.op_parameter_.type_ = schema::PrimitiveType_Mul; | ||||
| lite::InnerContext *ctx = new lite::InnerContext; | lite::InnerContext *ctx = new lite::InnerContext; | ||||
| ctx->thread_num_ = 2; | |||||
| ctx->thread_num_ = 3; | |||||
| ASSERT_EQ(lite::RET_OK, ctx->Init()); | ASSERT_EQ(lite::RET_OK, ctx->Init()); | ||||
| kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Mul}; | kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Mul}; | ||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | ||||