From: @gongdaguo Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiangtags/v1.1.0
| @@ -260,6 +260,7 @@ union PrimitiveType { | |||
| SigmoidCrossEntropyWithLogitsGrad, | |||
| Reciprocal, | |||
| Merge, | |||
| Mod, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -890,6 +890,9 @@ table FloorDiv { | |||
| table FloorMod { | |||
| } | |||
| table Mod { | |||
| } | |||
| table L2Norm { | |||
| axis: [int]; | |||
| epsilon: float; | |||
| @@ -35,6 +35,36 @@ void ArgMin::SetTopK(int top_k) { this->primitive_->value.AsArgMin()->topK = top | |||
| void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->keepDims = keep_dims; } | |||
| void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; } | |||
| int ArgMin::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_ArgMin; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_ArgMin) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::ArgMinT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (prim.GetAttr("axis") != nullptr) { | |||
| attr->axis = static_cast<int32_t>(GetValue<int64_t>(prim.GetAttr("axis"))); | |||
| } | |||
| if (prim.GetAttr("keep_dims") != nullptr) { | |||
| attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims"))); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int ArgMin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -66,7 +96,7 @@ int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { | |||
| if (inputs_.size() != kSingleNum || outputs_.size() > kDoubleNum) { | |||
| MS_LOG(ERROR) << "tensor number is error."; | |||
| } | |||
| output->set_format(input->format()); | |||
| @@ -88,6 +118,11 @@ int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te | |||
| } | |||
| output->set_shape(output_shape); | |||
| if (outputs_.size() == kDoubleNum) { | |||
| outputs_.at(1)->set_format(input->format()); | |||
| outputs_.at(1)->set_data_type(input->data_type()); | |||
| outputs_.at(1)->set_shape(output_shape); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -36,6 +36,7 @@ class ArgMin : public PrimitiveC { | |||
| void SetTopK(int top_k); | |||
| void SetKeepDims(bool keep_dims); | |||
| void SetAxisType(int axis_type); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/mod.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Mod::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Mod; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Mod) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| delete this->primitive_; | |||
| this->primitive_ = nullptr; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::ModT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| delete this->primitive_; | |||
| this->primitive_ = nullptr; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Mod::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto val_offset = schema::CreateMod(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mod, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *ModCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mod>(primitive); } | |||
| Registry ModRegistry(schema::PrimitiveType_Mod, ModCreator); | |||
| #endif | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_MOD_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_MOD_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/arithmetic.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Mod : public Arithmetic { | |||
| public: | |||
| Mod() = default; | |||
| ~Mod() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Mod, Arithmetic); | |||
| explicit Mod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_ | |||
| @@ -67,6 +67,7 @@ Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithme | |||
| Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic); | |||
| Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); | |||
| Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic); | |||
| Registry ModParameterRegistry(schema::PrimitiveType_Mod, PopulateArithmetic); | |||
| Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -101,6 +101,7 @@ | |||
| #include "src/ops/logical_not.h" | |||
| #include "src/ops/floor_div.h" | |||
| #include "src/ops/floor_mod.h" | |||
| #include "src/ops/mod.h" | |||
| #include "src/ops/equal.h" | |||
| #include "src/ops/not_equal.h" | |||
| #include "src/ops/less.h" | |||
| @@ -597,6 +598,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Equal>(prim, inputs, quantType); | |||
| } else if (op_type == "TopK") { | |||
| return NewPrimitiveC<TopK>(prim, inputs, quantType); | |||
| } else if (op_type == "Mod") { | |||
| return NewPrimitiveC<Mod>(prim, inputs, quantType); | |||
| } else if (op_type == "ArgMinWithValue") { | |||
| return NewPrimitiveC<ArgMin>(prim, inputs, quantType); | |||
| } else if (op_type == "Range") { | |||
| return NewPrimitiveC<Range>(prim, inputs, quantType); | |||
| } else if (op_type == "Tile") { | |||
| @@ -805,6 +810,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) FloorDiv(primitive); | |||
| case schema::PrimitiveType_FloorMod: | |||
| return new (std::nothrow) FloorMod(primitive); | |||
| case schema::PrimitiveType_Mod: | |||
| return new (std::nothrow) Mod(primitive); | |||
| case schema::PrimitiveType_Equal: | |||
| return new (std::nothrow) Equal(primitive); | |||
| case schema::PrimitiveType_NotEqual: | |||