|
|
|
@@ -53,37 +53,37 @@ PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) { |
|
|
|
[](PythonTreeGetters &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) |
|
|
|
.def("GetOutputShapes", |
|
|
|
[](PythonTreeGetters &self) { |
|
|
|
std::vector<TensorShape> shapes; |
|
|
|
std::vector<TensorShape> shapes = {}; |
|
|
|
THROW_IF_ERROR(self.GetOutputShapes(&shapes)); |
|
|
|
return shapesToListOfShape(shapes); |
|
|
|
}) |
|
|
|
.def("GetOutputTypes", |
|
|
|
[](PythonTreeGetters &self) { |
|
|
|
std::vector<DataType> types; |
|
|
|
std::vector<DataType> types = {}; |
|
|
|
THROW_IF_ERROR(self.GetOutputTypes(&types)); |
|
|
|
return typesToListOfType(types); |
|
|
|
}) |
|
|
|
.def("GetNumClasses", |
|
|
|
[](PythonTreeGetters &self) { |
|
|
|
int64_t num_classes; |
|
|
|
int64_t num_classes = -1; |
|
|
|
THROW_IF_ERROR(self.GetNumClasses(&num_classes)); |
|
|
|
return num_classes; |
|
|
|
}) |
|
|
|
.def("GetRepeatCount", |
|
|
|
[](PythonTreeGetters &self) { |
|
|
|
int64_t repeat_count; |
|
|
|
int64_t repeat_count = -1; |
|
|
|
THROW_IF_ERROR(self.GetRepeatCount(&repeat_count)); |
|
|
|
return repeat_count; |
|
|
|
}) |
|
|
|
.def("GetBatchSize", |
|
|
|
[](PythonTreeGetters &self) { |
|
|
|
int64_t batch_size; |
|
|
|
int64_t batch_size = -1; |
|
|
|
THROW_IF_ERROR(self.GetBatchSize(&batch_size)); |
|
|
|
return batch_size; |
|
|
|
}) |
|
|
|
.def("GetColumnNames", |
|
|
|
[](PythonTreeGetters &self) { |
|
|
|
std::vector<std::string> col_names; |
|
|
|
std::vector<std::string> col_names = {}; |
|
|
|
THROW_IF_ERROR(self.GetColumnNames(&col_names)); |
|
|
|
return col_names; |
|
|
|
}) |
|
|
|
|