From 3b8279af10ef23caa424bd1930f5b5942da185d5 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Sat, 17 Oct 2020 18:29:26 +0800 Subject: [PATCH] fix_lite_kernel --- mindspore/lite/src/common/string_util.cc | 2 ++ mindspore/lite/src/ops/lsh_projection.cc | 2 +- mindspore/lite/src/ops/skip_gram.cc | 2 +- mindspore/lite/src/populate_parameter.cc | 3 ++- .../lite/src/runtime/kernel/arm/string/hashtable_lookup.cc | 2 +- mindspore/lite/src/runtime/kernel/arm/string/normalize.cc | 6 +++--- 6 files changed, 10 insertions(+), 7 deletions(-) diff --git a/mindspore/lite/src/common/string_util.cc b/mindspore/lite/src/common/string_util.cc index ed2f57fd27..e7ac78926b 100644 --- a/mindspore/lite/src/common/string_util.cc +++ b/mindspore/lite/src/common/string_util.cc @@ -47,6 +47,7 @@ int WriteStringsToTensor(Tensor *tensor, const std::vector &string_b } std::vector shape = {offset[num]}; tensor->set_shape(shape); + tensor->FreeData(); void *data = tensor->MutableData(); if (data == nullptr) { return RET_ERROR; @@ -80,6 +81,7 @@ int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector shape = {offset[num]}; tensor->set_shape(shape); + tensor->FreeData(); void *data = tensor->MutableData(); if (data == nullptr) { return RET_ERROR; diff --git a/mindspore/lite/src/ops/lsh_projection.cc b/mindspore/lite/src/ops/lsh_projection.cc index da14d6b36b..411e619f97 100644 --- a/mindspore/lite/src/ops/lsh_projection.cc +++ b/mindspore/lite/src/ops/lsh_projection.cc @@ -38,7 +38,7 @@ constexpr int kSparseType = 1; constexpr int kDenseType = 2; } // namespace int LshProjection::InferShape(std::vector inputs_, std::vector outputs_) { - if (inputs_.size() != kDoubleNum || inputs_.size() != kMultiNum) { + if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given."; return RET_ERROR; } diff --git a/mindspore/lite/src/ops/skip_gram.cc b/mindspore/lite/src/ops/skip_gram.cc index 8b0058fa40..a7e64210f6 100644 --- a/mindspore/lite/src/ops/skip_gram.cc +++ b/mindspore/lite/src/ops/skip_gram.cc @@ -40,7 +40,7 @@ int SkipGram::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer return RET_ERROR; } - auto val_offset = schema::CreateSkipGram(*fbb, attr->ngramSize(), attr->maxSkipSize(), attr->includeAllGrams()); + auto val_offset = schema::CreateSkipGram(*fbb, attr->includeAllGrams(), attr->maxSkipSize(), attr->ngramSize()); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SkipGram, val_offset.o); fbb->Finish(prim_offset); return RET_OK; diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 292da34f65..fd7fb83ba1 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -1759,11 +1759,12 @@ PopulateParameterRegistry::PopulateParameterRegistry() { populate_parameter_funcs_[schema::PrimitiveType_Elu] = PopulateEluParameter; populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter; populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter; + populate_parameter_funcs_[schema::PrimitiveType_CustomNormalize] = PopulateCommonOpParameter; populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter; populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter; + populate_parameter_funcs_[schema::PrimitiveType_LshProjection] = PopulateLshProjectionParameter; populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter; populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter; - populate_parameter_funcs_[schema::PrimitiveType_LshProjection] = PopulateLshProjectionParameter; } PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { diff --git a/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc b/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc index 77d5607f90..501571a3c6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc +++ b/mindspore/lite/src/runtime/kernel/arm/string/hashtable_lookup.cc @@ -52,7 +52,7 @@ int HashtableLookupCPUKernel::Run() { int32_t *input_data = reinterpret_cast(input_tensor->MutableData()); uint8_t *hits_data = reinterpret_cast(hits_tensor->MutableData()); std::vector output_string_pack; - std::vector all_string_pack = ParseTensorBuffer(input_tensor); + std::vector all_string_pack = ParseTensorBuffer(values_tensor); for (int i = 0; i < input_tensor->ElementsNum(); i++) { int index = -1; diff --git a/mindspore/lite/src/runtime/kernel/arm/string/normalize.cc b/mindspore/lite/src/runtime/kernel/arm/string/normalize.cc index 7865bca8f5..70bfeb0e7c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/string/normalize.cc +++ b/mindspore/lite/src/runtime/kernel/arm/string/normalize.cc @@ -84,7 +84,7 @@ std::string NormalizeCPUKernel::Normalize(const std::string &str) { if (result.size() > kMaxStringLength) { result = result.substr(0, kMaxStringLength); } - + result = " " + result + " "; return result; } @@ -112,9 +112,9 @@ int NormalizeCPUKernel::Run() { for (int i = 0; i < string_num; ++i) { auto chars = all_string_pack[i]; - std::string str(chars.data); + std::string str(chars.data, chars.len); std::string result = Normalize(str); - int str_length = result.size() + 1; + int str_length = result.size(); char *normalized_str = nullptr; normalized_str = reinterpret_cast(context_->allocator->Malloc(sizeof(char) * str_length));