|
|
|
@@ -48,11 +48,12 @@ int HashtableLookupCPUKernel::Run() { |
|
|
|
auto output_tensor = out_tensors_.at(0); |
|
|
|
auto hits_tensor = out_tensors_.at(1); |
|
|
|
|
|
|
|
int rows = values_tensor->DimensionSize(0); |
|
|
|
int rows = GetStringCount(values_tensor); |
|
|
|
int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData()); |
|
|
|
uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData()); |
|
|
|
std::vector<lite::StringPack> output_string_pack; |
|
|
|
std::vector<lite::StringPack> output_string_pack(input_tensor->ElementsNum()); |
|
|
|
std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(values_tensor); |
|
|
|
lite::StringPack null_string_pack = {0, nullptr}; |
|
|
|
|
|
|
|
for (int i = 0; i < input_tensor->ElementsNum(); i++) { |
|
|
|
int index = -1; |
|
|
|
@@ -61,11 +62,10 @@ int HashtableLookupCPUKernel::Run() { |
|
|
|
index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData()); |
|
|
|
} |
|
|
|
if (index >= rows || index < 0) { |
|
|
|
lite::StringPack tmp = {0, nullptr}; |
|
|
|
output_string_pack.push_back(tmp); |
|
|
|
output_string_pack[i] = null_string_pack; |
|
|
|
hits_data[i] = 0; |
|
|
|
} else { |
|
|
|
output_string_pack.push_back(all_string_pack[i]); |
|
|
|
output_string_pack[i] = all_string_pack[i]; |
|
|
|
hits_data[i] = 1; |
|
|
|
} |
|
|
|
} |
|
|
|
|