diff --git a/mindspore/lite/nnacl/predict_parameter.h b/mindspore/lite/nnacl/predict_parameter.h
new file mode 100644
index 0000000000..997523f599
--- /dev/null
+++ b/mindspore/lite/nnacl/predict_parameter.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_
+#define MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_
+
+#include "nnacl/op_base.h"
+typedef struct {
+  OpParameter op_parameter_;
+  int output_num;
+  float weight_threshold;
+} PredictParameter;
+
+typedef struct {
+  int label;
+  float weight;
+} LabelInfo;
+#endif  // MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_
diff --git a/mindspore/lite/src/ops/custom_predict.cc b/mindspore/lite/src/ops/custom_predict.cc
index ea630e912f..44044f1346 100644
--- a/mindspore/lite/src/ops/custom_predict.cc
+++ b/mindspore/lite/src/ops/custom_predict.cc
@@ -18,8 +18,20 @@
 namespace mindspore {
 namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
+int CustomPredict::GetOutputNum() const { return this->primitive_->value.AsCustomPredict()->outputNum; }
+float CustomPredict::GetWeightThreshold() const { return this->primitive_->value.AsCustomPredict()->weightThreshold; }
+
+void CustomPredict::SetOutputNum(int output_num) { this->primitive_->value.AsCustomPredict()->outputNum = output_num; }
+void CustomPredict::SetWeightThreshold(float weight_threshold) {
+  this->primitive_->value.AsCustomPredict()->weightThreshold = weight_threshold;
+}
 int CustomPredict::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_OK; }
 #else
+int CustomPredict::GetOutputNum() const { return this->primitive_->value_as_CustomPredict()->outputNum(); }
+float CustomPredict::GetWeightThreshold() const {
+  return this->primitive_->value_as_CustomPredict()->weightThreshold();
+}
+
 int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
   MS_ASSERT(nullptr != fbb);
@@ -30,8 +42,23 @@ int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
 }
 #endif
 int CustomPredict::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
