From 50b783ee1360031588e39483b64c2e039a4b8c46 Mon Sep 17 00:00:00 2001 From: anzhengqi Date: Fri, 15 Jan 2021 11:46:53 +0800 Subject: [PATCH] fix generator or user-defined sampler whose len method does not match its iter method --- .../engine/datasetops/source/generator_op.cc | 15 +++++++++++++-- .../engine/datasetops/source/generator_op.h | 5 ++++- .../engine/ir/datasetops/source/generator_node.cc | 3 ++- mindspore/dataset/engine/datasets.py | 4 +++- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index fa31f09156..e70213a03e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -49,14 +49,16 @@ Status GeneratorOp::Builder::Build(std::shared_ptr *ptr) { GeneratorOp::GeneratorOp(py::function generator_function, std::vector column_names, std::vector column_types, int32_t prefetch_size, int32_t buffer_size, - int32_t connector_size) + int32_t connector_size, int64_t pre_counter_size) : PipelineOp(connector_size), generator_function_(generator_function), column_names_(column_names), column_types_(column_types), prefetch_size_(prefetch_size), buffer_size_(buffer_size), - buffer_id_(0) {} + pre_counter_size_(pre_counter_size), + buffer_id_(0), + generator_counter_(0) {} GeneratorOp::~GeneratorOp() { this->Dealloc(); } @@ -146,6 +148,7 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) { TensorRow row; RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &row)); tt->push_back(std::move(row)); + generator_counter_++; } return Status::OK(); } @@ -209,6 +212,13 @@ Status GeneratorOp::operator()() { if (!eoe) { return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what()); } + if (pre_counter_size_ != -1 && pre_counter_size_ != generator_counter_) { + std::stringstream ss; + ss << "The 
actual amount of data read from generator " << generator_counter_ + << " is different from generator.len " << pre_counter_size_ + << ", you should adjust generator.len to make them match."; + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, ss.str()); + } } } if (fetched_table->size() > 0) { @@ -254,6 +264,7 @@ Status GeneratorOp::Reset() { // Wake up master thread wp_.Set(); } + generator_counter_ = 0; return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h index 609bf62a5c..a4d12cf9c2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h @@ -93,7 +93,8 @@ class GeneratorOp : public PipelineOp { }; GeneratorOp(py::function generator_function, std::vector column_names, - std::vector column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size); + std::vector column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size, + int64_t pre_counter_size = 0); ~GeneratorOp(); @@ -142,6 +143,8 @@ class GeneratorOp : public PipelineOp { std::vector column_types_; int32_t prefetch_size_; int32_t buffer_size_; + int64_t pre_counter_size_; + int64_t generator_counter_; py::object generator_; int32_t buffer_id_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc index 993ee7523e..49635e7345 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc @@ -46,6 +46,7 @@ std::shared_ptr GeneratorNode::Copy() { } else { node = std::make_shared(generator_function_, schema_); } + node->SetGeneratorDatasetSize(dataset_size_); 
return node; } @@ -72,7 +73,7 @@ Status GeneratorNode::Build(std::vector> *const node_ // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by // GeneratorOp internally. Here it is given a zero which is the default in generator builder std::shared_ptr op = std::make_shared(generator_function_, column_names_, column_types_, 0, - rows_per_buffer_, connector_que_size_); + rows_per_buffer_, connector_que_size_, dataset_size_); // Init() is called in builder when generator is built. Here, since we are getting away from the builder class, init // needs to be called when the op is built. The caveat is that Init needs to be made public (before it is private). diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index edc60183b1..668b88a360 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2663,7 +2663,7 @@ class ConcatDataset(Dataset): tem_sampler = copy.deepcopy(sampler) tem_sampler.set_offset(cumulative_samples_nums) - child.sampler = tem_sampler + child.use_sampler(tem_sampler) cumulative_samples_nums += self.children_sizes_[index] cumulative_samples_nums %= sampler.num_shards @@ -3808,6 +3808,8 @@ class GeneratorDataset(MappableDataset): self.dataset_size = math.ceil(len(self.source) / self.num_shards) rows_from_sampler = self._get_sampler_dataset_size() + if self.num_samples is not None and self.num_samples < rows_from_sampler: + rows_from_sampler = self.num_samples if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: self.dataset_size = rows_from_sampler