| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_ | |||
| #define MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_ | |||
| #include "nnacl/op_base.h" | |||
| typedef struct { | |||
| OpParameter op_parameter_; | |||
| int output_num; | |||
| float weight_threshold; | |||
| } PredictParameter; | |||
| typedef struct { | |||
| int label; | |||
| float weight; | |||
| } LabelInfo; | |||
| #endif // MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_ | |||
| @@ -18,8 +18,20 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int CustomPredict::GetOutputNum() const { return this->primitive_->value.AsCustomPredict()->outputNum; } | |||
| float CustomPredict::GetWeightThreshold() const { return this->primitive_->value.AsCustomPredict()->weightThreshold; } | |||
| void CustomPredict::SetOutputNum(int output_num) { this->primitive_->value.AsCustomPredict()->outputNum = output_num; } | |||
| void CustomPredict::SetWeightThreshold(float weight_threshold) { | |||
| this->primitive_->value.AsCustomPredict()->weightThreshold = weight_threshold; | |||
| } | |||
| int CustomPredict::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_OK; } | |||
| #else | |||
| int CustomPredict::GetOutputNum() const { return this->primitive_->value_as_CustomPredict()->outputNum(); } | |||
| float CustomPredict::GetWeightThreshold() const { | |||
| return this->primitive_->value_as_CustomPredict()->weightThreshold(); | |||
| } | |||
| int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| @@ -30,8 +42,23 @@ int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||
| } | |||
| #endif | |||
| int CustomPredict::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| PrimitiveC::InferShape(inputs_, outputs_); | |||
| return RET_INFER_INVALID; | |||
| auto input = inputs_.at(0); | |||
| auto output0 = outputs_.at(0); | |||
| auto output1 = outputs_.at(1); | |||
| MS_ASSERT(input != nullptr); | |||
| MS_ASSERT(output0 != nullptr); | |||
| MS_ASSERT(output1 != nullptr); | |||
| std::vector<int> shape; | |||
| shape.push_back(GetOutputNum()); | |||
| output0->set_shape(shape); | |||
| output0->set_data_type(kNumberTypeInt32); | |||
| output0->SetFormat(input->GetFormat()); | |||
| output1->set_shape(shape); | |||
| output1->set_data_type(kNumberTypeFloat32); | |||
| output1->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -27,9 +27,15 @@ class CustomPredict : public PrimitiveC { | |||
| MS_DECLARE_PARENT(CustomPredict, PrimitiveC); | |||
| CustomPredict() = default; | |||
| explicit CustomPredict(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int GetOutputNum() const; | |||
| float GetWeightThreshold() const; | |||
| void SetOutputNum(int output_num); | |||
| void SetWeightThreshold(float weight_threshold); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| CustomPredict() = default; | |||
| int GetOutputNum() const; | |||
| float GetWeightThreshold() const; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) override; | |||
| @@ -116,6 +116,7 @@ | |||
| #include "src/ops/neg.h" | |||
| #include "src/ops/detection_post_process.h" | |||
| #include "src/ops/skip_gram.h" | |||
| #include "src/ops/custom_predict.h" | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/fp32/arg_min_max.h" | |||
| #include "nnacl/fp32/cast.h" | |||
| @@ -176,6 +177,7 @@ | |||
| #include "nnacl/detection_post_process_parameter.h" | |||
| #include "nnacl/fp32/exp.h" | |||
| #include "nnacl/fp32/skip_gram.h" | |||
| #include "nnacl/predict_parameter.h" | |||
| namespace mindspore::kernel { | |||
| @@ -1603,6 +1605,31 @@ OpParameter *PopulateSkipGramParameter(const mindspore::lite::PrimitiveC *primit | |||
| return reinterpret_cast<OpParameter *>(skipGramParameter); | |||
| } | |||
| OpParameter *PopulateCommonOpParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "new OpParameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(param, 0, sizeof(OpParameter)); | |||
| param->type_ = primitive->Type(); | |||
| return param; | |||
| } | |||
| OpParameter *PopulateCustomPredictParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| PredictParameter *param = reinterpret_cast<PredictParameter *>(malloc(sizeof(PredictParameter))); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc param failed."; | |||
| return nullptr; | |||
| } | |||
| memset(param, 0, sizeof(PredictParameter)); | |||
| param->op_parameter_.type_ = primitive->Type(); | |||
| auto prim = reinterpret_cast<mindspore::lite::CustomPredict *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| param->output_num = prim->GetOutputNum(); | |||
| param->weight_threshold = prim->GetWeightThreshold(); | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| PopulateParameterRegistry::PopulateParameterRegistry() { | |||
| populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter; | |||
| @@ -1706,6 +1733,8 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||
| populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter; | |||
| populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter; | |||
| } | |||
| PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { | |||
| @@ -0,0 +1,118 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/string/predict.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include "src/kernel_registry.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_CustomPredict; | |||
| namespace mindspore::kernel { | |||
| int PredictCPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| int PredictCPUKernel::ReSize() { return RET_OK; } | |||
| std::vector<LabelInfo> PredictCPUKernel::GetLabelInfo() { | |||
| std::vector<LabelInfo> label_info_vec; | |||
| auto input_tensor = in_tensors_.at(0); | |||
| auto keys_tensor = in_tensors_.at(1); | |||
| auto labels_tensor = in_tensors_.at(2); | |||
| auto weights_tensor = in_tensors_.at(3); | |||
| int32_t *input = reinterpret_cast<int32_t *>(input_tensor->MutableData()); | |||
| int32_t *key_begin = reinterpret_cast<int32_t *>(keys_tensor->MutableData()); | |||
| int32_t *key_end = key_begin + keys_tensor->ElementsNum(); | |||
| int32_t *labels = reinterpret_cast<int32_t *>(labels_tensor->MutableData()); | |||
| float *weights = reinterpret_cast<float *>(weights_tensor->MutableData()); | |||
| int32_t input_elements_num = input_tensor->ElementsNum(); | |||
| int32_t items = labels_tensor->shape().at(1); | |||
| for (int i = 0; i < input_elements_num; i++) { | |||
| int *p = std::lower_bound(key_begin, key_end, input[i]); | |||
| if (p == nullptr || p == key_end || *p != input[i]) { | |||
| continue; | |||
| } | |||
| int index = p - key_begin; | |||
| for (int j = 0; j < items; j++) { | |||
| int offset = index * items + j; | |||
| auto it = std::find_if(label_info_vec.begin(), label_info_vec.end(), | |||
| [&](const LabelInfo &element) { return element.label == labels[offset]; }); | |||
| if (it != label_info_vec.end()) { | |||
| it->weight += weights[offset] / input_elements_num; | |||
| } else { | |||
| LabelInfo tmp = {labels[offset], weights[offset] / input_elements_num}; | |||
| label_info_vec.push_back(tmp); | |||
| } | |||
| } | |||
| } | |||
| return label_info_vec; | |||
| } | |||
| bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; } | |||
| int PredictCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret; | |||
| return ret; | |||
| } | |||
| std::vector<LabelInfo> label_info_vec = GetLabelInfo(); | |||
| std::sort(label_info_vec.begin(), label_info_vec.end(), LabelInfoCmp); | |||
| auto output_label_tensor = out_tensors_.at(0); | |||
| auto output_weight_tensor = out_tensors_.at(1); | |||
| auto output_label = reinterpret_cast<int32_t *>(output_label_tensor->MutableData()); | |||
| auto output_weight = reinterpret_cast<float *>(output_weight_tensor->MutableData()); | |||
| auto param = reinterpret_cast<PredictParameter *>(op_parameter_); | |||
| for (int i = 0; i < output_label_tensor->ElementsNum(); i++) { | |||
| if (static_cast<size_t>(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) { | |||
| output_label[i] = -1; | |||
| output_weight[i] = 0.0f; | |||
| } | |||
| output_label[i] = label_info_vec[i].label; | |||
| output_weight[i] = label_info_vec[i].weight; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuPredictKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| auto *kernel = new (std::nothrow) PredictCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new PredictCPUKernel fail!"; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ | |||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CustomPredict, CpuPredictKernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "include/context.h" | |||
| #include "nnacl/predict_parameter.h" | |||
| namespace mindspore::kernel { | |||
| class PredictCPUKernel : public LiteKernel { | |||
| public: | |||
| PredictCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~PredictCPUKernel() {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| std::vector<LabelInfo> GetLabelInfo(); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_ | |||