| @@ -0,0 +1,30 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_ | |||||
| #define MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_ | |||||
| #include "nnacl/op_base.h" | |||||
| typedef struct { | |||||
| OpParameter op_parameter_; | |||||
| int output_num; | |||||
| float weight_threshold; | |||||
| } PredictParameter; | |||||
| typedef struct { | |||||
| int label; | |||||
| float weight; | |||||
| } LabelInfo; | |||||
| #endif // MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_ | |||||
| @@ -18,8 +18,20 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| int CustomPredict::GetOutputNum() const { return this->primitive_->value.AsCustomPredict()->outputNum; } | |||||
| float CustomPredict::GetWeightThreshold() const { return this->primitive_->value.AsCustomPredict()->weightThreshold; } | |||||
| void CustomPredict::SetOutputNum(int output_num) { this->primitive_->value.AsCustomPredict()->outputNum = output_num; } | |||||
| void CustomPredict::SetWeightThreshold(float weight_threshold) { | |||||
| this->primitive_->value.AsCustomPredict()->weightThreshold = weight_threshold; | |||||
| } | |||||
| int CustomPredict::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_OK; } | int CustomPredict::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_OK; } | ||||
| #else | #else | ||||
| int CustomPredict::GetOutputNum() const { return this->primitive_->value_as_CustomPredict()->outputNum(); } | |||||
| float CustomPredict::GetWeightThreshold() const { | |||||
| return this->primitive_->value_as_CustomPredict()->weightThreshold(); | |||||
| } | |||||
| int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| MS_ASSERT(nullptr != fbb); | MS_ASSERT(nullptr != fbb); | ||||
| @@ -30,8 +42,23 @@ int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||||
| } | } | ||||
| #endif | #endif | ||||
| int CustomPredict::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int CustomPredict::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| PrimitiveC::InferShape(inputs_, outputs_); | |||||
| return RET_INFER_INVALID; | |||||
| auto input = inputs_.at(0); | |||||
| auto output0 = outputs_.at(0); | |||||
| auto output1 = outputs_.at(1); | |||||
| MS_ASSERT(input != nullptr); | |||||
| MS_ASSERT(output0 != nullptr); | |||||
| MS_ASSERT(output1 != nullptr); | |||||
| std::vector<int> shape; | |||||
| shape.push_back(GetOutputNum()); | |||||
| output0->set_shape(shape); | |||||
| output0->set_data_type(kNumberTypeInt32); | |||||
| output0->SetFormat(input->GetFormat()); | |||||
| output1->set_shape(shape); | |||||
| output1->set_data_type(kNumberTypeFloat32); | |||||
| output1->SetFormat(input->GetFormat()); | |||||
| return RET_OK; | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,9 +27,15 @@ class CustomPredict : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(CustomPredict, PrimitiveC); | MS_DECLARE_PARENT(CustomPredict, PrimitiveC); | ||||
| CustomPredict() = default; | CustomPredict() = default; | ||||
| explicit CustomPredict(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit CustomPredict(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| int GetOutputNum() const; | |||||
| float GetWeightThreshold() const; | |||||
| void SetOutputNum(int output_num); | |||||
| void SetWeightThreshold(float weight_threshold); | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | ||||
| #else | #else | ||||
| CustomPredict() = default; | CustomPredict() = default; | ||||
| int GetOutputNum() const; | |||||
| float GetWeightThreshold() const; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) override; | int InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) override; | ||||
| @@ -116,6 +116,7 @@ | |||||
| #include "src/ops/neg.h" | #include "src/ops/neg.h" | ||||
| #include "src/ops/detection_post_process.h" | #include "src/ops/detection_post_process.h" | ||||
| #include "src/ops/skip_gram.h" | #include "src/ops/skip_gram.h" | ||||
| #include "src/ops/custom_predict.h" | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/fp32/arg_min_max.h" | #include "nnacl/fp32/arg_min_max.h" | ||||
| #include "nnacl/fp32/cast.h" | #include "nnacl/fp32/cast.h" | ||||
| @@ -176,6 +177,7 @@ | |||||
| #include "nnacl/detection_post_process_parameter.h" | #include "nnacl/detection_post_process_parameter.h" | ||||
| #include "nnacl/fp32/exp.h" | #include "nnacl/fp32/exp.h" | ||||
| #include "nnacl/fp32/skip_gram.h" | #include "nnacl/fp32/skip_gram.h" | ||||
| #include "nnacl/predict_parameter.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| @@ -1603,6 +1605,31 @@ OpParameter *PopulateSkipGramParameter(const mindspore::lite::PrimitiveC *primit | |||||
| return reinterpret_cast<OpParameter *>(skipGramParameter); | return reinterpret_cast<OpParameter *>(skipGramParameter); | ||||
| } | } | ||||
| OpParameter *PopulateCommonOpParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "new OpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(OpParameter)); | |||||
| param->type_ = primitive->Type(); | |||||
| return param; | |||||
| } | |||||
| OpParameter *PopulateCustomPredictParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| PredictParameter *param = reinterpret_cast<PredictParameter *>(malloc(sizeof(PredictParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc param failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(param, 0, sizeof(PredictParameter)); | |||||
| param->op_parameter_.type_ = primitive->Type(); | |||||
| auto prim = reinterpret_cast<mindspore::lite::CustomPredict *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| param->output_num = prim->GetOutputNum(); | |||||
| param->weight_threshold = prim->GetWeightThreshold(); | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | |||||
| PopulateParameterRegistry::PopulateParameterRegistry() { | PopulateParameterRegistry::PopulateParameterRegistry() { | ||||
| populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter; | populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter; | populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter; | ||||
| @@ -1706,6 +1733,8 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||||
| populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter; | populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter; | populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter; | populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter; | |||||
| populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter; | |||||
| } | } | ||||
| PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { | PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { | ||||
| @@ -0,0 +1,118 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/string/predict.h" | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include "src/kernel_registry.h" | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_CustomPredict; | |||||
| namespace mindspore::kernel { | |||||
| int PredictCPUKernel::Init() { | |||||
| if (!InferShapeDone()) { | |||||
| return RET_OK; | |||||
| } | |||||
| return ReSize(); | |||||
| } | |||||
| int PredictCPUKernel::ReSize() { return RET_OK; } | |||||
| std::vector<LabelInfo> PredictCPUKernel::GetLabelInfo() { | |||||
| std::vector<LabelInfo> label_info_vec; | |||||
| auto input_tensor = in_tensors_.at(0); | |||||
| auto keys_tensor = in_tensors_.at(1); | |||||
| auto labels_tensor = in_tensors_.at(2); | |||||
| auto weights_tensor = in_tensors_.at(3); | |||||
| int32_t *input = reinterpret_cast<int32_t *>(input_tensor->MutableData()); | |||||
| int32_t *key_begin = reinterpret_cast<int32_t *>(keys_tensor->MutableData()); | |||||
| int32_t *key_end = key_begin + keys_tensor->ElementsNum(); | |||||
| int32_t *labels = reinterpret_cast<int32_t *>(labels_tensor->MutableData()); | |||||
| float *weights = reinterpret_cast<float *>(weights_tensor->MutableData()); | |||||
| int32_t input_elements_num = input_tensor->ElementsNum(); | |||||
| int32_t items = labels_tensor->shape().at(1); | |||||
| for (int i = 0; i < input_elements_num; i++) { | |||||
| int *p = std::lower_bound(key_begin, key_end, input[i]); | |||||
| if (p == nullptr || p == key_end || *p != input[i]) { | |||||
| continue; | |||||
| } | |||||
| int index = p - key_begin; | |||||
| for (int j = 0; j < items; j++) { | |||||
| int offset = index * items + j; | |||||
| auto it = std::find_if(label_info_vec.begin(), label_info_vec.end(), | |||||
| [&](const LabelInfo &element) { return element.label == labels[offset]; }); | |||||
| if (it != label_info_vec.end()) { | |||||
| it->weight += weights[offset] / input_elements_num; | |||||
| } else { | |||||
| LabelInfo tmp = {labels[offset], weights[offset] / input_elements_num}; | |||||
| label_info_vec.push_back(tmp); | |||||
| } | |||||
| } | |||||
| } | |||||
| return label_info_vec; | |||||
| } | |||||
| bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; } | |||||
| int PredictCPUKernel::Run() { | |||||
| auto ret = Prepare(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret; | |||||
| return ret; | |||||
| } | |||||
| std::vector<LabelInfo> label_info_vec = GetLabelInfo(); | |||||
| std::sort(label_info_vec.begin(), label_info_vec.end(), LabelInfoCmp); | |||||
| auto output_label_tensor = out_tensors_.at(0); | |||||
| auto output_weight_tensor = out_tensors_.at(1); | |||||
| auto output_label = reinterpret_cast<int32_t *>(output_label_tensor->MutableData()); | |||||
| auto output_weight = reinterpret_cast<float *>(output_weight_tensor->MutableData()); | |||||
| auto param = reinterpret_cast<PredictParameter *>(op_parameter_); | |||||
| for (int i = 0; i < output_label_tensor->ElementsNum(); i++) { | |||||
| if (static_cast<size_t>(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) { | |||||
| output_label[i] = -1; | |||||
| output_weight[i] = 0.0f; | |||||
| } | |||||
| output_label[i] = label_info_vec[i].label; | |||||
| output_weight[i] = label_info_vec[i].weight; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuPredictKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, | |||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||||
| const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto *kernel = new (std::nothrow) PredictCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "new PredictCPUKernel fail!"; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ | |||||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CustomPredict, CpuPredictKernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "include/context.h" | |||||
| #include "nnacl/predict_parameter.h" | |||||
| namespace mindspore::kernel { | |||||
| class PredictCPUKernel : public LiteKernel { | |||||
| public: | |||||
| PredictCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| ~PredictCPUKernel() {} | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| std::vector<LabelInfo> GetLabelInfo(); | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_ | |||||