Browse Source

!7426 [MSLITE][Develop] fix smart reply kernel

Merge pull request !7426 from sunsuodong/fix_lite_kernel
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
21b9d8e8a7
6 changed files with 10 additions and 7 deletions
  1. +2
    -0
      mindspore/lite/src/common/string_util.cc
  2. +1
    -1
      mindspore/lite/src/ops/lsh_projection.cc
  3. +1
    -1
      mindspore/lite/src/ops/skip_gram.cc
  4. +2
    -1
      mindspore/lite/src/populate_parameter.cc
  5. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc
  6. +3
    -3
      mindspore/lite/src/runtime/kernel/arm/string/normalize.cc

+ 2
- 0
mindspore/lite/src/common/string_util.cc View File

@@ -47,6 +47,7 @@ int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_b
}
std::vector<int> shape = {offset[num]};
tensor->set_shape(shape);
tensor->FreeData();
void *data = tensor->MutableData();
if (data == nullptr) {
return RET_ERROR;
@@ -80,6 +81,7 @@ int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<

std::vector<int> shape = {offset[num]};
tensor->set_shape(shape);
tensor->FreeData();
void *data = tensor->MutableData();
if (data == nullptr) {
return RET_ERROR;


+ 1
- 1
mindspore/lite/src/ops/lsh_projection.cc View File

@@ -38,7 +38,7 @@ constexpr int kSparseType = 1;
constexpr int kDenseType = 2;
} // namespace
int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
if (inputs_.size() != kDoubleNum || inputs_.size() != kMultiNum) {
if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) {
MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given.";
return RET_ERROR;
}


+ 1
- 1
mindspore/lite/src/ops/skip_gram.cc View File

@@ -40,7 +40,7 @@ int SkipGram::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
return RET_ERROR;
}

auto val_offset = schema::CreateSkipGram(*fbb, attr->ngramSize(), attr->maxSkipSize(), attr->includeAllGrams());
auto val_offset = schema::CreateSkipGram(*fbb, attr->includeAllGrams(), attr->maxSkipSize(), attr->ngramSize());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SkipGram, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;


+ 2
- 1
mindspore/lite/src/populate_parameter.cc View File

@@ -1759,11 +1759,12 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_Elu] = PopulateEluParameter;
populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter;
populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter;
populate_parameter_funcs_[schema::PrimitiveType_CustomNormalize] = PopulateCommonOpParameter;
populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter;
populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter;
populate_parameter_funcs_[schema::PrimitiveType_LshProjection] = PopulateLshProjectionParameter;
populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter;
populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter;
populate_parameter_funcs_[schema::PrimitiveType_LshProjection] = PopulateLshProjectionParameter;
}

PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc View File

@@ -52,7 +52,7 @@ int HashtableLookupCPUKernel::Run() {
int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData());
uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData());
std::vector<lite::StringPack> output_string_pack;
std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(input_tensor);
std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(values_tensor);

for (int i = 0; i < input_tensor->ElementsNum(); i++) {
int index = -1;


+ 3
- 3
mindspore/lite/src/runtime/kernel/arm/string/normalize.cc View File

@@ -84,7 +84,7 @@ std::string NormalizeCPUKernel::Normalize(const std::string &str) {
if (result.size() > kMaxStringLength) {
result = result.substr(0, kMaxStringLength);
}
result = "<S> " + result + " <E>";
return result;
}

@@ -112,9 +112,9 @@ int NormalizeCPUKernel::Run() {

for (int i = 0; i < string_num; ++i) {
auto chars = all_string_pack[i];
std::string str(chars.data);
std::string str(chars.data, chars.len);
std::string result = Normalize(str);
int str_length = result.size() + 1;
int str_length = result.size();

char *normalized_str = nullptr;
normalized_str = reinterpret_cast<char *>(context_->allocator->Malloc(sizeof(char) * str_length));


Loading…
Cancel
Save