diff --git a/mindspore/lite/src/ops/equal.cc b/mindspore/lite/src/ops/equal.cc index 1178e6f1c7..b7525d7902 100644 --- a/mindspore/lite/src/ops/equal.cc +++ b/mindspore/lite/src/ops/equal.cc @@ -35,7 +35,7 @@ int Equal::InferShape(std::vector inputs_, std::vector outpu auto output = outputs_.front(); MS_ASSERT(output != nullptr); output->set_shape(input->shape()); - output->set_data_type(TypeId::kNumberTypeUInt8); + output->set_data_type(TypeId::kNumberTypeBool); output->SetFormat(input->GetFormat()); return RET_OK; } diff --git a/mindspore/lite/src/ops/greater.cc b/mindspore/lite/src/ops/greater.cc index bd92f1a1b1..0d7bf7f555 100644 --- a/mindspore/lite/src/ops/greater.cc +++ b/mindspore/lite/src/ops/greater.cc @@ -29,5 +29,15 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers return RET_OK; } #endif +int Greater::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(TypeId::kNumberTypeBool); + output->SetFormat(input->GetFormat()); + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/greater.h b/mindspore/lite/src/ops/greater.h index ef720a7dea..a2024488ed 100644 --- a/mindspore/lite/src/ops/greater.h +++ b/mindspore/lite/src/ops/greater.h @@ -35,6 +35,7 @@ class Greater : public Arithmetic { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif + int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/greater_equal.cc b/mindspore/lite/src/ops/greater_equal.cc index bd2e5b1c45..ce30ba1742 100644 --- a/mindspore/lite/src/ops/greater_equal.cc +++ b/mindspore/lite/src/ops/greater_equal.cc @@ -28,5 +28,15 @@ int GreaterEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu return RET_OK; } #endif +int GreaterEqual::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(TypeId::kNumberTypeBool); + output->SetFormat(input->GetFormat()); + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/greater_equal.h b/mindspore/lite/src/ops/greater_equal.h index a25932dfed..5a5531a27e 100644 --- a/mindspore/lite/src/ops/greater_equal.h +++ b/mindspore/lite/src/ops/greater_equal.h @@ -36,6 +36,7 @@ class GreaterEqual : public Arithmetic { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif + int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less.cc b/mindspore/lite/src/ops/less.cc index 57a98d87a9..efc6e2af9d 100644 --- a/mindspore/lite/src/ops/less.cc +++ b/mindspore/lite/src/ops/less.cc @@ -30,5 +30,15 @@ int Less::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F return RET_OK; } #endif +int Less::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(TypeId::kNumberTypeBool); + output->SetFormat(input->GetFormat()); + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less.h b/mindspore/lite/src/ops/less.h index c88d28d906..6bf7bf35c6 100644 --- a/mindspore/lite/src/ops/less.h +++ b/mindspore/lite/src/ops/less.h @@ -36,6 +36,7 @@ class Less : public Arithmetic { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif + int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less_equal.cc b/mindspore/lite/src/ops/less_equal.cc index 7274f8cc22..0acc83c213 100644 --- a/mindspore/lite/src/ops/less_equal.cc +++ b/mindspore/lite/src/ops/less_equal.cc @@ -29,5 +29,15 @@ int LessEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe return RET_OK; } #endif +int LessEqual::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(TypeId::kNumberTypeBool); + output->SetFormat(input->GetFormat()); + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less_equal.h b/mindspore/lite/src/ops/less_equal.h index 705caf1cdd..a86395497a 100644 --- a/mindspore/lite/src/ops/less_equal.h +++ b/mindspore/lite/src/ops/less_equal.h @@ -36,6 +36,7 @@ class LessEqual : public Arithmetic { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif + int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/not_equal.cc b/mindspore/lite/src/ops/not_equal.cc index af97dae95a..31904273e3 100644 --- a/mindspore/lite/src/ops/not_equal.cc +++ b/mindspore/lite/src/ops/not_equal.cc @@ -29,5 +29,15 @@ int NotEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer return RET_OK; } #endif +int NotEqual::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(TypeId::kNumberTypeBool); + output->SetFormat(input->GetFormat()); + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/not_equal.h b/mindspore/lite/src/ops/not_equal.h index a75cd96837..77b4cb0fe7 100644 --- a/mindspore/lite/src/ops/not_equal.h +++ b/mindspore/lite/src/ops/not_equal.h @@ -36,6 +36,7 @@ class NotEqual : public Arithmetic { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif + int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite } // namespace mindspore