From: @xu_anyue Reviewed-by: @hangangqiang Signed-off-by: @hangangqiangtags/v1.1.0
| @@ -168,7 +168,7 @@ table FlattenGrad { | |||||
| } | } | ||||
| table Concat { | table Concat { | ||||
| axis: int; | axis: int; | ||||
| n: int; | |||||
| n: int; // DEPRECATED | |||||
| } | } | ||||
| table SoftMax { | table SoftMax { | ||||
| @@ -822,6 +822,7 @@ table Gather { | |||||
| } | } | ||||
| table GatherNd { | table GatherNd { | ||||
| batchDims: int; // DEPRECATED | |||||
| } | } | ||||
| table Fill { | table Fill { | ||||
| @@ -27,10 +27,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| int Concat::GetAxis() const { return this->primitive_->value.AsConcat()->axis; } | int Concat::GetAxis() const { return this->primitive_->value.AsConcat()->axis; } | ||||
| int Concat::GetN() const { return this->primitive_->value.AsConcat()->n; } | |||||
| void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; } | void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; } | ||||
| void Concat::SetN(int n) { this->primitive_->value.AsConcat()->n = n; } | |||||
| int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | ||||
| if (this->primitive_ == nullptr) { | if (this->primitive_ == nullptr) { | ||||
| @@ -71,13 +69,12 @@ int Concat::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: | |||||
| MS_LOG(ERROR) << "value_as_Concat return nullptr"; | MS_LOG(ERROR) << "value_as_Concat return nullptr"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto val_offset = schema::CreateConcat(*fbb, attr->axis(), attr->n()); | |||||
| auto val_offset = schema::CreateConcat(*fbb, attr->axis()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Concat, val_offset.o); | auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Concat, val_offset.o); | ||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); } | int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); } | ||||
| int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); } | |||||
| PrimitiveC *ConcatCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Concat>(primitive); } | PrimitiveC *ConcatCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Concat>(primitive); } | ||||
| Registry ConcatRegistry(schema::PrimitiveType_Concat, ConcatCreator); | Registry ConcatRegistry(schema::PrimitiveType_Concat, ConcatCreator); | ||||
| @@ -33,13 +33,11 @@ class Concat : public PrimitiveC { | |||||
| explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | ||||
| void SetAxis(int axis); | void SetAxis(int axis); | ||||
| void SetN(int n); | |||||
| #else | #else | ||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | ||||
| int GetAxis() const; | int GetAxis() const; | ||||
| int GetN() const; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -93,7 +93,6 @@ TEST_F(SchedulerTest, TestConstructSubGraphsTwoBranch) { | |||||
| concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat; | concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat; | ||||
| auto concat_primitive = new mindspore::schema::ConcatT; | auto concat_primitive = new mindspore::schema::ConcatT; | ||||
| concat_primitive->axis = 3; | concat_primitive->axis = 3; | ||||
| concat_primitive->n = 2; | |||||
| concat->primitive->value.value = concat_primitive; | concat->primitive->value.value = concat_primitive; | ||||
| concat->name = "concat"; | concat->name = "concat"; | ||||
| @@ -255,7 +254,6 @@ TEST_F(SchedulerTest, TestConstructSubGraphsThreeBranch) { | |||||
| concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat; | concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat; | ||||
| auto concat_primitive = new mindspore::schema::ConcatT; | auto concat_primitive = new mindspore::schema::ConcatT; | ||||
| concat_primitive->axis = 3; | concat_primitive->axis = 3; | ||||
| concat_primitive->n = 2; | |||||
| concat->primitive->value.value = concat_primitive; | concat->primitive->value.value = concat_primitive; | ||||
| concat->name = "concat"; | concat->name = "concat"; | ||||
| @@ -35,7 +35,6 @@ TEST_F(TestTfliteParserConcat, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr); | ||||
| auto val = meta_graph->nodes.front()->primitive->value.AsConcat(); | auto val = meta_graph->nodes.front()->primitive->value.AsConcat(); | ||||
| ASSERT_EQ(val->axis, 1); | ASSERT_EQ(val->axis, 1); | ||||
| ASSERT_EQ(val->n, 2); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,7 +60,6 @@ STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe: | |||||
| MS_LOG(DEBUG) << "by default, set axis = 1"; | MS_LOG(DEBUG) << "by default, set axis = 1"; | ||||
| attr->axis = 1; | attr->axis = 1; | ||||
| } | } | ||||
| attr->n = proto.bottom_size(); | |||||
| op->name = proto.name(); | op->name = proto.name(); | ||||
| op->primitive->value.type = schema::PrimitiveType_Concat; | op->primitive->value.type = schema::PrimitiveType_Concat; | ||||
| @@ -34,7 +34,6 @@ PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite: | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| attr->axis = tfliteAttr->axis; | attr->axis = tfliteAttr->axis; | ||||
| attr->n = tflite_op->inputs.size(); | |||||
| primitive->value.type = schema::PrimitiveType_Concat; | primitive->value.type = schema::PrimitiveType_Concat; | ||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| return PrimitiveC::Create(primitive.release()); | return PrimitiveC::Create(primitive.release()); | ||||