| @@ -53,6 +53,7 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/python_sampler.h" | |||
| #include "dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "dataset/engine/jagged_connector.h" | |||
| #include "dataset/kernels/data/to_float16_op.h" | |||
| @@ -415,6 +416,7 @@ void bindSamplerOps(py::module *m) { | |||
| (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler") | |||
| .def(py::init<>()); | |||
| (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler") | |||
| .def(py::init<std::vector<int64_t>>(), py::arg("indices")); | |||
| @@ -425,6 +427,9 @@ void bindSamplerOps(py::module *m) { | |||
| (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler") | |||
| .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), | |||
| py::arg("replacement")); | |||
| (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler") | |||
| .def(py::init<py::object>(), py::arg("pySampler")); | |||
| } | |||
| void bindInfoObjects(py::module *m) { | |||
| @@ -1,6 +1,7 @@ | |||
| add_library(engine-datasetops-source-sampler OBJECT | |||
| distributed_sampler.cc | |||
| pk_sampler.cc | |||
| python_sampler.cc | |||
| random_sampler.cc | |||
| sampler.cc | |||
| sequential_sampler.cc | |||
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/datasetops/source/sampler/python_sampler.h" | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer) | |||
| : Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | |||
| Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (need_to_reset_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| std::shared_ptr<Tensor> sample_ids; | |||
| { | |||
| py::gil_scoped_acquire gil_acquire; | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone); | |||
| if (Py_IsInitialized() == 0) { | |||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||
| } | |||
| try { | |||
| py::object py_ret = py_sampler_instance.attr("_get_indices")(); | |||
| py::array np_sample_ids = py_ret.cast<py::array>(); | |||
| Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor | |||
| } catch (const py::error_already_set &e) { | |||
| return Status(StatusCode::kPyFuncException, e.what()); | |||
| } | |||
| } | |||
| TensorRow row(1, sample_ids); | |||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||
| need_to_reset_ = true; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status PythonSampler::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0"); | |||
| { | |||
| py::gil_scoped_acquire gil_acquire; | |||
| if (Py_IsInitialized() == 0) { | |||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||
| } | |||
| try { | |||
| py_sampler_instance.attr("_handshake")(num_rows_, num_samples_); | |||
| } catch (const py::error_already_set &e) { | |||
| return Status(StatusCode::kPyFuncException, e.what()); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status PythonSampler::Reset() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); | |||
| need_to_reset_ = false; | |||
| py::gil_scoped_acquire gil_acquire; | |||
| if (Py_IsInitialized() == 0) { | |||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||
| } | |||
| try { | |||
| py_sampler_instance.attr("reset")(); | |||
| } catch (const py::error_already_set &e) { | |||
| return Status(StatusCode::kPyFuncException, e.what()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ | |||
| #include <limits> | |||
| #include <memory> | |||
| #include "dataset/engine/datasetops/source/sampler/sampler.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
// Sampler that obtains its sample ids from a Python sampler object by calling back into Python.
class PythonSampler : public Sampler {
 public:
  // Constructor
  // @param py::object py_sampler_instance - handle to the Python sampler object
  // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
  explicit PythonSampler(py::object py_sampler_instance,
                         int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());

  // Destructor.
  ~PythonSampler() = default;

  // Initialize the sampler: hands num_rows_ and num_samples_ to the Python object (_handshake).
  // @return Status
  Status InitSampler() override;

  // Prepare for the next epoch of sampleIds; calls reset() on the Python object.
  // @return - The error code return
  Status Reset() override;

  // Op calls this to get next Buffer that contains all the sampleIds
  // @param std::unique_ptr<DataBuffer> *out_buffer - Buffer to be returned to the caller
  // @return - The error code return
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;

 private:
  // NOTE: members are initialized in the order they are declared here, not in constructor-list order.
  bool need_to_reset_;  // Whether Reset() should be called before calling GetNextBuffer()

  py::object py_sampler_instance;  // The handle to the py_sampler python object
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ | |||
| @@ -48,9 +48,6 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { | |||
| std::unique_ptr<DataBuffer> db; | |||
| std::shared_ptr<Tensor> sample_ids; | |||
| // check samples_per_buffer is properly set and doesn't overflow | |||
| CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid"); | |||
| // A call to derived class to get sample ids wrapped inside a buffer | |||
| RETURN_IF_NOT_OK(GetNextBuffer(&db)); | |||
| // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch | |||
| @@ -42,6 +42,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) | |||
| } | |||
| Status SequentialSampler::InitSampler() { | |||
| num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); | |||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | |||
| return Status::OK(); | |||
| @@ -23,7 +23,7 @@ from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDataset | |||
| GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \ | |||
| Shuffle, zip | |||
| from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ | |||
| WeightedRandomSampler | |||
| WeightedRandomSampler, Sampler | |||
| from .engine.serializer_deserializer import serialize, deserialize, show | |||
| __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset", | |||
| @@ -2032,7 +2032,7 @@ class GeneratorDataset(SourceDataset): | |||
| if self.sampler is not None and hasattr(source, "__getitem__"): | |||
| if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | |||
| samplers.RandomSampler, samplers.SubsetRandomSampler, | |||
| samplers.WeightedRandomSampler)): | |||
| samplers.WeightedRandomSampler, samplers.Sampler)): | |||
| if num_samples is None: | |||
| num_samples = len(source) | |||
| sampler_instance = self.sampler.create() | |||
| @@ -16,11 +16,90 @@ | |||
| Sampler module provides several samplers to generate sampling data from dataset. | |||
| There are following samplers: DistributedSampler, PKSampler, RandomSampler, | |||
| SequentialSampler, SubsetRandomSampler, WeightedRandomSampler. | |||
| Users can also define a custom sampler by extending the Sampler class. | |||
| """ | |||
| import mindspore._c_dataengine as cde | |||
| import numpy as np | |||
| class DistributedSampler(): | |||
class Sampler:
    """
    Base class for user defined sampler.

    A user defined sampler can be used with any existing dataset that has sampler support.

    Subclasses must override __iter__() to generate the sample indices and may override
    reset() to run custom logic on every repeat. dataset_size and num_samples are set
    by the dataset once a dataset iterator is created.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> class ReverseSampler(ds.Sampler):
        >>>     def __iter__(self):
        >>>         for i in range(self.dataset_size - 1, -1, -1):
        >>>             yield i
        >>>
        >>> data = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler())
    """

    def __init__(self):
        # Both values are placeholders until _handshake() supplies the real ones.
        self.dataset_size = 0
        self.num_samples = 0

    def __iter__(self):
        """
        User defined iterator over sample indices, must be overridden.

        _handshake is guaranteed to be called prior to iterator construction.
        """
        raise NotImplementedError

    def reset(self):
        """
        Per repeat reset callback, override this method if necessary.
        """

    # Initialization handshake callback: records the dataset size and the number of samples.
    # Do not override this method!
    def _handshake(self, ds_size, num_samples):
        self.dataset_size = ds_size
        self.num_samples = num_samples

    # Indices fetcher: collects at most num_samples indices from a fresh iterator over self.
    # Do not override this method!
    def _get_indices(self):
        sampler_iter = iter(self)
        # zip() advances range() first, so the user iterator is never asked for more than
        # num_samples items, and an early StopIteration simply ends the collection.
        collected = [idx for _, idx in zip(range(self.num_samples), sampler_iter)]
        return np.array(collected)

    # Instance fetcher: wraps this object in the C++ PythonSampler binding.
    # Do not override this method!
    def create(self):
        return cde.PythonSampler(self)
class BuiltinSampler:
    """
    Base class for BuiltinSampler.

    User should not extend this class.
    """

    def __init__(self):
        """Nothing to set up in the base class."""

    def create(self):
        """Placeholder that returns None in this base class."""
| class DistributedSampler(BuiltinSampler): | |||
| """ | |||
| Sampler that access a shard of the dataset. | |||
| @@ -65,7 +144,7 @@ class DistributedSampler(): | |||
| return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed) | |||
| class PKSampler(): | |||
| class PKSampler(BuiltinSampler): | |||
| """ | |||
| Samples K elements for each P class in the dataset. | |||
| @@ -106,7 +185,7 @@ class PKSampler(): | |||
| return cde.PKSampler(self.num_val, self.shuffle) | |||
| class RandomSampler(): | |||
| class RandomSampler(BuiltinSampler): | |||
| """ | |||
| Samples the elements randomly. | |||
| @@ -147,7 +226,7 @@ class RandomSampler(): | |||
| return cde.RandomSampler(self.replacement, self.num_samples) | |||
| class SequentialSampler(): | |||
| class SequentialSampler(BuiltinSampler): | |||
| """ | |||
| Samples the dataset elements sequentially, same as not having a sampler. | |||
| @@ -165,7 +244,7 @@ class SequentialSampler(): | |||
| return cde.SequentialSampler() | |||
| class SubsetRandomSampler(): | |||
| class SubsetRandomSampler(BuiltinSampler): | |||
| """ | |||
| Samples the elements randomly from a sequence of indices. | |||
| @@ -196,7 +275,8 @@ class SubsetRandomSampler(): | |||
| def _create_for_minddataset(self): | |||
| return cde.MindrecordSubsetRandomSampler(self.indices) | |||
| class WeightedRandomSampler(): | |||
| class WeightedRandomSampler(BuiltinSampler): | |||
| """ | |||
| Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). | |||
| @@ -297,9 +297,7 @@ def check_sampler_shuffle_shard_options(param_dict): | |||
| shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') | |||
| num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') | |||
| if sampler is not None and not isinstance(sampler, ( | |||
| samplers.DistributedSampler, samplers.PKSampler, samplers.RandomSampler, samplers.SequentialSampler, | |||
| samplers.SubsetRandomSampler, samplers.WeightedRandomSampler)): | |||
| if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)): | |||
| raise ValueError("sampler is not a valid Sampler type.") | |||
| if sampler is not None: | |||
| @@ -579,11 +577,11 @@ def check_generatordataset(method): | |||
| raise ValueError("PKSampler is not supported by GeneratorDataset") | |||
| if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | |||
| samplers.RandomSampler, samplers.SubsetRandomSampler, | |||
| samplers.WeightedRandomSampler)): | |||
| samplers.WeightedRandomSampler, samplers.Sampler)): | |||
| try: | |||
| iter(sampler) | |||
| except TypeError: | |||
| raise TypeError("sampler should be either iterable or from dataset.samplers.py") | |||
| raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers") | |||
| return method(*args, **kwargs) | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| import numpy as np | |||
| # test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631] | |||
| @@ -107,8 +108,64 @@ def test_sampler_py_api(): | |||
| sampler.get_indices() | |||
def test_python_sampler():
    """Exercise user defined samplers (ds.Sampler subclasses) with ManifestDataset and GeneratorDataset."""
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # Maps (un-decoded image size, label) to the index used in the expected results below.
    row_index = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    class Sp1(ds.Sampler):
        """Yields 0 .. dataset_size - 1 in order."""

        def __iter__(self):
            return iter(list(range(self.dataset_size)))

    class Sp2(ds.Sampler):
        """Yields one constant index per epoch; the index advances on every reset()."""

        def __init__(self):
            super().__init__()
            # at this stage, self.dataset_size and self.num_samples are not yet known
            self.cnt = 0

        def __iter__(self):  # first epoch, all 0, second epoch all 1, third all 2 etc.. ...
            return iter([self.cnt] * self.num_samples)

        def reset(self):
            self.cnt = (self.cnt + 1) % self.dataset_size

    def test_config(num_samples, num_repeats, sampler):
        """Run the manifest dataset with the given sampler and return the visited row indices."""
        dataset = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
        if num_repeats is not None:
            dataset = dataset.repeat(num_repeats)
        visited = []
        for item in dataset.create_dict_iterator():
            image_size = item["image"].shape[0]
            label = item["label"].item()
            logger.info("item[image].shape[0]: {}, item[label].item(): {}".format(image_size, label))
            visited.append(row_index[(image_size, label)])
        return visited

    def test_generator():
        """A sampler counting down from 99 must make GeneratorDataset yield its rows in reverse."""

        class MySampler(ds.Sampler):
            def __iter__(self):
                for i in range(99, -1, -1):
                    yield i

        dataset = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler=MySampler())
        expected = 99
        for data in dataset:
            assert data[0] == (np.array(expected),)
            expected -= 1

    assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
    test_generator()

    # Drive the C++ PythonSampler directly through its Python binding.
    native_sampler = Sp1().create()
    native_sampler.set_num_rows(5)
    native_sampler.set_num_samples(5)
    native_sampler.initialize()
    assert list(native_sampler.get_indices()) == [0, 1, 2, 3, 4]
# Run the listed sampler tests when this file is executed directly as a script.
if __name__ == '__main__':
    test_sequential_sampler(True)
    test_random_sampler(True)
    test_random_sampler_multi_iter(True)
    test_sampler_py_api()
    test_python_sampler()