Browse Source

add voc support split

tags/v0.5.0-beta
xiefangqi 5 years ago
parent
commit
5e4728c50f
5 changed files with 78 additions and 4 deletions
  1. +7
    -0
      mindspore/ccsrc/dataset/api/python_bindings.cc
  2. +26
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
  3. +9
    -0
      mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
  4. +18
    -4
      mindspore/dataset/engine/datasets.py
  5. +18
    -0
      tests/ut/python/dataset/test_datasets_voc.py

+ 7
- 0
mindspore/ccsrc/dataset/api/python_bindings.cc View File

@@ -202,6 +202,13 @@ void bindDatasetOps(py::module *m) {
return count; return count;
}); });
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp") (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_num_rows",
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples) {
int64_t count = 0;
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
return count;
})
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type, .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
const std::string &task_mode, const py::dict &dict, int64_t numSamples) { const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
std::map<std::string, int32_t> output_class_indexing; std::map<std::string, int32_t> output_class_indexing;


+ 26
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc View File

@@ -442,6 +442,32 @@ Status VOCOp::GetNumRowsInDataset(int64_t *num) const {
return Status::OK(); return Status::OK();
} }


Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples, int64_t *count) {
if (task_type == "Detection") {
std::map<std::string, int32_t> input_class_indexing;
for (auto p : dict) {
(void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
py::reinterpret_borrow<py::int_>(p.second)));
}

std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(
Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
*count = static_cast<int64_t>(op->image_ids_.size());
} else if (task_type == "Segmentation") {
std::shared_ptr<VOCOp> op;
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op));
RETURN_IF_NOT_OK(op->ParseImageIds());
*count = static_cast<int64_t>(op->image_ids_.size());
}
*count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;

return Status::OK();
}

Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples, const py::dict &dict, int64_t numSamples,
std::map<std::string, int32_t> *output_class_indexing) { std::map<std::string, int32_t> *output_class_indexing) {


+ 9
- 0
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h View File

@@ -208,6 +208,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param show_all // @param show_all
void Print(std::ostream &out, bool show_all) const override; void Print(std::ostream &out, bool show_all) const override;


// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
// @param const py::dict &dict - input dict of class index
// @param int64_t numSamples - samples number of VOCDataset
// @param int64_t *count - output rows number of VOCDataset
static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples, int64_t *count);

// @param const std::string &dir - VOC dir path // @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job // @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job // @param const std::string &task_mode - task mode of reading voc job


+ 18
- 4
mindspore/dataset/engine/datasets.py View File

@@ -1210,8 +1210,10 @@ class MappableDataset(SourceDataset):
>>> new_sampler = ds.DistributedSampler(10, 2) >>> new_sampler = ds.DistributedSampler(10, 2)
>>> data.use_sampler(new_sampler) >>> data.use_sampler(new_sampler)
""" """
if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise TypeError("new_sampler is not an instance of a sampler.")
if new_sampler is None:
raise TypeError("Input sampler could not be None.")
if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise TypeError("Input sampler is not an instance of a sampler.")


self.sampler = self.sampler.child_sampler self.sampler = self.sampler.child_sampler
self.add_sampler(new_sampler) self.add_sampler(new_sampler)
@@ -3914,12 +3916,24 @@ class VOCDataset(MappableDataset):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples

if self.class_indexing is None:
class_indexing = dict()
else:
class_indexing = self.class_indexing

num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size() rows_from_sampler = self._get_sampler_dataset_size()


if rows_from_sampler is None: if rows_from_sampler is None:
return self.num_samples
return rows_per_shard


return min(rows_from_sampler, self.num_samples)
return min(rows_from_sampler, rows_per_shard)


def get_class_indexing(self): def get_class_indexing(self):
""" """


+ 18
- 0
tests/ut/python/dataset/test_datasets_voc.py View File

@@ -115,6 +115,23 @@ def test_case_1():
assert (num == 18) assert (num == 18)




def test_case_2():
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
sizes = [0.5, 0.5]
randomize = False
dataset1, dataset2 = data1.split(sizes=sizes, randomize=randomize)

num_iter = 0
for _ in dataset1.create_dict_iterator():
num_iter += 1
assert (num_iter == 5)

num_iter = 0
for _ in dataset2.create_dict_iterator():
num_iter += 1
assert (num_iter == 5)


def test_voc_exception(): def test_voc_exception():
try: try:
data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True) data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True)
@@ -172,4 +189,5 @@ if __name__ == '__main__':
test_voc_get_class_indexing() test_voc_get_class_indexing()
test_case_0() test_case_0()
test_case_1() test_case_1()
test_case_2()
test_voc_exception() test_voc_exception()

Loading…
Cancel
Save