From: @mahdirahmanihanzaki Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -7,11 +7,11 @@ if(ENABLE_PYTHON) | |||
| python/bindings/dataset/engine/cache/bindings.cc | |||
| python/bindings/dataset/engine/datasetops/bindings.cc | |||
| python/bindings/dataset/engine/datasetops/source/bindings.cc | |||
| python/bindings/dataset/engine/datasetops/source/sampler/bindings.cc | |||
| python/bindings/dataset/engine/gnn/bindings.cc | |||
| python/bindings/dataset/include/datasets_bindings.cc | |||
| python/bindings/dataset/include/iterator_bindings.cc | |||
| python/bindings/dataset/include/execute_binding.cc | |||
| python/bindings/dataset/include/sampler_bindings.cc | |||
| python/bindings/dataset/include/schema_bindings.cc | |||
| python/bindings/dataset/kernels/bindings.cc | |||
| python/bindings/dataset/kernels/data/bindings.cc | |||
| @@ -1,93 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Register SamplerRT with Python under the name "Sampler" and expose its row/sample setters, | |||
| // initialization, index retrieval and child attachment. Each wrapper raises via THROW_IF_ERROR. | |||
| PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) { | |||
|   py::class_<SamplerRT, std::shared_ptr<SamplerRT>> binding(*m, "Sampler"); | |||
|   binding.def("set_num_rows", [](SamplerRT &self, int64_t rows) { | |||
|     THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); | |||
|   }); | |||
|   binding.def("set_num_samples", [](SamplerRT &self, int64_t samples) { | |||
|     THROW_IF_ERROR(self.SetNumSamples(samples)); | |||
|   }); | |||
|   binding.def("initialize", [](SamplerRT &self) { | |||
|     THROW_IF_ERROR(self.InitSampler()); | |||
|   }); | |||
|   // Returns the ids produced by GetAllIdsThenReset as a py::array. | |||
|   binding.def("get_indices", [](SamplerRT &self) { | |||
|     py::array ids; | |||
|     THROW_IF_ERROR(self.GetAllIdsThenReset(&ids)); | |||
|     return ids; | |||
|   }); | |||
|   binding.def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) { | |||
|     THROW_IF_ERROR(self->AddChild(child)); | |||
|   }); | |||
| })); | |||
| // Register DistributedSamplerRT (derived from SamplerRT) with Python as "DistributedSampler". | |||
| PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<DistributedSamplerRT>; | |||
|   py::class_<DistributedSamplerRT, SamplerRT, Holder> binding(*m, "DistributedSampler"); | |||
|   binding.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>()); | |||
| })); | |||
| // Register PKSamplerRT (derived from SamplerRT) with Python as "PKSampler". | |||
| PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<PKSamplerRT>; | |||
|   py::class_<PKSamplerRT, SamplerRT, Holder> binding(*m, "PKSampler"); | |||
|   binding.def(py::init<int64_t, int64_t, bool>()); | |||
| })); | |||
| // Register PythonSamplerRT (derived from SamplerRT) with Python as "PythonSampler". | |||
| PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<PythonSamplerRT>; | |||
|   py::class_<PythonSamplerRT, SamplerRT, Holder> binding(*m, "PythonSampler"); | |||
|   binding.def(py::init<int64_t, py::object>()); | |||
| })); | |||
| // Register RandomSamplerRT (derived from SamplerRT) with Python as "RandomSampler". | |||
| PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<RandomSamplerRT>; | |||
|   py::class_<RandomSamplerRT, SamplerRT, Holder> binding(*m, "RandomSampler"); | |||
|   binding.def(py::init<int64_t, bool, bool>()); | |||
| })); | |||
| // Register SequentialSamplerRT (derived from SamplerRT) with Python as "SequentialSampler". | |||
| PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<SequentialSamplerRT>; | |||
|   py::class_<SequentialSamplerRT, SamplerRT, Holder> binding(*m, "SequentialSampler"); | |||
|   binding.def(py::init<int64_t, int64_t>()); | |||
| })); | |||
| // Register SubsetRandomSamplerRT (derived from SubsetSamplerRT) with Python as "SubsetRandomSampler". | |||
| PYBIND_REGISTER(SubsetRandomSamplerRT, 2, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<SubsetRandomSamplerRT>; | |||
|   py::class_<SubsetRandomSamplerRT, SubsetSamplerRT, Holder> binding(*m, "SubsetRandomSampler"); | |||
|   binding.def(py::init<int64_t, std::vector<int64_t>>()); | |||
| })); | |||
| // Register SubsetSamplerRT (derived from SamplerRT) with Python as "SubsetSampler". | |||
| PYBIND_REGISTER(SubsetSamplerRT, 1, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<SubsetSamplerRT>; | |||
|   py::class_<SubsetSamplerRT, SamplerRT, Holder> binding(*m, "SubsetSampler"); | |||
|   binding.def(py::init<int64_t, std::vector<int64_t>>()); | |||
| })); | |||
| // Register WeightedRandomSamplerRT (derived from SamplerRT) with Python as "WeightedRandomSampler". | |||
| PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<WeightedRandomSamplerRT>; | |||
|   py::class_<WeightedRandomSamplerRT, SamplerRT, Holder> binding(*m, "WeightedRandomSampler"); | |||
|   binding.def(py::init<int64_t, std::vector<double>, bool>()); | |||
| })); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,127 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/stl.h" | |||
| #include "pybind11/stl_bind.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h" | |||
| #include "minddata/dataset/api/python/pybind_conversion.h" | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/callback/py_ds_callback.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/include/datasets.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Register SamplerObj with Python as "SamplerObj". The only exposed method, add_child, | |||
| // forwards to AddChildSampler and raises through THROW_IF_ERROR on failure. | |||
| PYBIND_REGISTER(SamplerObj, 1, ([](const py::module *m) { | |||
|   py::class_<SamplerObj, std::shared_ptr<SamplerObj>> binding(*m, "SamplerObj", "to create a SamplerObj"); | |||
|   binding.def("add_child", [](std::shared_ptr<SamplerObj> self, std::shared_ptr<SamplerObj> child) { | |||
|     THROW_IF_ERROR(self->AddChildSampler(child)); | |||
|   }); | |||
| })); | |||
| // Register DistributedSamplerObj (derived from SamplerObj) with Python. The factory builds the | |||
| // object, runs ValidateParams() on it, and raises through THROW_IF_ERROR if validation fails. | |||
| PYBIND_REGISTER(DistributedSamplerObj, 2, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<DistributedSamplerObj>; | |||
|   py::class_<DistributedSamplerObj, SamplerObj, Holder> binding(*m, "DistributedSamplerObj", | |||
|                                                                "to create a DistributedSamplerObj"); | |||
|   binding.def(py::init([](int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, | |||
|                           uint32_t seed, int64_t offset, bool even_dist) { | |||
|     auto obj = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, | |||
|                                                        offset, even_dist); | |||
|     THROW_IF_ERROR(obj->ValidateParams()); | |||
|     return obj; | |||
|   })); | |||
| })); | |||
| // Register PreBuiltSamplerObj (derived from SamplerObj) with Python. The factory wraps the given | |||
| // Python object in a PythonSamplerRT, hands that to a PreBuiltSamplerObj, then validates it. | |||
| PYBIND_REGISTER(PreBuiltSamplerObj, 2, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<PreBuiltSamplerObj>; | |||
|   py::class_<PreBuiltSamplerObj, SamplerObj, Holder> binding(*m, "PreBuiltSamplerObj", | |||
|                                                             "to create a PreBuiltSamplerObj"); | |||
|   binding.def(py::init([](int64_t num_samples, py::object sampler) { | |||
|     std::shared_ptr<PythonSamplerRT> runtime = std::make_shared<PythonSamplerRT>(num_samples, sampler); | |||
|     std::shared_ptr<PreBuiltSamplerObj> wrapped = std::make_shared<PreBuiltSamplerObj>(std::move(runtime)); | |||
|     THROW_IF_ERROR(wrapped->ValidateParams()); | |||
|     return wrapped; | |||
|   })); | |||
| })); | |||
| // Register PKSamplerObj (derived from SamplerObj) with Python. The factory builds the object | |||
| // from (num_val, shuffle, num_samples) and raises if ValidateParams() reports an error. | |||
| PYBIND_REGISTER(PKSamplerObj, 2, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<PKSamplerObj>; | |||
|   py::class_<PKSamplerObj, SamplerObj, Holder> binding(*m, "PKSamplerObj", "to create a PKSamplerObj"); | |||
|   binding.def(py::init([](int64_t num_val, bool shuffle, int64_t num_samples) { | |||
|     auto obj = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples); | |||
|     THROW_IF_ERROR(obj->ValidateParams()); | |||
|     return obj; | |||
|   })); | |||
| })); | |||
| // Register RandomSamplerObj (derived from SamplerObj) with Python. The factory builds the object | |||
| // from (replacement, num_samples, reshuffle_each_epoch) and raises if ValidateParams() fails. | |||
| PYBIND_REGISTER(RandomSamplerObj, 2, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<RandomSamplerObj>; | |||
|   py::class_<RandomSamplerObj, SamplerObj, Holder> binding(*m, "RandomSamplerObj", | |||
|                                                           "to create a RandomSamplerObj"); | |||
|   binding.def(py::init([](bool replacement, int64_t num_samples, bool reshuffle_each_epoch) { | |||
|     auto obj = std::make_shared<RandomSamplerObj>(replacement, num_samples, reshuffle_each_epoch); | |||
|     THROW_IF_ERROR(obj->ValidateParams()); | |||
|     return obj; | |||
|   })); | |||
| })); | |||
| // Register SequentialSamplerObj (derived from SamplerObj) with Python. The factory builds the | |||
| // object from (start_index, num_samples) and raises if ValidateParams() reports an error. | |||
| PYBIND_REGISTER(SequentialSamplerObj, 2, ([](const py::module *m) { | |||
|   using Holder = std::shared_ptr<SequentialSamplerObj>; | |||
|   py::class_<SequentialSamplerObj, SamplerObj, Holder> binding(*m, "SequentialSamplerObj", | |||
|                                                               "to create a SequentialSamplerObj"); | |||
|   binding.def(py::init([](int64_t start_index, int64_t num_samples) { | |||
|     auto obj = std::make_shared<SequentialSamplerObj>(start_index, num_samples); | |||
|     THROW_IF_ERROR(obj->ValidateParams()); | |||
|     return obj; | |||
|   })); | |||
| })); | |||
| // Register SubsetSamplerObj (derived from SamplerObj) with Python. The factory builds the object | |||
| // from (indices, num_samples) and raises through THROW_IF_ERROR if ValidateParams() fails. | |||
| PYBIND_REGISTER(SubsetSamplerObj, 2, ([](const py::module *m) { | |||
|   (void)py::class_<SubsetSamplerObj, SamplerObj, std::shared_ptr<SubsetSamplerObj>>( | |||
|     *m, "SubsetSamplerObj", "to create a SubsetSamplerObj") | |||
|     .def(py::init([](std::vector<int64_t> indices, int64_t num_samples) { | |||
|       // indices is already a by-value copy owned by this lambda; move it rather than copy it again. | |||
|       std::shared_ptr<SubsetSamplerObj> sampler = | |||
|         std::make_shared<SubsetSamplerObj>(std::move(indices), num_samples); | |||
|       THROW_IF_ERROR(sampler->ValidateParams()); | |||
|       return sampler; | |||
|     })); | |||
| })); | |||
| // Register SubsetRandomSamplerObj (derived from SubsetSamplerObj) with Python. The factory builds | |||
| // the object from (indices, num_samples) and raises if ValidateParams() reports an error. | |||
| PYBIND_REGISTER(SubsetRandomSamplerObj, 3, ([](const py::module *m) { | |||
|   (void)py::class_<SubsetRandomSamplerObj, SubsetSamplerObj, std::shared_ptr<SubsetRandomSamplerObj>>( | |||
|     *m, "SubsetRandomSamplerObj", "to create a SubsetRandomSamplerObj") | |||
|     .def(py::init([](std::vector<int64_t> indices, int64_t num_samples) { | |||
|       // indices is already a by-value copy owned by this lambda; move it rather than copy it again. | |||
|       std::shared_ptr<SubsetRandomSamplerObj> sampler = | |||
|         std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples); | |||
|       THROW_IF_ERROR(sampler->ValidateParams()); | |||
|       return sampler; | |||
|     })); | |||
| })); | |||
| // Register WeightedRandomSamplerObj (derived from SamplerObj) with Python. The factory builds | |||
| // the object from (weights, num_samples, replacement) and raises if ValidateParams() fails. | |||
| PYBIND_REGISTER(WeightedRandomSamplerObj, 2, ([](const py::module *m) { | |||
|   (void)py::class_<WeightedRandomSamplerObj, SamplerObj, std::shared_ptr<WeightedRandomSamplerObj>>( | |||
|     *m, "WeightedRandomSamplerObj", "to create a WeightedRandomSamplerObj") | |||
|     .def(py::init([](std::vector<double> weights, int64_t num_samples, bool replacement) { | |||
|       // weights is already a by-value copy owned by this lambda; move it rather than copy it again. | |||
|       std::shared_ptr<WeightedRandomSamplerObj> sampler = | |||
|         std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement); | |||
|       THROW_IF_ERROR(sampler->ValidateParams()); | |||
|       return sampler; | |||
|     })); | |||
| })); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -150,15 +150,13 @@ std::shared_ptr<SamplerObj> toSamplerObj(py::handle py_sampler, bool isMindDatas | |||
| std::shared_ptr<SamplerObj> sampler_obj; | |||
| if (!isMindDataset) { | |||
| // Common Sampler | |||
| std::shared_ptr<SamplerRT> sampler; | |||
| auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create"); | |||
| sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||
| sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler)); | |||
| auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse"); | |||
| sampler_obj = parse().cast<std::shared_ptr<SamplerObj>>(); | |||
| } else { | |||
| // Mindrecord Sampler | |||
| std::shared_ptr<mindrecord::ShardOperator> sampler; | |||
| auto create = py::reinterpret_borrow<py::object>(py_sampler).attr("create_for_minddataset"); | |||
| sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>(); | |||
| auto parse = py::reinterpret_borrow<py::object>(py_sampler).attr("parse_for_minddataset"); | |||
| sampler = parse().cast<std::shared_ptr<mindrecord::ShardOperator>>(); | |||
| sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler)); | |||
| } | |||
| return sampler_obj; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -211,6 +211,27 @@ std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDa | |||
| } | |||
| #endif | |||
| // Serialize this sampler into JSON: sampler_name, num_shards, shard_id, shuffle, num_samples and | |||
| // offset, plus a "child_sampler" array when children exist. *out_json is assigned only after | |||
| // every child has been serialized; a child error is propagated via RETURN_IF_NOT_OK. | |||
| Status DistributedSamplerObj::to_json(nlohmann::json *out_json) { | |||
|   nlohmann::json args; | |||
|   args["sampler_name"] = "DistributedSampler"; | |||
|   args["num_shards"] = num_shards_; | |||
|   args["shard_id"] = shard_id_; | |||
|   args["shuffle"] = shuffle_; | |||
|   args["num_samples"] = num_samples_; | |||
|   args["offset"] = offset_; | |||
|   if (!children_.empty()) { | |||
|     std::vector<nlohmann::json> children_args; | |||
|     children_args.reserve(children_.size()); | |||
|     // Iterate by const reference so no shared_ptr is copied per child. | |||
|     for (const auto &child : children_) { | |||
|       nlohmann::json child_arg; | |||
|       RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
|       children_args.push_back(std::move(child_arg)); | |||
|     } | |||
|     args["child_sampler"] = std::move(children_args); | |||
|   } | |||
|   *out_json = std::move(args); | |||
|   return Status::OK(); | |||
| } | |||
| // PKSampler | |||
| PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) | |||
| : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} | |||
| @@ -226,6 +247,25 @@ Status PKSamplerObj::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Serialize this sampler into JSON: sampler_name, num_val, shuffle and num_samples, plus a | |||
| // "child_sampler" array when children exist. *out_json is assigned only after every child has | |||
| // been serialized; a child error is propagated via RETURN_IF_NOT_OK. | |||
| Status PKSamplerObj::to_json(nlohmann::json *out_json) { | |||
|   nlohmann::json args; | |||
|   args["sampler_name"] = "PKSampler"; | |||
|   args["num_val"] = num_val_; | |||
|   args["shuffle"] = shuffle_; | |||
|   args["num_samples"] = num_samples_; | |||
|   if (!children_.empty()) { | |||
|     std::vector<nlohmann::json> children_args; | |||
|     children_args.reserve(children_.size()); | |||
|     // Iterate by const reference so no shared_ptr is copied per child. | |||
|     for (const auto &child : children_) { | |||
|       nlohmann::json child_arg; | |||
|       RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
|       children_args.push_back(std::move(child_arg)); | |||
|     } | |||
|     args["child_sampler"] = std::move(children_args); | |||
|   } | |||
|   *out_json = std::move(args); | |||
|   return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); | |||
| @@ -233,6 +273,21 @@ std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() { | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // Build the mindrecord ShardPkSample for the "label" column. With shuffle enabled, the | |||
| // five-argument constructor is used (int64_t max and GetSeed() passed in addition). | |||
| std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | |||
|   if (!shuffle_) { | |||
|     return std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_); | |||
|   } | |||
|   return std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(), | |||
|                                                      GetSeed(), num_samples_); | |||
| } | |||
| #endif | |||
| // PreBuiltOperation | |||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {} | |||
| @@ -274,24 +329,9 @@ Status PreBuiltSamplerObj::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // Build the mindrecord ShardPkSample for the "label" column; the shuffled variant additionally | |||
| // receives int64_t max and GetSeed(). | |||
| std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | |||
|   return shuffle_ ? std::make_shared<mindrecord::ShardPkSample>( | |||
|                       "label", num_val_, std::numeric_limits<int64_t>::max(), GetSeed(), num_samples_) | |||
|                   : std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_); | |||
| } | |||
| #endif | |||
| // RandomSampler | |||
| RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples) | |||
| : replacement_(replacement), num_samples_(num_samples) {} | |||
| RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch) | |||
| : replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {} | |||
| Status RandomSamplerObj::ValidateParams() { | |||
| if (num_samples_ < 0) { | |||
| @@ -300,10 +340,28 @@ Status RandomSamplerObj::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Serialize this sampler into JSON: sampler_name, replacement, num_samples and | |||
| // reshuffle_each_epoch, plus a "child_sampler" array when children exist. *out_json is assigned | |||
| // only after every child has been serialized; a child error is propagated via RETURN_IF_NOT_OK. | |||
| Status RandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
|   nlohmann::json args; | |||
|   args["sampler_name"] = "RandomSampler"; | |||
|   args["replacement"] = replacement_; | |||
|   args["num_samples"] = num_samples_; | |||
|   args["reshuffle_each_epoch"] = reshuffle_each_epoch_; | |||
|   if (!children_.empty()) { | |||
|     std::vector<nlohmann::json> children_args; | |||
|     children_args.reserve(children_.size()); | |||
|     // Iterate by const reference so no shared_ptr is copied per child. | |||
|     for (const auto &child : children_) { | |||
|       nlohmann::json child_arg; | |||
|       RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
|       children_args.push_back(std::move(child_arg)); | |||
|     } | |||
|     args["child_sampler"] = std::move(children_args); | |||
|   } | |||
|   *out_json = std::move(args); | |||
|   return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| bool reshuffle_each_epoch = true; | |||
| auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch); | |||
| auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| } | |||
| @@ -311,7 +369,6 @@ std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() { | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| bool reshuffle_each_epoch_ = true; | |||
| auto mind_sampler = | |||
| std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_); | |||
| @@ -335,6 +392,24 @@ Status SequentialSamplerObj::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Serialize this sampler into JSON: sampler_name, start_index and num_samples, plus a | |||
| // "child_sampler" array when children exist. *out_json is assigned only after every child has | |||
| // been serialized; a child error is propagated via RETURN_IF_NOT_OK. | |||
| Status SequentialSamplerObj::to_json(nlohmann::json *out_json) { | |||
|   nlohmann::json args; | |||
|   args["sampler_name"] = "SequentialSampler"; | |||
|   args["start_index"] = start_index_; | |||
|   args["num_samples"] = num_samples_; | |||
|   if (!children_.empty()) { | |||
|     std::vector<nlohmann::json> children_args; | |||
|     children_args.reserve(children_.size()); | |||
|     // Iterate by const reference so no shared_ptr is copied per child. | |||
|     for (const auto &child : children_) { | |||
|       nlohmann::json child_arg; | |||
|       RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
|       children_args.push_back(std::move(child_arg)); | |||
|     } | |||
|     args["child_sampler"] = std::move(children_args); | |||
|   } | |||
|   *out_json = std::move(args); | |||
|   return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); | |||
| @@ -378,6 +453,23 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetSamplerObj::BuildForMindDataset | |||
| return mind_sampler; | |||
| } | |||
| #endif | |||
| // Serialize this sampler into JSON: sampler_name, indices and num_samples, plus a | |||
| // "child_sampler" array when children exist. *out_json is assigned only after every child has | |||
| // been serialized; a child error is propagated via RETURN_IF_NOT_OK. | |||
| Status SubsetSamplerObj::to_json(nlohmann::json *out_json) { | |||
|   nlohmann::json args; | |||
|   args["sampler_name"] = "SubsetSampler"; | |||
|   args["indices"] = indices_; | |||
|   args["num_samples"] = num_samples_; | |||
|   if (!children_.empty()) { | |||
|     std::vector<nlohmann::json> children_args; | |||
|     children_args.reserve(children_.size()); | |||
|     // Iterate by const reference so no shared_ptr is copied per child. | |||
|     for (const auto &child : children_) { | |||
|       nlohmann::json child_arg; | |||
|       RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
|       children_args.push_back(std::move(child_arg)); | |||
|     } | |||
|     args["child_sampler"] = std::move(children_args); | |||
|   } | |||
|   *out_json = std::move(args); | |||
|   return Status::OK(); | |||
| } | |||
| // SubsetRandomSampler | |||
| SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples) | |||
| @@ -399,6 +491,24 @@ std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindD | |||
| } | |||
| #endif | |||
| // Serialize this sampler into JSON: sampler_name, indices and num_samples, plus a | |||
| // "child_sampler" array when children exist. *out_json is assigned only after every child has | |||
| // been serialized; a child error is propagated via RETURN_IF_NOT_OK. | |||
| Status SubsetRandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
|   nlohmann::json args; | |||
|   args["sampler_name"] = "SubsetRandomSampler"; | |||
|   args["indices"] = indices_; | |||
|   args["num_samples"] = num_samples_; | |||
|   if (!children_.empty()) { | |||
|     std::vector<nlohmann::json> children_args; | |||
|     children_args.reserve(children_.size()); | |||
|     // Iterate by const reference so no shared_ptr is copied per child. | |||
|     for (const auto &child : children_) { | |||
|       nlohmann::json child_arg; | |||
|       RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
|       children_args.push_back(std::move(child_arg)); | |||
|     } | |||
|     args["child_sampler"] = std::move(children_args); | |||
|   } | |||
|   *out_json = std::move(args); | |||
|   return Status::OK(); | |||
| } | |||
| // WeightedRandomSampler | |||
| WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement) | |||
| : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} | |||
| @@ -426,6 +536,25 @@ Status WeightedRandomSamplerObj::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| // Serialize this sampler into JSON: sampler_name, weights, num_samples and replacement, plus a | |||
| // "child_sampler" array when children exist. *out_json is assigned only after every child has | |||
| // been serialized; a child error is propagated via RETURN_IF_NOT_OK. | |||
| Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
|   nlohmann::json args; | |||
|   args["sampler_name"] = "WeightedRandomSampler"; | |||
|   args["weights"] = weights_; | |||
|   args["num_samples"] = num_samples_; | |||
|   args["replacement"] = replacement_; | |||
|   if (!children_.empty()) { | |||
|     std::vector<nlohmann::json> children_args; | |||
|     children_args.reserve(children_.size()); | |||
|     // Iterate by const reference so no shared_ptr is copied per child. | |||
|     for (const auto &child : children_) { | |||
|       nlohmann::json child_arg; | |||
|       RETURN_IF_NOT_OK(child->to_json(&child_arg)); | |||
|       children_args.push_back(std::move(child_arg)); | |||
|     } | |||
|     args["child_sampler"] = std::move(children_args); | |||
|   } | |||
|   *out_json = std::move(args); | |||
|   return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() { | |||
| auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); | |||
| BuildChildren(sampler); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -66,6 +66,8 @@ class SamplerObj { | |||
| /// \brief Default serialization hook; this base implementation writes nothing to out_json | |||
| /// \param[out] out_json Left untouched by this implementation | |||
| /// \return Status::OK() | |||
| virtual Status to_json(nlohmann::json *out_json) { | |||
|   return Status::OK(); | |||
| } | |||
| /// \brief Get the child samplers held in children_ | |||
| /// \return A copy of the vector of child sampler pointers | |||
| std::vector<std::shared_ptr<SamplerObj>> GetChild() { | |||
|   return children_; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object, | |||
| /// only overridden by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler | |||
| @@ -175,6 +177,11 @@ class DistributedSamplerObj : public SamplerObj { | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| /// \brief Function to get the shard id of sampler | |||
| @@ -211,6 +218,11 @@ class PKSamplerObj : public SamplerObj { | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| @@ -249,14 +261,14 @@ class PreBuiltSamplerObj : public SamplerObj { | |||
| class RandomSamplerObj : public SamplerObj { | |||
| public: | |||
| RandomSamplerObj(bool replacement, int64_t num_samples); | |||
| RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true); | |||
| ~RandomSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_); | |||
| auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_); | |||
| for (auto child : children_) { | |||
| sampler->AddChildSampler(child); | |||
| } | |||
| @@ -267,11 +279,17 @@ class RandomSamplerObj : public SamplerObj { | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| bool replacement_; | |||
| int64_t num_samples_; | |||
| bool reshuffle_each_epoch_; | |||
| }; | |||
| class SequentialSamplerObj : public SamplerObj { | |||
| @@ -294,6 +312,11 @@ class SequentialSamplerObj : public SamplerObj { | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| @@ -321,6 +344,11 @@ class SubsetSamplerObj : public SamplerObj { | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| protected: | |||
| @@ -334,6 +362,8 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj { | |||
| ~SubsetRandomSamplerObj() = default; | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| @@ -367,6 +397,11 @@ class WeightedRandomSamplerObj : public SamplerObj { | |||
| return sampler; | |||
| } | |||
| /// \brief Get the arguments of node | |||
| /// \param[out] out_json JSON string of all attributes | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -36,7 +36,7 @@ class BuiltinSampler: | |||
| self.child_sampler = None | |||
| self.num_samples = num_samples | |||
| def create(self): | |||
| def parse(self): | |||
| pass | |||
| def add_child(self, sampler): | |||
| @@ -59,16 +59,16 @@ class BuiltinSampler: | |||
| def get_child(self): | |||
| return self.child_sampler | |||
| def create_child(self): | |||
| def parse_child(self): | |||
| c_child_sampler = None | |||
| if self.child_sampler is not None: | |||
| c_child_sampler = self.child_sampler.create() | |||
| c_child_sampler = self.child_sampler.parse() | |||
| return c_child_sampler | |||
| def create_child_for_minddataset(self): | |||
| def parse_child_for_minddataset(self): | |||
| c_child_sampler = None | |||
| if self.child_sampler is not None: | |||
| c_child_sampler = self.child_sampler.create_for_minddataset() | |||
| c_child_sampler = self.child_sampler.parse_for_minddataset() | |||
| return c_child_sampler | |||
| def is_shuffled(self): | |||
| @@ -158,6 +158,8 @@ class Sampler(BuiltinSampler): | |||
| def __init__(self, num_samples=None): | |||
| super().__init__(num_samples) | |||
| self.dataset_size = 0 | |||
| self.child_sampler = None | |||
| self.num_samples = num_samples | |||
| def __iter__(self): | |||
| """ | |||
| @@ -192,13 +194,26 @@ class Sampler(BuiltinSampler): | |||
| # Instance fetcher | |||
| # Do not override this method! | |||
| def create(self): | |||
| def parse(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.PythonSampler(num_samples, self) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler = cde.PreBuiltSamplerObj(num_samples, self) | |||
| c_child_sampler = self.parse_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def add_child(self, sampler): | |||
| self.child_sampler = sampler | |||
| def get_child(self): | |||
| return self.child_sampler | |||
| def parse_child(self): | |||
| c_child_sampler = None | |||
| if self.child_sampler is not None: | |||
| c_child_sampler = self.child_sampler.parse() | |||
| return c_child_sampler | |||
| def is_shuffled(self): | |||
| if self.child_sampler is None: | |||
| return False | |||
| @@ -246,24 +261,15 @@ class DistributedSampler(BuiltinSampler): | |||
| """ | |||
| def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1): | |||
| if num_shards <= 0: | |||
| raise ValueError("num_shards should be a positive integer value, but got num_shards:{}.".format(num_shards)) | |||
| if not isinstance(num_shards, int): | |||
| raise ValueError("num_shards must be integer but was: {}.".format(num_shards)) | |||
| if shard_id < 0 or shard_id >= num_shards: | |||
| raise ValueError("shard_id should in range [0, {}], but got shard_id: {}.".format(num_shards, shard_id)) | |||
| if not isinstance(shard_id, int): | |||
| raise ValueError("shard_id must be integer but was: {}.".format(shard_id)) | |||
| if not isinstance(shuffle, bool): | |||
| raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle)) | |||
| if num_samples is not None: | |||
| if num_samples <= 0: | |||
| raise ValueError("num_samples should be a positive integer " | |||
| "value, but got num_samples: {}.".format(num_samples)) | |||
| if offset > num_shards: | |||
| raise ValueError("offset should be no more than num_shards: {}, " | |||
| "but got offset: {}".format(num_shards, offset)) | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.shuffle = shuffle | |||
| @@ -271,21 +277,23 @@ class DistributedSampler(BuiltinSampler): | |||
| self.offset = offset | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| def parse(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| shuffle = self.shuffle if self.shuffle is not None else True | |||
| offset = self.offset if self.offset is not None else -1 | |||
| # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle | |||
| self.seed += 1 | |||
| c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id, | |||
| self.shuffle, self.seed, self.offset) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler = cde.DistributedSamplerObj(self.num_shards, self.shard_id, | |||
| shuffle, num_samples, self.seed, offset, True) | |||
| c_child_sampler = self.parse_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def create_for_minddataset(self): | |||
| def parse_for_minddataset(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, | |||
| self.seed, num_samples, self.offset) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_child_sampler = self.parse_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -334,8 +342,8 @@ class PKSampler(BuiltinSampler): | |||
| """ | |||
| def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None): | |||
| if num_val <= 0: | |||
| raise ValueError("num_val should be a positive integer value, but got num_val: {}.".format(num_val)) | |||
| if not isinstance(num_val, int): | |||
| raise ValueError("num_val must be integer but was: {}.".format(num_val)) | |||
| if num_class is not None: | |||
| raise NotImplementedError("Not supported to specify num_class for PKSampler.") | |||
| @@ -343,20 +351,16 @@ class PKSampler(BuiltinSampler): | |||
| if not isinstance(shuffle, bool): | |||
| raise ValueError("shuffle should be a boolean value, but got shuffle: {}.".format(shuffle)) | |||
| if num_samples is not None: | |||
| if num_samples <= 0: | |||
| raise ValueError("num_samples should be a positive integer " | |||
| "value, but got num_samples: {}.".format(num_samples)) | |||
| self.num_val = num_val | |||
| self.shuffle = shuffle | |||
| self.class_column = class_column # work for minddataset | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| def parse(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle) | |||
| c_child_sampler = self.create_child() | |||
| shuffle = self.shuffle if self.shuffle is not None else False | |||
| c_sampler = cde.PKSamplerObj(self.num_val, shuffle, num_samples) | |||
| c_child_sampler = self.parse_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -372,13 +376,13 @@ class PKSampler(BuiltinSampler): | |||
| return self.child_sampler.is_sharded() | |||
| def create_for_minddataset(self): | |||
| def parse_for_minddataset(self): | |||
| if not self.class_column or not isinstance(self.class_column, str): | |||
| raise ValueError("class_column should be a not empty string value, \ | |||
| but got class_column: {}.".format(class_column)) | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_child_sampler = self.parse_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -409,27 +413,23 @@ class RandomSampler(BuiltinSampler): | |||
| if not isinstance(replacement, bool): | |||
| raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement)) | |||
| if num_samples is not None: | |||
| if num_samples <= 0: | |||
| raise ValueError("num_samples should be a positive integer " | |||
| "value, but got num_samples: {}.".format(num_samples)) | |||
| self.deterministic = False | |||
| self.replacement = replacement | |||
| self.reshuffle_each_epoch = True | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| def parse(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch) | |||
| c_child_sampler = self.create_child() | |||
| replacement = self.replacement if self.replacement is not None else False | |||
| c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch) | |||
| c_child_sampler = self.parse_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def create_for_minddataset(self): | |||
| def parse_for_minddataset(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_child_sampler = self.parse_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -462,32 +462,22 @@ class SequentialSampler(BuiltinSampler): | |||
| """ | |||
| def __init__(self, start_index=None, num_samples=None): | |||
| if num_samples is not None: | |||
| if num_samples <= 0: | |||
| raise ValueError("num_samples should be a positive integer " | |||
| "value, but got num_samples: {}.".format(num_samples)) | |||
| if start_index is not None: | |||
| if start_index < 0: | |||
| raise ValueError("start_index should be a positive integer " | |||
| "value or 0, but got start_index: {}.".format(start_index)) | |||
| self.start_index = start_index | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| def parse(self): | |||
| start_index = self.start_index if self.start_index is not None else 0 | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.SequentialSampler(num_samples, start_index) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler = cde.SequentialSamplerObj(start_index, num_samples) | |||
| c_child_sampler = self.parse_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def create_for_minddataset(self): | |||
| def parse_for_minddataset(self): | |||
| start_index = self.start_index if self.start_index is not None else 0 | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_child_sampler = self.parse_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -525,21 +515,21 @@ class SubsetSampler(BuiltinSampler): | |||
| """ | |||
| def __init__(self, indices, num_samples=None): | |||
| if num_samples is not None: | |||
| if num_samples <= 0: | |||
| raise ValueError("num_samples should be a positive integer " | |||
| "value, but got num_samples: {}.".format(num_samples)) | |||
| if not isinstance(indices, list): | |||
| indices = [indices] | |||
| for i, item in enumerate(indices): | |||
| if not isinstance(item, numbers.Number): | |||
| raise TypeError("type of weights element should be number, " | |||
| "but got w[{}]: {}, type: {}.".format(i, item, type(item))) | |||
| self.indices = indices | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| def parse(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.SubsetSampler(num_samples, self.indices) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler = cde.SubsetSamplerObj(self.indices, num_samples) | |||
| c_child_sampler = self.parse_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -552,9 +542,9 @@ class SubsetSampler(BuiltinSampler): | |||
| return self.child_sampler.is_sharded() | |||
| def create_for_minddataset(self): | |||
| def parse_for_minddataset(self): | |||
| c_sampler = cde.MindrecordSubsetSampler(self.indices) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_child_sampler = self.parse_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -586,19 +576,19 @@ class SubsetRandomSampler(SubsetSampler): | |||
| >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler) | |||
| """ | |||
| def create(self): | |||
| def parse(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.SubsetRandomSampler(num_samples, self.indices) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler = cde.SubsetRandomSamplerObj(self.indices, num_samples) | |||
| c_child_sampler = self.parse_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def is_shuffled(self): | |||
| return True | |||
| def create_for_minddataset(self): | |||
| def parse_for_minddataset(self): | |||
| c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed()) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_child_sampler = self.parse_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -637,20 +627,6 @@ class WeightedRandomSampler(BuiltinSampler): | |||
| raise TypeError("type of weights element should be number, " | |||
| "but got w[{}]: {}, type: {}.".format(ind, w, type(w))) | |||
| if weights == []: | |||
| raise ValueError("weights size should not be 0") | |||
| if list(filter(lambda x: x < 0, weights)) != []: | |||
| raise ValueError("weights should not contain negative numbers.") | |||
| if list(filter(lambda x: x == 0, weights)) == weights: | |||
| raise ValueError("elements of weights should not be all zeros.") | |||
| if num_samples is not None: | |||
| if num_samples <= 0: | |||
| raise ValueError("num_samples should be a positive integer " | |||
| "value, but got num_samples: {}.".format(num_samples)) | |||
| if not isinstance(replacement, bool): | |||
| raise ValueError("replacement should be a boolean value, but got replacement: {}.".format(replacement)) | |||
| @@ -658,10 +634,11 @@ class WeightedRandomSampler(BuiltinSampler): | |||
| self.replacement = replacement | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| def parse(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement) | |||
| c_child_sampler = self.create_child() | |||
| replacement = self.replacement if self.replacement is not None else True | |||
| c_sampler = cde.WeightedRandomSamplerObj(self.weights, num_samples, replacement) | |||
| c_child_sampler = self.parse_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -401,20 +401,23 @@ def test_weighted_random_sampler_exception(): | |||
| weights = (0.9, 0.8, 1.1) | |||
| ds.WeightedRandomSampler(weights) | |||
| error_msg_3 = "weights size should not be 0" | |||
| with pytest.raises(ValueError, match=error_msg_3): | |||
| error_msg_3 = "WeightedRandomSampler: weights vector must not be empty" | |||
| with pytest.raises(RuntimeError, match=error_msg_3): | |||
| weights = [] | |||
| ds.WeightedRandomSampler(weights) | |||
| sampler = ds.WeightedRandomSampler(weights) | |||
| sampler.parse() | |||
| error_msg_4 = "weights should not contain negative numbers" | |||
| with pytest.raises(ValueError, match=error_msg_4): | |||
| error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: " | |||
| with pytest.raises(RuntimeError, match=error_msg_4): | |||
| weights = [1.0, 0.1, 0.02, 0.3, -0.4] | |||
| ds.WeightedRandomSampler(weights) | |||
| sampler = ds.WeightedRandomSampler(weights) | |||
| sampler.parse() | |||
| error_msg_5 = "elements of weights should not be all zero" | |||
| with pytest.raises(ValueError, match=error_msg_5): | |||
| error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero" | |||
| with pytest.raises(RuntimeError, match=error_msg_5): | |||
| weights = [0, 0, 0, 0, 0] | |||
| ds.WeightedRandomSampler(weights) | |||
| sampler = ds.WeightedRandomSampler(weights) | |||
| sampler.parse() | |||
| def test_chained_sampler_01(): | |||
| @@ -1,5 +1,5 @@ | |||
| #!/usr/bin/env python | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -273,14 +273,14 @@ def test_cv_minddataset_partition_num_samples_equals_0(): | |||
| for partition_id in range(num_shards): | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, | |||
| num_shards=num_shards, | |||
| shard_id=partition_id, num_samples=0) | |||
| shard_id=partition_id, num_samples=-1) | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(num_epochs=1): | |||
| num_iter += 1 | |||
| with pytest.raises(Exception) as error_info: | |||
| with pytest.raises(ValueError) as error_info: | |||
| partitions(5) | |||
| try: | |||
| assert 'num_samples should be a positive integer value, but got num_samples: 0.' in str(error_info.value) | |||
| assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(error_info.value) | |||
| except Exception as error: | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -91,23 +91,9 @@ def test_random_sampler_multi_iter(print_res=False): | |||
| def test_sampler_py_api(): | |||
| sampler = ds.SequentialSampler().create() | |||
| sampler.set_num_rows(128) | |||
| sampler.set_num_samples(64) | |||
| sampler.initialize() | |||
| sampler.get_indices() | |||
| sampler = ds.RandomSampler().create() | |||
| sampler.set_num_rows(128) | |||
| sampler.set_num_samples(64) | |||
| sampler.initialize() | |||
| sampler.get_indices() | |||
| sampler = ds.DistributedSampler(8, 4).create() | |||
| sampler.set_num_rows(128) | |||
| sampler.set_num_samples(64) | |||
| sampler.initialize() | |||
| sampler.get_indices() | |||
| sampler = ds.SequentialSampler().parse() | |||
| sampler1 = ds.RandomSampler().parse() | |||
| sampler1.add_child(sampler) | |||
| def test_python_sampler(): | |||
| @@ -158,12 +144,6 @@ def test_python_sampler(): | |||
| assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] | |||
| test_generator() | |||
| sp1 = Sp1().create() | |||
| sp1.set_num_rows(5) | |||
| sp1.set_num_samples(5) | |||
| sp1.initialize() | |||
| assert list(sp1.get_indices()) == [0, 1, 2, 3, 4] | |||
| def test_sequential_sampler2(): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| @@ -229,8 +209,8 @@ def test_subset_sampler(): | |||
| test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]") | |||
| test_config([0, 9, -6, 2], exception_msg="Sample ID (-6) is out of bound, expected range [0, 9]") | |||
| # test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset | |||
| test_config([0, 9, 3, 2], num_samples=0, | |||
| exception_msg="num_samples should be a positive integer value, but got num_samples: 0.") | |||
| test_config([0, 9, 3, 2], num_samples=-1, | |||
| exception_msg="SubsetRandomSampler: invalid num_samples: -1") | |||
| def test_sampler_chain(): | |||
| @@ -280,9 +260,9 @@ def test_add_sampler_invalid_input(): | |||
| def test_distributed_sampler_invalid_offset(): | |||
| with pytest.raises(ValueError) as info: | |||
| sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5) | |||
| assert "offset should be no more than num_shards" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| sampler = ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5).parse() | |||
| assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value) | |||
| if __name__ == '__main__': | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -377,7 +377,7 @@ def test_serdes_exception(): | |||
| def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files): | |||
| """ | |||
| Utility function for testing serdes files. It is to check if a json file is indeed created with correct name | |||
| after serializing and if it remains the same after repeatly saving and loading. | |||
| after serializing and if it remains the same after repeatedly saving and loading. | |||
| :param data_orig: original data pipeline to be serialized | |||
| :param filename: filename to be saved as json format | |||
| :param remove_json_files: whether to remove the json file after testing | |||