From d2559d11115099fc7215ec980d4b2ed4e1d1c221 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Fri, 9 Oct 2020 18:19:11 +0800 Subject: [PATCH] add extract_feature --- mindspore/lite/src/common/string_util.cc | 143 ++++++++++++++++++ mindspore/lite/src/common/string_util.h | 11 ++ .../lite/src/ops/custom_extract_features.cc | 26 +++- .../src/runtime/kernel/arm/CMakeLists.txt | 1 + .../kernel/arm/string/extract_feature.cc | 97 ++++++++++++ .../kernel/arm/string/extract_feature.h | 42 +++++ mindspore/lite/test/CMakeLists.txt | 1 + 7 files changed, 319 insertions(+), 2 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/arm/string/extract_feature.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/string/extract_feature.h diff --git a/mindspore/lite/src/common/string_util.cc b/mindspore/lite/src/common/string_util.cc index 56579f6184..eb166e6eae 100644 --- a/mindspore/lite/src/common/string_util.cc +++ b/mindspore/lite/src/common/string_util.cc @@ -102,5 +102,148 @@ int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector(data)); } + +int GetStringCount(Tensor *tensor) { return GetStringCount(tensor->MutableData()); } + +// Some primes between 2^63 and 2^64 +static uint64_t k0 = 0xc3a5c85c97cb3127ULL; +static uint64_t k1 = 0xb492b66fbe98f273ULL; +static uint64_t k2 = 0x9ae16a3b2f90404fULL; + +uint64_t Fetch64Bit(const char *p) { + uint64_t result; + memcpy(&result, p, sizeof(uint64_t)); + return __builtin_bswap64(result); +} + +uint32_t Fetch32Bit(const char *p) { + uint32_t result; + memcpy(&result, p, sizeof(uint32_t)); + return __builtin_bswap32(result); +} + +uint64_t Rotate64(uint64_t value, int shift) { + return shift == 0 ? value : ((value >> shift) | (value << (64 - shift))); +} + +uint64_t HashLen16(uint64_t u, uint64_t v, uint64_t multiple) { + uint64_t a = (u ^ v) * multiple; + a ^= (a >> 47); + uint64_t b = (v ^ a) * multiple; + b ^= (b >> 47); + b *= multiple; + return b; +} + +uint64_t ShiftMix(uint64_t value) { return value ^ (value >> 47); } + +uint64_t HashStringLen0to16(const char *s, size_t len) { + if (len >= 8) { + uint64_t mul = k2 + len * 2; + uint64_t a = Fetch64Bit(s) + k2; + uint64_t b = Fetch64Bit(s + len - 8); + uint64_t c = Rotate64(b, 37) * mul + a; + uint64_t d = (Rotate64(a, 25) + b) * mul; + return HashLen16(c, d, mul); + } + if (len >= 4) { + uint64_t mul = k2 + len * 2; + uint64_t a = Fetch32Bit(s); + return HashLen16(len + (a << 3), Fetch32Bit(s + len - 4), mul); + } + if (len > 0) { + uint8_t a = s[0]; + uint8_t b = s[len >> 1]; + uint8_t c = s[len - 1]; + uint32_t y = static_cast(a) + (static_cast(b) << 8); + uint32_t z = len + (static_cast(c) << 2); + return ShiftMix(y * k2 ^ z * k0) * k2; + } + return k2; +} + +uint64_t HashStringLen17to32(const char *s, size_t len) { + uint64_t mul = k2 + len * 2; + uint64_t a = Fetch64Bit(s) * k1; + uint64_t b = Fetch64Bit(s + 8); + uint64_t c = Fetch64Bit(s + len - 8) * mul; + uint64_t d = Fetch64Bit(s + len - 16) * k2; + return HashLen16(Rotate64(a + b, 43) + Rotate64(c, 30) + d, a + Rotate64(b + k2, 18) + c, mul); +} + +uint64_t HashStringLen33to64(const char *s, size_t len) { + uint64_t mul = k2 + len * 2; + uint64_t a = Fetch64Bit(s) * k2; + uint64_t b = Fetch64Bit(s + 8); + uint64_t c = Fetch64Bit(s + len - 8) * mul; + uint64_t d = Fetch64Bit(s + len - 16) * k2; + uint64_t y = Rotate64(a + b, 43) + Rotate64(c, 30) + d; + uint64_t z = HashLen16(y, a + Rotate64(b + k2, 18) + c, mul); + uint64_t e = Fetch64Bit(s + 16) * mul; + uint64_t f = Fetch64Bit(s + 24); + uint64_t g = (y + Fetch64Bit(s + len - 32)) * mul; + uint64_t h = (z + Fetch64Bit(s + len - 24)) * mul; + return HashLen16(Rotate64(e + f, 43) + Rotate64(g, 30) + h, e + Rotate64(f + a, 18) + g, mul); +} + +std::pair HashLen32WithSeeds(const char *s, uint64_t a, uint64_t b) { + a += Fetch64Bit(s); + b = Rotate64(b + a + Fetch64Bit(s + 24), 21); + uint64_t c = a; + a += Fetch64Bit(s + 8); + a += Fetch64Bit(s + 16); + b += Rotate64(a, 44); + return std::make_pair(a + Fetch64Bit(s + 24), b + c); +} + +uint64_t StringHash64(const char *s, size_t len) { + uint64_t seed_value = 81; + if (len <= 16) { + return HashStringLen0to16(s, len); + } else if (len <= 32) { + return HashStringLen17to32(s, len); + } else if (len <= 64) { + return HashStringLen33to64(s, len); + } + + uint64_t x = seed_value; + uint64_t y = seed_value * k1 + 113; + uint64_t tmp = y * k2 + 113; + uint64_t z = (tmp ^ (tmp >> 47)) * k2; + std::pair v = std::make_pair(0, 0); + std::pair w = std::make_pair(0, 0); + x = x * k2 + Fetch64Bit(s); + + const char *end = s + ((len - 1) / 64) * 64; + const char *last64 = end + ((len - 1) & 63) - 63; + MS_ASSERT(s + len - 64 == last64); + do { + x = Rotate64(x + y + v.first + Fetch64Bit(s + 8), 37) * k1; + y = Rotate64(y + v.second + Fetch64Bit(s + 48), 42) * k1; + x ^= w.second; + y += v.first + Fetch64Bit(s + 40); + z = Rotate64(z + w.first, 33) * k1; + v = HashLen32WithSeeds(s, v.second * k1, x + w.first); + w = HashLen32WithSeeds(s + 32, z + w.second, y + Fetch64Bit(s + 16)); + std::swap(z, x); + s += 64; + } while (s != end); + uint64_t mul = k1 + ((z & 0xff) << 1); + s = last64; + w.first += ((len - 1) & 63); + v.first += w.first; + w.first += v.first; + x = Rotate64(x + y + v.first + Fetch64Bit(s + 8), 37) * mul; + y = Rotate64(y + v.second + Fetch64Bit(s + 48), 42) * mul; + x ^= w.second * 9; + y += v.first * 9 + Fetch64Bit(s + 40); + z = Rotate64(z + w.first, 33) * mul; + v = HashLen32WithSeeds(s, v.second * mul, x + w.first); + w = HashLen32WithSeeds(s + 32, z + w.second, y + Fetch64Bit(s + 16)); + std::swap(z, x); + return HashLen16(HashLen16(v.first, w.first, mul) + ShiftMix(y) * k0 + z, HashLen16(v.second, w.second, mul) + x, + mul); +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/common/string_util.h b/mindspore/lite/src/common/string_util.h index 9e43ad2885..1ccf271d25 100644 --- a/mindspore/lite/src/common/string_util.h +++ b/mindspore/lite/src/common/string_util.h @@ -32,12 +32,23 @@ typedef struct { const char *data; } StringPack; +// example of string tensor: +// 3, 0, 0, 0 # int32, num of strings +// 20, 0, 0, 0 # int32, offset of 0-th string +// 23, 0, 0, 0 # int32, offset of 0-th string +// 26, 0, 0, 0 # int32, offset of 0-th string +// 29, 0, 0, 0 # int32, total length of tensor data +// 'h', 'o', 'w', 'a', 'r', 'e', 'y', 'o', 'u' # char, how are you std::vector ParseTensorBuffer(Tensor *tensor); std::vector ParseStringBuffer(const void *data); int WriteStringsToTensor(Tensor *tensor, const std::vector &string_buffer); int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector> &string_buffer); +int GetStringCount(const void *data); +int GetStringCount(Tensor *tensor); + +uint64_t StringHash64(const char *s, size_t len); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/custom_extract_features.cc b/mindspore/lite/src/ops/custom_extract_features.cc index 52f511d44c..601c1321e6 100644 --- a/mindspore/lite/src/ops/custom_extract_features.cc +++ b/mindspore/lite/src/ops/custom_extract_features.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include "src/ops/custom_extract_features.h" +#include "src/common/string_util.h" namespace mindspore { namespace lite { @@ -30,9 +31,30 @@ int CustomExtractFeatures::UnPackToFlatBuilder(const schema::Primitive *primitiv return RET_OK; } #endif + int CustomExtractFeatures::InferShape(std::vector inputs_, std::vector outputs_) { - PrimitiveC::InferShape(inputs_, outputs_); - return RET_INFER_INVALID; + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + if (input->data_c() == nullptr) { + MS_LOG(INFO) << "Do infer shape in runtime."; + return RET_INFER_INVALID; + } + int string_num = lite::GetStringCount(input); + auto output0 = outputs_.at(0); + auto output1 = outputs_.at(1); + MS_ASSERT(output0 != nullptr); + MS_ASSERT(output1 != nullptr); + + std::vector shape; + shape.push_back(string_num == 0 ? 1 : string_num); + + output0->set_shape(shape); + output0->set_data_type(input->data_type()); + output0->SetFormat(input->GetFormat()); + output1->set_shape(shape); + output1->set_data_type(input->data_type()); + output1->SetFormat(input->GetFormat()); + return RET_OK; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt index a108bce633..fceac9912b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt +++ b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt @@ -4,6 +4,7 @@ file(GLOB KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/string/*.cc ) list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc) diff --git a/mindspore/lite/src/runtime/kernel/arm/string/extract_feature.cc b/mindspore/lite/src/runtime/kernel/arm/string/extract_feature.cc new file mode 100644 index 0000000000..08bb100db0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/string/extract_feature.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/string/extract_feature.h" +#include +#include "src/kernel_registry.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_CustomExtractFeatures; + +namespace mindspore::kernel { +int ExtractFeatureCPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int ExtractFeatureCPUKernel::ReSize() { return RET_OK; } + +bool ExtractFeatureCPUKernel::IsInBlacklist(const lite::StringPack &str) { + std::vector kBlacklist = {"", "", " "}; + for (const auto &s : kBlacklist) { + if (str.len != static_cast(s.length())) { + continue; + } + if (memcmp(str.data, s.data(), str.len) == 0) { + return true; + } + } + return false; +} + +int ExtractFeatureCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret; + return ret; + } + const int kMaxDimension = 1000000; + auto input_tensor = in_tensors_.at(0); + auto label_data = reinterpret_cast(out_tensors_.at(0)->MutableData()); + auto weight_data = out_tensors_.at(1)->MutableData(); + int string_num = lite::GetStringCount(input_tensor); + std::vector all_string_pack = ParseTensorBuffer(input_tensor); + + for (int i = 0; i < string_num; i++) { + lite::StringPack str = all_string_pack[i]; + if (IsInBlacklist(str)) { + label_data[i] = 0; + reinterpret_cast(weight_data)[i] = 0; + continue; + } + int64_t hash_value = lite::StringHash64(str.data, str.len) % kMaxDimension; + label_data[i] = hash_value; + reinterpret_cast(weight_data)[i] = std::count(str.data, str.data + str.len, ' ') + 1; + } + if (string_num == 0) { + label_data[0] = 0; + reinterpret_cast(weight_data)[0] = 0; + } + return RET_OK; +} + +kernel::LiteKernel *CpuExtractFeatureKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + auto *kernel = new (std::nothrow) ExtractFeatureCPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ExtractFeatureCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CustomExtractFeatures, CpuExtractFeatureKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/string/extract_feature.h b/mindspore/lite/src/runtime/kernel/arm/string/extract_feature.h new file mode 100644 index 0000000000..72e8d23b6a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/string/extract_feature.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_EXTRACT_FEATURE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_EXTRACT_FEATURE_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/common/string_util.h" + +namespace mindspore::kernel { +class ExtractFeatureCPUKernel : public LiteKernel { + public: + ExtractFeatureCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~ExtractFeatureCPUKernel() {} + + int Init() override; + int ReSize() override; + int Run() override; + + private: + bool IsInBlacklist(const lite::StringPack &str); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_EXTRACT_FEATURE_H_ diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 09df4130a5..fc7311d3d6 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -28,6 +28,7 @@ file(GLOB KERNEL_OP_SRC ${LITE_DIR}/src/runtime/kernel/arm/base/*.cc ${LITE_DIR}/src/runtime/kernel/arm/fp32/*.cc ${LITE_DIR}/src/runtime/kernel/arm/int8/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/string/*.cc ${LITE_DIR}/nnacl/*.c ${LITE_DIR}/nnacl/fp32/*.c ${LITE_DIR}/nnacl/int8/*.c