DistributedSampler with GeneratorOp bug Issue#2: GeneratorOp with iterable source doesn't respect num_samples Issue#3: Generator with schemas does not take samplers tags/v1.2.0-rc1
| @@ -193,9 +193,9 @@ Status GeneratorOp::operator()() { | |||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); | ||||
| std::unique_ptr<DataBuffer> fetched_buffer; | std::unique_ptr<DataBuffer> fetched_buffer; | ||||
| int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_; | |||||
| RETURN_IF_NOT_OK(Init()); | RETURN_IF_NOT_OK(Init()); | ||||
| int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_; | |||||
| bool eof = false; | bool eof = false; | ||||
| while (!eof) { | while (!eof) { | ||||
| // Create new buffer each iteration | // Create new buffer each iteration | ||||
| @@ -184,7 +184,6 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) { | |||||
| if (device_id_ < remainder) shard_size++; | if (device_id_ < remainder) shard_size++; | ||||
| if (device_id_ < offset_) shard_size--; | if (device_id_ < offset_) shard_size--; | ||||
| } else { | } else { | ||||
| offset_ = 0; | |||||
| shard_size = (child_num_rows + num_devices_ - 1) / num_devices_; | shard_size = (child_num_rows + num_devices_ - 1) / num_devices_; | ||||
| } | } | ||||
| // add 1 to an empty shard | // add 1 to an empty shard | ||||
| @@ -42,7 +42,12 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector< | |||||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema, | GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema, | ||||
| int64_t source_len, std::shared_ptr<SamplerObj> sampler) | int64_t source_len, std::shared_ptr<SamplerObj> sampler) | ||||
| : MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {} | |||||
| : MappableSourceNode(), | |||||
| generator_function_(generator_function), | |||||
| schema_(schema), | |||||
| reset_ancestor_(nullptr), | |||||
| sampler_(std::move(sampler)), | |||||
| source_len_(source_len) {} | |||||
| std::shared_ptr<DatasetNode> GeneratorNode::Copy() { | std::shared_ptr<DatasetNode> GeneratorNode::Copy() { | ||||
| std::shared_ptr<GeneratorNode> node; | std::shared_ptr<GeneratorNode> node; | ||||
| @@ -2233,6 +2233,7 @@ _GLOBAL_PYFUNC_LIST = [] | |||||
| _OP_NAME = dict() | _OP_NAME = dict() | ||||
| _OP_PROCESS = dict() | _OP_PROCESS = dict() | ||||
| # Pyfunc worker init function | # Pyfunc worker init function | ||||
| # Python multiprocessing library forbid sending lambda function through pipe. | # Python multiprocessing library forbid sending lambda function through pipe. | ||||
| # This init function allow us to add all Python function to a global collection and then fork afterwards. | # This init function allow us to add all Python function to a global collection and then fork afterwards. | ||||
| @@ -3781,6 +3782,8 @@ class GeneratorDataset(MappableDataset): | |||||
| try: | try: | ||||
| new_op.sampler = None | new_op.sampler = None | ||||
| new_op.sample_fn = sample_fn | new_op.sample_fn = sample_fn | ||||
| new_op.source_len = min(new_op.source_len, | |||||
| new_op.num_samples) if new_op.num_samples is not None else new_op.source_len | |||||
| iter(self.source) | iter(self.source) | ||||
| except TypeError: | except TypeError: | ||||
| # Use generator function if input callable | # Use generator function if input callable | ||||