Merge pull request !6946 from 徐安越/mastertags/v1.1.0
| @@ -35,7 +35,7 @@ int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu | |||||
| auto output = outputs_.front(); | auto output = outputs_.front(); | ||||
| MS_ASSERT(output != nullptr); | MS_ASSERT(output != nullptr); | ||||
| output->set_shape(input->shape()); | output->set_shape(input->shape()); | ||||
| output->set_data_type(TypeId::kNumberTypeUInt8); | |||||
| output->set_data_type(TypeId::kNumberTypeBool); | |||||
| output->SetFormat(input->GetFormat()); | output->SetFormat(input->GetFormat()); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,5 +29,15 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #endif | #endif | ||||
| int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| auto input = inputs_.front(); | |||||
| MS_ASSERT(input != nullptr); | |||||
| auto output = outputs_.front(); | |||||
| MS_ASSERT(output != nullptr); | |||||
| output->set_shape(input->shape()); | |||||
| output->set_data_type(TypeId::kNumberTypeBool); | |||||
| output->SetFormat(input->GetFormat()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,6 +35,7 @@ class Greater : public Arithmetic { | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -28,5 +28,15 @@ int GreaterEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #endif | #endif | ||||
| int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| auto input = inputs_.front(); | |||||
| MS_ASSERT(input != nullptr); | |||||
| auto output = outputs_.front(); | |||||
| MS_ASSERT(output != nullptr); | |||||
| output->set_shape(input->shape()); | |||||
| output->set_data_type(TypeId::kNumberTypeBool); | |||||
| output->SetFormat(input->GetFormat()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,6 +36,7 @@ class GreaterEqual : public Arithmetic { | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,5 +30,15 @@ int Less::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #endif | #endif | ||||
| int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| auto input = inputs_.front(); | |||||
| MS_ASSERT(input != nullptr); | |||||
| auto output = outputs_.front(); | |||||
| MS_ASSERT(output != nullptr); | |||||
| output->set_shape(input->shape()); | |||||
| output->set_data_type(TypeId::kNumberTypeBool); | |||||
| output->SetFormat(input->GetFormat()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,6 +36,7 @@ class Less : public Arithmetic { | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,5 +29,15 @@ int LessEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #endif | #endif | ||||
| int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| auto input = inputs_.front(); | |||||
| MS_ASSERT(input != nullptr); | |||||
| auto output = outputs_.front(); | |||||
| MS_ASSERT(output != nullptr); | |||||
| output->set_shape(input->shape()); | |||||
| output->set_data_type(TypeId::kNumberTypeBool); | |||||
| output->SetFormat(input->GetFormat()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,6 +36,7 @@ class LessEqual : public Arithmetic { | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,5 +29,15 @@ int NotEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #endif | #endif | ||||
| int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| auto input = inputs_.front(); | |||||
| MS_ASSERT(input != nullptr); | |||||
| auto output = outputs_.front(); | |||||
| MS_ASSERT(output != nullptr); | |||||
| output->set_shape(input->shape()); | |||||
| output->set_data_type(TypeId::kNumberTypeBool); | |||||
| output->SetFormat(input->GetFormat()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,6 +36,7 @@ class NotEqual : public Arithmetic { | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||