Browse Source

add predict kernel

tags/v1.1.0
sunsuodong 5 years ago
parent
commit
2f967536ed
6 changed files with 254 additions and 2 deletions
  1. +30
    -0
      mindspore/lite/nnacl/predict_parameter.h
  2. +29
    -2
      mindspore/lite/src/ops/custom_predict.cc
  3. +6
    -0
      mindspore/lite/src/ops/custom_predict.h
  4. +29
    -0
      mindspore/lite/src/populate_parameter.cc
  5. +118
    -0
      mindspore/lite/src/runtime/kernel/arm/string/predict.cc
  6. +42
    -0
      mindspore/lite/src/runtime/kernel/arm/string/predict.h

+ 30
- 0
mindspore/lite/nnacl/predict_parameter.h View File

@@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_

#include "nnacl/op_base.h"
typedef struct {
// Common op header; must stay the first member so PredictParameter* can be
// safely cast to/from OpParameter* (see populate_parameter.cc).
OpParameter op_parameter_;
// Number of (label, weight) prediction slots written to each output tensor.
int output_num;
// Predictions with accumulated weight below this threshold are replaced by
// the sentinel label -1 / weight 0.0f in the kernel.
float weight_threshold;
} PredictParameter;

typedef struct {
// Predicted label id.
int label;
// Accumulated weight (averaged over the number of input tokens).
float weight;
} LabelInfo;
#endif // MINDSPORE_LITE_NNACL_PREDICT_PARAMETER_H_

+ 29
- 2
mindspore/lite/src/ops/custom_predict.cc View File

@@ -18,8 +18,20 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Accessors for the CustomPredict attributes stored in the writeable
// primitive (PRIMITIVE_WRITEABLE build).
int CustomPredict::GetOutputNum() const { return this->primitive_->value.AsCustomPredict()->outputNum; }
float CustomPredict::GetWeightThreshold() const { return this->primitive_->value.AsCustomPredict()->weightThreshold; }

void CustomPredict::SetOutputNum(int output_num) { this->primitive_->value.AsCustomPredict()->outputNum = output_num; }
void CustomPredict::SetWeightThreshold(float weight_threshold) {
this->primitive_->value.AsCustomPredict()->weightThreshold = weight_threshold;
}
// CustomPredict carries no attributes to unpack from the Anf primitive.
int CustomPredict::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_OK; }
#else
// Read-only accessors backed by the flatbuffer-serialized primitive
// (non-writeable build).
int CustomPredict::GetOutputNum() const { return this->primitive_->value_as_CustomPredict()->outputNum(); }
float CustomPredict::GetWeightThreshold() const {
return this->primitive_->value_as_CustomPredict()->weightThreshold();
}

int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
@@ -30,8 +42,23 @@ int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
}
#endif
// Infers the shapes of CustomPredict's two outputs from the outputNum
// attribute: output0 holds the predicted labels (int32), output1 the
// corresponding weights (float32); both are 1-D vectors of length
// GetOutputNum(). Returns RET_OK on success.
// NOTE(review): the diff view had fused the removed lines
// (`PrimitiveC::InferShape(...); return RET_INFER_INVALID;`) into this body,
// which made everything below them unreachable — they are dropped here.
int CustomPredict::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto input = inputs_.at(0);
  auto output0 = outputs_.at(0);
  auto output1 = outputs_.at(1);
  MS_ASSERT(input != nullptr);
  MS_ASSERT(output0 != nullptr);
  MS_ASSERT(output1 != nullptr);

  std::vector<int> shape;
  shape.push_back(GetOutputNum());

  output0->set_shape(shape);
  output0->set_data_type(kNumberTypeInt32);
  output0->SetFormat(input->GetFormat());
  output1->set_shape(shape);
  output1->set_data_type(kNumberTypeFloat32);
  output1->SetFormat(input->GetFormat());
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 6
- 0
mindspore/lite/src/ops/custom_predict.h View File

