| @@ -517,7 +517,7 @@ Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<Datase | |||||
| std::string key = py::str(arg.first); | std::string key = py::str(arg.first); | ||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "generator_function") { | |||||
| if (key == "source") { | |||||
| py::object obj = py::cast(&value); | py::object obj = py::cast(&value); | ||||
| if (!py::isinstance<py::function>(obj)) { | if (!py::isinstance<py::function>(obj)) { | ||||
| std::string err_msg = "Error: generator is invalid or not set."; | std::string err_msg = "Error: generator is invalid or not set."; | ||||
| @@ -384,7 +384,16 @@ void bindTensorOps4(py::module *m) { | |||||
| } | } | ||||
| void bindSamplerOps(py::module *m) { | void bindSamplerOps(py::module *m) { | ||||
| (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler"); | |||||
| (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler") | |||||
| .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) | |||||
| .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) | |||||
| .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) | |||||
| .def("get_indices", [](Sampler &self) { | |||||
| py::array ret; | |||||
| THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); | |||||
| return ret; | |||||
| }); | |||||
| (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator"); | (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator"); | ||||
| (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler") | (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler") | ||||
| @@ -491,6 +491,8 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const { | |||||
| // return data as numpy, should return status | // return data as numpy, should return status | ||||
| Status Tensor::GetDataAsNumpy(py::array *data) { | Status Tensor::GetDataAsNumpy(py::array *data) { | ||||
| RETURN_UNEXPECTED_IF_NULL(data_); | |||||
| RETURN_UNEXPECTED_IF_NULL(data); | |||||
| if (type_ == DataType::DE_BOOL) { | if (type_ == DataType::DE_BOOL) { | ||||
| *data = py::array_t<bool>(shape_.AsVector(), reinterpret_cast<bool *>(data_)); | *data = py::array_t<bool>(shape_.AsVector(), reinterpret_cast<bool *>(data_)); | ||||
| } else if (type_ == DataType::DE_INT8) { | } else if (type_ == DataType::DE_INT8) { | ||||
| @@ -100,7 +100,7 @@ Status CelebAOp::LaunchThreadsAndInitOp() { | |||||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1))); | RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1))); | ||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| RETURN_IF_NOT_OK(ParseImageAttrInfo()); | RETURN_IF_NOT_OK(ParseImageAttrInfo()); | ||||
| RETURN_IF_NOT_OK(sampler_->Init(this)); | |||||
| RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -240,7 +240,7 @@ Status CifarOp::Reset() { | |||||
| // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows | // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows | ||||
| Status CifarOp::InitSampler() { | Status CifarOp::InitSampler() { | ||||
| RETURN_IF_NOT_OK(sampler_->Init(this)); | |||||
| RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -258,7 +258,7 @@ Status ImageFolderOp::Reset() { | |||||
| // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows | // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows | ||||
| Status ImageFolderOp::InitSampler() { | Status ImageFolderOp::InitSampler() { | ||||
| RETURN_IF_NOT_OK(sampler_->Init(this)); | |||||
| RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -254,7 +254,7 @@ Status ManifestOp::Reset() { | |||||
| // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows | // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows | ||||
| Status ManifestOp::InitSampler() { | Status ManifestOp::InitSampler() { | ||||
| RETURN_IF_NOT_OK(sampler_->Init(this)); | |||||
| RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -205,7 +205,7 @@ Status MnistOp::Reset() { | |||||
| // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows | // hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows | ||||
| Status MnistOp::InitSampler() { | Status MnistOp::InitSampler() { | ||||
| RETURN_IF_NOT_OK(sampler_->Init(this)); | |||||
| RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -31,8 +31,9 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu | |||||
| num_devices_(num_dev), | num_devices_(num_dev), | ||||
| shuffle_(shuffle) {} | shuffle_(shuffle) {} | ||||
| Status DistributedSampler::Init(const RandomAccessOp *op) { | |||||
| RETURN_IF_NOT_OK(Sampler::Init(op)); | |||||
| Status DistributedSampler::InitSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0, | CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0, | ||||
| "fail to init DistributedSampler"); | "fail to init DistributedSampler"); | ||||
| rnd_.seed(seed_++); | rnd_.seed(seed_++); | ||||
| @@ -41,10 +41,8 @@ class DistributedSampler : public Sampler { | |||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| // first handshake between StorageOp and Sampler | |||||
| // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() | |||||
| // @return | |||||
| Status Init(const RandomAccessOp *) override; | |||||
| // Init sampler, called by base class or python | |||||
| Status InitSampler() override; | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| @@ -28,9 +28,7 @@ PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer) | |||||
| num_pk_samples_(0), | num_pk_samples_(0), | ||||
| samples_per_class_(val) {} | samples_per_class_(val) {} | ||||
| Status PKSampler::Init(const RandomAccessOp *op) { | |||||
| RETURN_UNEXPECTED_IF_NULL(op); | |||||
| RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); | |||||
| Status PKSampler::InitSampler() { | |||||
| labels_.reserve(label_to_ids_.size()); | labels_.reserve(label_to_ids_.size()); | ||||
| for (const auto &pair : label_to_ids_) { | for (const auto &pair : label_to_ids_) { | ||||
| if (pair.second.empty() == false) { | if (pair.second.empty() == false) { | ||||
| @@ -79,5 +77,13 @@ Status PKSampler::Reset() { | |||||
| rnd_.seed(seed_++); | rnd_.seed(seed_++); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
// Handshake between a RandomAccessOp and this PKSampler: asks the op for its class ids,
// storing them in label_to_ids_, then runs InitSampler().
// @param op - RandomAccessOp to query; checked with RETURN_UNEXPECTED_IF_NULL before use
// @return - The error code return
Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
  RETURN_UNEXPECTED_IF_NULL(op);
  // label_to_ids_ has to be filled first: InitSampler() iterates over it
  RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
  RETURN_IF_NOT_OK(InitSampler());
  return Status::OK();
}
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -45,7 +45,10 @@ class PKSampler : public Sampler { // NOT YET FINISHED | |||||
| // first handshake between StorageOp and Sampler | // first handshake between StorageOp and Sampler | ||||
| // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() | // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() | ||||
| // @return | // @return | ||||
| Status Init(const RandomAccessOp *op) override; | |||||
| Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; | |||||
| // init sampler, to be called by python or Handshake | |||||
| Status InitSampler() override; | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| @@ -49,10 +49,9 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status RandomSampler::Init(const RandomAccessOp *op) { | |||||
| RETURN_IF_NOT_OK(Sampler::Init(op)); | |||||
| Status RandomSampler::InitSampler() { | |||||
| num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; | num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_; | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "Fail to init RandomSampler"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); | |||||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | ||||
| if (replacement_ == false) { | if (replacement_ == false) { | ||||
| shuffled_ids_.reserve(num_rows_); | shuffled_ids_.reserve(num_rows_); | ||||
| @@ -42,10 +42,8 @@ class RandomSampler : public Sampler { | |||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| // first handshake between StorageOp and Sampler | |||||
| // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() | |||||
| // @return | |||||
| Status Init(const RandomAccessOp *op) override; | |||||
| // meant to be called by base class or python | |||||
| Status InitSampler() override; | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| @@ -20,12 +20,13 @@ namespace dataset { | |||||
// Constructor. num_rows_ and num_samples_ start at 0 and col_desc_ at nullptr; only
// samples_per_buffer_ is taken from the caller.
// @param samples_per_buffer - value stored in samples_per_buffer_
//        NOTE(review): presumably the number of sample ids per output buffer - confirm
Sampler::Sampler(int64_t samples_per_buffer)
    : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
| Status Sampler::Init(const RandomAccessOp *op) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr && samples_per_buffer_ > 0, "Fail to init Sampler()\n"); | |||||
// Handshake with a RandomAccessOp: reads the op's sample count and row count into
// num_samples_ and num_rows_, then calls InitSampler().
// @param op - RandomAccessOp to query, must not be nullptr
// @return - The error code return
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
  RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
  RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
  // It's up to the derived class to check the validity of the two args
  // Because some sampler only needs one of the arg (weighted_random_sampler)
  RETURN_IF_NOT_OK(InitSampler());  // init sampler after callback
  return Status::OK();
}
| @@ -42,5 +43,49 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t | |||||
| (void)(*sample_ids)->StartAddr(); // allocate memory in case user forgets! | (void)(*sample_ids)->StartAddr(); // allocate memory in case user forgets! | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Sampler::GetAllIdsThenReset(py::array *data) { | |||||
| std::unique_ptr<DataBuffer> db; | |||||
| std::shared_ptr<Tensor> sample_ids; | |||||
| // check samples_per_buffer is properly set and doesn't overflow | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid"); | |||||
| // A call to derived class to get sample ids wrapped inside a buffer | |||||
| RETURN_IF_NOT_OK(GetNextBuffer(&db)); | |||||
| // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch | |||||
| RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0)); | |||||
| // check this buffer is not a ctrl buffer | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received"); | |||||
| { | |||||
| py::gil_scoped_acquire gil_acquire; | |||||
| if (Py_IsInitialized() == 0) { | |||||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||||
| } | |||||
| try { | |||||
| RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data)); | |||||
| } catch (const std::runtime_error &e) { | |||||
| return Status(StatusCode::kPyFuncException, e.what()); | |||||
| } | |||||
| } | |||||
| // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch | |||||
| RETURN_IF_NOT_OK(GetNextBuffer(&db)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); | |||||
| // Reset Sampler since this is the end of the epoch | |||||
| RETURN_IF_NOT_OK(Reset()); | |||||
| return Status::OK(); | |||||
| } | |||||
// setter function for num_samples_
// @param num_samples - new value for num_samples_, must be greater than 0
// @return - The error code return
Status Sampler::SetNumSamples(int64_t num_samples) {
  CHECK_FAIL_RETURN_UNEXPECTED(num_samples > 0, "num_samples is negative or 0");
  num_samples_ = num_samples;
  return Status::OK();
}
// setter function for num_rows_
// @param num_rows - new value for num_rows_, must be greater than 0
// @return - The error code return
Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
  CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "num_rows is negative or 0");
  num_rows_ = num_rows;
  return Status::OK();
}
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -78,14 +78,26 @@ class Sampler : public DatasetOp { | |||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0; | Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override = 0; | ||||
| // return all ids in one epoch as a numpy array, then call reset | |||||
| Status GetAllIdsThenReset(py::array *data); | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status Reset() override = 0; | Status Reset() override = 0; | ||||
| // first handshake between StorageOp and Sampler. Base class init will call both GetNumRows and GetNumSamples | |||||
| // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() | |||||
| // setter function for num_rows_ | |||||
| Status SetNumRowsInDataset(int64_t num_rows); | |||||
| // setter function for num_samples_ | |||||
| Status SetNumSamples(int64_t num_samples); | |||||
| // first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples | |||||
| // @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds() | |||||
| // @return | // @return | ||||
| virtual Status Init(const RandomAccessOp *op); | |||||
| virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op); | |||||
| // initialize sampler and perform checks on certain vars | |||||
| virtual Status InitSampler() { return Status::OK(); } | |||||
| // Not meant to be called | // Not meant to be called | ||||
| // @return | // @return | ||||
| @@ -41,9 +41,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status SequentialSampler::Init(const RandomAccessOp *op) { | |||||
| RETURN_UNEXPECTED_IF_NULL(op); | |||||
| RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_)); | |||||
| Status SequentialSampler::InitSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); | CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); | ||||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -32,10 +32,8 @@ class SequentialSampler : public Sampler { | |||||
| // Destructor. | // Destructor. | ||||
| ~SequentialSampler() = default; | ~SequentialSampler() = default; | ||||
| // Initialize the sampler. | |||||
| // @param op | |||||
| // @return Status | |||||
| Status Init(const RandomAccessOp *op) override; | |||||
| // init sampler, called by python | |||||
| Status InitSampler() override; | |||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | // @return - The error code return | ||||
| @@ -31,9 +31,8 @@ SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, in | |||||
| : Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} | : Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} | ||||
| // Initialized this Sampler. | // Initialized this Sampler. | ||||
| Status SubsetRandomSampler::Init(const RandomAccessOp *op) { | |||||
| // Calling base class init. | |||||
| RETURN_IF_NOT_OK(Sampler::Init(op)); | |||||
| Status SubsetRandomSampler::InitSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); | |||||
| // Initialize random generator with seed from config manager | // Initialize random generator with seed from config manager | ||||
| rand_gen_.seed(GetSeed()); | rand_gen_.seed(GetSeed()); | ||||
| @@ -38,9 +38,8 @@ class SubsetRandomSampler : public Sampler { | |||||
| ~SubsetRandomSampler() = default; | ~SubsetRandomSampler() = default; | ||||
| // Initialize the sampler. | // Initialize the sampler. | ||||
| // @param op (Not used in this sampler) | |||||
| // @return Status | // @return Status | ||||
| Status Init(const RandomAccessOp *op) override; | |||||
| Status InitSampler() override; | |||||
| // Reset the internal variable to the initial state and reshuffle the indices. | // Reset the internal variable to the initial state and reshuffle the indices. | ||||
| // @return Status | // @return Status | ||||
| @@ -29,21 +29,21 @@ namespace dataset { | |||||
| // Constructor. | // Constructor. | ||||
| WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement, | WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement, | ||||
| int64_t samples_per_buffer) | int64_t samples_per_buffer) | ||||
| : Sampler(samples_per_buffer), weights_(weights), replacement_(replacement), sample_id_(0), buffer_id_(0) { | |||||
| num_samples_ = num_samples; // this variable is defined in base class sampler | |||||
| } | |||||
| : Sampler(samples_per_buffer), | |||||
| weights_(weights), | |||||
| replacement_(replacement), | |||||
| sample_id_(0), | |||||
| buffer_id_(0), | |||||
| user_num_samples_(num_samples) {} | |||||
| // Initialized this Sampler. | // Initialized this Sampler. | ||||
| Status WeightedRandomSampler::Init(const RandomAccessOp *op) { | |||||
| RETURN_UNEXPECTED_IF_NULL(op); | |||||
| RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); | |||||
| Status WeightedRandomSampler::InitSampler() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); | |||||
| // Initialize random generator with seed from config manager | // Initialize random generator with seed from config manager | ||||
| rand_gen_.seed(GetSeed()); | rand_gen_.seed(GetSeed()); | ||||
| samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init WeightedRandomSampler"); | |||||
| samples_per_buffer_ = (samples_per_buffer_ > user_num_samples_) ? user_num_samples_ : samples_per_buffer_; | |||||
| if (!replacement_) { | if (!replacement_) { | ||||
| exp_dist_ = std::make_unique<std::exponential_distribution<>>(1); | exp_dist_ = std::make_unique<std::exponential_distribution<>>(1); | ||||
| @@ -65,8 +65,8 @@ void WeightedRandomSampler::InitOnePassSampling() { | |||||
| } | } | ||||
| // Partial sort the first `numSamples` elements. | // Partial sort the first `numSamples` elements. | ||||
| std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end()); | |||||
| for (int64_t i = 0; i < num_samples_; i++) { | |||||
| std::partial_sort(val_idx.begin(), val_idx.begin() + user_num_samples_, val_idx.end()); | |||||
| for (int64_t i = 0; i < user_num_samples_; i++) { | |||||
| onepass_ids_.push_back(val_idx[i].second); | onepass_ids_.push_back(val_idx[i].second); | ||||
| } | } | ||||
| } | } | ||||
| @@ -91,11 +91,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf | |||||
| "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); | "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); | ||||
| } | } | ||||
| if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { | |||||
| if (!replacement_ && (weights_.size() < static_cast<size_t>(user_num_samples_))) { | |||||
| RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); | RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); | ||||
| } | } | ||||
| if (sample_id_ == num_samples_) { | |||||
| if (sample_id_ == user_num_samples_) { | |||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | ||||
| } else { | } else { | ||||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); | (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone); | ||||
| @@ -103,8 +103,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf | |||||
| int64_t last_id = sample_id_ + samples_per_buffer_; | int64_t last_id = sample_id_ + samples_per_buffer_; | ||||
| // Handling the return all samples at once, and when last draw is not a full batch. | // Handling the return all samples at once, and when last draw is not a full batch. | ||||
| if (last_id > num_samples_) { | |||||
| last_id = num_samples_; | |||||
| if (last_id > user_num_samples_) { | |||||
| last_id = user_num_samples_; | |||||
| } | } | ||||
| // Allocate tensor. | // Allocate tensor. | ||||
| @@ -43,7 +43,7 @@ class WeightedRandomSampler : public Sampler { | |||||
| // Initialize the sampler. | // Initialize the sampler. | ||||
| // @param op (Not used in this sampler) | // @param op (Not used in this sampler) | ||||
| // @return Status | // @return Status | ||||
| Status Init(const RandomAccessOp *op) override; | |||||
| Status InitSampler() override; | |||||
| // Reset the internal variable to the initial state and reshuffle the indices. | // Reset the internal variable to the initial state and reshuffle the indices. | ||||
| Status Reset() override; | Status Reset() override; | ||||
| @@ -69,6 +69,9 @@ class WeightedRandomSampler : public Sampler { | |||||
| // Random engine and device | // Random engine and device | ||||
| std::mt19937 rand_gen_; | std::mt19937 rand_gen_; | ||||
| // num_samples from user | |||||
| int64_t user_num_samples_; | |||||
| // Discrete distribution for generating weighted random numbers with replacement. | // Discrete distribution for generating weighted random numbers with replacement. | ||||
| std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_; | std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_; | ||||
| @@ -220,7 +220,7 @@ Status VOCOp::ParseImageIds() { | |||||
| } | } | ||||
| Status VOCOp::InitSampler() { | Status VOCOp::InitSampler() { | ||||
| RETURN_IF_NOT_OK(sampler_->Init(this)); | |||||
| RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1748,14 +1748,70 @@ class MindDataset(SourceDataset): | |||||
| return num_rows | return num_rows | ||||
| def ds_fn(dataset): | |||||
| for val in dataset: | |||||
| # convert output tensors to ndarrays | |||||
| yield tuple([np.array(x) for x in val]) | |||||
| def _iter_fn(dataset, num_samples): | |||||
| """ | |||||
| Generator function wrapper for iterable dataset | |||||
| """ | |||||
| if num_samples is not None: | |||||
| ds_iter = iter(dataset) | |||||
| for _ in range(num_samples): | |||||
| try: | |||||
| val = next(ds_iter) | |||||
| except StopIteration: | |||||
| return | |||||
| # convert output tensors to ndarrays | |||||
| yield tuple([np.array(x) for x in val]) | |||||
| else: | |||||
| for val in dataset: | |||||
| # convert output tensors to ndarrays | |||||
| yield tuple([np.array(x) for x in val]) | |||||
| def _generator_fn(generator, num_samples): | |||||
| """ | |||||
| Generator function wrapper for generator function dataset | |||||
| """ | |||||
| if num_samples is not None: | |||||
| gen_iter = generator() | |||||
| for _ in range(num_samples): | |||||
| try: | |||||
| val = next(gen_iter) | |||||
| except StopIteration: | |||||
| return | |||||
| yield val | |||||
| else: | |||||
| gen_iter = generator() | |||||
| for val in gen_iter: | |||||
| yield val | |||||
| def sampler_fn(sampler, dataset): | |||||
| for i in sampler: | |||||
| def _py_sampler_fn(sampler, num_samples, dataset): | |||||
| """ | |||||
| Generator function wrapper for mappable dataset with python sampler | |||||
| """ | |||||
| if num_samples is not None: | |||||
| sampler_iter = iter(sampler) | |||||
| for _ in range(num_samples): | |||||
| try: | |||||
| idx = next(sampler_iter) | |||||
| except StopIteration: | |||||
| return | |||||
| val = dataset[idx] | |||||
| # convert output tensors to ndarrays | |||||
| yield tuple([np.array(x) for x in val]) | |||||
| else: | |||||
| for i in sampler: | |||||
| val = dataset[i] | |||||
| # convert output tensors to ndarrays | |||||
| yield tuple([np.array(x) for x in val]) | |||||
| def _cpp_sampler_fn(sampler, dataset): | |||||
| """ | |||||
| Generator function wrapper for mappable dataset with cpp sampler | |||||
| """ | |||||
| indices = sampler.get_indices() | |||||
| for i in indices: | |||||
| val = dataset[i] | val = dataset[i] | ||||
| # convert output tensors to ndarrays | # convert output tensors to ndarrays | ||||
| yield tuple([np.array(x) for x in val]) | yield tuple([np.array(x) for x in val]) | ||||
| @@ -1763,49 +1819,122 @@ def sampler_fn(sampler, dataset): | |||||
| class GeneratorDataset(SourceDataset): | class GeneratorDataset(SourceDataset): | ||||
| """ | """ | ||||
| A source dataset that generate data from calling generator function each epoch. | |||||
| A source dataset that generate data from python by invoking python data source each epoch. | |||||
| This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table | |||||
| below shows what input args are allowed and their expected behavior. | |||||
| .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle' | |||||
| :widths: 25 25 50 | |||||
| :header-rows: 1 | |||||
| * - Parameter 'sampler' | |||||
| - Parameter 'shuffle' | |||||
| - Expected Order Behavior | |||||
| * - None | |||||
| - None | |||||
| - random order | |||||
| * - None | |||||
| - True | |||||
| - random order | |||||
| * - None | |||||
| - False | |||||
| - sequential order | |||||
| * - Sampler object | |||||
| - None | |||||
| - order defined by sampler | |||||
| * - Sampler object | |||||
| - True | |||||
| - not allowed | |||||
| * - Sampler object | |||||
| - False | |||||
| - not allowed | |||||
| Args: | Args: | ||||
| generator_function (callable): | |||||
| A callable object that returns an Generator object that supports the iter() protocol. | |||||
| Generator object is required to return a tuple of numpy array as a row of the dataset on next(). | |||||
| source (Callable/Iterable/Random Accessible): | |||||
| A generator callable object, an iterable python object or a random accessible python object. | |||||
| Callable source is required to return a tuple of numpy array as a row of the dataset on source().next(). | |||||
| Iterable source is required to return a tuple of numpy array as a row of the dataset on iter(source).next(). | |||||
| Random accessible source is required to return a tuple of numpy array as a row of the dataset on | |||||
| source[idx]. | |||||
| column_names (list[str]): List of column names of the dataset. | column_names (list[str]): List of column names of the dataset. | ||||
| column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None). | column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None). | ||||
| If provided, sanity check will be performed on generator output. | If provided, sanity check will be performed on generator output. | ||||
| prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None). | |||||
| sampler (Sampler, optional): Object used to choose samples from the dataset (default=None). | |||||
| schema (Schema/String, optional): Path to the json schema file or schema object (default=None). | |||||
| If the schema is not provided, the meta data from column_names and column_types is considered the schema. | |||||
| num_samples (int, optional): The number of samples to be included in the dataset | |||||
| (default=None, all images). | |||||
| shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. | |||||
| (default=None, expected order behavior shown in the table). | |||||
| sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is | |||||
| required. | |||||
| (default=None, expected order behavior shown in the table). | |||||
| num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). | |||||
| This argument should be specified only when 'num_samples' is "None". Random accessible input is required. | |||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only | |||||
| when num_shards is also specified. Random accessible input is required. | |||||
| Examples: | Examples: | ||||
| >>> import mindspore.dataset as ds | |||||
| >>> # 1) generator function that generates multi-dimensional data | |||||
| >>> import mindspore.dataengine as de | |||||
| >>> # 1) Multidimensional generator function as callable input | |||||
| >>> def generator_md(): | >>> def generator_md(): | ||||
| >>> for i in range(64): | >>> for i in range(64): | ||||
| >>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),) | >>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),) | ||||
| >>> # create multi_dimension_generator_dataset with GeneratorMD() and column name "multi_dimensional_data" | |||||
| >>> multi_dimension_generator_dataset = ds.GeneratorDataset(generator_md, ["multi_dimensional_data"]) | |||||
| >>> # 2) generator function that generates multi-columns data | |||||
| >>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data" | |||||
| >>> multi_dimension_generator_dataset = de.GeneratorDataset(generator_md, ["multi_dimensional_data"]) | |||||
| >>> # 2) Multi-column generator function as callable input | |||||
| >>> def generator_mc(maxid = 64): | >>> def generator_mc(maxid = 64): | ||||
| >>> for i in range(maxid): | >>> for i in range(maxid): | ||||
| >>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) | >>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) | ||||
| >>> # create multi_column_generator_dataset with GeneratorMC() and column names "col1" and "col2" | |||||
| >>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1, col2"]) | |||||
| >>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2" | |||||
| >>> multi_column_generator_dataset = de.GeneratorDataset(generator_mc, ["col1", "col2"]) | |||||
| >>> # 3) Iterable dataset as iterable input | |||||
| >>> class MyIterable(): | |||||
| >>> def __iter__(self): | |||||
| >>> return # User implementation | |||||
| >>> # create iterable_generator_dataset with MyIterable object | |||||
| >>> iterable_generator_dataset = de.GeneratorDataset(MyIterable(), ["col1"]) | |||||
| >>> # 4) Random accessible dataset as Random accessible input | |||||
| >>> class MyRA(): | |||||
| >>> def __getitem__(self, index): | |||||
| >>> return # User implementation | |||||
| >>> # create ra_generator_dataset with MyRA object | |||||
| >>> ra_generator_dataset = de.GeneratorDataset(MyRA(), ["col1"]) | |||||
| >>> # List/Dict/Tuple is also random accessible | |||||
| >>> list_generator = de.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col1"]) | |||||
| >>> # 5) Built-in Sampler | |||||
| >>> my_generator = de.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler()) | |||||
| >>> | |||||
| """ | """ | ||||
| @check_generatordataset | @check_generatordataset | ||||
| def __init__(self, generator_function, column_names, column_types=None, prefetch_size=None, sampler=None): | |||||
| super().__init__(1) | |||||
| if sampler is not None: | |||||
| self.generator_function = (lambda: sampler_fn(sampler, generator_function)) | |||||
| def __init__(self, source, column_names, column_types=None, schema=None, num_samples=None, num_parallel_workers=1, | |||||
| shuffle=None, sampler=None, num_shards=None, shard_id=None): | |||||
| super().__init__(num_parallel_workers) | |||||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | |||||
| if self.sampler is not None and hasattr(source, "__getitem__"): | |||||
| if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | |||||
| samplers.RandomSampler, samplers.SubsetRandomSampler, | |||||
| samplers.WeightedRandomSampler)): | |||||
| if num_samples is None: | |||||
| num_samples = len(source) | |||||
| sampler_instance = self.sampler.create() | |||||
| sampler_instance.set_num_rows(len(source)) | |||||
| sampler_instance.set_num_samples(num_samples) | |||||
| sampler_instance.initialize() | |||||
| self.source = (lambda: _cpp_sampler_fn(sampler_instance, source)) | |||||
| else: | |||||
| self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) | |||||
| else: | else: | ||||
| try: | try: | ||||
| # test to see if generator_function is iterable | |||||
| iter(generator_function) | |||||
| iter(source) | |||||
| except TypeError: | except TypeError: | ||||
| # generator_function was not iterable, assume it is a function | |||||
| self.generator_function = generator_function | |||||
| # Use generator function if input callable | |||||
| self.source = (lambda: _generator_fn(source, num_samples)) | |||||
| else: | else: | ||||
| # generator_function was iterable, build a function around it | |||||
| self.generator_function = (lambda: ds_fn(generator_function)) | |||||
| # Use iterator function if input is iterable | |||||
| # Random accessible input is also iterable | |||||
| self.source = (lambda: _iter_fn(source, num_samples)) | |||||
| self.column_names = column_names | self.column_names = column_names | ||||
| @@ -1813,17 +1942,12 @@ class GeneratorDataset(SourceDataset): | |||||
| self.column_types = mstypelist_to_detypelist(column_types) | self.column_types = mstypelist_to_detypelist(column_types) | ||||
| else: | else: | ||||
| self.column_types = column_types | self.column_types = column_types | ||||
| self.distribution = "" | |||||
| self.prefetch_size = prefetch_size | |||||
| self.sampler = sampler | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| args["generator_function"] = self.generator_function | |||||
| args["source"] = self.source | |||||
| args["column_names"] = self.column_names | args["column_names"] = self.column_names | ||||
| args["column_types"] = self.column_types | args["column_types"] = self.column_types | ||||
| args["prefetch_size"] = self.prefetch_size | |||||
| args["sampler"] = self.sampler | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -20,7 +20,6 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler. | |||||
| import mindspore._c_dataengine as cde | import mindspore._c_dataengine as cde | ||||
| class DistributedSampler(): | class DistributedSampler(): | ||||
| """ | """ | ||||
| Sampler that accesses a shard of the dataset. | Sampler that accesses a shard of the dataset. | ||||
| @@ -543,28 +543,48 @@ def check_generatordataset(method): | |||||
| def new_method(*args, **kwargs): | def new_method(*args, **kwargs): | ||||
| param_dict = make_param_dict(method, args, kwargs) | param_dict = make_param_dict(method, args, kwargs) | ||||
| nreq_param_int = ['prefetch_size'] | |||||
| nreq_param_list = ['column_names', 'column_types'] | |||||
| # check source; required argument | # check source; required argument | ||||
| generator_function = param_dict.get('generator_function') | |||||
| if generator_function is None: | |||||
| raise ValueError("generator_function is not provided.") | |||||
| source = param_dict.get('source') | |||||
| if source is None: | |||||
| raise ValueError("source is not provided.") | |||||
| if not callable(source): | |||||
| try: | |||||
| iter(source) | |||||
| except TypeError: | |||||
| raise TypeError("source should be callable, iterable or random accessible") | |||||
| # check column_names; required argument | # check column_names; required argument | ||||
| column_names = param_dict.get('column_names') | column_names = param_dict.get('column_names') | ||||
| if column_names is None: | if column_names is None: | ||||
| raise ValueError("column_names is not provided.") | raise ValueError("column_names is not provided.") | ||||
| # check prefetch_size range | |||||
| prefetch_size = param_dict.get('prefetch_size') | |||||
| if prefetch_size is not None and (prefetch_size <= 0 or prefetch_size > 1024): | |||||
| raise ValueError("prefetch_size exceeds the boundary.") | |||||
| # check optional argument | |||||
| nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"] | |||||
| check_param_type(nreq_param_int, param_dict, int) | check_param_type(nreq_param_int, param_dict, int) | ||||
| nreq_param_list = ["column_types"] | |||||
| check_param_type(nreq_param_list, param_dict, list) | check_param_type(nreq_param_list, param_dict, list) | ||||
| num_shards = param_dict.get("num_shards") | |||||
| shard_id = param_dict.get("shard_id") | |||||
| if (num_shards is None) != (shard_id is None): | |||||
| # These two parameters appear together. | |||||
| raise ValueError("num_shards and shard_id need to be passed in together") | |||||
| if num_shards is not None: | |||||
| if shard_id >= num_shards: | |||||
| raise ValueError("shard_id should be less than num_shards") | |||||
| sampler = param_dict.get("sampler") | |||||
| if sampler is not None: | |||||
| if isinstance(sampler, samplers.PKSampler): | |||||
| raise ValueError("PKSampler is not supported by GeneratorDataset") | |||||
| if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | |||||
| samplers.RandomSampler, samplers.SubsetRandomSampler, | |||||
| samplers.WeightedRandomSampler)): | |||||
| try: | |||||
| iter(sampler) | |||||
| except TypeError: | |||||
| raise TypeError("sampler should be either iterable or from dataset.samplers.py") | |||||
| return method(*args, **kwargs) | return method(*args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -75,7 +75,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { | |||||
| std::shared_ptr<Tensor> tensor; | std::shared_ptr<Tensor> tensor; | ||||
| for (int i = 0; i < 6; i++) { | for (int i = 0; i < 6; i++) { | ||||
| std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true)); | std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true)); | ||||
| sampler->Init(&mock); | |||||
| sampler->HandshakeRandomAccessOp(&mock); | |||||
| sampler->GetNextBuffer(&db); | sampler->GetNextBuffer(&db); | ||||
| db->GetTensor(&tensor, 0, 0); | db->GetTensor(&tensor, 0, 0); | ||||
| MS_LOG(DEBUG) << (*tensor); | MS_LOG(DEBUG) << (*tensor); | ||||
| @@ -95,7 +95,7 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) { | |||||
| std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3); | std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3); | ||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| std::shared_ptr<Tensor> tensor; | std::shared_ptr<Tensor> tensor; | ||||
| sampler->Init(&mock); | |||||
| sampler->HandshakeRandomAccessOp(&mock); | |||||
| sampler->GetNextBuffer(&db); | sampler->GetNextBuffer(&db); | ||||
| db->GetTensor(&tensor, 0, 0); | db->GetTensor(&tensor, 0, 0); | ||||
| EXPECT_TRUE((*tensor) == (*label1)); | EXPECT_TRUE((*tensor) == (*label1)); | ||||
| @@ -52,8 +52,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { | |||||
| std::unordered_set<int64_t> in_set(in.begin(), in.end()); | std::unordered_set<int64_t> in_set(in.begin(), in.end()); | ||||
| SubsetRandomSampler sampler(in); | SubsetRandomSampler sampler(in); | ||||
| DummyRandomAccessOp dummy_random_access_op(5); | |||||
| sampler.Init(&dummy_random_access_op); | |||||
| DummyRandomAccessOp dummyRandomAccessOp(5); | |||||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| @@ -80,8 +80,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { | |||||
| std::vector<int64_t> input(total_samples, 1); | std::vector<int64_t> input(total_samples, 1); | ||||
| SubsetRandomSampler sampler(input, samples_per_buffer); | SubsetRandomSampler sampler(input, samples_per_buffer); | ||||
| DummyRandomAccessOp dummy_random_access_op(total_samples); | |||||
| sampler.Init(&dummy_random_access_op); | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| @@ -111,8 +111,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||||
| std::unordered_set<int64_t> in_set(in.begin(), in.end()); | std::unordered_set<int64_t> in_set(in.begin(), in.end()); | ||||
| SubsetRandomSampler sampler(in); | SubsetRandomSampler sampler(in); | ||||
| DummyRandomAccessOp dummy_random_access_op(5); | |||||
| sampler.Init(&dummy_random_access_op); | |||||
| DummyRandomAccessOp dummyRandomAccessOp(5); | |||||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| @@ -60,8 +60,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { | |||||
| // create sampler with replacement = true | // create sampler with replacement = true | ||||
| WeightedRandomSampler m_sampler(weights, num_samples, true); | WeightedRandomSampler m_sampler(weights, num_samples, true); | ||||
| DummyRandomAccessOp dummy_random_access_op(total_samples); | |||||
| m_sampler.Init(&dummy_random_access_op); | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| @@ -90,8 +90,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { | |||||
| // create sampler with replacement = false | // create sampler with replacement = false | ||||
| WeightedRandomSampler m_sampler(weights, num_samples, false); | WeightedRandomSampler m_sampler(weights, num_samples, false); | ||||
| DummyRandomAccessOp dummy_random_access_op(total_samples); | |||||
| m_sampler.Init(&dummy_random_access_op); | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| @@ -126,8 +126,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { | |||||
| // create sampler with replacement = true | // create sampler with replacement = true | ||||
| WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer); | WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer); | ||||
| DummyRandomAccessOp dummy_random_access_op(total_samples); | |||||
| m_sampler.Init(&dummy_random_access_op); | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| @@ -162,8 +162,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { | |||||
| // create sampler with replacement = false | // create sampler with replacement = false | ||||
| WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer); | WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer); | ||||
| DummyRandomAccessOp dummy_random_access_op(total_samples); | |||||
| m_sampler.Init(&dummy_random_access_op); | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| @@ -203,8 +203,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||||
| // create sampler with replacement = true | // create sampler with replacement = true | ||||
| WeightedRandomSampler m_sampler(weights, num_samples, true); | WeightedRandomSampler m_sampler(weights, num_samples, true); | ||||
| DummyRandomAccessOp dummy_random_access_op(total_samples); | |||||
| m_sampler.Init(&dummy_random_access_op); | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| @@ -248,8 +248,8 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||||
| // create sampler with replacement = false | // create sampler with replacement = false | ||||
| WeightedRandomSampler m_sampler(weights, num_samples, false); | WeightedRandomSampler m_sampler(weights, num_samples, false); | ||||
| DummyRandomAccessOp dummy_random_access_op(total_samples); | |||||
| m_sampler.Init(&dummy_random_access_op); | |||||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||||
| std::unique_ptr<DataBuffer> db; | std::unique_ptr<DataBuffer> db; | ||||
| TensorRow row; | TensorRow row; | ||||
| @@ -439,6 +439,74 @@ def test_case_error_4(): | |||||
| assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value) | assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value) | ||||
| def test_sequential_sampler(): | |||||
| source = [(np.array([x]),) for x in range(64)] | |||||
| ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler()) | |||||
| i = 0 | |||||
| for data in ds1.create_dict_iterator(): # each data is a dictionary | |||||
| golden = np.array([i]) | |||||
| assert np.array_equal(data["data"], golden) | |||||
| i = i + 1 | |||||
| def test_random_sampler(): | |||||
| source = [(np.array([x]),) for x in range(64)] | |||||
| ds1 = ds.GeneratorDataset(source, ["data"], shuffle = True) | |||||
| for data in ds1.create_dict_iterator(): # each data is a dictionary | |||||
| pass | |||||
| def test_distributed_sampler(): | |||||
| source = [(np.array([x]),) for x in range(64)] | |||||
| for sid in range(8): | |||||
| ds1 = ds.GeneratorDataset(source, ["data"], shuffle = False, num_shards=8, shard_id=sid) | |||||
| i = sid | |||||
| for data in ds1.create_dict_iterator(): # each data is a dictionary | |||||
| golden = np.array([i]) | |||||
| assert np.array_equal(data["data"], golden) | |||||
| i = i + 8 | |||||
| def test_num_samples(): | |||||
| source = [(np.array([x]),) for x in range(64)] | |||||
| num_samples = 32 | |||||
| ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples = num_samples) | |||||
| ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples = num_samples) | |||||
| ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples = num_samples) | |||||
| count = 0 | |||||
| for _ in ds1.create_dict_iterator(): | |||||
| count = count + 1 | |||||
| assert count == num_samples | |||||
| count = 0 | |||||
| for _ in ds2.create_dict_iterator(): | |||||
| count = count + 1 | |||||
| assert count == num_samples | |||||
| count = 0 | |||||
| for _ in ds3.create_dict_iterator(): | |||||
| count = count + 1 | |||||
| assert count == num_samples | |||||
| def test_num_samples_underflow(): | |||||
| source = [(np.array([x]),) for x in range(64)] | |||||
| num_samples = 256 | |||||
| ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples = num_samples) | |||||
| ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples = num_samples) | |||||
| count = 0 | |||||
| for _ in ds2.create_dict_iterator(): | |||||
| count = count + 1 | |||||
| assert count == 64 | |||||
| count = 0 | |||||
| for _ in ds3.create_dict_iterator(): | |||||
| count = count + 1 | |||||
| assert count == 64 | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_case_0() | test_case_0() | ||||
| test_case_1() | test_case_1() | ||||
| @@ -458,3 +526,6 @@ if __name__ == "__main__": | |||||
| test_case_error_2() | test_case_error_2() | ||||
| test_case_error_3() | test_case_error_3() | ||||
| test_case_error_4() | test_case_error_4() | ||||
| test_sequential_sampler() | |||||
| test_distributed_sampler() | |||||
| test_random_sampler() | |||||
| @@ -87,7 +87,28 @@ def test_random_sampler_multi_iter(print_res=False): | |||||
| test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4, 5]) | test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4, 5]) | ||||
| def test_sampler_py_api(): | |||||
| sampler = ds.SequentialSampler().create() | |||||
| sampler.set_num_rows(128) | |||||
| sampler.set_num_samples(64) | |||||
| sampler.initialize() | |||||
| sampler.get_indices() | |||||
| sampler = ds.RandomSampler().create() | |||||
| sampler.set_num_rows(128) | |||||
| sampler.set_num_samples(64) | |||||
| sampler.initialize() | |||||
| sampler.get_indices() | |||||
| sampler = ds.DistributedSampler(8, 4).create() | |||||
| sampler.set_num_rows(128) | |||||
| sampler.set_num_samples(64) | |||||
| sampler.initialize() | |||||
| sampler.get_indices() | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_sequential_sampler(True) | test_sequential_sampler(True) | ||||
| test_random_sampler(True) | test_random_sampler(True) | ||||
| test_random_sampler_multi_iter(True) | test_random_sampler_multi_iter(True) | ||||
| test_sampler_py_api() | |||||