Browse Source

Issue #1:
DistributedSampler used together with GeneratorOp produces incorrect results (bug)
Issue #2:
GeneratorOp with an iterable source does not respect num_samples
Issue #3:
Generator constructed with a schema does not accept samplers
tags/v1.2.0-rc1
hesham 4 years ago
parent
commit
edfb2e1414
4 changed files with 10 additions and 3 deletions
  1. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
  2. +0
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
  3. +6
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
  4. +3
    -0
      mindspore/dataset/engine/datasets.py

+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc View File

@@ -193,9 +193,9 @@ Status GeneratorOp::operator()() {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
std::unique_ptr<DataBuffer> fetched_buffer;
int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_;
RETURN_IF_NOT_OK(Init());

int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_;
bool eof = false;
while (!eof) {
// Create new buffer each iteration


+ 0
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc View File

@@ -184,7 +184,6 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
if (device_id_ < remainder) shard_size++;
if (device_id_ < offset_) shard_size--;
} else {
offset_ = 0;
shard_size = (child_num_rows + num_devices_ - 1) / num_devices_;
}
// add 1 to an empty shard


+ 6
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc View File

@@ -42,7 +42,12 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<

GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema,
int64_t source_len, std::shared_ptr<SamplerObj> sampler)
: MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {}
: MappableSourceNode(),
generator_function_(generator_function),
schema_(schema),
reset_ancestor_(nullptr),
sampler_(std::move(sampler)),
source_len_(source_len) {}

std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
std::shared_ptr<GeneratorNode> node;


+ 3
- 0
mindspore/dataset/engine/datasets.py View File

@@ -2233,6 +2233,7 @@ _GLOBAL_PYFUNC_LIST = []
_OP_NAME = dict()
_OP_PROCESS = dict()


# Pyfunc worker init function
# Python multiprocessing library forbid sending lambda function through pipe.
# This init function allow us to add all Python function to a global collection and then fork afterwards.
@@ -3781,6 +3782,8 @@ class GeneratorDataset(MappableDataset):
try:
new_op.sampler = None
new_op.sample_fn = sample_fn
new_op.source_len = min(new_op.source_len,
new_op.num_samples) if new_op.num_samples is not None else new_op.source_len
iter(self.source)
except TypeError:
# Use generator function if input callable


Loading…
Cancel
Save