From 4b98754dee8520359c50fdc321f8db9d4f369020 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Tue, 13 Oct 2020 19:05:33 +0800 Subject: [PATCH] custom op normalize --- mindspore/lite/src/common/string_util.cc | 4 +- mindspore/lite/src/ops/custom_normalize.cc | 21 ++- .../runtime/kernel/arm/string/normalize.cc | 160 ++++++++++++++++++ .../src/runtime/kernel/arm/string/normalize.h | 47 +++++ mindspore/lite/test/CMakeLists.txt | 1 + .../runtime/kernel/arm/string/normalize.cc | 87 ++++++++++ 6 files changed, 316 insertions(+), 4 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/arm/string/normalize.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/string/normalize.h create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/string/normalize.cc diff --git a/mindspore/lite/src/common/string_util.cc b/mindspore/lite/src/common/string_util.cc index eb166e6eae..ed2f57fd27 100644 --- a/mindspore/lite/src/common/string_util.cc +++ b/mindspore/lite/src/common/string_util.cc @@ -52,8 +52,8 @@ int WriteStringsToTensor(Tensor *tensor, const std::vector &string_b return RET_ERROR; } - auto *string_info = reinterpret_cast(data); - auto *string_data = reinterpret_cast(data); + int32_t *string_info = reinterpret_cast(data); + char *string_data = reinterpret_cast(data); string_info[0] = num; for (int i = 0; i <= num; i++) { diff --git a/mindspore/lite/src/ops/custom_normalize.cc b/mindspore/lite/src/ops/custom_normalize.cc index a92f3b1314..00225d6451 100644 --- a/mindspore/lite/src/ops/custom_normalize.cc +++ b/mindspore/lite/src/ops/custom_normalize.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include "src/ops/custom_normalize.h" +#include "src/common/string_util.h" namespace mindspore { namespace lite { @@ -30,8 +31,24 @@ int CustomNormalize::UnPackToFlatBuilder(const schema::Primitive *primitive, fla } #endif int CustomNormalize::InferShape(std::vector inputs_, std::vector outputs_) { - PrimitiveC::InferShape(inputs_, outputs_); - return RET_INFER_INVALID; + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + if (input->data_c() == nullptr) { + MS_LOG(INFO) << "Do infer shape in runtime."; + return RET_INFER_INVALID; + } + int string_num = lite::GetStringCount(input); + auto output = outputs_.at(0); + MS_ASSERT(output != nullptr); + + std::vector shape; + shape.push_back(string_num == 0 ? 1 : string_num); + + output->set_shape(shape); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; } + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/arm/string/normalize.cc b/mindspore/lite/src/runtime/kernel/arm/string/normalize.cc new file mode 100644 index 0000000000..4412ae97ab --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/string/normalize.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/string/normalize.h" +#include +#include +#include +#include +#include "src/kernel_registry.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_CustomNormalize; + +namespace mindspore::kernel { +namespace { +const char kPunctuationsRegex[] = "[.*()\"]"; +const std::map *kRegexTransforms = new (std::nothrow) std::map({ + {"([\\S]+)n't", "$1 not"}, + {"([\\S]+)'nt", "$1 not"}, + {"([\\S]+)'ll", "$1 will"}, + {"([\\S]+)'re", "$1 are"}, + {"([\\S]+)'ve", "$1 have"}, + {"i'm", "i am"}, +}); +const int32_t kMaxStringLength = 300; + +} // namespace + +int NormalizeCPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int NormalizeCPUKernel::ReSize() { return RET_OK; } + +std::string NormalizeCPUKernel::Trim(const std::string &str, const std::string &whitespace /*= " \t\n\v\f\r"*/) { + auto begin = str.find_first_not_of(whitespace); + auto end = str.find_last_not_of(whitespace); + const auto range = end - begin + 1; + return str.substr(begin, range); +} + +std::string NormalizeCPUKernel::GlobalReplace(const std::string &str, const std::string ®, + const std::string &replace) { + std::regex e(reg); + return std::regex_replace(str, e, replace); +} + +std::string NormalizeCPUKernel::Normalize(const std::string &str) { + std::string result; + std::transform(str.begin(), str.end(), back_inserter(result), [](unsigned char c) { return std::tolower(c); }); + result = Trim(result); + result = GlobalReplace(result, kPunctuationsRegex, ""); + result = GlobalReplace(result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])", "$1$2"); + result = GlobalReplace(result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "$1"); + // transform shortening to full + MS_ASSERT(kRegexTransforms != nullptr); + for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end(); iter++) { + result = GlobalReplace(result, iter->first, iter->second); + } + result = GlobalReplace(result, "([?])+", "$1"); + result = GlobalReplace(result, "([!])+", "$1"); + result = GlobalReplace(result, "([^?!]+)([?!])", "$1 $2 "); + result = GlobalReplace(result, "([?!])([?!])", "$1 $2"); + + result = GlobalReplace(result, "[\\s,:;\\-&'\"]+$", ""); + result = GlobalReplace(result, "^[\\s,:;\\-&'\"]+", ""); + + result = Trim(result); + if (result.size() > kMaxStringLength) { + result = result.substr(0, kMaxStringLength); + } + + return result; +} + +void NormalizeCPUKernel::FreeBuffer() { + for (size_t j = 0; j < normalized_strs.size(); ++j) { + if (normalized_strs[j] != nullptr) { + context_->allocator->Free(normalized_strs[j]); + normalized_strs[j] = nullptr; + } + } +} + +int NormalizeCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret; + return ret; + } + auto input_tensor = in_tensors_.at(0); + int string_num = lite::GetStringCount(input_tensor); + std::vector all_string_pack = ParseTensorBuffer(input_tensor); + + std::vector out_string_pack; + normalized_strs.resize(string_num, nullptr); + + for (int i = 0; i < string_num; ++i) { + auto chars = all_string_pack[i]; + std::string str(chars.data); + std::string result = Normalize(str); + int str_length = result.size() + 1; + + char *normalized_str = nullptr; + normalized_str = reinterpret_cast(context_->allocator->Malloc(sizeof(char) * str_length)); + if (normalized_str == nullptr) { + MS_LOG(ERROR) << "Malloc data failed!"; + FreeBuffer(); + return RET_ERROR; + } + normalized_strs[i] = normalized_str; + + memcpy(normalized_str, result.data(), str_length); + out_string_pack.push_back({str_length, normalized_str}); + } + if (string_num == 0) { + out_string_pack.push_back({1, ""}); + } + auto out_tensor = out_tensors_.at(0); + WriteStringsToTensor(out_tensor, out_string_pack); + FreeBuffer(); + return RET_OK; +} + +kernel::LiteKernel *CpuNormalizeKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + auto *kernel = new (std::nothrow) NormalizeCPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new NormalizeCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CustomNormalize, CpuNormalizeKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/string/normalize.h b/mindspore/lite/src/runtime/kernel/arm/string/normalize.h new file mode 100644 index 0000000000..b911af8b32 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/string/normalize.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_NORMALIZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_NORMALIZE_H_ + +#include +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/common/string_util.h" + +namespace mindspore::kernel { +class NormalizeCPUKernel : public LiteKernel { + public: + NormalizeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~NormalizeCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + std::string Trim(const std::string &str, const std::string &whitespace = " \t\n\v\f\r"); + std::string GlobalReplace(const std::string &str, const std::string ®, const std::string &replace); + std::string Normalize(const std::string &str); + std::vector normalized_strs; + void FreeBuffer(); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_NORMALIZE_H_ diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index dd9b1ec10b..f8bb8555bd 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -209,6 +209,7 @@ file(GLOB_RECURSE TEST_CASE_KERNEL_SRC ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc + ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc ) file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/string/normalize.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/string/normalize.cc new file mode 100644 index 0000000000..5237373a44 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/string/normalize.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/fp32/skip_gram.h" +#include "src/runtime/kernel/arm/string/normalize.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "nnacl/fp32/skip_gram.h" +#include "src/common/file_utils.h" +#include "common/common_test.h" +#include "src/common/log_adapter.h" +#include "src/common/string_util.h" + +namespace mindspore { +using mindspore::lite::StringPack; +using mindspore::lite::Tensor; + +class TestNormalize : public mindspore::CommonTest { + public: + TestNormalize() {} + void NormalizeTestInit(); + + public: + Tensor input_tensor_; + Tensor output_tensor_; + std::vector inputs_{&input_tensor_}; + std::vector outputs_{&output_tensor_}; + OpParameter parameter_ = {}; + lite::InnerContext ctx_ = lite::InnerContext(); + kernel::KernelKey desc_ = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_CustomNormalize}; + kernel::KernelCreator creator_ = nullptr; + kernel::LiteKernel *kernel_ = nullptr; +}; + +void TestNormalize::NormalizeTestInit() { + input_tensor_.set_data_type(kObjectTypeString); + input_tensor_.SetFormat(schema::Format_NHWC); + + std::vector str_pack; + const char sentence1[] = " I don't know what happened\n"; + str_pack.push_back({static_cast(strlen(sentence1) + 1), sentence1}); + const char sentence2[] = "She's not here when Alex arrived!!!"; + str_pack.push_back({static_cast(strlen(sentence2) + 1), sentence2}); + mindspore::lite::WriteStringsToTensor(&input_tensor_, str_pack); + + output_tensor_.set_data_type(kObjectTypeString); + output_tensor_.SetFormat(schema::Format_NHWC); +} + +TEST_F(TestNormalize, TestSentence) { + NormalizeTestInit(); + ASSERT_EQ(lite::RET_OK, ctx_.Init()); + creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc_); + ASSERT_NE(creator_, nullptr); + kernel_ = creator_(inputs_, outputs_, ¶meter_, &ctx_, desc_, nullptr); + ASSERT_NE(kernel_, nullptr); + auto ret = kernel_->Init(); + MS_ASSERT(ret == 0); + ret = kernel_->Run(); + MS_ASSERT(ret == 0); + + std::vector output = mindspore::lite::ParseTensorBuffer(outputs_[0]); + for (int i = 0; i < output.size(); i++) { + for (int j = 0; j < output[i].len; j++) { + printf("%c", output[i].data[j]); + } + printf("\n"); + } + + input_tensor_.SetData(nullptr); + output_tensor_.SetData(nullptr); +} + +} // namespace mindspore