Browse Source

add common_populate

tags/v1.1.0
yeyunpeng 5 years ago
parent
commit
9ee2f1baa3
9 changed files with 37 additions and 18 deletions
  1. +0
    -1
      mindspore/lite/nnacl/fp32/gatherNd_fp32.h
  2. +0
    -1
      mindspore/lite/schema/ops.fbs
  3. +1
    -5
      mindspore/lite/src/ops/gather_nd.cc
  4. +0
    -3
      mindspore/lite/src/ops/gather_nd.h
  5. +36
    -0
      mindspore/lite/src/ops/populate/common_populate.cc
  6. +0
    -3
      mindspore/lite/src/ops/populate/gather_nd_populate.cc
  7. +0
    -1
      mindspore/lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc
  8. +0
    -2
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc
  9. +0
    -2
      mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc

+ 0
- 1
mindspore/lite/nnacl/fp32/gatherNd_fp32.h View File

@@ -21,7 +21,6 @@


typedef struct GatherNdParameter { typedef struct GatherNdParameter {
OpParameter op_parameter_; OpParameter op_parameter_;
int batchDims_;
} GatherNdParameter; } GatherNdParameter;


#ifdef __cplusplus #ifdef __cplusplus


+ 0
- 1
mindspore/lite/schema/ops.fbs View File

@@ -822,7 +822,6 @@ table Gather {
} }


table GatherNd { table GatherNd {
batchDims: int;
} }


table Fill { table Fill {


+ 1
- 5
mindspore/lite/src/ops/gather_nd.cc View File

@@ -23,9 +23,6 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
int GatherNd::GetBatchDims() const { return this->primitive_->value.AsGatherNd()->batchDims; }

void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd()->batchDims = batch_dims; }


#else #else
int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
@@ -37,12 +34,11 @@ int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
return RET_ERROR; return RET_ERROR;
} }


auto val_offset = schema::CreateGatherNd(*fbb, attr->batchDims());
auto val_offset = schema::CreateGatherNd(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GatherNd, val_offset.o); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GatherNd, val_offset.o);
fbb->Finish(prim_offset); fbb->Finish(prim_offset);
return RET_OK; return RET_OK;
} }
int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); }


PrimitiveC *GatherNdCreator(const schema::Primitive *primitive) { PrimitiveC *GatherNdCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<GatherNd>(primitive); return PrimitiveC::NewPrimitiveC<GatherNd>(primitive);


+ 0
- 3
mindspore/lite/src/ops/gather_nd.h View File

@@ -32,13 +32,10 @@ class GatherNd : public PrimitiveC {
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GatherNd, PrimitiveC); MS_DECLARE_PARENT(GatherNd, PrimitiveC);
explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetBatchDims(int batch_dims);

#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetBatchDims() const;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 36
- 0
mindspore/lite/src/ops/populate/common_populate.cc View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {

OpParameter *PopulateCommonParameter(const mindspore::lite::PrimitiveC *primitive) {
auto *common_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (common_parameter == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter failed.";
return nullptr;
}
memset(common_parameter, 0, sizeof(OpParameter));
return common_parameter;
}

Registry ZerosLikeParameterRegistry(schema::PrimitiveType_ZerosLike, PopulateCommonParameter);

} // namespace lite
} // namespace mindspore

+ 0
- 3
mindspore/lite/src/ops/populate/gather_nd_populate.cc View File

@@ -30,9 +30,6 @@ OpParameter *PopulateGatherNdParameter(const mindspore::lite::PrimitiveC *primit
} }
memset(gather_nd_param, 0, sizeof(GatherNdParameter)); memset(gather_nd_param, 0, sizeof(GatherNdParameter));
gather_nd_param->op_parameter_.type_ = primitive->Type(); gather_nd_param->op_parameter_.type_ = primitive->Type();
auto gatherNd_attr =
reinterpret_cast<mindspore::lite::GatherNd *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
gather_nd_param->batchDims_ = gatherNd_attr->GetBatchDims();
return reinterpret_cast<OpParameter *>(gather_nd_param); return reinterpret_cast<OpParameter *>(gather_nd_param);
} }




+ 0
- 1
mindspore/lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc View File

@@ -37,7 +37,6 @@ TEST_F(TestGatherNdInt8, GatherNdTest) {


GatherNdParameter op_param; GatherNdParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_GatherNd; op_param.op_parameter_.type_ = schema::PrimitiveType_GatherNd;
op_param.batchDims_ = 1;
std::vector<int> shape = {1, 2, 2, 5}; std::vector<int> shape = {1, 2, 2, 5};
std::vector<int> out_shape = {1, 3, 5}; std::vector<int> out_shape = {1, 3, 5};




+ 0
- 2
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc View File

@@ -33,8 +33,6 @@ TEST_F(TestTfliteParserGatherNd, OpType) {


TEST_F(TestTfliteParserGatherNd, AttrValue) { TEST_F(TestTfliteParserGatherNd, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd();
ASSERT_EQ(val->batchDims, 0);
} }


} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc View File

@@ -34,8 +34,6 @@ PrimitiveC *TfliteGatherNdParser::ParseLitePrimitive(const std::unique_ptr<tflit
return nullptr; return nullptr;
} }


attr->batchDims = 0;

primitive->value.type = schema::PrimitiveType_GatherNd; primitive->value.type = schema::PrimitiveType_GatherNd;
primitive->value.value = attr.release(); primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release()); return PrimitiveC::Create(primitive.release());


Loading…
Cancel
Save