Browse Source

!7505 [MSLITE][Develop] fix SmartReply

Merge pull request !7505 from sunsuodong/fix_lite_kernel
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
c8a8c2623d
4 changed files with 15 additions and 10 deletions
  1. +1
    -1
      mindspore/lite/src/common/string_util.cc
  2. +6
    -2
      mindspore/lite/src/lite_session.cc
  3. +5
    -5
      mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc
  4. +3
    -2
      mindspore/lite/src/runtime/kernel/arm/string/predict.cc

+ 1
- 1
mindspore/lite/src/common/string_util.cc View File

@@ -20,7 +20,7 @@ namespace mindspore {
namespace lite {

std::vector<StringPack> ParseTensorBuffer(Tensor *tensor) {
if (tensor->MutableData() == nullptr) {
if (tensor->data_c() == nullptr) {
MS_LOG(ERROR) << "Tensor data is null, cannot be parsed";
return std::vector<StringPack>{};
}


+ 6
- 2
mindspore/lite/src/lite_session.cc View File

@@ -65,8 +65,12 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
MS_LOG(DEBUG) << "Dims of " << i << "th tensor is nullptr";
} else {
if (TensorCategory(srcTensor) == Tensor::Category::CONST) {
for (size_t j = 0; j < srcTensor->dims()->size(); j++) {
shape.push_back(srcTensor->dims()->data()[j]);
if (srcTensor->dataType() == kObjectTypeString && srcTensor->data() != nullptr) {
shape.push_back(srcTensor->data()->size());
} else {
for (size_t j = 0; j < srcTensor->dims()->size(); j++) {
shape.push_back(srcTensor->dims()->data()[j]);
}
}
}
}


+ 5
- 5
mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc View File

@@ -48,11 +48,12 @@ int HashtableLookupCPUKernel::Run() {
auto output_tensor = out_tensors_.at(0);
auto hits_tensor = out_tensors_.at(1);

int rows = values_tensor->DimensionSize(0);
int rows = GetStringCount(values_tensor);
int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData());
uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData());
std::vector<lite::StringPack> output_string_pack;
std::vector<lite::StringPack> output_string_pack(input_tensor->ElementsNum());
std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(values_tensor);
lite::StringPack null_string_pack = {0, nullptr};

for (int i = 0; i < input_tensor->ElementsNum(); i++) {
int index = -1;
@@ -61,11 +62,10 @@ int HashtableLookupCPUKernel::Run() {
index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData());
}
if (index >= rows || index < 0) {
lite::StringPack tmp = {0, nullptr};
output_string_pack.push_back(tmp);
output_string_pack[i] = null_string_pack;
hits_data[i] = 0;
} else {
output_string_pack.push_back(all_string_pack[i]);
output_string_pack[i] = all_string_pack[i];
hits_data[i] = 1;
}
}


+ 3
- 2
mindspore/lite/src/runtime/kernel/arm/string/predict.cc View File

@@ -88,9 +88,10 @@ int PredictCPUKernel::Run() {
if (static_cast<size_t>(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) {
output_label[i] = -1;
output_weight[i] = 0.0f;
} else {
output_label[i] = label_info_vec[i].label;
output_weight[i] = label_info_vec[i].weight;
}
output_label[i] = label_info_vec[i].label;
output_weight[i] = label_info_vec[i].weight;
}
return RET_OK;
}


Loading…
Cancel
Save