From: @lyvette Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -23,16 +23,13 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| OpParameter *PopulateSparseToDenseParameter(const mindspore::lite::PrimitiveC *primitive) { | OpParameter *PopulateSparseToDenseParameter(const mindspore::lite::PrimitiveC *primitive) { | ||||
| SparseToDenseParameter *sparse_to_dense_param = | |||||
| reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter))); | |||||
| auto *sparse_to_dense_param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter))); | |||||
| if (sparse_to_dense_param == nullptr) { | if (sparse_to_dense_param == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc SparseToDenseParameter failed."; | MS_LOG(ERROR) << "malloc SparseToDenseParameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(sparse_to_dense_param, 0, sizeof(SparseToDenseParameter)); | memset(sparse_to_dense_param, 0, sizeof(SparseToDenseParameter)); | ||||
| sparse_to_dense_param->op_parameter_.type_ = primitive->Type(); | sparse_to_dense_param->op_parameter_.type_ = primitive->Type(); | ||||
| auto param = reinterpret_cast<mindspore::lite::SparseToDense *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| sparse_to_dense_param->validate_indices_ = param->GetValidateIndices(); | |||||
| return reinterpret_cast<OpParameter *>(sparse_to_dense_param); | return reinterpret_cast<OpParameter *>(sparse_to_dense_param); | ||||
| } | } | ||||
| @@ -22,16 +22,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| bool SparseToDense::GetValidateIndices() const { return this->primitive_->value.AsSparseToDense()->validateIndices; } | |||||
| void SparseToDense::SetValidateIndices(bool validate_indices) { | |||||
| this->primitive_->value.AsSparseToDense()->validateIndices = validate_indices; | |||||
| } | |||||
| #else | |||||
| bool SparseToDense::GetValidateIndices() const { return this->primitive_->value_as_SparseToDense()->validateIndices(); } | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| int SparseToDense::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int SparseToDense::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| MS_ASSERT(nullptr != fbb); | MS_ASSERT(nullptr != fbb); | ||||
| @@ -40,7 +31,7 @@ int SparseToDense::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||||
| MS_LOG(ERROR) << "value_as_SparseToDense return nullptr"; | MS_LOG(ERROR) << "value_as_SparseToDense return nullptr"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto val_offset = schema::CreateSparseToDense(*fbb, attr->validateIndices()); | |||||
| auto val_offset = schema::CreateSparseToDense(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SparseToDense, val_offset.o); | auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SparseToDense, val_offset.o); | ||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -36,14 +36,12 @@ class SparseToDense : public PrimitiveC { | |||||
| void SetOutputShape(const std::vector<int> &output_shape); | void SetOutputShape(const std::vector<int> &output_shape); | ||||
| void SetSparseValue(const std::vector<int> &sparse_value); | void SetSparseValue(const std::vector<int> &sparse_value); | ||||
| void SetDefaultValue(const std::vector<int> &default_value); | void SetDefaultValue(const std::vector<int> &default_value); | ||||
| void SetValidateIndices(bool validate_indices); | |||||
| #else | #else | ||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| std::vector<int> GetOutputShape() const; | std::vector<int> GetOutputShape() const; | ||||
| std::vector<int> GetSparseValue() const; | std::vector<int> GetSparseValue() const; | ||||
| std::vector<int> GetDefaultValue() const; | std::vector<int> GetDefaultValue() const; | ||||
| bool GetValidateIndices() const; | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | ||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -31,10 +31,4 @@ TEST_F(TestTfliteParserSparseToDense, OpType) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); | ||||
| ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SparseToDense) << "wrong Op Type"; | ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SparseToDense) << "wrong Op Type"; | ||||
| } | } | ||||
| TEST_F(TestTfliteParserSparseToDense, AttrValue) { | |||||
| ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr); | |||||
| auto val = meta_graph->nodes.front()->primitive->value.AsSparseToDense(); | |||||
| ASSERT_EQ(val->validateIndices, false); | |||||
| } | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,7 +35,6 @@ PrimitiveC *TfliteSparseToDenseParser::ParseLitePrimitive(const std::unique_ptr< | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| attr->validateIndices = false; | |||||
| primitive->value.type = schema::PrimitiveType_SparseToDense; | primitive->value.type = schema::PrimitiveType_SparseToDense; | ||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| return PrimitiveC::Create(primitive.release()); | return PrimitiveC::Create(primitive.release()); | ||||