@@ -517,7 +517,7 @@ Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<Datase
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "generator_function") {
+      if (key == "source") {
        py::object obj = py::cast(&value);
        if (!py::isinstance<py::function>(obj)) {
          std::string err_msg = "Error: generator is invalid or not set.";
@@ -384,7 +384,16 @@ void bindTensorOps4(py::module *m) {
 }
 void bindSamplerOps(py::module *m) {
-  (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler");
+  (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
+    .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
+    .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
+    .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
+    .def("get_indices", [](Sampler &self) {
+      py::array ret;
+      THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
+      return ret;
+    });
   (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
   (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
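
Note: these bindings expose the sampler lifecycle directly to Python, which is what lets GeneratorDataset drive a C++ sampler over a random accessible source. A minimal sketch of the intended call sequence, mirroring test_sampler_py_api further down in this diff:

    import mindspore.dataset as ds

    # .create() builds the underlying C++ sampler behind a Python sampler object.
    sampler = ds.SequentialSampler().create()
    sampler.set_num_rows(128)        # SetNumRowsInDataset: rows available in the source
    sampler.set_num_samples(64)      # SetNumSamples: rows to actually draw
    sampler.initialize()             # InitSampler: validates the two values
    indices = sampler.get_indices()  # GetAllIdsThenReset: numpy array of ids for one epoch
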
@@ -491,6 +491,8 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
 // return data as numpy, should return status
 Status Tensor::GetDataAsNumpy(py::array *data) {
   RETURN_UNEXPECTED_IF_NULL(data_);
+  RETURN_UNEXPECTED_IF_NULL(data);
   if (type_ == DataType::DE_BOOL) {
     *data = py::array_t<bool>(shape_.AsVector(), reinterpret_cast<bool *>(data_));
   } else if (type_ == DataType::DE_INT8) {
@@ -100,7 +100,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() {
   RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1)));
   TaskManager::FindMe()->Post();
   RETURN_IF_NOT_OK(ParseImageAttrInfo());
-  RETURN_IF_NOT_OK(sampler_->Init(this));
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
   return Status::OK();
 }
@@ -240,7 +240,7 @@ Status CifarOp::Reset() {
 // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
 Status CifarOp::InitSampler() {
-  RETURN_IF_NOT_OK(sampler_->Init(this));
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
   return Status::OK();
 }
@@ -258,7 +258,7 @@ Status ImageFolderOp::Reset() {
 // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
 Status ImageFolderOp::InitSampler() {
-  RETURN_IF_NOT_OK(sampler_->Init(this));
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
   return Status::OK();
 }
@@ -254,7 +254,7 @@ Status ManifestOp::Reset() {
 // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
 Status ManifestOp::InitSampler() {
-  RETURN_IF_NOT_OK(sampler_->Init(this));
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
   return Status::OK();
 }
@@ -205,7 +205,7 @@ Status MnistOp::Reset() {
 // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
 Status MnistOp::InitSampler() {
-  RETURN_IF_NOT_OK(sampler_->Init(this));
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
   return Status::OK();
 }
@@ -31,8 +31,9 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu
       num_devices_(num_dev),
       shuffle_(shuffle) {}
-Status DistributedSampler::Init(const RandomAccessOp *op) {
-  RETURN_IF_NOT_OK(Sampler::Init(op));
+Status DistributedSampler::InitSampler() {
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n");
+  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
   CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
                                "fail to init DistributedSampler");
   rnd_.seed(seed_++);
@@ -41,10 +41,8 @@ class DistributedSampler : public Sampler {
   // @return - The error code return
   Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
-  // first handshake between StorageOp and Sampler
-  // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
-  // @return
-  Status Init(const RandomAccessOp *) override;
+  // Init sampler, called by base class or python
+  Status InitSampler() override;
   // for next epoch of sampleIds
   // @return - The error code return
@@ -28,9 +28,7 @@ PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer)
       num_pk_samples_(0),
       samples_per_class_(val) {}
-Status PKSampler::Init(const RandomAccessOp *op) {
-  RETURN_UNEXPECTED_IF_NULL(op);
-  RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
+Status PKSampler::InitSampler() {
   labels_.reserve(label_to_ids_.size());
   for (const auto &pair : label_to_ids_) {
     if (pair.second.empty() == false) {
@@ -79,5 +77,13 @@ Status PKSampler::Reset() {
   rnd_.seed(seed_++);
   return Status::OK();
 }
+Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
+  RETURN_UNEXPECTED_IF_NULL(op);
+  RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
+  RETURN_IF_NOT_OK(InitSampler());
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
@@ -45,7 +45,10 @@ class PKSampler : public Sampler {  // NOT YET FINISHED
   // first handshake between StorageOp and Sampler
   // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
   // @return
-  Status Init(const RandomAccessOp *op) override;
+  Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
+  // init sampler, to be called by python or Handshake
+  Status InitSampler() override;
   // for next epoch of sampleIds
   // @return - The error code return
@@ -49,10 +49,9 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
   return Status::OK();
 }
-Status RandomSampler::Init(const RandomAccessOp *op) {
-  RETURN_IF_NOT_OK(Sampler::Init(op));
+Status RandomSampler::InitSampler() {
   num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
-  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "Fail to init RandomSampler");
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
   samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
   if (replacement_ == false) {
     shuffled_ids_.reserve(num_rows_);
@@ -42,10 +42,8 @@ class RandomSampler : public Sampler {
   // @return - The error code return
   Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
-  // first handshake between StorageOp and Sampler
-  // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
-  // @return
-  Status Init(const RandomAccessOp *op) override;
+  // meant to be called by base class or python
+  Status InitSampler() override;
   // for next epoch of sampleIds
   // @return - The error code return
@@ -20,12 +20,13 @@ namespace dataset {
 Sampler::Sampler(int64_t samples_per_buffer)
     : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
-Status Sampler::Init(const RandomAccessOp *op) {
-  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr && samples_per_buffer_ > 0, "Fail to init Sampler()\n");
+Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
+  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
   RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
   RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
+  // It's up to the derived class to check the validity of the two args,
+  // because some samplers only need one of them (e.g. WeightedRandomSampler)
+  RETURN_IF_NOT_OK(InitSampler());  // init sampler after the callbacks
   return Status::OK();
 }
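
Note: HandshakeRandomAccessOp is now the C++-driven entry point — it pulls the two counters from the RandomAccessOp and finishes with InitSampler() — while Python callers set the same counters through the new bindings and call initialize() themselves. A Python model of the handshake, illustrative only (get_num_samples/get_num_rows_in_dataset are hypothetical stand-ins for the C++ callbacks):

    def handshake_random_access_op(sampler, op):
        # CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, ...)
        assert op is not None, "RandomAccessOp is nullptr"
        sampler.set_num_samples(op.get_num_samples())        # op->GetNumSamples(&num_samples_)
        sampler.set_num_rows(op.get_num_rows_in_dataset())   # op->GetNumRowsInDataset(&num_rows_)
        sampler.initialize()                                 # InitSampler(); derived class validates
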
@@ -42,5 +43,49 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
   (void)(*sample_ids)->StartAddr();  // allocate memory in case user forgets!
   return Status::OK();
 }
+Status Sampler::GetAllIdsThenReset(py::array *data) {
+  std::unique_ptr<DataBuffer> db;
+  std::shared_ptr<Tensor> sample_ids;
+  // check samples_per_buffer is properly set and doesn't overflow
+  CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid");
+  // A call to the derived class to get the sample ids wrapped inside a buffer
+  RETURN_IF_NOT_OK(GetNextBuffer(&db));
+  // Get the only tensor inside the buffer; it contains the actual SampleIds for the entire epoch
+  RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0));
+  // check this buffer is not a ctrl buffer
+  CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received");
+  {
+    py::gil_scoped_acquire gil_acquire;
+    if (Py_IsInitialized() == 0) {
+      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+    }
+    try {
+      RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data));
+    } catch (const std::runtime_error &e) {
+      return Status(StatusCode::kPyFuncException, e.what());
+    }
+  }
+  // Error check: the next buffer is supposed to be EOE, since the last one already contained all ids for this epoch
+  RETURN_IF_NOT_OK(GetNextBuffer(&db));
+  CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received");
+  // Reset the sampler since this is the end of the epoch
+  RETURN_IF_NOT_OK(Reset());
+  return Status::OK();
+}
+Status Sampler::SetNumSamples(int64_t num_samples) {
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples > 0, "num_samples is negative or 0");
+  num_samples_ = num_samples;
+  return Status::OK();
+}
+Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
+  CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "num_rows is negative or 0");
+  num_rows_ = num_rows;
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
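
Note: GetAllIdsThenReset drains exactly one epoch — the ids buffer plus the trailing EOE — and then calls Reset(), so each get_indices() call from Python yields a fresh epoch. A sketch of that contract, reusing the binding names from above:

    import mindspore.dataset as ds

    sampler = ds.RandomSampler().create()
    sampler.set_num_rows(128)
    sampler.set_num_samples(64)
    sampler.initialize()
    epoch1 = sampler.get_indices()  # ids for one epoch; EOE consumed, Reset() runs
    epoch2 = sampler.get_indices()  # next epoch; a new shuffle for RandomSampler
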
@@ -78,14 +78,26 @@ class Sampler : public DatasetOp {
   // @return - The error code return
   Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0;
+  // return all ids in one epoch as a numpy array, then call Reset()
+  Status GetAllIdsThenReset(py::array *data);
   // for next epoch of sampleIds
   // @return - The error code return
   Status Reset() override = 0;
-  // first handshake between StorageOp and Sampler. Base class init will call both GetNumRows and GetNumSamples
-  // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
+  // setter function for num_rows_
+  Status SetNumRowsInDataset(int64_t num_rows);
+  // setter function for num_samples_
+  Status SetNumSamples(int64_t num_samples);
+  // first handshake between StorageOp and Sampler. This func will call GetNumRows and GetNumSamples
+  // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and GetClassIds()
   // @return
-  virtual Status Init(const RandomAccessOp *op);
+  virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
+  // initialize sampler and perform checks on certain vars
+  virtual Status InitSampler() { return Status::OK(); }
   // Not meant to be called
   // @return
@@ -41,9 +41,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
   return Status::OK();
 }
-Status SequentialSampler::Init(const RandomAccessOp *op) {
-  RETURN_UNEXPECTED_IF_NULL(op);
-  RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
+Status SequentialSampler::InitSampler() {
   CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
   samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
   return Status::OK();
@@ -32,10 +32,8 @@ class SequentialSampler : public Sampler {
   // Destructor.
   ~SequentialSampler() = default;
-  // Initialize the sampler.
-  // @param op
-  // @return Status
-  Status Init(const RandomAccessOp *op) override;
+  // init sampler, called by python
+  Status InitSampler() override;
   // for next epoch of sampleIds
   // @return - The error code return
@@ -31,9 +31,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in
     : Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
 // Initialize this Sampler.
-Status SubsetRandomSampler::Init(const RandomAccessOp *op) {
-  // Calling base class init.
-  RETURN_IF_NOT_OK(Sampler::Init(op));
+Status SubsetRandomSampler::InitSampler() {
+  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
   // Initialize random generator with seed from config manager
   rand_gen_.seed(GetSeed());
@@ -38,9 +38,8 @@ class SubsetRandomSampler : public Sampler {
   ~SubsetRandomSampler() = default;
   // Initialize the sampler.
-  // @param op (Not used in this sampler)
   // @return Status
-  Status Init(const RandomAccessOp *op) override;
+  Status InitSampler() override;
   // Reset the internal variable to the initial state and reshuffle the indices.
   // @return Status
@@ -29,21 +29,21 @@ namespace dataset {
 // Constructor.
 WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement,
                                              int64_t samples_per_buffer)
-    : Sampler(samples_per_buffer), weights_(weights), replacement_(replacement), sample_id_(0), buffer_id_(0) {
-  num_samples_ = num_samples;  // this variable is defined in base class sampler
-}
+    : Sampler(samples_per_buffer),
+      weights_(weights),
+      replacement_(replacement),
+      sample_id_(0),
+      buffer_id_(0),
+      user_num_samples_(num_samples) {}
 // Initialize this Sampler.
-Status WeightedRandomSampler::Init(const RandomAccessOp *op) {
-  RETURN_UNEXPECTED_IF_NULL(op);
-  RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
+Status WeightedRandomSampler::InitSampler() {
+  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_ > 0, "num_samples & num_rows need to be positive");
+  CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer <= 0\n");
   // Initialize random generator with seed from config manager
   rand_gen_.seed(GetSeed());
-  samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
-  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init WeightedRandomSampler");
+  samples_per_buffer_ = (samples_per_buffer_ > user_num_samples_) ? user_num_samples_ : samples_per_buffer_;
   if (!replacement_) {
     exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
@@ -65,8 +65,8 @@ void WeightedRandomSampler::InitOnePassSampling() {
   }
   // Partial sort the first `numSamples` elements.
-  std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end());
-  for (int64_t i = 0; i < num_samples_; i++) {
+  std::partial_sort(val_idx.begin(), val_idx.begin() + user_num_samples_, val_idx.end());
+  for (int64_t i = 0; i < user_num_samples_; i++) {
     onepass_ids_.push_back(val_idx[i].second);
   }
 }
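
Note: InitOnePassSampling together with the exp_dist_ member implements one-pass weighted sampling without replacement: every row draws a key from Exp(1) scaled by its weight, and the user_num_samples_ smallest keys win (the partial_sort above). A NumPy sketch of the same idea, illustrative rather than the library API:

    import numpy as np

    def weighted_sample_no_replacement(weights, num_samples, rng=None):
        """One-pass weighted sampling without replacement via exponential keys."""
        rng = rng or np.random.default_rng()
        weights = np.asarray(weights, dtype=float)
        keys = rng.exponential(1.0, size=len(weights)) / weights  # smaller key = more likely
        # Equivalent of the std::partial_sort + copy loop above.
        return np.argpartition(keys, num_samples - 1)[:num_samples]

    ids = weighted_sample_no_replacement([0.1, 0.1, 0.4, 0.4], num_samples=2)
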
@@ -91,11 +91,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
                              "number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
   }
-  if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
+  if (!replacement_ && (weights_.size() < static_cast<size_t>(user_num_samples_))) {
     RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples");
   }
-  if (sample_id_ == num_samples_) {
+  if (sample_id_ == user_num_samples_) {
     (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
   } else {
     (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
@@ -103,8 +103,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
     int64_t last_id = sample_id_ + samples_per_buffer_;
     // Handle returning all samples at once, and the case where the last draw is not a full batch.
-    if (last_id > num_samples_) {
-      last_id = num_samples_;
+    if (last_id > user_num_samples_) {
+      last_id = user_num_samples_;
     }
     // Allocate tensor.
@@ -43,7 +43,7 @@ class WeightedRandomSampler : public Sampler {
   // Initialize the sampler.
   // @param op (Not used in this sampler)
   // @return Status
-  Status Init(const RandomAccessOp *op) override;
+  Status InitSampler() override;
   // Reset the internal variable to the initial state and reshuffle the indices.
   Status Reset() override;
@@ -69,6 +69,9 @@ class WeightedRandomSampler : public Sampler {
   // Random engine and device
   std::mt19937 rand_gen_;
+  // num_samples from user
+  int64_t user_num_samples_;
   // Discrete distribution for generating weighted random numbers with replacement.
   std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_;
@@ -220,7 +220,7 @@ Status VOCOp::ParseImageIds() {
 }
 Status VOCOp::InitSampler() {
-  RETURN_IF_NOT_OK(sampler_->Init(this));
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
   return Status::OK();
 }
@@ -1748,14 +1748,70 @@ class MindDataset(SourceDataset):
         return num_rows
-def ds_fn(dataset):
-    for val in dataset:
-        # convert output tensors to ndarrays
-        yield tuple([np.array(x) for x in val])
+def _iter_fn(dataset, num_samples):
+    """
+    Generator function wrapper for iterable dataset
+    """
+    if num_samples is not None:
+        ds_iter = iter(dataset)
+        for _ in range(num_samples):
+            try:
+                val = next(ds_iter)
+            except StopIteration:
+                return
+            # convert output tensors to ndarrays
+            yield tuple([np.array(x) for x in val])
+    else:
+        for val in dataset:
+            # convert output tensors to ndarrays
+            yield tuple([np.array(x) for x in val])
+def _generator_fn(generator, num_samples):
+    """
+    Generator function wrapper for generator function dataset
+    """
+    if num_samples is not None:
+        gen_iter = generator()
+        for _ in range(num_samples):
+            try:
+                val = next(gen_iter)
+            except StopIteration:
+                return
+            yield val
+    else:
+        gen_iter = generator()
+        for val in gen_iter:
+            yield val
-def sampler_fn(sampler, dataset):
-    for i in sampler:
+def _py_sampler_fn(sampler, num_samples, dataset):
+    """
+    Generator function wrapper for mappable dataset with python sampler
+    """
+    if num_samples is not None:
+        sampler_iter = iter(sampler)
+        for _ in range(num_samples):
+            try:
+                idx = next(sampler_iter)
+            except StopIteration:
+                return
+            val = dataset[idx]
+            # convert output tensors to ndarrays
+            yield tuple([np.array(x) for x in val])
+    else:
+        for i in sampler:
+            val = dataset[i]
+            # convert output tensors to ndarrays
+            yield tuple([np.array(x) for x in val])
+def _cpp_sampler_fn(sampler, dataset):
+    """
+    Generator function wrapper for mappable dataset with cpp sampler
+    """
+    indices = sampler.get_indices()
+    for i in indices:
+        val = dataset[i]
+        # convert output tensors to ndarrays
+        yield tuple([np.array(x) for x in val])
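
Note: together these four wrappers normalize every supported source into the callable-returning-iterator shape that GeneratorOp consumes. A condensed sketch of the dispatch performed by GeneratorDataset.__init__ below (is_builtin_sampler and make_cpp_sampler are hypothetical shorthand for the isinstance chain and the create/set/initialize sequence):

    def _route_source(source, sampler, num_samples):
        if sampler is not None and hasattr(source, "__getitem__"):
            if is_builtin_sampler(sampler):  # one of the dataset.samplers classes
                cpp_sampler = make_cpp_sampler(sampler, len(source), num_samples)
                return lambda: _cpp_sampler_fn(cpp_sampler, source)
            return lambda: _py_sampler_fn(sampler, num_samples, source)  # iterable sampler
        try:
            iter(source)
        except TypeError:
            return lambda: _generator_fn(source, num_samples)  # callable source
        return lambda: _iter_fn(source, num_samples)           # iterable / random accessible
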
@@ -1763,49 +1819,122 @@ def sampler_fn(sampler, dataset):
 class GeneratorDataset(SourceDataset):
     """
-    A source dataset that generate data from calling generator function each epoch.
+    A source dataset that generates data from Python by invoking a Python data source each epoch.
+    This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
+    below shows what input arguments are allowed and their expected behavior.
+    .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
+       :widths: 25 25 50
+       :header-rows: 1
+       * - Parameter 'sampler'
+         - Parameter 'shuffle'
+         - Expected Order Behavior
+       * - None
+         - None
+         - random order
+       * - None
+         - True
+         - random order
+       * - None
+         - False
+         - sequential order
+       * - Sampler object
+         - None
+         - order defined by sampler
+       * - Sampler object
+         - True
+         - not allowed
+       * - Sampler object
+         - False
+         - not allowed
     Args:
-        generator_function (callable):
-            A callable object that returns an Generator object that supports the iter() protocol.
-            Generator object is required to return a tuple of numpy array as a row of the dataset on next().
+        source (Callable/Iterable/Random Accessible):
+            A generator callable object, an iterable Python object or a random accessible Python object.
+            A callable source is required to return a tuple of numpy arrays as a row of the dataset on source().next().
+            An iterable source is required to return a tuple of numpy arrays as a row of the dataset on
+            iter(source).next().
+            A random accessible source is required to return a tuple of numpy arrays as a row of the dataset on
+            source[idx].
         column_names (list[str]): List of column names of the dataset.
         column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
             If provided, sanity check will be performed on generator output.
-        prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).
-        sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
+        schema (Schema/String, optional): Path to the JSON schema file or schema object (default=None).
+            If the schema is not provided, the meta data from column_names and column_types is considered the schema.
+        num_samples (int, optional): The number of samples to be included in the dataset
+            (default=None, all samples).
+        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is
+            required (default=None, expected order behavior shown in the table).
+        sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible
+            input is required (default=None, expected order behavior shown in the table).
+        num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
+            This argument should be specified only when num_samples is None. Random accessible input is required.
+        shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified
+            only when num_shards is also specified. Random accessible input is required.
     Examples:
-        >>> import mindspore.dataset as ds
-        >>> # 1) generator function that generates multi-dimensional data
+        >>> import mindspore.dataengine as de
+        >>> # 1) Multidimensional generator function as callable input
         >>> def generator_md():
         >>>     for i in range(64):
         >>>         yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
-        >>> # create multi_dimension_generator_dataset with GeneratorMD() and column name "multi_dimensional_data"
-        >>> multi_dimension_generator_dataset = ds.GeneratorDataset(generator_md, ["multi_dimensional_data"])
-        >>> # 2) generator function that generates multi-columns data
+        >>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data"
+        >>> multi_dimension_generator_dataset = de.GeneratorDataset(generator_md, ["multi_dimensional_data"])
+        >>> # 2) Multi-column generator function as callable input
         >>> def generator_mc(maxid = 64):
         >>>     for i in range(maxid):
         >>>         yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
-        >>> # create multi_column_generator_dataset with GeneratorMC() and column names "col1" and "col2"
-        >>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1, col2"])
+        >>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2"
+        >>> multi_column_generator_dataset = de.GeneratorDataset(generator_mc, ["col1", "col2"])
+        >>> # 3) Iterable dataset as iterable input
+        >>> class MyIterable():
+        >>>     def __iter__(self):
+        >>>         return  # User implementation
+        >>> # create iterable_generator_dataset with MyIterable object
+        >>> iterable_generator_dataset = de.GeneratorDataset(MyIterable(), ["col1"])
+        >>> # 4) Random accessible dataset as random accessible input
+        >>> class MyRA():
+        >>>     def __getitem__(self, index):
+        >>>         return  # User implementation
+        >>> # create ra_generator_dataset with MyRA object
+        >>> ra_generator_dataset = de.GeneratorDataset(MyRA(), ["col1"])
+        >>> # A list/dict/tuple is also random accessible
+        >>> list_generator = de.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col1"])
+        >>> # 5) Built-in Sampler
+        >>> my_generator = de.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
+        >>>
     """
     @check_generatordataset
-    def __init__(self, generator_function, column_names, column_types=None, prefetch_size=None, sampler=None):
-        super().__init__(1)
-        if sampler is not None:
-            self.generator_function = (lambda: sampler_fn(sampler, generator_function))
+    def __init__(self, source, column_names, column_types=None, schema=None, num_samples=None, num_parallel_workers=1,
+                 shuffle=None, sampler=None, num_shards=None, shard_id=None):
+        super().__init__(num_parallel_workers)
+        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        if self.sampler is not None and hasattr(source, "__getitem__"):
+            if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
+                                         samplers.RandomSampler, samplers.SubsetRandomSampler,
+                                         samplers.WeightedRandomSampler)):
+                if num_samples is None:
+                    num_samples = len(source)
+                sampler_instance = self.sampler.create()
+                sampler_instance.set_num_rows(len(source))
+                sampler_instance.set_num_samples(num_samples)
+                sampler_instance.initialize()
+                self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
+            else:
+                self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
         else:
             try:
-                # test to see if generator_function is iterable
-                iter(generator_function)
+                iter(source)
             except TypeError:
-                # generator_function was not iterable, assume it is a function
-                self.generator_function = generator_function
+                # Use generator function if input callable
+                self.source = (lambda: _generator_fn(source, num_samples))
             else:
-                # generator_function was iterable, build a function around it
-                self.generator_function = (lambda: ds_fn(generator_function))
+                # Use iterator function if input is iterable
+                # Random accessible input is also iterable
+                self.source = (lambda: _iter_fn(source, num_samples))
         self.column_names = column_names
@@ -1813,17 +1942,12 @@ class GeneratorDataset(SourceDataset):
             self.column_types = mstypelist_to_detypelist(column_types)
         else:
             self.column_types = column_types
-        self.distribution = ""
-        self.prefetch_size = prefetch_size
-        self.sampler = sampler
     def get_args(self):
         args = super().get_args()
-        args["generator_function"] = self.generator_function
+        args["source"] = self.source
         args["column_names"] = self.column_names
         args["column_types"] = self.column_types
-        args["prefetch_size"] = self.prefetch_size
-        args["sampler"] = self.sampler
         return args
     def get_dataset_size(self):
@@ -20,7 +20,6 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
 import mindspore._c_dataengine as cde
 class DistributedSampler():
     """
     Sampler that accesses a shard of the dataset.
@@ -543,28 +543,48 @@ def check_generatordataset(method):
     def new_method(*args, **kwargs):
         param_dict = make_param_dict(method, args, kwargs)
-        nreq_param_int = ['prefetch_size']
-        nreq_param_list = ['column_names', 'column_types']
-        # check generator_function; required argument
-        generator_function = param_dict.get('generator_function')
-        if generator_function is None:
-            raise ValueError("generator_function is not provided.")
+        # check source; required argument
+        source = param_dict.get('source')
+        if source is None:
+            raise ValueError("source is not provided.")
+        if not callable(source):
+            try:
+                iter(source)
+            except TypeError:
+                raise TypeError("source should be callable, iterable or random accessible")
         # check column_names; required argument
         column_names = param_dict.get('column_names')
         if column_names is None:
             raise ValueError("column_names is not provided.")
-        # check prefetch_size range
-        prefetch_size = param_dict.get('prefetch_size')
-        if prefetch_size is not None and (prefetch_size <= 0 or prefetch_size > 1024):
-            raise ValueError("prefetch_size exceeds the boundary.")
+        # check optional arguments
+        nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
+        check_param_type(nreq_param_int, param_dict, int)
+        nreq_param_list = ["column_types"]
+        check_param_type(nreq_param_list, param_dict, list)
+        num_shards = param_dict.get("num_shards")
+        shard_id = param_dict.get("shard_id")
+        if (num_shards is None) != (shard_id is None):
+            # These two parameters appear together.
+            raise ValueError("num_shards and shard_id need to be passed in together")
+        if num_shards is not None:
+            if shard_id >= num_shards:
+                raise ValueError("shard_id should be less than num_shards")
+        sampler = param_dict.get("sampler")
+        if sampler is not None:
+            if isinstance(sampler, samplers.PKSampler):
+                raise ValueError("PKSampler is not supported by GeneratorDataset")
+            if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
+                                        samplers.RandomSampler, samplers.SubsetRandomSampler,
+                                        samplers.WeightedRandomSampler)):
+                try:
+                    iter(sampler)
+                except TypeError:
+                    raise TypeError("sampler should be either iterable or from dataset.samplers.py")
         return method(*args, **kwargs)
     return new_method
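
Note: for reference, what this validator accepts and rejects at the API surface (a hypothetical snippet against the GeneratorDataset signature above):

    import numpy as np
    import mindspore.dataset as ds

    def gen():
        for i in range(3):
            yield (np.array([i]),)

    ds.GeneratorDataset(gen, ["col1"])                 # callable source: accepted
    ds.GeneratorDataset([(np.array([0]),)], ["col1"])  # iterable / random accessible: accepted
    try:
        ds.GeneratorDataset(42, ["col1"])              # neither callable nor iterable
    except TypeError as e:
        print(e)  # source should be callable, iterable or random accessible
    try:
        ds.GeneratorDataset(gen, ["col1"], num_shards=2)  # shard_id missing
    except ValueError as e:
        print(e)  # num_shards and shard_id need to be passed in together
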
@@ -75,7 +75,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
   std::shared_ptr<Tensor> tensor;
   for (int i = 0; i < 6; i++) {
     std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true));
-    sampler->Init(&mock);
+    sampler->HandshakeRandomAccessOp(&mock);
     sampler->GetNextBuffer(&db);
     db->GetTensor(&tensor, 0, 0);
     MS_LOG(DEBUG) << (*tensor);
@@ -95,7 +95,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
   std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3);
   std::unique_ptr<DataBuffer> db;
   std::shared_ptr<Tensor> tensor;
-  sampler->Init(&mock);
+  sampler->HandshakeRandomAccessOp(&mock);
   sampler->GetNextBuffer(&db);
   db->GetTensor(&tensor, 0, 0);
   EXPECT_TRUE((*tensor) == (*label1));
@@ -52,8 +52,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
   std::unordered_set<int64_t> in_set(in.begin(), in.end());
   SubsetRandomSampler sampler(in);
-  DummyRandomAccessOp dummy_random_access_op(5);
-  sampler.Init(&dummy_random_access_op);
+  DummyRandomAccessOp dummyRandomAccessOp(5);
+  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
   std::unique_ptr<DataBuffer> db;
   TensorRow row;
@@ -80,8 +80,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
   std::vector<int64_t> input(total_samples, 1);
   SubsetRandomSampler sampler(input, samples_per_buffer);
-  DummyRandomAccessOp dummy_random_access_op(total_samples);
-  sampler.Init(&dummy_random_access_op);
+  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
+  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
   std::unique_ptr<DataBuffer> db;
   TensorRow row;
@@ -111,8 +111,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
   std::unordered_set<int64_t> in_set(in.begin(), in.end());
   SubsetRandomSampler sampler(in);
-  DummyRandomAccessOp dummy_random_access_op(5);
-  sampler.Init(&dummy_random_access_op);
+  DummyRandomAccessOp dummyRandomAccessOp(5);
+  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
   std::unique_ptr<DataBuffer> db;
   TensorRow row;
@@ -60,8 +60,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
   // create sampler with replacement = true
   WeightedRandomSampler m_sampler(weights, num_samples, true);
-  DummyRandomAccessOp dummy_random_access_op(total_samples);
-  m_sampler.Init(&dummy_random_access_op);
+  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
+  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
   std::unique_ptr<DataBuffer> db;
   TensorRow row;
@@ -90,8 +90,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
   // create sampler with replacement = false
   WeightedRandomSampler m_sampler(weights, num_samples, false);
-  DummyRandomAccessOp dummy_random_access_op(total_samples);
-  m_sampler.Init(&dummy_random_access_op);
+  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
+  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
   std::unique_ptr<DataBuffer> db;
   TensorRow row;
@@ -126,8 +126,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
   // create sampler with replacement = true
   WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer);
-  DummyRandomAccessOp dummy_random_access_op(total_samples);
-  m_sampler.Init(&dummy_random_access_op);
+  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
+  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
   std::unique_ptr<DataBuffer> db;
   TensorRow row;
@@ -162,8 +162,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
   // create sampler with replacement = false
   WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer);
-  DummyRandomAccessOp dummy_random_access_op(total_samples);
-  m_sampler.Init(&dummy_random_access_op);
+  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
+  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
   std::unique_ptr<DataBuffer> db;
   TensorRow row;
@@ -203,8 +203,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
   // create sampler with replacement = true
   WeightedRandomSampler m_sampler(weights, num_samples, true);
-  DummyRandomAccessOp dummy_random_access_op(total_samples);
-  m_sampler.Init(&dummy_random_access_op);
+  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
+  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
   std::unique_ptr<DataBuffer> db;
   TensorRow row;
@@ -248,8 +248,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
   // create sampler with replacement = false
   WeightedRandomSampler m_sampler(weights, num_samples, false);
-  DummyRandomAccessOp dummy_random_access_op(total_samples);
-  m_sampler.Init(&dummy_random_access_op);
+  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
+  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
   std::unique_ptr<DataBuffer> db;
   TensorRow row;
@@ -439,6 +439,74 @@ def test_case_error_4():
         assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)
+def test_sequential_sampler():
+    source = [(np.array([x]),) for x in range(64)]
+    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler())
+    i = 0
+    for data in ds1.create_dict_iterator():  # each data is a dictionary
+        golden = np.array([i])
+        assert np.array_equal(data["data"], golden)
+        i = i + 1
+def test_random_sampler():
+    source = [(np.array([x]),) for x in range(64)]
+    ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
+    for data in ds1.create_dict_iterator():  # each data is a dictionary
+        pass
+def test_distributed_sampler():
+    source = [(np.array([x]),) for x in range(64)]
+    for sid in range(8):
+        ds1 = ds.GeneratorDataset(source, ["data"], shuffle=False, num_shards=8, shard_id=sid)
+        i = sid
+        for data in ds1.create_dict_iterator():  # each data is a dictionary
+            golden = np.array([i])
+            assert np.array_equal(data["data"], golden)
+            i = i + 8
+def test_num_samples():
+    source = [(np.array([x]),) for x in range(64)]
+    num_samples = 32
+    ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples=num_samples)
+    ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples)
+    ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
+    count = 0
+    for _ in ds1.create_dict_iterator():
+        count = count + 1
+    assert count == num_samples
+    count = 0
+    for _ in ds2.create_dict_iterator():
+        count = count + 1
+    assert count == num_samples
+    count = 0
+    for _ in ds3.create_dict_iterator():
+        count = count + 1
+    assert count == num_samples
+def test_num_samples_underflow():
+    source = [(np.array([x]),) for x in range(64)]
+    num_samples = 256
+    ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples=num_samples)
+    ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
+    count = 0
+    for _ in ds2.create_dict_iterator():
+        count = count + 1
+    assert count == 64
+    count = 0
+    for _ in ds3.create_dict_iterator():
+        count = count + 1
+    assert count == 64
 if __name__ == "__main__":
     test_case_0()
     test_case_1()
@@ -458,3 +526,6 @@ if __name__ == "__main__":
     test_case_error_2()
     test_case_error_3()
     test_case_error_4()
+    test_sequential_sampler()
+    test_distributed_sampler()
+    test_random_sampler()
@@ -87,7 +87,28 @@ def test_random_sampler_multi_iter(print_res=False):
     test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4, 5])
+def test_sampler_py_api():
+    sampler = ds.SequentialSampler().create()
+    sampler.set_num_rows(128)
+    sampler.set_num_samples(64)
+    sampler.initialize()
+    sampler.get_indices()
+    sampler = ds.RandomSampler().create()
+    sampler.set_num_rows(128)
+    sampler.set_num_samples(64)
+    sampler.initialize()
+    sampler.get_indices()
+    sampler = ds.DistributedSampler(8, 4).create()
+    sampler.set_num_rows(128)
+    sampler.set_num_samples(64)
+    sampler.initialize()
+    sampler.get_indices()
 if __name__ == '__main__':
     test_sequential_sampler(True)
     test_random_sampler(True)
     test_random_sampler_multi_iter(True)
+    test_sampler_py_api()