Browse Source

nms op registry

tags/v1.1.0
zhaozhenlong 5 years ago
parent
commit
a43fe26a01
2 changed files with 13 additions and 1 deletions
  1. +11
    -0
      mindspore/lite/src/ops/non_max_suppression.cc
  2. +2
    -1
      mindspore/lite/src/ops/populate/non_max_suppression_populate.cc

+ 11
- 0
mindspore/lite/src/ops/non_max_suppression.cc View File

@@ -16,6 +16,10 @@

#include "src/ops/non_max_suppression.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -44,6 +48,13 @@ int NonMaxSuppression::UnPackToFlatBuilder(const schema::Primitive *primitive, f
int NonMaxSuppression::GetCenterPointBox() const {
return this->primitive_->value_as_NonMaxSuppression()->centerPointBox();
}

PrimitiveC *NonMaxSuppressionCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<NonMaxSuppression>(primitive);
}

Registry NonMaxSuppressionRegistry(schema::PrimitiveType_NonMaxSuppression, NonMaxSuppressionCreator);

#endif
int NonMaxSuppression::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();


+ 2
- 1
mindspore/lite/src/ops/populate/non_max_suppression_populate.cc View File

@@ -35,7 +35,8 @@ OpParameter *PopulateNonMaxSuppressionParameter(const mindspore::lite::Primitive
param->center_point_box_ = prim->GetCenterPointBox();
return reinterpret_cast<OpParameter *>(param);
}
Registry NonMaxSuppressionParameterRegistry(schema::PrimitiveType_OneHot, PopulateNonMaxSuppressionParameter);
Registry NonMaxSuppressionParameterRegistry(schema::PrimitiveType_NonMaxSuppression,
PopulateNonMaxSuppressionParameter);

} // namespace lite
} // namespace mindspore

Loading…
Cancel
Save