From: @gongdaguo Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiangtags/v1.1.0
| @@ -260,6 +260,7 @@ union PrimitiveType { | |||||
| SigmoidCrossEntropyWithLogitsGrad, | SigmoidCrossEntropyWithLogitsGrad, | ||||
| Reciprocal, | Reciprocal, | ||||
| Merge, | Merge, | ||||
| Mod, | |||||
| } | } | ||||
| enum QuantType: int { | enum QuantType: int { | ||||
| @@ -890,6 +890,9 @@ table FloorDiv { | |||||
| table FloorMod { | table FloorMod { | ||||
| } | } | ||||
| table Mod { | |||||
| } | |||||
| table L2Norm { | table L2Norm { | ||||
| axis: [int]; | axis: [int]; | ||||
| epsilon: float; | epsilon: float; | ||||
| @@ -35,6 +35,36 @@ void ArgMin::SetTopK(int top_k) { this->primitive_->value.AsArgMin()->topK = top | |||||
| void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->keepDims = keep_dims; } | void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->keepDims = keep_dims; } | ||||
| void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; } | void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; } | ||||
| int ArgMin::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_ArgMin; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_ArgMin) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::ArgMinT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| if (prim.GetAttr("axis") != nullptr) { | |||||
| attr->axis = static_cast<int32_t>(GetValue<int64_t>(prim.GetAttr("axis"))); | |||||
| } | |||||
| if (prim.GetAttr("keep_dims") != nullptr) { | |||||
| attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims"))); | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int ArgMin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int ArgMin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -66,7 +96,7 @@ int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te | |||||
| MS_ASSERT(input != nullptr); | MS_ASSERT(input != nullptr); | ||||
| auto output = outputs_.front(); | auto output = outputs_.front(); | ||||
| MS_ASSERT(output != nullptr); | MS_ASSERT(output != nullptr); | ||||
| if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { | |||||
| if (inputs_.size() != kSingleNum || outputs_.size() > kDoubleNum) { | |||||
| MS_LOG(ERROR) << "tensor number is error."; | MS_LOG(ERROR) << "tensor number is error."; | ||||
| } | } | ||||
| output->set_format(input->format()); | output->set_format(input->format()); | ||||
| @@ -88,6 +118,11 @@ int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te | |||||
| } | } | ||||
| output->set_shape(output_shape); | output->set_shape(output_shape); | ||||
| if (outputs_.size() == kDoubleNum) { | |||||
| outputs_.at(1)->set_format(input->format()); | |||||
| outputs_.at(1)->set_data_type(input->data_type()); | |||||
| outputs_.at(1)->set_shape(output_shape); | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -36,6 +36,7 @@ class ArgMin : public PrimitiveC { | |||||
| void SetTopK(int top_k); | void SetTopK(int top_k); | ||||
| void SetKeepDims(bool keep_dims); | void SetKeepDims(bool keep_dims); | ||||
| void SetAxisType(int axis_type); | void SetAxisType(int axis_type); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/mod.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int Mod::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Mod; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Mod) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| delete this->primitive_; | |||||
| this->primitive_ = nullptr; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::ModT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| delete this->primitive_; | |||||
| this->primitive_ = nullptr; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int Mod::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto val_offset = schema::CreateMod(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mod, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| PrimitiveC *ModCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mod>(primitive); } | |||||
| Registry ModRegistry(schema::PrimitiveType_Mod, ModCreator); | |||||
| #endif | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_MOD_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_MOD_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/arithmetic.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class Mod : public Arithmetic { | |||||
| public: | |||||
| Mod() = default; | |||||
| ~Mod() = default; | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(Mod, Arithmetic); | |||||
| explicit Mod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_ | |||||
| @@ -67,6 +67,7 @@ Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithme | |||||
| Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic); | Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic); | ||||
| Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); | Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); | ||||
| Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic); | Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic); | ||||
| Registry ModParameterRegistry(schema::PrimitiveType_Mod, PopulateArithmetic); | |||||
| Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic); | Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -101,6 +101,7 @@ | |||||
| #include "src/ops/logical_not.h" | #include "src/ops/logical_not.h" | ||||
| #include "src/ops/floor_div.h" | #include "src/ops/floor_div.h" | ||||
| #include "src/ops/floor_mod.h" | #include "src/ops/floor_mod.h" | ||||
| #include "src/ops/mod.h" | |||||
| #include "src/ops/equal.h" | #include "src/ops/equal.h" | ||||
| #include "src/ops/not_equal.h" | #include "src/ops/not_equal.h" | ||||
| #include "src/ops/less.h" | #include "src/ops/less.h" | ||||
| @@ -597,6 +598,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<Equal>(prim, inputs, quantType); | return NewPrimitiveC<Equal>(prim, inputs, quantType); | ||||
| } else if (op_type == "TopK") { | } else if (op_type == "TopK") { | ||||
| return NewPrimitiveC<TopK>(prim, inputs, quantType); | return NewPrimitiveC<TopK>(prim, inputs, quantType); | ||||
| } else if (op_type == "Mod") { | |||||
| return NewPrimitiveC<Mod>(prim, inputs, quantType); | |||||
| } else if (op_type == "ArgMinWithValue") { | |||||
| return NewPrimitiveC<ArgMin>(prim, inputs, quantType); | |||||
| } else if (op_type == "Range") { | } else if (op_type == "Range") { | ||||
| return NewPrimitiveC<Range>(prim, inputs, quantType); | return NewPrimitiveC<Range>(prim, inputs, quantType); | ||||
| } else if (op_type == "Tile") { | } else if (op_type == "Tile") { | ||||
| @@ -805,6 +810,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new (std::nothrow) FloorDiv(primitive); | return new (std::nothrow) FloorDiv(primitive); | ||||
| case schema::PrimitiveType_FloorMod: | case schema::PrimitiveType_FloorMod: | ||||
| return new (std::nothrow) FloorMod(primitive); | return new (std::nothrow) FloorMod(primitive); | ||||
| case schema::PrimitiveType_Mod: | |||||
| return new (std::nothrow) Mod(primitive); | |||||
| case schema::PrimitiveType_Equal: | case schema::PrimitiveType_Equal: | ||||
| return new (std::nothrow) Equal(primitive); | return new (std::nothrow) Equal(primitive); | ||||
| case schema::PrimitiveType_NotEqual: | case schema::PrimitiveType_NotEqual: | ||||