Browse Source

!10115 [MS][LITE]Add Mod parser and ArgMinWithValue parser

From: @gongdaguo
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
e7aec28280
8 changed files with 161 additions and 1 deletions
  1. +1
    -0
      mindspore/lite/schema/model.fbs
  2. +3
    -0
      mindspore/lite/schema/ops.fbs
  3. +36
    -1
      mindspore/lite/src/ops/argmin.cc
  4. +1
    -0
      mindspore/lite/src/ops/argmin.h
  5. +70
    -0
      mindspore/lite/src/ops/mod.cc
  6. +42
    -0
      mindspore/lite/src/ops/mod.h
  7. +1
    -0
      mindspore/lite/src/ops/populate/arithmetic_populate.cc
  8. +7
    -0
      mindspore/lite/src/ops/primitive_c.cc

+ 1
- 0
mindspore/lite/schema/model.fbs View File

@@ -260,6 +260,7 @@ union PrimitiveType {
SigmoidCrossEntropyWithLogitsGrad,
Reciprocal,
Merge,
Mod,
}

enum QuantType: int {


+ 3
- 0
mindspore/lite/schema/ops.fbs View File

@@ -890,6 +890,9 @@ table FloorDiv {
table FloorMod {
}

table Mod {
}

table L2Norm {
axis: [int];
epsilon: float;


+ 36
- 1
mindspore/lite/src/ops/argmin.cc View File

@@ -35,6 +35,36 @@ void ArgMin::SetTopK(int top_k) { this->primitive_->value.AsArgMin()->topK = top
void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->keepDims = keep_dims; }
void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; }

int ArgMin::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_ArgMin;
}
if (this->primitive_->value.type != schema::PrimitiveType_ArgMin) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::ArgMinT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (prim.GetAttr("axis") != nullptr) {
attr->axis = static_cast<int32_t>(GetValue<int64_t>(prim.GetAttr("axis")));
}
if (prim.GetAttr("keep_dims") != nullptr) {
attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims")));
}
}
return RET_OK;
}

#else
int ArgMin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
@@ -66,7 +96,7 @@ int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
if (inputs_.size() != kSingleNum || outputs_.size() > kDoubleNum) {
MS_LOG(ERROR) << "tensor number is error.";
}
output->set_format(input->format());
@@ -88,6 +118,11 @@ int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
}

output->set_shape(output_shape);
if (outputs_.size() == kDoubleNum) {
outputs_.at(1)->set_format(input->format());
outputs_.at(1)->set_data_type(input->data_type());
outputs_.at(1)->set_shape(output_shape);
}
return RET_OK;
}
} // namespace lite


+ 1
- 0
mindspore/lite/src/ops/argmin.h View File

@@ -36,6 +36,7 @@ class ArgMin : public PrimitiveC {
void SetTopK(int top_k);
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+ 70
- 0
mindspore/lite/src/ops/mod.cc View File

@@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/mod.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE

int Mod::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Mod;
}
if (this->primitive_->value.type != schema::PrimitiveType_Mod) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
delete this->primitive_;
this->primitive_ = nullptr;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::ModT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
delete this->primitive_;
this->primitive_ = nullptr;
return RET_ERROR;
}
this->primitive_->value.value = attr;
}
return RET_OK;
}

#else

int Mod::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateMod(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mod, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *ModCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mod>(primitive); }
Registry ModRegistry(schema::PrimitiveType_Mod, ModCreator);
#endif

} // namespace lite
} // namespace mindspore

+ 42
- 0
mindspore/lite/src/ops/mod.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_MOD_H_
#define LITE_MINDSPORE_LITE_C_OPS_MOD_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/arithmetic.h"

namespace mindspore {
namespace lite {
class Mod : public Arithmetic {
public:
Mod() = default;
~Mod() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Mod, Arithmetic);
explicit Mod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_

+ 1
- 0
mindspore/lite/src/ops/populate/arithmetic_populate.cc View File

@@ -67,6 +67,7 @@ Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithme
Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic);
Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic);
Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic);
Registry ModParameterRegistry(schema::PrimitiveType_Mod, PopulateArithmetic);
Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic);
} // namespace lite
} // namespace mindspore

+ 7
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -101,6 +101,7 @@
#include "src/ops/logical_not.h"
#include "src/ops/floor_div.h"
#include "src/ops/floor_mod.h"
#include "src/ops/mod.h"
#include "src/ops/equal.h"
#include "src/ops/not_equal.h"
#include "src/ops/less.h"
@@ -597,6 +598,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Equal>(prim, inputs, quantType);
} else if (op_type == "TopK") {
return NewPrimitiveC<TopK>(prim, inputs, quantType);
} else if (op_type == "Mod") {
return NewPrimitiveC<Mod>(prim, inputs, quantType);
} else if (op_type == "ArgMinWithValue") {
return NewPrimitiveC<ArgMin>(prim, inputs, quantType);
} else if (op_type == "Range") {
return NewPrimitiveC<Range>(prim, inputs, quantType);
} else if (op_type == "Tile") {
@@ -805,6 +810,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) FloorDiv(primitive);
case schema::PrimitiveType_FloorMod:
return new (std::nothrow) FloorMod(primitive);
case schema::PrimitiveType_Mod:
return new (std::nothrow) Mod(primitive);
case schema::PrimitiveType_Equal:
return new (std::nothrow) Equal(primitive);
case schema::PrimitiveType_NotEqual:


Loading…
Cancel
Save