| @@ -202,6 +202,13 @@ void bindDatasetOps(py::module *m) { | |||
| return count; | |||
| }); | |||
| (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp") | |||
| .def_static("get_num_rows", | |||
| [](const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| const py::dict &dict, int64_t numSamples) { | |||
| int64_t count = 0; | |||
| THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count)); | |||
| return count; | |||
| }) | |||
| .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type, | |||
| const std::string &task_mode, const py::dict &dict, int64_t numSamples) { | |||
| std::map<std::string, int32_t> output_class_indexing; | |||
| @@ -442,6 +442,32 @@ Status VOCOp::GetNumRowsInDataset(int64_t *num) const { | |||
| return Status::OK(); | |||
| } | |||
| Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| const py::dict &dict, int64_t numSamples, int64_t *count) { | |||
| if (task_type == "Detection") { | |||
| std::map<std::string, int32_t> input_class_indexing; | |||
| for (auto p : dict) { | |||
| (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first), | |||
| py::reinterpret_borrow<py::int_>(p.second))); | |||
| } | |||
| std::shared_ptr<VOCOp> op; | |||
| RETURN_IF_NOT_OK( | |||
| Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op)); | |||
| RETURN_IF_NOT_OK(op->ParseImageIds()); | |||
| RETURN_IF_NOT_OK(op->ParseAnnotationIds()); | |||
| *count = static_cast<int64_t>(op->image_ids_.size()); | |||
| } else if (task_type == "Segmentation") { | |||
| std::shared_ptr<VOCOp> op; | |||
| RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op)); | |||
| RETURN_IF_NOT_OK(op->ParseImageIds()); | |||
| *count = static_cast<int64_t>(op->image_ids_.size()); | |||
| } | |||
| *count = (numSamples == 0 || *count < numSamples) ? *count : numSamples; | |||
| return Status::OK(); | |||
| } | |||
| Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| const py::dict &dict, int64_t numSamples, | |||
| std::map<std::string, int32_t> *output_class_indexing) { | |||
| @@ -208,6 +208,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| // @param show_all | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // @param const std::string &dir - VOC dir path | |||
| // @param const std::string &task_type - task type of reading voc job | |||
| // @param const std::string &task_mode - task mode of reading voc job | |||
| // @param const py::dict &dict - input dict of class index | |||
| // @param int64_t numSamples - samples number of VOCDataset | |||
| // @param int64_t *count - output rows number of VOCDataset | |||
| static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| const py::dict &dict, int64_t numSamples, int64_t *count); | |||
| // @param const std::string &dir - VOC dir path | |||
| // @param const std::string &task_type - task type of reading voc job | |||
| // @param const std::string &task_mode - task mode of reading voc job | |||
| @@ -1210,8 +1210,10 @@ class MappableDataset(SourceDataset): | |||
| >>> new_sampler = ds.DistributedSampler(10, 2) | |||
| >>> data.use_sampler(new_sampler) | |||
| """ | |||
| if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): | |||
| raise TypeError("new_sampler is not an instance of a sampler.") | |||
| if new_sampler is None: | |||
| raise TypeError("Input sampler could not be None.") | |||
| if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): | |||
| raise TypeError("Input sampler is not an instance of a sampler.") | |||
| self.sampler = self.sampler.child_sampler | |||
| self.add_sampler(new_sampler) | |||
| @@ -3914,12 +3916,24 @@ class VOCDataset(MappableDataset): | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| if self.num_samples is None: | |||
| num_samples = 0 | |||
| else: | |||
| num_samples = self.num_samples | |||
| if self.class_indexing is None: | |||
| class_indexing = dict() | |||
| else: | |||
| class_indexing = self.class_indexing | |||
| num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples) | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| if rows_from_sampler is None: | |||
| return self.num_samples | |||
| return rows_per_shard | |||
| return min(rows_from_sampler, self.num_samples) | |||
| return min(rows_from_sampler, rows_per_shard) | |||
| def get_class_indexing(self): | |||
| """ | |||
| @@ -115,6 +115,23 @@ def test_case_1(): | |||
| assert (num == 18) | |||
| def test_case_2(): | |||
| data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True) | |||
| sizes = [0.5, 0.5] | |||
| randomize = False | |||
| dataset1, dataset2 = data1.split(sizes=sizes, randomize=randomize) | |||
| num_iter = 0 | |||
| for _ in dataset1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert (num_iter == 5) | |||
| num_iter = 0 | |||
| for _ in dataset2.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert (num_iter == 5) | |||
| def test_voc_exception(): | |||
| try: | |||
| data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True) | |||
| @@ -172,4 +189,5 @@ if __name__ == '__main__': | |||
| test_voc_get_class_indexing() | |||
| test_case_0() | |||
| test_case_1() | |||
| test_case_2() | |||
| test_voc_exception() | |||