Browse Source

add common_populate

tags/v1.1.0
yeyunpeng 5 years ago
parent
commit
9ee2f1baa3
9 changed files with 37 additions and 18 deletions
  1. +0
    -1
      mindspore/lite/nnacl/fp32/gatherNd_fp32.h
  2. +0
    -1
      mindspore/lite/schema/ops.fbs
  3. +1
    -5
      mindspore/lite/src/ops/gather_nd.cc
  4. +0
    -3
      mindspore/lite/src/ops/gather_nd.h
  5. +36
    -0
      mindspore/lite/src/ops/populate/common_populate.cc
  6. +0
    -3
      mindspore/lite/src/ops/populate/gather_nd_populate.cc
  7. +0
    -1
      mindspore/lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc
  8. +0
    -2
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc
  9. +0
    -2
      mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc

+ 0
- 1
mindspore/lite/nnacl/fp32/gatherNd_fp32.h View File

@@ -21,7 +21,6 @@

typedef struct GatherNdParameter {
OpParameter op_parameter_;
int batchDims_;
} GatherNdParameter;

#ifdef __cplusplus


+ 0
- 1
mindspore/lite/schema/ops.fbs View File

@@ -822,7 +822,6 @@ table Gather {
}

table GatherNd {
batchDims: int;
}

table Fill {


+ 1
- 5
mindspore/lite/src/ops/gather_nd.cc View File

@@ -23,9 +23,6 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int GatherNd::GetBatchDims() const { return this->primitive_->value.AsGatherNd()->batchDims; }

void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd()->batchDims = batch_dims; }

#else
int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
@@ -37,12 +34,11 @@ int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
return RET_ERROR;
}

auto val_offset = schema::CreateGatherNd(*fbb, attr->batchDims());
auto val_offset = schema::CreateGatherNd(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GatherNd, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); }

PrimitiveC *GatherNdCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<GatherNd>(primitive);


+ 0
- 3
mindspore/lite/src/ops/gather_nd.h View File

@@ -32,13 +32,10 @@ class GatherNd : public PrimitiveC {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GatherNd, PrimitiveC);
explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetBatchDims(int batch_dims);

#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetBatchDims() const;
};
} // namespace lite
} // namespace mindspore


+ 36
- 0
mindspore/lite/src/ops/populate/common_populate.cc View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {

OpParameter *PopulateCommonParameter(const mindspore::lite::PrimitiveC *primitive) {
auto *common_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (common_parameter == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter failed.";
return nullptr;
}
memset(common_parameter, 0, sizeof(OpParameter));
return common_parameter;
}

Registry ZerosLikeParameterRegistry(schema::PrimitiveType_ZerosLike, PopulateCommonParameter);

} // namespace lite
} // namespace mindspore

+ 0
- 3
mindspore/lite/src/ops/populate/gather_nd_populate.cc View File

@@ -30,9 +30,6 @@ OpParameter *PopulateGatherNdParameter(const mindspore::lite::PrimitiveC *primit
}
memset(gather_nd_param, 0, sizeof(GatherNdParameter));
gather_nd_param->op_parameter_.type_ = primitive->Type();
auto gatherNd_attr =
reinterpret_cast<mindspore::lite::GatherNd *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
gather_nd_param->batchDims_ = gatherNd_attr->GetBatchDims();
return reinterpret_cast<OpParameter *>(gather_nd_param);
}



+ 0
- 1
mindspore/lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc View File

@@ -37,7 +37,6 @@ TEST_F(TestGatherNdInt8, GatherNdTest) {

GatherNdParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_GatherNd;
op_param.batchDims_ = 1;
std::vector<int> shape = {1, 2, 2, 5};
std::vector<int> out_shape = {1, 3, 5};



+ 0
- 2
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc View File

@@ -33,8 +33,6 @@ TEST_F(TestTfliteParserGatherNd, OpType) {

TEST_F(TestTfliteParserGatherNd, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd();
ASSERT_EQ(val->batchDims, 0);
}

} // namespace mindspore

+ 0
- 2
mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc View File

@@ -34,8 +34,6 @@ PrimitiveC *TfliteGatherNdParser::ParseLitePrimitive(const std::unique_ptr<tflit
return nullptr;
}

attr->batchDims = 0;

primitive->value.type = schema::PrimitiveType_GatherNd;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());


Loading…
Cancel
Save