Browse Source

fix compare function output dataType

tags/v1.1.0
xuanyue 5 years ago
parent
commit
1dca718a1e
11 changed files with 56 additions and 1 deletions
  1. +1
    -1
      mindspore/lite/src/ops/equal.cc
  2. +10
    -0
      mindspore/lite/src/ops/greater.cc
  3. +1
    -0
      mindspore/lite/src/ops/greater.h
  4. +10
    -0
      mindspore/lite/src/ops/greater_equal.cc
  5. +1
    -0
      mindspore/lite/src/ops/greater_equal.h
  6. +10
    -0
      mindspore/lite/src/ops/less.cc
  7. +1
    -0
      mindspore/lite/src/ops/less.h
  8. +10
    -0
      mindspore/lite/src/ops/less_equal.cc
  9. +1
    -0
      mindspore/lite/src/ops/less_equal.h
  10. +10
    -0
      mindspore/lite/src/ops/not_equal.cc
  11. +1
    -0
      mindspore/lite/src/ops/not_equal.h

+ 1
- 1
mindspore/lite/src/ops/equal.cc View File

@@ -35,7 +35,7 @@ int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeUInt8);
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}


+ 10
- 0
mindspore/lite/src/ops/greater.cc View File

@@ -29,5 +29,15 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
return RET_OK;
}
#endif
int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 1
- 0
mindspore/lite/src/ops/greater.h View File

@@ -35,6 +35,7 @@ class Greater : public Arithmetic {

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore


+ 10
- 0
mindspore/lite/src/ops/greater_equal.cc View File

@@ -28,5 +28,15 @@ int GreaterEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu
return RET_OK;
}
#endif
int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 1
- 0
mindspore/lite/src/ops/greater_equal.h View File

@@ -36,6 +36,7 @@ class GreaterEqual : public Arithmetic {

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore


+ 10
- 0
mindspore/lite/src/ops/less.cc View File

@@ -30,5 +30,15 @@ int Less::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
return RET_OK;
}
#endif
int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 1
- 0
mindspore/lite/src/ops/less.h View File

@@ -36,6 +36,7 @@ class Less : public Arithmetic {

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore


+ 10
- 0
mindspore/lite/src/ops/less_equal.cc View File

@@ -29,5 +29,15 @@ int LessEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
return RET_OK;
}
#endif
int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 1
- 0
mindspore/lite/src/ops/less_equal.h View File

@@ -36,6 +36,7 @@ class LessEqual : public Arithmetic {

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore


+ 10
- 0
mindspore/lite/src/ops/not_equal.cc View File

@@ -29,5 +29,15 @@ int NotEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
return RET_OK;
}
#endif
int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 1
- 0
mindspore/lite/src/ops/not_equal.h View File

@@ -36,6 +36,7 @@ class NotEqual : public Arithmetic {

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore


Loading…
Cancel
Save