| @@ -34,26 +34,27 @@ int CustomExtractFeatures::UnPackToFlatBuilder(const schema::Primitive *primitiv | |||||
| int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| auto input = inputs_.at(0); | auto input = inputs_.at(0); | ||||
| MS_ASSERT(input != nullptr); | |||||
| if (input->data_c() == nullptr) { | |||||
| MS_LOG(INFO) << "Do infer shape in runtime."; | |||||
| return RET_INFER_INVALID; | |||||
| } | |||||
| int string_num = lite::GetStringCount(input); | |||||
| auto output0 = outputs_.at(0); | auto output0 = outputs_.at(0); | ||||
| auto output1 = outputs_.at(1); | auto output1 = outputs_.at(1); | ||||
| MS_ASSERT(input != nullptr); | |||||
| MS_ASSERT(output0 != nullptr); | MS_ASSERT(output0 != nullptr); | ||||
| MS_ASSERT(output1 != nullptr); | MS_ASSERT(output1 != nullptr); | ||||
| output0->set_data_type(input->data_type()); | |||||
| output0->SetFormat(input->GetFormat()); | |||||
| output1->set_data_type(input->data_type()); | |||||
| output1->SetFormat(input->GetFormat()); | |||||
| if (input->data_c() == nullptr) { | |||||
| MS_LOG(INFO) << "Do infer shape in runtime."; | |||||
| return RET_INFER_INVALID; | |||||
| } | |||||
| std::vector<int> shape; | std::vector<int> shape; | ||||
| int string_num = lite::GetStringCount(input); | |||||
| shape.push_back(string_num == 0 ? 1 : string_num); | shape.push_back(string_num == 0 ? 1 : string_num); | ||||
| output0->set_shape(shape); | output0->set_shape(shape); | ||||
| output0->set_data_type(input->data_type()); | |||||
| output0->SetFormat(input->GetFormat()); | |||||
| output1->set_shape(shape); | output1->set_shape(shape); | ||||
| output1->set_data_type(input->data_type()); | |||||
| output1->SetFormat(input->GetFormat()); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -14,6 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/hashtable_lookup.h" | #include "src/ops/hashtable_lookup.h" | ||||
| #include "src/common/string_util.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -30,8 +31,34 @@ int HashtableLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, fla | |||||
| } | } | ||||
| #endif | #endif | ||||
| int HashtableLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int HashtableLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| PrimitiveC::InferShape(inputs_, outputs_); | |||||
| return RET_INFER_INVALID; | |||||
| auto input = inputs_.at(0); | |||||
| auto values = inputs_.at(2); | |||||
| auto output = outputs_.at(0); | |||||
| auto hits = outputs_.at(1); | |||||
| MS_ASSERT(input != nullptr); | |||||
| MS_ASSERT(keys != nullptr); | |||||
| MS_ASSERT(values != nullptr); | |||||
| MS_ASSERT(output != nullptr); | |||||
| MS_ASSERT(hits != nullptr); | |||||
| std::vector<int> hits_shape; | |||||
| hits_shape.push_back(input->DimensionSize(0)); | |||||
| output->set_data_type(values->data_type()); | |||||
| output->SetFormat(input->GetFormat()); | |||||
| hits->set_shape(hits_shape); | |||||
| hits->set_data_type(kNumberTypeUInt8); | |||||
| hits->SetFormat(input->GetFormat()); | |||||
| if (input->data_c() == nullptr) { | |||||
| MS_LOG(INFO) << "Do infer shape in runtime."; | |||||
| return RET_INFER_INVALID; | |||||
| } | |||||
| int string_num = lite::GetStringCount(input); | |||||
| std::vector<int> output_shape; | |||||
| output_shape.push_back(string_num == 0 ? 1 : string_num); | |||||
| output->set_shape(output_shape); | |||||
| return RET_OK; | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1746,6 +1746,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||||
| populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter; | populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter; | populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter; | populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter; | |||||
| } | } | ||||
| PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { | PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { | ||||
| @@ -0,0 +1,96 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/string/hashtable_lookup.h" | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/common/string_util.h" | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_HashtableLookup; | |||||
| namespace mindspore::kernel { | |||||
| int HashtableLookupCPUKernel::Init() { | |||||
| if (!InferShapeDone()) { | |||||
| return RET_OK; | |||||
| } | |||||
| return ReSize(); | |||||
| } | |||||
| int HashtableLookupCPUKernel::ReSize() { return RET_OK; } | |||||
| static int CmpKeyFunc(const void *lhs, const void *rhs) { | |||||
| return *static_cast<const int *>(lhs) - *static_cast<const int *>(rhs); | |||||
| } | |||||
| int HashtableLookupCPUKernel::Run() { | |||||
| auto ret = Prepare(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret; | |||||
| return ret; | |||||
| } | |||||
| auto input_tensor = in_tensors_.at(0); | |||||
| auto keys_tensor = in_tensors_.at(1); | |||||
| auto values_tensor = in_tensors_.at(2); | |||||
| auto output_tensor = out_tensors_.at(0); | |||||
| auto hits_tensor = out_tensors_.at(1); | |||||
| int rows = values_tensor->DimensionSize(0); | |||||
| int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData()); | |||||
| uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData()); | |||||
| std::vector<lite::StringPack> output_string_pack; | |||||
| std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(input_tensor); | |||||
| for (int i = 0; i < input_tensor->ElementsNum(); i++) { | |||||
| int index = -1; | |||||
| void *p = bsearch(&(input_data[i]), keys_tensor->MutableData(), rows, sizeof(int32_t), CmpKeyFunc); | |||||
| if (p != nullptr) { | |||||
| index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData()); | |||||
| } | |||||
| if (index >= rows || index < 0) { | |||||
| lite::StringPack tmp = {0, nullptr}; | |||||
| output_string_pack.push_back(tmp); | |||||
| hits_data[i] = 0; | |||||
| } else { | |||||
| output_string_pack.push_back(all_string_pack[i]); | |||||
| hits_data[i] = 1; | |||||
| } | |||||
| } | |||||
| WriteStringsToTensor(output_tensor, output_string_pack); | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuHashtableLookupKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, | |||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||||
| const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto *kernel = new (std::nothrow) HashtableLookupCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "new HashtableLookupCPUKernel fail!"; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ | |||||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_HashtableLookup, CpuHashtableLookupKernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "include/context.h" | |||||
| namespace mindspore::kernel { | |||||
| class HashtableLookupCPUKernel : public LiteKernel { | |||||
| public: | |||||
| HashtableLookupCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| ~HashtableLookupCPUKernel() {} | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_ | |||||
| @@ -68,7 +68,7 @@ std::vector<LabelInfo> PredictCPUKernel::GetLabelInfo() { | |||||
| return label_info_vec; | return label_info_vec; | ||||
| } | } | ||||
| bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; } | |||||
| static bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; } | |||||
| int PredictCPUKernel::Run() { | int PredictCPUKernel::Run() { | ||||
| auto ret = Prepare(); | auto ret = Prepare(); | ||||