@@ -27,9 +27,15 @@ class CustomPredict : public PrimitiveC {
MS_DECLARE_PARENT(CustomPredict, PrimitiveC);
CustomPredict() = default;
explicit CustomPredict(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
// Number of (label, weight) prediction slots.
int GetOutputNum() const;
// Minimum weight below which a prediction slot is considered empty.
float GetWeightThreshold() const;
void SetOutputNum(int output_num);
void SetWeightThreshold(float weight_threshold);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
CustomPredict() = default;
// Read-only accessors backed by the flatbuffer primitive in this build.
int GetOutputNum() const;
float GetWeightThreshold() const;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
// Sets both outputs to shape [GetOutputNum()]: labels (int32) and weights (float32).
int InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) override;


+ 29
- 0
mindspore/lite/src/populate_parameter.cc View File

@@ -116,6 +116,7 @@
#include "src/ops/neg.h"
#include "src/ops/detection_post_process.h"
#include "src/ops/skip_gram.h"
#include "src/ops/custom_predict.h"
#include "nnacl/op_base.h"
#include "nnacl/fp32/arg_min_max.h"
#include "nnacl/fp32/cast.h"
@@ -176,6 +177,7 @@
#include "nnacl/detection_post_process_parameter.h"
#include "nnacl/fp32/exp.h"
#include "nnacl/fp32/skip_gram.h"
#include "nnacl/predict_parameter.h"

namespace mindspore::kernel {

@@ -1603,6 +1605,31 @@ OpParameter *PopulateSkipGramParameter(const mindspore::lite::PrimitiveC *primit
return reinterpret_cast<OpParameter *>(skipGramParameter);
}

// Builds a bare OpParameter for ops that carry no op-specific attributes;
// only the primitive type tag is recorded. Returns nullptr on allocation
// failure. Caller takes ownership of the malloc'd parameter.
OpParameter *PopulateCommonOpParameter(const mindspore::lite::PrimitiveC *primitive) {
  auto *common_param = static_cast<OpParameter *>(malloc(sizeof(OpParameter)));
  if (common_param == nullptr) {
    MS_LOG(ERROR) << "new OpParameter failed.";
    return nullptr;
  }
  memset(common_param, 0, sizeof(OpParameter));
  common_param->type_ = primitive->Type();
  return common_param;
}

// Builds a PredictParameter from a CustomPredict primitive, copying the
// outputNum and weightThreshold attributes. Returns nullptr on allocation
// failure. Caller takes ownership of the malloc'd parameter.
OpParameter *PopulateCustomPredictParameter(const mindspore::lite::PrimitiveC *primitive) {
  auto *predict_param = static_cast<PredictParameter *>(malloc(sizeof(PredictParameter)));
  if (predict_param == nullptr) {
    MS_LOG(ERROR) << "malloc param failed.";
    return nullptr;
  }
  memset(predict_param, 0, sizeof(PredictParameter));
  predict_param->op_parameter_.type_ = primitive->Type();
  auto *custom_predict =
    reinterpret_cast<mindspore::lite::CustomPredict *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  predict_param->output_num = custom_predict->GetOutputNum();
  predict_param->weight_threshold = custom_predict->GetWeightThreshold();
  return reinterpret_cast<OpParameter *>(predict_param);
}

PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter;
populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
@@ -1706,6 +1733,8 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter;
populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter;
populate_parameter_funcs_[schema::PrimitiveType_SkipGram] = PopulateSkipGramParameter;
populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter;
populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter;
}

PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {


+ 118
- 0
mindspore/lite/src/runtime/kernel/arm/string/predict.cc View File

@@ -0,0 +1,118 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/string/predict.h"
#include <string>
#include <algorithm>
#include "src/kernel_registry.h"

using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_CustomPredict;

namespace mindspore::kernel {
// One-time setup: nothing to pre-compute until shapes are inferred, so
// defer all work to ReSize() once InferShapeDone() reports true.
int PredictCPUKernel::Init() { return InferShapeDone() ? ReSize() : RET_OK; }

int PredictCPUKernel::ReSize() { return RET_OK; }

// Accumulates (label, weight) candidates for all input tokens.
// Inputs (by convention of the CustomPredict op):
//   in_tensors_[0] — int32 token ids to look up
//   in_tensors_[1] — int32 sorted key table (binary-searched)
//   in_tensors_[2] — int32 label table, 2-D: one row of `items` labels per key
//   in_tensors_[3] — float weight table, parallel to the label table
// Each matched token contributes weight/input_elements_num per label;
// duplicate labels across tokens are merged by summing their weights.
// Fix vs original: dropped the dead `p == nullptr` test — std::lower_bound
// returns an iterator in [key_begin, key_end], never nullptr.
std::vector<LabelInfo> PredictCPUKernel::GetLabelInfo() {
  std::vector<LabelInfo> label_info_vec;
  auto input_tensor = in_tensors_.at(0);
  auto keys_tensor = in_tensors_.at(1);
  auto labels_tensor = in_tensors_.at(2);
  auto weights_tensor = in_tensors_.at(3);

  int32_t *input = reinterpret_cast<int32_t *>(input_tensor->MutableData());
  int32_t *key_begin = reinterpret_cast<int32_t *>(keys_tensor->MutableData());
  int32_t *key_end = key_begin + keys_tensor->ElementsNum();
  int32_t *labels = reinterpret_cast<int32_t *>(labels_tensor->MutableData());
  float *weights = reinterpret_cast<float *>(weights_tensor->MutableData());

  int32_t input_elements_num = input_tensor->ElementsNum();
  // Number of labels stored per key (assumes labels tensor is 2-D — TODO confirm).
  int32_t items = labels_tensor->shape().at(1);

  for (int i = 0; i < input_elements_num; i++) {
    int *p = std::lower_bound(key_begin, key_end, input[i]);
    if (p == key_end || *p != input[i]) {
      continue;  // token not present in the key table
    }
    int index = p - key_begin;
    for (int j = 0; j < items; j++) {
      int offset = index * items + j;
      // Merge with an existing entry for the same label, if any.
      auto it = std::find_if(label_info_vec.begin(), label_info_vec.end(),
                             [&](const LabelInfo &element) { return element.label == labels[offset]; });
      if (it != label_info_vec.end()) {
        it->weight += weights[offset] / input_elements_num;
      } else {
        LabelInfo tmp = {labels[offset], weights[offset] / input_elements_num};
        label_info_vec.push_back(tmp);
      }
    }
  }
  return label_info_vec;
}

bool LabelInfoCmp(const LabelInfo &lhs, const LabelInfo &rhs) { return lhs.weight > rhs.weight; }

// Runs the predict kernel: gathers per-label weights, sorts them by
// descending weight, and fills the two outputs slot by slot —
// out_tensors_[0] with int32 labels, out_tensors_[1] with float weights.
// Slots with no candidate, or whose weight falls below the configured
// weight_threshold, are filled with the sentinel pair (-1, 0.0f).
int PredictCPUKernel::Run() {
  auto ret = Prepare();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
    return ret;
  }
  std::vector<LabelInfo> label_info_vec = GetLabelInfo();
  std::sort(label_info_vec.begin(), label_info_vec.end(), LabelInfoCmp);

  auto output_label_tensor = out_tensors_.at(0);
  auto output_weight_tensor = out_tensors_.at(1);
  auto output_label = reinterpret_cast<int32_t *>(output_label_tensor->MutableData());
  auto output_weight = reinterpret_cast<float *>(output_weight_tensor->MutableData());
  auto param = reinterpret_cast<PredictParameter *>(op_parameter_);
  for (int i = 0; i < output_label_tensor->ElementsNum(); i++) {
    if (static_cast<size_t>(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) {
      // BUGFIX: the original fell through after writing the sentinel and then
      // read label_info_vec[i] anyway — out-of-bounds when i >= size().
      output_label[i] = -1;
      output_weight[i] = 0.0f;
      continue;
    }
    output_label[i] = label_info_vec[i].label;
    output_weight[i] = label_info_vec[i].weight;
  }
  return RET_OK;
}

// Factory registered for PrimitiveType_CustomPredict: creates and
// initializes a PredictCPUKernel. Returns nullptr on allocation, parameter,
// or Init() failure; the caller takes ownership of the returned kernel.
kernel::LiteKernel *CpuPredictKernelCreator(const std::vector<lite::Tensor *> &inputs,
                                            const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
                                            const lite::InnerContext *ctx, const kernel::KernelKey &desc,
                                            const mindspore::lite::PrimitiveC *primitive) {
  // BUGFIX: parameter was dereferenced (parameter->name_) without a null
  // check on the Init-failure path; guard up front.
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "parameter is nullptr";
    return nullptr;
  }
  auto *kernel = new (std::nothrow) PredictCPUKernel(parameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "new PredictCPUKernel fail!";
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CustomPredict, CpuPredictKernelCreator)
} // namespace mindspore::kernel

+ 42
- 0
mindspore/lite/src/runtime/kernel/arm/string/predict.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_

#include <vector>
#include "src/lite_kernel.h"
#include "include/context.h"
#include "nnacl/predict_parameter.h"

namespace mindspore::kernel {
// CPU kernel for the CustomPredict string op: maps input token ids to
// (label, weight) predictions via a sorted key table plus parallel
// label/weight tables (see predict.cc for the tensor layout).
class PredictCPUKernel : public LiteKernel {
 public:
  PredictCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                   const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
  // Rule-of-zero: no owned resources beyond the base class, so default
  // the destructor instead of providing an empty user-defined body.
  ~PredictCPUKernel() = default;

  int Init() override;
  int ReSize() override;
  int Run() override;

 private:
  // Accumulates per-label weights for all input tokens.
  std::vector<LabelInfo> GetLabelInfo();
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_PREDICT_H_

Loading…
Cancel
Save