| @@ -21,7 +21,6 @@ | |||||
| typedef struct GatherNdParameter { | typedef struct GatherNdParameter { | ||||
| OpParameter op_parameter_; | OpParameter op_parameter_; | ||||
| int batchDims_; | |||||
| } GatherNdParameter; | } GatherNdParameter; | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -822,7 +822,6 @@ table Gather { | |||||
| } | } | ||||
| table GatherNd { | table GatherNd { | ||||
| batchDims: int; | |||||
| } | } | ||||
| table Fill { | table Fill { | ||||
| @@ -23,9 +23,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| int GatherNd::GetBatchDims() const { return this->primitive_->value.AsGatherNd()->batchDims; } | |||||
| void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd()->batchDims = batch_dims; } | |||||
| #else | #else | ||||
| int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| @@ -37,12 +34,11 @@ int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto val_offset = schema::CreateGatherNd(*fbb, attr->batchDims()); | |||||
| auto val_offset = schema::CreateGatherNd(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GatherNd, val_offset.o); | auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GatherNd, val_offset.o); | ||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); } | |||||
| PrimitiveC *GatherNdCreator(const schema::Primitive *primitive) { | PrimitiveC *GatherNdCreator(const schema::Primitive *primitive) { | ||||
| return PrimitiveC::NewPrimitiveC<GatherNd>(primitive); | return PrimitiveC::NewPrimitiveC<GatherNd>(primitive); | ||||
| @@ -32,13 +32,10 @@ class GatherNd : public PrimitiveC { | |||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| MS_DECLARE_PARENT(GatherNd, PrimitiveC); | MS_DECLARE_PARENT(GatherNd, PrimitiveC); | ||||
| explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| void SetBatchDims(int batch_dims); | |||||
| #else | #else | ||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | ||||
| int GetBatchDims() const; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "src/ops/populate/populate_register.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| OpParameter *PopulateCommonParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto *common_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| if (common_parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc OpParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(common_parameter, 0, sizeof(OpParameter)); | |||||
| return common_parameter; | |||||
| } | |||||
| Registry ZerosLikeParameterRegistry(schema::PrimitiveType_ZerosLike, PopulateCommonParameter); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -30,9 +30,6 @@ OpParameter *PopulateGatherNdParameter(const mindspore::lite::PrimitiveC *primit | |||||
| } | } | ||||
| memset(gather_nd_param, 0, sizeof(GatherNdParameter)); | memset(gather_nd_param, 0, sizeof(GatherNdParameter)); | ||||
| gather_nd_param->op_parameter_.type_ = primitive->Type(); | gather_nd_param->op_parameter_.type_ = primitive->Type(); | ||||
| auto gatherNd_attr = | |||||
| reinterpret_cast<mindspore::lite::GatherNd *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| gather_nd_param->batchDims_ = gatherNd_attr->GetBatchDims(); | |||||
| return reinterpret_cast<OpParameter *>(gather_nd_param); | return reinterpret_cast<OpParameter *>(gather_nd_param); | ||||
| } | } | ||||
| @@ -37,7 +37,6 @@ TEST_F(TestGatherNdInt8, GatherNdTest) { | |||||
| GatherNdParameter op_param; | GatherNdParameter op_param; | ||||
| op_param.op_parameter_.type_ = schema::PrimitiveType_GatherNd; | op_param.op_parameter_.type_ = schema::PrimitiveType_GatherNd; | ||||
| op_param.batchDims_ = 1; | |||||
| std::vector<int> shape = {1, 2, 2, 5}; | std::vector<int> shape = {1, 2, 2, 5}; | ||||
| std::vector<int> out_shape = {1, 3, 5}; | std::vector<int> out_shape = {1, 3, 5}; | ||||
| @@ -33,8 +33,6 @@ TEST_F(TestTfliteParserGatherNd, OpType) { | |||||
| TEST_F(TestTfliteParserGatherNd, AttrValue) { | TEST_F(TestTfliteParserGatherNd, AttrValue) { | ||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr); | ||||
| auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd(); | |||||
| ASSERT_EQ(val->batchDims, 0); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,8 +34,6 @@ PrimitiveC *TfliteGatherNdParser::ParseLitePrimitive(const std::unique_ptr<tflit | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| attr->batchDims = 0; | |||||
| primitive->value.type = schema::PrimitiveType_GatherNd; | primitive->value.type = schema::PrimitiveType_GatherNd; | ||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| return PrimitiveC::Create(primitive.release()); | return PrimitiveC::Create(primitive.release()); | ||||