diff --git a/mindspore/lite/src/ops/custom_extract_features.cc b/mindspore/lite/src/ops/custom_extract_features.cc index 601c1321e6..27c68cd5e1 100644 --- a/mindspore/lite/src/ops/custom_extract_features.cc +++ b/mindspore/lite/src/ops/custom_extract_features.cc @@ -34,26 +34,27 @@ int CustomExtractFeatures::UnPackToFlatBuilder(const schema::Primitive *primitiv int CustomExtractFeatures::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.at(0); - MS_ASSERT(input != nullptr); - if (input->data_c() == nullptr) { - MS_LOG(INFO) << "Do infer shape in runtime."; - return RET_INFER_INVALID; - } - int string_num = lite::GetStringCount(input); auto output0 = outputs_.at(0); auto output1 = outputs_.at(1); + MS_ASSERT(input != nullptr); MS_ASSERT(output0 != nullptr); MS_ASSERT(output1 != nullptr); + output0->set_data_type(input->data_type()); + output0->SetFormat(input->GetFormat()); + output1->set_data_type(input->data_type()); + output1->SetFormat(input->GetFormat()); + + if (input->data_c() == nullptr) { + MS_LOG(INFO) << "Do infer shape in runtime."; + return RET_INFER_INVALID; + } std::vector shape; + int string_num = lite::GetStringCount(input); shape.push_back(string_num == 0 ? 1 : string_num); output0->set_shape(shape); - output0->set_data_type(input->data_type()); - output0->SetFormat(input->GetFormat()); output1->set_shape(shape); - output1->set_data_type(input->data_type()); - output1->SetFormat(input->GetFormat()); return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/hashtable_lookup.cc b/mindspore/lite/src/ops/hashtable_lookup.cc index cd2e95371b..f1c3652217 100644 --- a/mindspore/lite/src/ops/hashtable_lookup.cc +++ b/mindspore/lite/src/ops/hashtable_lookup.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include "src/ops/hashtable_lookup.h" +#include "src/common/string_util.h" namespace mindspore { namespace lite { @@ -30,8 +31,34 @@ int HashtableLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, fla } #endif int HashtableLookup::InferShape(std::vector inputs_, std::vector outputs_) { - PrimitiveC::InferShape(inputs_, outputs_); - return RET_INFER_INVALID; + auto input = inputs_.at(0); + auto values = inputs_.at(2); + auto output = outputs_.at(0); + auto hits = outputs_.at(1); + MS_ASSERT(input != nullptr); + MS_ASSERT(keys != nullptr); + MS_ASSERT(values != nullptr); + MS_ASSERT(output != nullptr); + MS_ASSERT(hits != nullptr); + + std::vector hits_shape; + hits_shape.push_back(input->DimensionSize(0)); + + output->set_data_type(values->data_type()); + output->SetFormat(input->GetFormat()); + hits->set_shape(hits_shape); + hits->set_data_type(kNumberTypeUInt8); + hits->SetFormat(input->GetFormat()); + + if (input->data_c() == nullptr) { + MS_LOG(INFO) << "Do infer shape in runtime."; + return RET_INFER_INVALID; + } + int string_num = lite::GetStringCount(input); + std::vector output_shape; + output_shape.push_back(string_num == 0 ? 1 : string_num); + output->set_shape(output_shape); + return RET_OK; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 59983c1974..78561749e2 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -1746,6 +1746,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter; populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter; populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter; + populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter; } PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { diff --git a/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc b/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc new file mode 100644 index 0000000000..8a87074b15 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/string/hashtable_lookup.h" +#include +#include +#include "src/kernel_registry.h" +#include "src/common/string_util.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_HashtableLookup; + +namespace mindspore::kernel { +int HashtableLookupCPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int HashtableLookupCPUKernel::ReSize() { return RET_OK; } + +static int CmpKeyFunc(const void *lhs, const void *rhs) { + return *static_cast(lhs) - *static_cast(rhs); +} + +int HashtableLookupCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret; + return ret; + } + auto input_tensor = in_tensors_.at(0); + auto keys_tensor = in_tensors_.at(1); + auto values_tensor = in_tensors_.at(2); + auto output_tensor = out_tensors_.at(0); + auto hits_tensor = out_tensors_.at(1); + + int rows = values_tensor->DimensionSize(0); + int32_t *input_data = reinterpret_cast(input_tensor->MutableData()); + uint8_t *hits_data = reinterpret_cast(hits_tensor->MutableData()); + std::vector output_string_pack; + std::vector all_string_pack = ParseTensorBuffer(input_tensor); + + for (int i = 0; i < input_tensor->ElementsNum(); i++) { + int index = -1; + void *p = bsearch(&(input_data[i]), keys_tensor->MutableData(), rows, sizeof(int32_t), CmpKeyFunc); + if (p != nullptr) { + index = reinterpret_cast(p) - reinterpret_cast(keys_tensor->MutableData()); + } + if (index >= rows || index < 0) { + lite::StringPack tmp = {0, nullptr}; + output_string_pack.push_back(tmp); + hits_data[i] = 0; + } else { + output_string_pack.push_back(all_string_pack[i]); + hits_data[i] = 1; + } + } + WriteStringsToTensor(output_tensor, output_string_pack); + return RET_OK; +} + +kernel::LiteKernel *CpuHashtableLookupKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + auto *kernel = new (std::nothrow) HashtableLookupCPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new HashtableLookupCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_HashtableLookup, CpuHashtableLookupKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.h b/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.h new file mode 100644 index 0000000000..75faadebdc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" + +namespace mindspore::kernel { +class HashtableLookupCPUKernel : public LiteKernel { + public: + HashtableLookupCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~HashtableLookupCPUKernel() {} + + int Init() override; + int ReSize() override; + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/string/predict.cc b/mindspore/lite/src/runtime/kernel/arm/string/predict.cc index c56c2a978d..dd30921f17 100644 --- a/mindspore/lite/src/runtime/kernel/arm/string/predict.cc +++ b/mindspore/lite/src/runtime/kernel/arm/string/predict.cc @@ -68,7 +68,7 @@ std::vector PredictCPUKernel::GetLabelInfo() { return label_info_vec; } -bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; } +static bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; } int PredictCPUKernel::Run() { auto ret = Prepare();