-  PrimitiveC::InferShape(inputs_, outputs_);
-  return RET_INFER_INVALID;
+  auto input = inputs_.at(0);
+  auto output0 = outputs_.at(0);
+  auto output1 = outputs_.at(1);
+  MS_ASSERT(input != nullptr);
+  MS_ASSERT(output0 != nullptr);
+  MS_ASSERT(output1 != nullptr);
+
+  std::vector<int> shape;
+  shape.push_back(GetOutputNum());
+
+  output0->set_shape(shape);
+  output0->set_data_type(kNumberTypeInt32);
+  output0->SetFormat(input->GetFormat());
+  output1->set_shape(shape);
+  output1->set_data_type(kNumberTypeFloat32);
+  output1->SetFormat(input->GetFormat());
+  return RET_OK;
 }
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/ops/custom_predict.h b/mindspore/lite/src/ops/custom_predict.h
index 4706be2cb7..fc29961fb4 100644
--- a/mindspore/lite/src/ops/custom_predict.h
+++ b/mindspore/lite/src/ops/custom_predict.h
@@ -27,9 +27,15 @@ class CustomPredict : public PrimitiveC {
   MS_DECLARE_PARENT(CustomPredict, PrimitiveC);
   CustomPredict() = default;
   explicit CustomPredict(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
+  int GetOutputNum() const;
+  float GetWeightThreshold() const;
+  void SetOutputNum(int output_num);
+  void SetWeightThreshold(float weight_threshold);
   int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   CustomPredict() = default;
+  int GetOutputNum() const;
+  float GetWeightThreshold() const;
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
   int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc
index f9d93610c5..3aa71f0a03 100644
--- a/mindspore/lite/src/populate_parameter.cc
+++ b/mindspore/lite/src/populate_parameter.cc
@@ -116,6 +116,7 @@
 #include "src/ops/neg.h"
 #include "src/ops/detection_post_process.h"
 #include "src/ops/skip_gram.h"
+#include "src/ops/custom_predict.h"
 #include "nnacl/op_base.h"
 #include "nnacl/fp32/arg_min_max.h"
 #include "nnacl/fp32/cast.h"
@@ -176,6 +177,7 @@
 #include "nnacl/detection_post_process_parameter.h"
 #include "nnacl/fp32/exp.h"
 #include "nnacl/fp32/skip_gram.h"
+#include "nnacl/predict_parameter.h"
 
 namespace mindspore::kernel {
 
@@ -1603,6 +1605,31 @@ OpParameter *PopulateSkipGramParameter(const mindspore::lite::PrimitiveC *primit
   return reinterpret_cast<OpParameter *>(skipGramParameter);
 }
 
+OpParameter *PopulateCommonOpParameter(const mindspore::lite::PrimitiveC *primitive) {
+  OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "new OpParameter failed.";
+    return nullptr;
+  }
+  memset(param, 0, sizeof(OpParameter));
+  param->type_ = primitive->Type();
+  return param;
+}
+
+OpParameter *PopulateCustomPredictParameter(const mindspore::lite::PrimitiveC *primitive) {
+  PredictParameter *param = reinterpret_cast<PredictParameter *>(malloc(sizeof(PredictParameter)));
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "malloc param failed.";
+    return nullptr;
+  }
+  memset(param, 0, sizeof(PredictParameter));
+  param->op_parameter_.type_ = primitive->Type();
+  auto prim = reinterpret_cast<mindspore::lite::CustomPredict *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
+  param->output_num = prim->GetOutputNum();
+  param->weight_threshold = prim->GetWeightThreshold();
+  return reinterpret_cast<OpParameter *>(param);
+}
+
 PopulateParameterRegistry::PopulateParameterRegistry() {
   populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter;
   populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
@@ -1706,6 +1733,8 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
   populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter;
   populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter;
   populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter;
+  populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter;
+  populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter;
 }
 
 PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {
diff --git a/mindspore/lite/src/runtime/kernel/arm/string/predict.cc b/mindspore/lite/src/runtime/kernel/arm/string/predict.cc
new file mode 100644
index 0000000000..c56c2a978d
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/string/predict.cc
@@ -0,0 +1,124 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/runtime/kernel/arm/string/predict.h"
+#include <vector>
+#include <algorithm>
+#include "src/kernel_registry.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::schema::PrimitiveType_CustomPredict;
+
+namespace mindspore::kernel {
+int PredictCPUKernel::Init() {
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+int PredictCPUKernel::ReSize() { return RET_OK; }
+
+// Looks up every id of the input tensor in the sorted keys table and
+// accumulates the averaged weight of each matched label.
+std::vector<LabelInfo> PredictCPUKernel::GetLabelInfo() {
+  std::vector<LabelInfo> label_info_vec;
+  auto input_tensor = in_tensors_.at(0);
+  auto keys_tensor = in_tensors_.at(1);
+  auto labels_tensor = in_tensors_.at(2);
+  auto weights_tensor = in_tensors_.at(3);
+
+  int32_t *input = reinterpret_cast<int32_t *>(input_tensor->MutableData());
+  int32_t *key_begin = reinterpret_cast<int32_t *>(keys_tensor->MutableData());
+  int32_t *key_end = key_begin + keys_tensor->ElementsNum();
+  int32_t *labels = reinterpret_cast<int32_t *>(labels_tensor->MutableData());
+  float *weights = reinterpret_cast<float *>(weights_tensor->MutableData());
+
+  int32_t input_elements_num = input_tensor->ElementsNum();
+  int32_t items = labels_tensor->shape().at(1);
+
+  for (int i = 0; i < input_elements_num; i++) {
+    // keys are assumed sorted ascending, so a binary search finds the row.
+    int *p = std::lower_bound(key_begin, key_end, input[i]);
+    if (p == nullptr || p == key_end || *p != input[i]) {
+      continue;
+    }
+    int index = p - key_begin;
+    for (int j = 0; j < items; j++) {
+      int offset = index * items + j;
+      auto it = std::find_if(label_info_vec.begin(), label_info_vec.end(),
+                             [&](const LabelInfo &element) { return element.label == labels[offset]; });
+      if (it != label_info_vec.end()) {
+        it->weight += weights[offset] / input_elements_num;
+      } else {
+        LabelInfo tmp = {labels[offset], weights[offset] / input_elements_num};
+        label_info_vec.push_back(tmp);
+      }
+    }
+  }
+  return label_info_vec;
+}
+
+// Sort predicate: descending weight.
+bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; }
+
+int PredictCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
+    return ret;
+  }
+  std::vector<LabelInfo> label_info_vec = GetLabelInfo();
+  std::sort(label_info_vec.begin(), label_info_vec.end(), LabelInfoCmp);
+
+  auto output_label_tensor = out_tensors_.at(0);
+  auto output_weight_tensor = out_tensors_.at(1);
+  auto output_label = reinterpret_cast<int32_t *>(output_label_tensor->MutableData());
+  auto output_weight = reinterpret_cast<float *>(output_weight_tensor->MutableData());
+  auto param = reinterpret_cast<PredictParameter *>(op_parameter_);
+  for (int i = 0; i < output_label_tensor->ElementsNum(); i++) {
+    // Pad missing or below-threshold entries with label -1 / weight 0.
+    if (static_cast<size_t>(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) {
+      output_label[i] = -1;
+      output_weight[i] = 0.0f;
+    } else {
+      output_label[i] = label_info_vec[i].label;
+      output_weight[i] = label_info_vec[i].weight;
+    }
+  }
+  return RET_OK;
+}
+
+kernel::LiteKernel *CpuPredictKernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                            const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
+                                            const lite::InnerContext *ctx, const kernel::KernelKey &desc,
+                                            const mindspore::lite::PrimitiveC *primitive) {
+  auto *kernel = new (std::nothrow) PredictCPUKernel(parameter, inputs, outputs, ctx, primitive);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "new PredictCPUKernel fail!";
+    return nullptr;
+  }
+  auto ret = kernel->Init();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
+                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
+    delete kernel;
+    return nullptr;
+  }
+  return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CustomPredict, CpuPredictKernelCreator)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/string/predict.h b/mindspore/lite/src/runtime/kernel/arm/string/predict.h
new file mode 100644
index 0000000000..4239c6de78
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/string/predict.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_
+
+#include <vector>
+#include "src/lite_kernel.h"
+#include "include/context.h"
+#include "nnacl/predict_parameter.h"
+
+namespace mindspore::kernel {
+// CPU kernel implementing the CustomPredict string op.
+class PredictCPUKernel : public LiteKernel {
+ public:
+  PredictCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+                   const mindspore::lite::PrimitiveC *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
+  ~PredictCPUKernel() {}
+
+  int Init() override;
+  int ReSize() override;
+  int Run() override;
+
+ private:
+  std::vector<LabelInfo> GetLabelInfo();
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_