| @@ -34,26 +34,27 @@ int CustomExtractFeatures::UnPackToFlatBuilder(const schema::Primitive *primitiv | |||
| int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.at(0); | |||
| MS_ASSERT(input != nullptr); | |||
| if (input->data_c() == nullptr) { | |||
| MS_LOG(INFO) << "Do infer shape in runtime."; | |||
| return RET_INFER_INVALID; | |||
| } | |||
| int string_num = lite::GetStringCount(input); | |||
| auto output0 = outputs_.at(0); | |||
| auto output1 = outputs_.at(1); | |||
| MS_ASSERT(input != nullptr); | |||
| MS_ASSERT(output0 != nullptr); | |||
| MS_ASSERT(output1 != nullptr); | |||
| output0->set_data_type(input->data_type()); | |||
| output0->SetFormat(input->GetFormat()); | |||
| output1->set_data_type(input->data_type()); | |||
| output1->SetFormat(input->GetFormat()); | |||
| if (input->data_c() == nullptr) { | |||
| MS_LOG(INFO) << "Do infer shape in runtime."; | |||
| return RET_INFER_INVALID; | |||
| } | |||
| std::vector<int> shape; | |||
| int string_num = lite::GetStringCount(input); | |||
| shape.push_back(string_num == 0 ? 1 : string_num); | |||
| output0->set_shape(shape); | |||
| output0->set_data_type(input->data_type()); | |||
| output0->SetFormat(input->GetFormat()); | |||
| output1->set_shape(shape); | |||
| output1->set_data_type(input->data_type()); | |||
| output1->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/hashtable_lookup.h" | |||
| #include "src/common/string_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -30,8 +31,34 @@ int HashtableLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, fla | |||
| } | |||
| #endif | |||
| int HashtableLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| PrimitiveC::InferShape(inputs_, outputs_); | |||
| return RET_INFER_INVALID; | |||
| auto input = inputs_.at(0); | |||
| auto values = inputs_.at(2); | |||
| auto output = outputs_.at(0); | |||
| auto hits = outputs_.at(1); | |||
| MS_ASSERT(input != nullptr); | |||
| MS_ASSERT(keys != nullptr); | |||
| MS_ASSERT(values != nullptr); | |||
| MS_ASSERT(output != nullptr); | |||
| MS_ASSERT(hits != nullptr); | |||
| std::vector<int> hits_shape; | |||
| hits_shape.push_back(input->DimensionSize(0)); | |||
| output->set_data_type(values->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| hits->set_shape(hits_shape); | |||
| hits->set_data_type(kNumberTypeUInt8); | |||
| hits->SetFormat(input->GetFormat()); | |||
| if (input->data_c() == nullptr) { | |||
| MS_LOG(INFO) << "Do infer shape in runtime."; | |||
| return RET_INFER_INVALID; | |||
| } | |||
| int string_num = lite::GetStringCount(input); | |||
| std::vector<int> output_shape; | |||
| output_shape.push_back(string_num == 0 ? 1 : string_num); | |||
| output->set_shape(output_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1746,6 +1746,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||
| populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter; | |||
| } | |||
| PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { | |||
| @@ -0,0 +1,96 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/string/hashtable_lookup.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include "src/kernel_registry.h" | |||
| #include "src/common/string_util.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_HashtableLookup; | |||
| namespace mindspore::kernel { | |||
| int HashtableLookupCPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| int HashtableLookupCPUKernel::ReSize() { return RET_OK; } | |||
| static int CmpKeyFunc(const void *lhs, const void *rhs) { | |||
| return *static_cast<const int *>(lhs) - *static_cast<const int *>(rhs); | |||
| } | |||
| int HashtableLookupCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret; | |||
| return ret; | |||
| } | |||
| auto input_tensor = in_tensors_.at(0); | |||
| auto keys_tensor = in_tensors_.at(1); | |||
| auto values_tensor = in_tensors_.at(2); | |||
| auto output_tensor = out_tensors_.at(0); | |||
| auto hits_tensor = out_tensors_.at(1); | |||
| int rows = values_tensor->DimensionSize(0); | |||
| int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData()); | |||
| uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData()); | |||
| std::vector<lite::StringPack> output_string_pack; | |||
| std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(input_tensor); | |||
| for (int i = 0; i < input_tensor->ElementsNum(); i++) { | |||
| int index = -1; | |||
| void *p = bsearch(&(input_data[i]), keys_tensor->MutableData(), rows, sizeof(int32_t), CmpKeyFunc); | |||
| if (p != nullptr) { | |||
| index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData()); | |||
| } | |||
| if (index >= rows || index < 0) { | |||
| lite::StringPack tmp = {0, nullptr}; | |||
| output_string_pack.push_back(tmp); | |||
| hits_data[i] = 0; | |||
| } else { | |||
| output_string_pack.push_back(all_string_pack[i]); | |||
| hits_data[i] = 1; | |||
| } | |||
| } | |||
| WriteStringsToTensor(output_tensor, output_string_pack); | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuHashtableLookupKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| auto *kernel = new (std::nothrow) HashtableLookupCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new HashtableLookupCPUKernel fail!"; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ | |||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_HashtableLookup, CpuHashtableLookupKernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "include/context.h" | |||
| namespace mindspore::kernel { | |||
| class HashtableLookupCPUKernel : public LiteKernel { | |||
| public: | |||
| HashtableLookupCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~HashtableLookupCPUKernel() {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_ | |||
| @@ -68,7 +68,7 @@ std::vector<LabelInfo> PredictCPUKernel::GetLabelInfo() { | |||
| return label_info_vec; | |||
| } | |||
| bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; } | |||
| static bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; } | |||
| int PredictCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||