Merge pull request !7124 from sunsuodong/extract_featuretags/v1.1.0
| @@ -102,5 +102,148 @@ int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector< | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int GetStringCount(const void *data) { return *(static_cast<const int32_t *>(data)); } | |||||
| int GetStringCount(Tensor *tensor) { return GetStringCount(tensor->MutableData()); } | |||||
| // Some primes between 2^63 and 2^64 | |||||
| static uint64_t k0 = 0xc3a5c85c97cb3127ULL; | |||||
| static uint64_t k1 = 0xb492b66fbe98f273ULL; | |||||
| static uint64_t k2 = 0x9ae16a3b2f90404fULL; | |||||
| uint64_t Fetch64Bit(const char *p) { | |||||
| uint64_t result; | |||||
| memcpy(&result, p, sizeof(uint64_t)); | |||||
| return __builtin_bswap64(result); | |||||
| } | |||||
| uint32_t Fetch32Bit(const char *p) { | |||||
| uint32_t result; | |||||
| memcpy(&result, p, sizeof(uint32_t)); | |||||
| return __builtin_bswap32(result); | |||||
| } | |||||
| uint64_t Rotate64(uint64_t value, int shift) { | |||||
| return shift == 0 ? value : ((value >> shift) | (value << (64 - shift))); | |||||
| } | |||||
| uint64_t HashLen16(uint64_t u, uint64_t v, uint64_t multiple) { | |||||
| uint64_t a = (u ^ v) * multiple; | |||||
| a ^= (a >> 47); | |||||
| uint64_t b = (v ^ a) * multiple; | |||||
| b ^= (b >> 47); | |||||
| b *= multiple; | |||||
| return b; | |||||
| } | |||||
| uint64_t ShiftMix(uint64_t value) { return value ^ (value >> 47); } | |||||
| uint64_t HashStringLen0to16(const char *s, size_t len) { | |||||
| if (len >= 8) { | |||||
| uint64_t mul = k2 + len * 2; | |||||
| uint64_t a = Fetch64Bit(s) + k2; | |||||
| uint64_t b = Fetch64Bit(s + len - 8); | |||||
| uint64_t c = Rotate64(b, 37) * mul + a; | |||||
| uint64_t d = (Rotate64(a, 25) + b) * mul; | |||||
| return HashLen16(c, d, mul); | |||||
| } | |||||
| if (len >= 4) { | |||||
| uint64_t mul = k2 + len * 2; | |||||
| uint64_t a = Fetch32Bit(s); | |||||
| return HashLen16(len + (a << 3), Fetch32Bit(s + len - 4), mul); | |||||
| } | |||||
| if (len > 0) { | |||||
| uint8_t a = s[0]; | |||||
| uint8_t b = s[len >> 1]; | |||||
| uint8_t c = s[len - 1]; | |||||
| uint32_t y = static_cast<uint32_t>(a) + (static_cast<uint32_t>(b) << 8); | |||||
| uint32_t z = len + (static_cast<uint32_t>(c) << 2); | |||||
| return ShiftMix(y * k2 ^ z * k0) * k2; | |||||
| } | |||||
| return k2; | |||||
| } | |||||
| uint64_t HashStringLen17to32(const char *s, size_t len) { | |||||
| uint64_t mul = k2 + len * 2; | |||||
| uint64_t a = Fetch64Bit(s) * k1; | |||||
| uint64_t b = Fetch64Bit(s + 8); | |||||
| uint64_t c = Fetch64Bit(s + len - 8) * mul; | |||||
| uint64_t d = Fetch64Bit(s + len - 16) * k2; | |||||
| return HashLen16(Rotate64(a + b, 43) + Rotate64(c, 30) + d, a + Rotate64(b + k2, 18) + c, mul); | |||||
| } | |||||
| uint64_t HashStringLen33to64(const char *s, size_t len) { | |||||
| uint64_t mul = k2 + len * 2; | |||||
| uint64_t a = Fetch64Bit(s) * k2; | |||||
| uint64_t b = Fetch64Bit(s + 8); | |||||
| uint64_t c = Fetch64Bit(s + len - 8) * mul; | |||||
| uint64_t d = Fetch64Bit(s + len - 16) * k2; | |||||
| uint64_t y = Rotate64(a + b, 43) + Rotate64(c, 30) + d; | |||||
| uint64_t z = HashLen16(y, a + Rotate64(b + k2, 18) + c, mul); | |||||
| uint64_t e = Fetch64Bit(s + 16) * mul; | |||||
| uint64_t f = Fetch64Bit(s + 24); | |||||
| uint64_t g = (y + Fetch64Bit(s + len - 32)) * mul; | |||||
| uint64_t h = (z + Fetch64Bit(s + len - 24)) * mul; | |||||
| return HashLen16(Rotate64(e + f, 43) + Rotate64(g, 30) + h, e + Rotate64(f + a, 18) + g, mul); | |||||
| } | |||||
| std::pair<uint64_t, uint64_t> HashLen32WithSeeds(const char *s, uint64_t a, uint64_t b) { | |||||
| a += Fetch64Bit(s); | |||||
| b = Rotate64(b + a + Fetch64Bit(s + 24), 21); | |||||
| uint64_t c = a; | |||||
| a += Fetch64Bit(s + 8); | |||||
| a += Fetch64Bit(s + 16); | |||||
| b += Rotate64(a, 44); | |||||
| return std::make_pair(a + Fetch64Bit(s + 24), b + c); | |||||
| } | |||||
| uint64_t StringHash64(const char *s, size_t len) { | |||||
| uint64_t seed_value = 81; | |||||
| if (len <= 16) { | |||||
| return HashStringLen0to16(s, len); | |||||
| } else if (len <= 32) { | |||||
| return HashStringLen17to32(s, len); | |||||
| } else if (len <= 64) { | |||||
| return HashStringLen33to64(s, len); | |||||
| } | |||||
| uint64_t x = seed_value; | |||||
| uint64_t y = seed_value * k1 + 113; | |||||
| uint64_t tmp = y * k2 + 113; | |||||
| uint64_t z = (tmp ^ (tmp >> 47)) * k2; | |||||
| std::pair<uint64_t, uint64_t> v = std::make_pair(0, 0); | |||||
| std::pair<uint64_t, uint64_t> w = std::make_pair(0, 0); | |||||
| x = x * k2 + Fetch64Bit(s); | |||||
| const char *end = s + ((len - 1) / 64) * 64; | |||||
| const char *last64 = end + ((len - 1) & 63) - 63; | |||||
| MS_ASSERT(s + len - 64 == last64); | |||||
| do { | |||||
| x = Rotate64(x + y + v.first + Fetch64Bit(s + 8), 37) * k1; | |||||
| y = Rotate64(y + v.second + Fetch64Bit(s + 48), 42) * k1; | |||||
| x ^= w.second; | |||||
| y += v.first + Fetch64Bit(s + 40); | |||||
| z = Rotate64(z + w.first, 33) * k1; | |||||
| v = HashLen32WithSeeds(s, v.second * k1, x + w.first); | |||||
| w = HashLen32WithSeeds(s + 32, z + w.second, y + Fetch64Bit(s + 16)); | |||||
| std::swap(z, x); | |||||
| s += 64; | |||||
| } while (s != end); | |||||
| uint64_t mul = k1 + ((z & 0xff) << 1); | |||||
| s = last64; | |||||
| w.first += ((len - 1) & 63); | |||||
| v.first += w.first; | |||||
| w.first += v.first; | |||||
| x = Rotate64(x + y + v.first + Fetch64Bit(s + 8), 37) * mul; | |||||
| y = Rotate64(y + v.second + Fetch64Bit(s + 48), 42) * mul; | |||||
| x ^= w.second * 9; | |||||
| y += v.first * 9 + Fetch64Bit(s + 40); | |||||
| z = Rotate64(z + w.first, 33) * mul; | |||||
| v = HashLen32WithSeeds(s, v.second * mul, x + w.first); | |||||
| w = HashLen32WithSeeds(s + 32, z + w.second, y + Fetch64Bit(s + 16)); | |||||
| std::swap(z, x); | |||||
| return HashLen16(HashLen16(v.first, w.first, mul) + ShiftMix(y) * k0 + z, HashLen16(v.second, w.second, mul) + x, | |||||
| mul); | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -32,12 +32,23 @@ typedef struct { | |||||
| const char *data; | const char *data; | ||||
| } StringPack; | } StringPack; | ||||
| // example of string tensor: | |||||
| // 3, 0, 0, 0 # int32, num of strings | |||||
| // 20, 0, 0, 0 # int32, offset of 0-th string | |||||
| // 23, 0, 0, 0 # int32, offset of 0-th string | |||||
| // 26, 0, 0, 0 # int32, offset of 0-th string | |||||
| // 29, 0, 0, 0 # int32, total length of tensor data | |||||
| // 'h', 'o', 'w', 'a', 'r', 'e', 'y', 'o', 'u' # char, how are you | |||||
| std::vector<StringPack> ParseTensorBuffer(Tensor *tensor); | std::vector<StringPack> ParseTensorBuffer(Tensor *tensor); | ||||
| std::vector<StringPack> ParseStringBuffer(const void *data); | std::vector<StringPack> ParseStringBuffer(const void *data); | ||||
| int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_buffer); | int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_buffer); | ||||
| int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<StringPack>> &string_buffer); | int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<StringPack>> &string_buffer); | ||||
| int GetStringCount(const void *data); | |||||
| int GetStringCount(Tensor *tensor); | |||||
| uint64_t StringHash64(const char *s, size_t len); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,6 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/custom_extract_features.h" | #include "src/ops/custom_extract_features.h" | ||||
| #include "src/common/string_util.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -30,9 +31,30 @@ int CustomExtractFeatures::UnPackToFlatBuilder(const schema::Primitive *primitiv | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #endif | #endif | ||||
| int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| PrimitiveC::InferShape(inputs_, outputs_); | |||||
| return RET_INFER_INVALID; | |||||
| auto input = inputs_.at(0); | |||||
| MS_ASSERT(input != nullptr); | |||||
| if (input->data_c() == nullptr) { | |||||
| MS_LOG(INFO) << "Do infer shape in runtime."; | |||||
| return RET_INFER_INVALID; | |||||
| } | |||||
| int string_num = lite::GetStringCount(input); | |||||
| auto output0 = outputs_.at(0); | |||||
| auto output1 = outputs_.at(1); | |||||
| MS_ASSERT(output0 != nullptr); | |||||
| MS_ASSERT(output1 != nullptr); | |||||
| std::vector<int> shape; | |||||
| shape.push_back(string_num == 0 ? 1 : string_num); | |||||
| output0->set_shape(shape); | |||||
| output0->set_data_type(input->data_type()); | |||||
| output0->SetFormat(input->GetFormat()); | |||||
| output1->set_shape(shape); | |||||
| output1->set_data_type(input->data_type()); | |||||
| output1->SetFormat(input->GetFormat()); | |||||
| return RET_OK; | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -4,6 +4,7 @@ file(GLOB KERNEL_SRC | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc | ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc | ${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc | ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/string/*.cc | |||||
| ) | ) | ||||
| list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc) | list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc) | ||||
| @@ -0,0 +1,97 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/string/extract_feature.h" | |||||
| #include <string> | |||||
| #include "src/kernel_registry.h" | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_CustomExtractFeatures; | |||||
| namespace mindspore::kernel { | |||||
| int ExtractFeatureCPUKernel::Init() { | |||||
| if (!InferShapeDone()) { | |||||
| return RET_OK; | |||||
| } | |||||
| return ReSize(); | |||||
| } | |||||
| int ExtractFeatureCPUKernel::ReSize() { return RET_OK; } | |||||
| bool ExtractFeatureCPUKernel::IsInBlacklist(const lite::StringPack &str) { | |||||
| std::vector<std::string> kBlacklist = {"<S>", "<E>", "<S> <E>"}; | |||||
| for (const auto &s : kBlacklist) { | |||||
| if (str.len != static_cast<int>(s.length())) { | |||||
| continue; | |||||
| } | |||||
| if (memcmp(str.data, s.data(), str.len) == 0) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| int ExtractFeatureCPUKernel::Run() { | |||||
| auto ret = Prepare(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret; | |||||
| return ret; | |||||
| } | |||||
| const int kMaxDimension = 1000000; | |||||
| auto input_tensor = in_tensors_.at(0); | |||||
| auto label_data = reinterpret_cast<int32_t *>(out_tensors_.at(0)->MutableData()); | |||||
| auto weight_data = out_tensors_.at(1)->MutableData(); | |||||
| int string_num = lite::GetStringCount(input_tensor); | |||||
| std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(input_tensor); | |||||
| for (int i = 0; i < string_num; i++) { | |||||
| lite::StringPack str = all_string_pack[i]; | |||||
| if (IsInBlacklist(str)) { | |||||
| label_data[i] = 0; | |||||
| reinterpret_cast<int32_t *>(weight_data)[i] = 0; | |||||
| continue; | |||||
| } | |||||
| int64_t hash_value = lite::StringHash64(str.data, str.len) % kMaxDimension; | |||||
| label_data[i] = hash_value; | |||||
| reinterpret_cast<float *>(weight_data)[i] = std::count(str.data, str.data + str.len, ' ') + 1; | |||||
| } | |||||
| if (string_num == 0) { | |||||
| label_data[0] = 0; | |||||
| reinterpret_cast<int32_t *>(weight_data)[0] = 0; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuExtractFeatureKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, | |||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||||
| const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto *kernel = new (std::nothrow) ExtractFeatureCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "new ExtractFeatureCPUKernel fail!"; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ | |||||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CustomExtractFeatures, CpuExtractFeatureKernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_EXTRACT_FEATURE_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_EXTRACT_FEATURE_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "include/context.h" | |||||
| #include "src/common/string_util.h" | |||||
| namespace mindspore::kernel { | |||||
| class ExtractFeatureCPUKernel : public LiteKernel { | |||||
| public: | |||||
| ExtractFeatureCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| ~ExtractFeatureCPUKernel() {} | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| bool IsInBlacklist(const lite::StringPack &str); | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_EXTRACT_FEATURE_H_ | |||||
| @@ -28,6 +28,7 @@ file(GLOB KERNEL_OP_SRC | |||||
| ${LITE_DIR}/src/runtime/kernel/arm/base/*.cc | ${LITE_DIR}/src/runtime/kernel/arm/base/*.cc | ||||
| ${LITE_DIR}/src/runtime/kernel/arm/fp32/*.cc | ${LITE_DIR}/src/runtime/kernel/arm/fp32/*.cc | ||||
| ${LITE_DIR}/src/runtime/kernel/arm/int8/*.cc | ${LITE_DIR}/src/runtime/kernel/arm/int8/*.cc | ||||
| ${LITE_DIR}/src/runtime/kernel/arm/string/*.cc | |||||
| ${LITE_DIR}/nnacl/*.c | ${LITE_DIR}/nnacl/*.c | ||||
| ${LITE_DIR}/nnacl/fp32/*.c | ${LITE_DIR}/nnacl/fp32/*.c | ||||
| ${LITE_DIR}/nnacl/int8/*.c | ${LITE_DIR}/nnacl/int8/*.c | ||||