From b0822ba9080ace61fa6b53b5fd05a711a61dda30 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Wed, 21 Oct 2020 11:21:47 +0800 Subject: [PATCH] add lite string interface --- mindspore/lite/include/lite_utils.h | 14 ++++++++++++ mindspore/lite/src/common/string_util.cc | 28 +++++++++++++++++++++++- mindspore/lite/src/common/string_util.h | 11 +++++----- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/mindspore/lite/include/lite_utils.h b/mindspore/lite/include/lite_utils.h index 9a86b054ef..a17524a683 100644 --- a/mindspore/lite/include/lite_utils.h +++ b/mindspore/lite/include/lite_utils.h @@ -20,6 +20,7 @@ #include #include #include "schema/model_generated.h" +#include "include/ms_tensor.h" namespace mindspore::lite { /// \brief Allocator defined a memory pool for malloc memory and free memory dynamically. @@ -36,5 +37,18 @@ using Uint32Vector = std::vector; using String = std::string; using NodeType = schema::NodeType; using AllocatorPtr = std::shared_ptr; + +/// \brief Set data of MSTensor from string vector. +/// +/// \param[in] inputs string vector. +/// \param[out] tensor MSTensor. +/// +/// \return STATUS as an error code of this interface, STATUS is defined in errorcode.h. +int StringsToMSTensor(const std::vector &inputs, tensor::MSTensor *tensor); + +/// \brief Get string vector from MSTensor. +/// \param[in] tensor MSTensor. +/// \return string vector. 
+std::vector MSTensorToStrings(const tensor::MSTensor *tensor); } // namespace mindspore::lite #endif // MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_ diff --git a/mindspore/lite/src/common/string_util.cc b/mindspore/lite/src/common/string_util.cc index 82de77790e..909f710145 100644 --- a/mindspore/lite/src/common/string_util.cc +++ b/mindspore/lite/src/common/string_util.cc @@ -15,6 +15,8 @@ */ #include "src/common/string_util.h" +#include +#include "include/ms_tensor.h" namespace mindspore { namespace lite { @@ -28,9 +30,13 @@ std::vector ParseTensorBuffer(Tensor *tensor) { } std::vector ParseStringBuffer(const void *data) { + std::vector buffer; + if (data == nullptr) { + MS_LOG(ERROR) << "data is nullptr"; + return buffer; + } const int32_t *offset = reinterpret_cast(data); int32_t num = *offset; - std::vector buffer; for (int i = 0; i < num; i++) { offset += 1; buffer.push_back(StringPack{(*(offset + 1)) - (*offset), reinterpret_cast(data) + (*offset)}); @@ -108,6 +114,26 @@ int GetStringCount(const void *data) { return *(static_cast(dat int GetStringCount(Tensor *tensor) { return GetStringCount(tensor->MutableData()); } +int StringsToMSTensor(const std::vector &inputs, tensor::MSTensor *tensor) { + std::vector all_pack; + for (auto &input : inputs) { + StringPack pack = {static_cast(input.length()), input.data()}; + all_pack.push_back(pack); + } + return WriteStringsToTensor(static_cast(tensor), all_pack); +} + +std::vector MSTensorToStrings(const tensor::MSTensor *tensor) { + const void *ptr = static_cast(tensor)->data_c(); + std::vector all_pack = ParseStringBuffer(ptr); + std::vector result(all_pack.size()); + std::transform(all_pack.begin(), all_pack.end(), result.begin(), [](StringPack &pack) { + std::string str(pack.data, pack.len); + return str; + }); + return result; +} + // Some primes between 2^63 and 2^64 static uint64_t k0 = 0xc3a5c85c97cb3127ULL; static uint64_t k1 = 0xb492b66fbe98f273ULL; diff --git a/mindspore/lite/src/common/string_util.h 
b/mindspore/lite/src/common/string_util.h index 1ccf271d25..d62d2b2c34 100644 --- a/mindspore/lite/src/common/string_util.h +++ b/mindspore/lite/src/common/string_util.h @@ -33,12 +33,11 @@ typedef struct { } StringPack; // example of string tensor: -// 3, 0, 0, 0 # int32, num of strings -// 20, 0, 0, 0 # int32, offset of 0-th string -// 23, 0, 0, 0 # int32, offset of 0-th string -// 26, 0, 0, 0 # int32, offset of 0-th string -// 29, 0, 0, 0 # int32, total length of tensor data -// 'h', 'o', 'w', 'a', 'r', 'e', 'y', 'o', 'u' # char, how are you +// 2, 0, 0, 0 # int32, num of strings +// 16, 0, 0, 0 # int32, offset of 0-th string +// 21, 0, 0, 0 # int32, offset of 1-th string +// 30, 0, 0, 0 # int32, total length of tensor data +// 'h', 'e', 'l', 'l', 'o', 'h', 'o', 'w', 'a', 'r', 'e', 'y', 'o', 'u' # char, "hello", "howareyou" std::vector ParseTensorBuffer(Tensor *tensor); std::vector ParseStringBuffer(const void *data);