Fix bug in batch getDatasetSize Add getDatasetSize support for generator_op.cc Fix header guards of some filestags/v1.1.0
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <map> | #include <map> | ||||
| @@ -260,4 +260,4 @@ class DEPipeline { | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_DE_PIPELINE_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef API_PYBIND_API_H_ | |||||
| #define API_PYBIND_API_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_REGISTER_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_REGISTER_H_ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| @@ -78,4 +78,4 @@ class PybindDefineRegisterer { | |||||
| #endif | #endif | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // API_PYBIND_API_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_REGISTER_H_ | |||||
| @@ -546,9 +546,9 @@ Status BatchOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); | RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); | ||||
| if (num_rows > 0 && start_batch_size_ > 0) { | if (num_rows > 0 && start_batch_size_ > 0) { | ||||
| if (drop_) { | if (drop_) { | ||||
| num_rows = floor(num_rows / start_batch_size_); | |||||
| num_rows = floor(num_rows / (1.0 * start_batch_size_)); | |||||
| } else { | } else { | ||||
| num_rows = ceil(num_rows / start_batch_size_); | |||||
| num_rows = ceil(num_rows / (1.0 * start_batch_size_)); | |||||
| } | } | ||||
| } | } | ||||
| *dataset_size = num_rows; | *dataset_size = num_rows; | ||||
| @@ -312,7 +312,8 @@ Status DatasetOp::GetNumClasses(int64_t *num_classes) { | |||||
| if (!child_.empty()) { | if (!child_.empty()) { | ||||
| return child_[0]->GetNumClasses(num_classes); | return child_[0]->GetNumClasses(num_classes); | ||||
| } else { | } else { | ||||
| RETURN_STATUS_UNEXPECTED("Can't get the dataset size for the current tree."); | |||||
| *num_classes = -1; | |||||
| RETURN_STATUS_UNEXPECTED("Can't get the number of classes for the current tree."); | |||||
| } | } | ||||
| } | } | ||||
| @@ -274,5 +274,11 @@ Status GeneratorOp::ComputeColMap() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status GeneratorOp::GetDatasetSize(int64_t *dataset_size) { // Get Dataset size | |||||
| // We are returning -1 because we can't easily calculate GetDatasetSize. Returning -1 will make TreeGetters to | |||||
| // iterate over the dataset and count the size | |||||
| *dataset_size = dataset_size_; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -136,6 +136,8 @@ class GeneratorOp : public PipelineOp { | |||||
| Status Init(); | Status Init(); | ||||
| Status GetDatasetSize(int64_t *dataset_size) override; | |||||
| private: | private: | ||||
| py::function generator_function_; | py::function generator_function_; | ||||
| std::vector<std::string> column_names_; | std::vector<std::string> column_names_; | ||||