| @@ -20,7 +20,7 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| std::vector<StringPack> ParseTensorBuffer(Tensor *tensor) { | std::vector<StringPack> ParseTensorBuffer(Tensor *tensor) { | ||||
| if (tensor->MutableData() == nullptr) { | |||||
| if (tensor->data_c() == nullptr) { | |||||
| MS_LOG(ERROR) << "Tensor data is null, cannot be parsed"; | MS_LOG(ERROR) << "Tensor data is null, cannot be parsed"; | ||||
| return std::vector<StringPack>{}; | return std::vector<StringPack>{}; | ||||
| } | } | ||||
| @@ -65,8 +65,12 @@ int LiteSession::ConvertTensors(const lite::Model *model) { | |||||
| MS_LOG(DEBUG) << "Dims of " << i << "th tensor is nullptr"; | MS_LOG(DEBUG) << "Dims of " << i << "th tensor is nullptr"; | ||||
| } else { | } else { | ||||
| if (TensorCategory(srcTensor) == Tensor::Category::CONST) { | if (TensorCategory(srcTensor) == Tensor::Category::CONST) { | ||||
| for (size_t j = 0; j < srcTensor->dims()->size(); j++) { | |||||
| shape.push_back(srcTensor->dims()->data()[j]); | |||||
| if (srcTensor->dataType() == kObjectTypeString && srcTensor->data() != nullptr) { | |||||
| shape.push_back(srcTensor->data()->size()); | |||||
| } else { | |||||
| for (size_t j = 0; j < srcTensor->dims()->size(); j++) { | |||||
| shape.push_back(srcTensor->dims()->data()[j]); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -48,11 +48,12 @@ int HashtableLookupCPUKernel::Run() { | |||||
| auto output_tensor = out_tensors_.at(0); | auto output_tensor = out_tensors_.at(0); | ||||
| auto hits_tensor = out_tensors_.at(1); | auto hits_tensor = out_tensors_.at(1); | ||||
| int rows = values_tensor->DimensionSize(0); | |||||
| int rows = GetStringCount(values_tensor); | |||||
| int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData()); | int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData()); | ||||
| uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData()); | uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData()); | ||||
| std::vector<lite::StringPack> output_string_pack; | |||||
| std::vector<lite::StringPack> output_string_pack(input_tensor->ElementsNum()); | |||||
| std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(values_tensor); | std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(values_tensor); | ||||
| lite::StringPack null_string_pack = {0, nullptr}; | |||||
| for (int i = 0; i < input_tensor->ElementsNum(); i++) { | for (int i = 0; i < input_tensor->ElementsNum(); i++) { | ||||
| int index = -1; | int index = -1; | ||||
| @@ -61,11 +62,10 @@ int HashtableLookupCPUKernel::Run() { | |||||
| index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData()); | index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData()); | ||||
| } | } | ||||
| if (index >= rows || index < 0) { | if (index >= rows || index < 0) { | ||||
| lite::StringPack tmp = {0, nullptr}; | |||||
| output_string_pack.push_back(tmp); | |||||
| output_string_pack[i] = null_string_pack; | |||||
| hits_data[i] = 0; | hits_data[i] = 0; | ||||
| } else { | } else { | ||||
| output_string_pack.push_back(all_string_pack[i]); | |||||
| output_string_pack[i] = all_string_pack[i]; | |||||
| hits_data[i] = 1; | hits_data[i] = 1; | ||||
| } | } | ||||
| } | } | ||||
| @@ -88,9 +88,10 @@ int PredictCPUKernel::Run() { | |||||
| if (static_cast<size_t>(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) { | if (static_cast<size_t>(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) { | ||||
| output_label[i] = -1; | output_label[i] = -1; | ||||
| output_weight[i] = 0.0f; | output_weight[i] = 0.0f; | ||||
| } else { | |||||
| output_label[i] = label_info_vec[i].label; | |||||
| output_weight[i] = label_info_vec[i].weight; | |||||
| } | } | ||||
| output_label[i] = label_info_vec[i].label; | |||||
| output_weight[i] = label_info_vec[i].weight; | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||