@@ -202,6 +202,13 @@ void bindDatasetOps(py::module *m) {
       return count;
     });
   (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
+    .def_static("get_num_rows",
+                [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                   const py::dict &dict, int64_t numSamples) {
+                  int64_t count = 0;
+                  THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
+                  return count;
+                })
     .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
                                          const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
       std::map<std::string, int32_t> output_class_indexing;
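The lambda above simply forwards to `VOCOp::CountTotalRows` (added in the next hunk) so the row count becomes reachable from Python. A minimal sketch of the expected call shape, assuming the compiled engine module is importable as `mindspore._c_dataengine` (the import path and directory below are assumptions for illustration; the `get_dataset_size` change further down calls the same binding):

```python
# Sketch only: VOCOp is the pybind11 class registered above; the import path
# and the dataset directory are placeholders, not taken from this change.
from mindspore._c_dataengine import VOCOp

dataset_dir = "/path/to/VOCdevkit/VOC2012"  # hypothetical VOC-format directory
class_indexing = {}                         # empty dict lets the op build its own class index
num_samples = 0                             # 0 means "no cap" (see the clamp in CountTotalRows)

num_rows = VOCOp.get_num_rows(dataset_dir, "Detection", "train", class_indexing, num_samples)
print(num_rows)
```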
@@ -442,6 +442,32 @@ Status VOCOp::GetNumRowsInDataset(int64_t *num) const {
   return Status::OK();
 }
 
+Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                             const py::dict &dict, int64_t numSamples, int64_t *count) {
+  if (task_type == "Detection") {
+    std::map<std::string, int32_t> input_class_indexing;
+    for (auto p : dict) {
+      (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
+                                                                        py::reinterpret_borrow<py::int_>(p.second)));
+    }
+    std::shared_ptr<VOCOp> op;
+    RETURN_IF_NOT_OK(
+      Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
+    RETURN_IF_NOT_OK(op->ParseImageIds());
+    RETURN_IF_NOT_OK(op->ParseAnnotationIds());
+    *count = static_cast<int64_t>(op->image_ids_.size());
+  } else if (task_type == "Segmentation") {
+    std::shared_ptr<VOCOp> op;
+    RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op));
+    RETURN_IF_NOT_OK(op->ParseImageIds());
+    *count = static_cast<int64_t>(op->image_ids_.size());
+  }
+
+  *count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
+
+  return Status::OK();
+}
+
 Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
                                const py::dict &dict, int64_t numSamples,
                                std::map<std::string, int32_t> *output_class_indexing) {
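The final clamp treats `numSamples == 0` as "no limit": the parsed row count is kept unless a positive `numSamples` is smaller. Note that `*count` is only written in the `Detection` and `Segmentation` branches, so any other `task_type` is assumed to have been rejected by earlier argument validation. A one-line Python equivalent of the clamp, for reference (the helper name is ours, not part of the change):

```python
def clamp_row_count(count, num_samples):
    """Mirror of `*count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;`."""
    return count if num_samples == 0 or count < num_samples else num_samples

assert clamp_row_count(10, 0) == 10  # 0 -> no cap
assert clamp_row_count(10, 4) == 4   # capped at the requested sample count
assert clamp_row_count(3, 4) == 3    # fewer rows than requested samples
```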
@@ -208,6 +208,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   // @param show_all
   void Print(std::ostream &out, bool show_all) const override;
 
+  // @param const std::string &dir - VOC dir path
+  // @param const std::string &task_type - task type of reading voc job
+  // @param const std::string &task_mode - task mode of reading voc job
+  // @param const py::dict &dict - input dict of class index
+  // @param int64_t numSamples - samples number of VOCDataset
+  // @param int64_t *count - output rows number of VOCDataset
+  static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                               const py::dict &dict, int64_t numSamples, int64_t *count);
+
   // @param const std::string &dir - VOC dir path
   // @param const std::string &task_type - task type of reading voc job
   // @param const std::string &task_mode - task mode of reading voc job
@@ -1210,8 +1210,10 @@ class MappableDataset(SourceDataset):
             >>> new_sampler = ds.DistributedSampler(10, 2)
             >>> data.use_sampler(new_sampler)
         """
-        if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
-            raise TypeError("new_sampler is not an instance of a sampler.")
+        if new_sampler is None:
+            raise TypeError("Input sampler cannot be None.")
+        if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
+            raise TypeError("Input sampler is not an instance of a sampler.")
 
         self.sampler = self.sampler.child_sampler
         self.add_sampler(new_sampler)
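Splitting the old combined check into two branches gives a more specific error when `None` is passed. A short illustration of the expected behaviour, with `DATA_DIR` as a placeholder for a VOC-format directory:

```python
# Illustration only: DATA_DIR is a placeholder for a VOC-format directory.
import mindspore.dataset as ds

DATA_DIR = "/path/to/VOCdevkit/VOC2012"
data = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)

data.use_sampler(ds.DistributedSampler(10, 2))  # 10 shards, shard id 2, as in the docstring above

try:
    data.use_sampler(None)
except TypeError as err:
    print(err)  # "Input sampler cannot be None."

try:
    data.use_sampler("not a sampler")
except TypeError as err:
    print(err)  # "Input sampler is not an instance of a sampler."
```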
@@ -3914,12 +3916,24 @@ class VOCDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
+        if self.num_samples is None:
+            num_samples = 0
+        else:
+            num_samples = self.num_samples
+
+        if self.class_indexing is None:
+            class_indexing = dict()
+        else:
+            class_indexing = self.class_indexing
+
+        num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
+        rows_per_shard = get_num_rows(num_rows, self.num_shards)
         rows_from_sampler = self._get_sampler_dataset_size()
 
         if rows_from_sampler is None:
-            return self.num_samples
+            return rows_per_shard
 
-        return min(rows_from_sampler, self.num_samples)
+        return min(rows_from_sampler, rows_per_shard)
 
     def get_class_indexing(self):
         """
@@ -115,6 +115,23 @@ def test_case_1():
     assert (num == 18)
 
 
+def test_case_2():
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
+    sizes = [0.5, 0.5]
+    randomize = False
+    dataset1, dataset2 = data1.split(sizes=sizes, randomize=randomize)
+
+    num_iter = 0
+    for _ in dataset1.create_dict_iterator():
+        num_iter += 1
+    assert (num_iter == 5)
+
+    num_iter = 0
+    for _ in dataset2.create_dict_iterator():
+        num_iter += 1
+    assert (num_iter == 5)
+
+
 def test_voc_exception():
     try:
         data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True)
@@ -172,4 +189,5 @@ if __name__ == '__main__':
     test_voc_get_class_indexing()
     test_case_0()
     test_case_1()
+    test_case_2()
     test_voc_exception()
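`test_case_2` exercises `Dataset.split` on the VOC test data, which in turn relies on the corrected `get_dataset_size`: fractional sizes are converted to absolute row counts before splitting, so the 10 Segmentation rows in the test data split into 5 and 5. A rough sketch of that conversion (the rounding and remainder handling here are assumptions for illustration, not taken from the actual `split` implementation):

```python
def fractions_to_rows(total_rows, fractions):
    """Rough sketch: turn fractional split sizes into absolute row counts."""
    rows = [int(total_rows * f) for f in fractions]
    rows[-1] += total_rows - sum(rows)  # hand any rounding remainder to the last split
    return rows

print(fractions_to_rows(10, [0.5, 0.5]))  # [5, 5]
```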