| @@ -391,35 +391,27 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *in_partitions) { | |||
| if (args["partitions"].is_none()) { | |||
| std::string err_msg = "Error: partitions is not set (None)"; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| py::list list = py::reinterpret_borrow<py::list>(args["partitions"]); | |||
| for (auto l : list) { | |||
| if (!l.is_none()) { | |||
| in_partitions->push_back(ToInt(l)); | |||
| Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle, | |||
| std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators, | |||
| int num_padded) { | |||
| auto sampler = py::reinterpret_borrow<py::object>(handle); | |||
| auto create = sampler.attr("create_for_minddataset"); | |||
| auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>(); | |||
| std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops; | |||
| while (op != nullptr) { | |||
| auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op); | |||
| if (sampler_op && num_padded > 0) { | |||
| sampler_op->SetNumPaddedSamples(num_padded); | |||
| stack_ops.push(sampler_op); | |||
| } else { | |||
| stack_ops.push(op); | |||
| } | |||
| op = op->GetChildOp(); | |||
| } | |||
| if (in_partitions->size() != 2) { | |||
| std::string err_msg = "Error: partitions is invalid or not set."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| constexpr int kMaxPartitions = 1024; | |||
| if (in_partitions->at(0) <= 0 || in_partitions->at(0) > kMaxPartitions) { | |||
| std::string err_msg = "Error: partitions is invalid or not set."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| if (in_partitions->at(1) < 0 || in_partitions->at(1) >= in_partitions->at(0)) { | |||
| std::string err_msg = "Error: partitions is invalid or not set."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| while (!stack_ops.empty()) { | |||
| operators->push_back(stack_ops.top()); | |||
| stack_ops.pop(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -460,34 +452,16 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas | |||
| (void)builder->SetNumMindRecordWorkers(ToInt(value)); | |||
| } else if (key == "block_reader" && ToBool(value) == true) { | |||
| (void)builder->SetBlockReader(); | |||
| } else if (key == "shuffle_option" && ToBool(value) == true) { | |||
| if (!args["partitions"].is_none()) continue; | |||
| uint32_t seed = GetSeed(); | |||
| operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed)); | |||
| } else if (key == "sampler") { | |||
| auto sampler = py::reinterpret_borrow<py::object>(value); | |||
| auto create = sampler.attr("_create_for_minddataset"); | |||
| auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>(); | |||
| operators.push_back(op); | |||
| int num_padded = 0; | |||
| if (!args["num_padded"].is_none()) { | |||
| num_padded = ToInt(args["num_padded"]); | |||
| } | |||
| RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded)); | |||
| } | |||
| } | |||
| } | |||
| std::vector<int> in_partitions; | |||
| if (!args["partitions"].is_none()) { | |||
| auto ret = CheckMindRecordPartitionInfo(args, &in_partitions); | |||
| if (Status::OK() != ret) { | |||
| return ret; | |||
| } | |||
| auto shuffle = ToBool(args["shuffle_option"]); | |||
| int num_padded = 0; | |||
| if (!args["num_padded"].is_none()) { | |||
| num_padded = ToInt(args["num_padded"]); | |||
| } | |||
| operators.push_back( | |||
| std::make_shared<mindrecord::ShardDistributedSample>(in_partitions[0], in_partitions[1], num_padded, shuffle, 0)); | |||
| } | |||
| if (!operators.empty()) { | |||
| (void)builder->SetOperators(operators); | |||
| } | |||
| @@ -18,6 +18,7 @@ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <stack> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| @@ -108,10 +109,12 @@ class DEPipeline { | |||
| Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *ptr); | |||
| Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status BuildMindrecordSamplerChain(const py::handle &handle, | |||
| std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators, | |||
| int num_padded); | |||
| Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| @@ -71,6 +71,7 @@ | |||
| #include "mindrecord/include/shard_pk_sample.h" | |||
| #include "mindrecord/include/shard_distributed_sample.h" | |||
| #include "mindrecord/include/shard_sample.h" | |||
| #include "mindrecord/include/shard_sequential_sample.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/stl.h" | |||
| #include "pybind11/stl_bind.h" | |||
| @@ -165,8 +166,8 @@ void bindDatasetOps(py::module *m) { | |||
| const int64_t num_padded) { | |||
| int64_t count = 0; | |||
| std::shared_ptr<mindrecord::ShardOperator> op; | |||
| if (py::hasattr(sampler, "_create_for_minddataset")) { | |||
| auto create = sampler.attr("_create_for_minddataset"); | |||
| if (py::hasattr(sampler, "create_for_minddataset")) { | |||
| auto create = sampler.attr("create_for_minddataset"); | |||
| op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>(); | |||
| } | |||
| THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded)); | |||
| @@ -486,7 +487,9 @@ void bindSamplerOps(py::module *m) { | |||
| .def("add_child", | |||
| [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); }); | |||
| (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator"); | |||
| (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator") | |||
| .def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self, | |||
| std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); }); | |||
| (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler") | |||
| .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>()); | |||
| @@ -518,6 +521,22 @@ void bindSamplerOps(py::module *m) { | |||
| } | |||
| })); | |||
| (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample, | |||
| std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler") | |||
| .def(py::init<int64_t, int64_t, bool, uint32_t>()); | |||
| (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>( | |||
| *m, "MindrecordRandomSampler") | |||
| .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) { | |||
| return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch); | |||
| })); | |||
| (void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample, | |||
| std::shared_ptr<mindrecord::ShardSequentialSample>>(*m, "MindrecordSequentialSampler") | |||
| .def(py::init([](int num_samples, int start_index) { | |||
| return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index); | |||
| })); | |||
| (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler") | |||
| .def(py::init<int64_t, std::vector<double>, bool>()); | |||
| @@ -31,6 +31,10 @@ class ShardDistributedSample : public ShardSample { | |||
| public: | |||
| ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed); | |||
| ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed); | |||
| void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; } | |||
| ~ShardDistributedSample() override{}; | |||
| MSRStatus PreExecute(ShardTask &tasks) override; | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ | |||
| #define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ | |||
| #include <memory> | |||
| #include "mindrecord/include/shard_task.h" | |||
| namespace mindspore { | |||
| @@ -37,6 +38,14 @@ class ShardOperator { | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| virtual bool HasChildOp() { return child_op_ != nullptr; } | |||
| virtual MSRStatus SetChildOp(std::shared_ptr<ShardOperator> child_op) { | |||
| if (child_op != nullptr) child_op_ = child_op; | |||
| return SUCCESS; | |||
| } | |||
| virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; } | |||
| virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; } | |||
| @@ -44,7 +53,10 @@ class ShardOperator { | |||
| virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; } | |||
| virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; } | |||
| virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } | |||
| private: | |||
| std::shared_ptr<ShardOperator> child_op_ = nullptr; | |||
| }; | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -34,6 +34,7 @@ | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <set> | |||
| #include <stack> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <tuple> | |||
| @@ -44,6 +45,7 @@ | |||
| #include "mindrecord/include/common/shard_utils.h" | |||
| #include "mindrecord/include/shard_category.h" | |||
| #include "mindrecord/include/shard_column.h" | |||
| #include "mindrecord/include/shard_distributed_sample.h" | |||
| #include "mindrecord/include/shard_error.h" | |||
| #include "mindrecord/include/shard_index_generator.h" | |||
| #include "mindrecord/include/shard_operator.h" | |||
| @@ -48,10 +48,10 @@ class ShardSample : public ShardOperator { | |||
| int numerator_; | |||
| int denominator_; | |||
| int partition_id_; | |||
| int no_of_samples_; | |||
| std::shared_ptr<ShardShuffle> shuffle_op_; | |||
| private: | |||
| int no_of_samples_; | |||
| std::vector<int64_t> indices_; | |||
| SamplerType sampler_type_; | |||
| }; | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ | |||
| #define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "mindrecord/include/shard_sample.h" | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| class ShardSequentialSample : public ShardSample { | |||
| public: | |||
| ShardSequentialSample(int n, int offset); | |||
| ShardSequentialSample(float per, float per_offset); | |||
| ~ShardSequentialSample() override{}; | |||
| MSRStatus Execute(ShardTask &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| private: | |||
| int offset_; | |||
| float per_; | |||
| float per_offset_; | |||
| }; | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| #endif // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ | |||
| @@ -26,12 +26,20 @@ class ShardShuffle : public ShardOperator { | |||
| public: | |||
| explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory); | |||
| ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch, | |||
| ShuffleType shuffle_type = kShuffleSample); | |||
| ~ShardShuffle() override{}; | |||
| MSRStatus Execute(ShardTask &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| private: | |||
| uint32_t shuffle_seed_; | |||
| int64_t no_of_samples_; | |||
| bool replacement_; | |||
| bool reshuffle_each_epoch_; | |||
| ShuffleType shuffle_type_; | |||
| }; | |||
| } // namespace mindrecord | |||
| @@ -792,24 +792,51 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) { | |||
| } | |||
| MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset, | |||
| const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded) { | |||
| const std::shared_ptr<ShardOperator> &ops, int64_t *count, const int num_padded) { | |||
| if (SUCCESS != Init(file_paths, load_dataset)) { | |||
| return FAILED; | |||
| } | |||
| int64_t num_samples = num_rows_; | |||
| if (std::dynamic_pointer_cast<ShardCategory>(op)) { | |||
| auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); | |||
| std::string category_field = category_op->GetCategoryField(); | |||
| auto num_classes = GetNumClasses(category_field); | |||
| num_samples = category_op->GetNumSamples(num_rows_, num_classes); | |||
| } else if (std::dynamic_pointer_cast<ShardSample>(op)) { | |||
| num_samples = op->GetNumSamples(num_rows_, 0); | |||
| if (-1 == num_samples) { | |||
| MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards."; | |||
| return FAILED; | |||
| bool root = true; | |||
| std::stack<std::shared_ptr<ShardOperator>> stack_ops; | |||
| std::shared_ptr<ShardOperator> op(ops); | |||
| while (op != nullptr) { | |||
| stack_ops.push(op); | |||
| op = op->GetChildOp(); | |||
| } | |||
| while (!stack_ops.empty()) { | |||
| op = stack_ops.top(); | |||
| stack_ops.pop(); | |||
| if (std::dynamic_pointer_cast<ShardShuffle>(op)) { | |||
| num_samples = op->GetNumSamples(num_samples, 0); | |||
| if (num_padded > 0 && root == true) { | |||
| num_samples += num_padded; | |||
| MS_LOG(DEBUG) << "Padding samples work on shuffle sampler."; | |||
| root = false; | |||
| } | |||
| } else if (std::dynamic_pointer_cast<ShardCategory>(op)) { | |||
| auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); | |||
| std::string category_field = category_op->GetCategoryField(); | |||
| auto num_classes = GetNumClasses(category_field); | |||
| num_samples = category_op->GetNumSamples(num_samples, num_classes); | |||
| } else if (std::dynamic_pointer_cast<ShardSample>(op)) { | |||
| if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) { | |||
| auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op); | |||
| if (root == true) { | |||
| sampler_op->SetNumPaddedSamples(num_padded); | |||
| num_samples = op->GetNumSamples(num_samples, 0); | |||
| if (-1 == num_samples) { | |||
| MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards."; | |||
| return FAILED; | |||
| } | |||
| root = false; | |||
| } | |||
| } else { | |||
| num_samples = op->GetNumSamples(num_samples, 0); | |||
| } | |||
| } else { | |||
| if (num_padded > 0) num_samples += num_padded; | |||
| } | |||
| } else { | |||
| if (num_padded > 0) num_samples += num_padded; | |||
| } | |||
| *count = num_samples; | |||
| return SUCCESS; | |||
| @@ -1385,12 +1412,16 @@ void ShardReader::Reset() { | |||
| } | |||
| void ShardReader::ShuffleTask() { | |||
| if (block_reader_) return; | |||
| // exist shuffle and distributed sampler in ops, skip shuffle | |||
| bool has_sharding = false; | |||
| for (const auto &op : operators_) { | |||
| if (block_reader_) { | |||
| continue; | |||
| if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) { | |||
| has_sharding = true; | |||
| } | |||
| if (std::dynamic_pointer_cast<ShardShuffle>(op)) { | |||
| } | |||
| for (const auto &op : operators_) { | |||
| if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) { | |||
| if (SUCCESS != (*op)(tasks_)) { | |||
| MS_LOG(WARNING) << "Reshuffle reader tasks failed."; | |||
| } | |||
| @@ -31,6 +31,9 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int | |||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); | |||
| } | |||
| ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed) | |||
| : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {} | |||
| int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| if (no_of_padded_samples_ <= 0) { | |||
| if (dataset_size % denominator_ == 0) { | |||
| @@ -0,0 +1,74 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindrecord/include/shard_sequential_sample.h" | |||
| using mindspore::LogStream; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::MsLogLevel::ERROR; | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| ShardSequentialSample::ShardSequentialSample(int n, int offset) | |||
| : ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {} | |||
| ShardSequentialSample::ShardSequentialSample(float per, float per_offset) | |||
| : ShardSample(0), offset_(0), per_(per), per_offset_(per_offset) {} | |||
| int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { | |||
| return dataset_size; | |||
| } | |||
| if (per_ > kEpsilon && per_ <= 1.0f) { | |||
| return dataset_size * kEpsilon; | |||
| } | |||
| return no_of_samples_; | |||
| } | |||
| MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) { | |||
| int total_no = static_cast<int>(tasks.Size()); | |||
| int taking; | |||
| if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { | |||
| taking = total_no; | |||
| } else if (per_ > kEpsilon && per_ <= 1.0f) { | |||
| taking = total_no * kEpsilon; | |||
| } else { | |||
| taking = no_of_samples_; | |||
| } | |||
| if (tasks.permutation_.empty()) { | |||
| ShardTask new_tasks; | |||
| total_no = static_cast<int>(tasks.Size()); | |||
| for (int i = offset_; i < taking + offset_; ++i) { | |||
| new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| } else { // shuffled | |||
| ShardTask new_tasks; | |||
| if (taking > static_cast<int>(tasks.permutation_.size())) { | |||
| return FAILED; | |||
| } | |||
| total_no = static_cast<int>(tasks.permutation_.size()); | |||
| for (size_t i = offset_; i < taking + offset_; ++i) { | |||
| new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -21,17 +21,52 @@ | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) | |||
| : shuffle_seed_(seed), shuffle_type_(shuffle_type) {} | |||
| : shuffle_seed_(seed), | |||
| no_of_samples_(0), | |||
| replacement_(false), | |||
| reshuffle_each_epoch_(true), | |||
| shuffle_type_(shuffle_type) {} | |||
| ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch, | |||
| ShuffleType shuffle_type) | |||
| : shuffle_seed_(seed), | |||
| no_of_samples_(no_of_samples), | |||
| replacement_(replacement), | |||
| reshuffle_each_epoch_(reshuffle_each_epoch), | |||
| shuffle_type_(shuffle_type) {} | |||
| int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| if (replacement_) { | |||
| return no_of_samples_ == 0 ? dataset_size : no_of_samples_; | |||
| } | |||
| return dataset_size; | |||
| } | |||
| MSRStatus ShardShuffle::Execute(ShardTask &tasks) { | |||
| if (tasks.categories < 1) { | |||
| return FAILED; | |||
| } | |||
| if (shuffle_type_ == kShuffleSample) { | |||
| if (shuffle_type_ == kShuffleSample) { // shuffle each sample | |||
| if (tasks.permutation_.empty() == true) { | |||
| tasks.MakePerm(); | |||
| } | |||
| std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); | |||
| if (replacement_ == true) { | |||
| ShardTask new_tasks; | |||
| if (no_of_samples_ == 0) { | |||
| no_of_samples_ = static_cast<int>(tasks.Size()); | |||
| } | |||
| if (no_of_samples_ <= 0) { | |||
| MS_LOG(ERROR) << "no_of_samples need to be positive."; | |||
| return FAILED; | |||
| } | |||
| new_tasks.task_list_.reserve(no_of_samples_); | |||
| for (uint32_t i = 0; i < no_of_samples_; ++i) { | |||
| new_tasks.InsertTask(tasks.GetRandomTask()); | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| } else { | |||
| std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); | |||
| } | |||
| } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) | |||
| uint32_t individual_size = tasks.Size() / tasks.categories; | |||
| std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); | |||
| @@ -46,7 +81,7 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) { | |||
| } | |||
| } | |||
| } | |||
| shuffle_seed_++; | |||
| if (reshuffle_each_epoch_) shuffle_seed_++; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace mindrecord | |||
| @@ -72,6 +72,7 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa | |||
| std::uniform_int_distribution<> dis(0, task_list_.size() - 1); | |||
| return task_list_[dis(gen)]; | |||
| } | |||
| ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) { | |||
| ShardTask res; | |||
| if (category_tasks.empty()) return res; | |||
| @@ -1015,10 +1015,8 @@ class Dataset: | |||
| def get_distribution(output_dataset): | |||
| dev_id = 0 | |||
| if isinstance(output_dataset, (MindDataset)): | |||
| return output_dataset.distribution, dev_id | |||
| if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2, | |||
| ManifestDataset, MnistDataset, VOCDataset, CelebADataset)): | |||
| ManifestDataset, MnistDataset, VOCDataset, CelebADataset, MindDataset)): | |||
| sampler = output_dataset.sampler | |||
| if isinstance(sampler, samplers.DistributedSampler): | |||
| dev_id = sampler.shard_id | |||
| @@ -2670,7 +2668,7 @@ class MnistDataset(MappableDataset): | |||
| return self.sampler.is_sharded() | |||
| class MindDataset(SourceDataset): | |||
| class MindDataset(MappableDataset): | |||
| """ | |||
| A source dataset that reads from shard files and database. | |||
| @@ -2687,11 +2685,13 @@ class MindDataset(SourceDataset): | |||
| sampler (Sampler, optional): Object used to choose samples from the | |||
| dataset (default=None, sampler is exclusive | |||
| with shuffle and block_reader). Support list: SubsetRandomSampler, | |||
| PkSampler. | |||
| PkSampler, RandomSampler, SequentialSampler, DistributedSampler. | |||
| padded_sample (dict, optional): Samples will be appended to dataset, which | |||
| keys are the same as column_list. | |||
| num_padded (int, optional): Number of padding samples.Dataset size | |||
| plus num_padded should be divisible by num_shards. | |||
| num_samples (int, optional): The number of samples to be included in the dataset | |||
| (default=None, all samples). | |||
| Raises: | |||
| ValueError: If num_shards is specified but shard_id is None. | |||
| @@ -2703,7 +2703,7 @@ class MindDataset(SourceDataset): | |||
| def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None, | |||
| shuffle=None, num_shards=None, shard_id=None, | |||
| block_reader=False, sampler=None, padded_sample=None, | |||
| num_padded=None): | |||
| num_padded=None, num_samples=None): | |||
| super().__init__(num_parallel_workers) | |||
| if isinstance(dataset_file, list): | |||
| self.load_dataset = False | |||
| @@ -2712,15 +2712,10 @@ class MindDataset(SourceDataset): | |||
| self.dataset_file = dataset_file | |||
| self.columns_list = columns_list | |||
| self.shuffle_option = shuffle | |||
| self.distribution = "" | |||
| self.sampler = sampler | |||
| if num_shards is None or shard_id is None: | |||
| self.partitions = None | |||
| else: | |||
| self.partitions = [num_shards, shard_id] | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| if block_reader is True and self.partitions is not None: | |||
| if block_reader is True and num_shards is not None: | |||
| raise ValueError("block reader not allowed true when use partitions") | |||
| if block_reader is True and shuffle is True: | |||
| @@ -2730,25 +2725,21 @@ class MindDataset(SourceDataset): | |||
| logger.warning("WARN: global shuffle is not used.") | |||
| if sampler is not None: | |||
| if isinstance(sampler, samplers.SubsetRandomSampler) is False and \ | |||
| isinstance(sampler, samplers.PKSampler) is False: | |||
| if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler, | |||
| samplers.DistributedSampler, samplers.RandomSampler, | |||
| samplers.SequentialSampler)) is False: | |||
| raise ValueError("the sampler is not supported yet.") | |||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | |||
| self.num_samples = num_samples | |||
| # sampler exclusive | |||
| if block_reader is True and sampler is not None: | |||
| raise ValueError("block reader not allowed true when use sampler") | |||
| if shuffle is not None and sampler is not None: | |||
| raise ValueError("shuffle not allowed when use sampler") | |||
| if block_reader is False and sampler is None: | |||
| self.shuffle_option = not bool(shuffle is False) | |||
| if num_padded is None: | |||
| num_padded = 0 | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.block_reader = block_reader | |||
| self.padded_sample = padded_sample | |||
| self.num_padded = num_padded | |||
| @@ -2766,10 +2757,8 @@ class MindDataset(SourceDataset): | |||
| args["load_dataset"] = self.load_dataset | |||
| args["columns_list"] = self.columns_list | |||
| args["shuffle_option"] = self.shuffle_option | |||
| args["partitions"] = self.partitions | |||
| args["num_samples"] = self.num_samples | |||
| args["block_reader"] = self.block_reader | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["num_padded"] = self.num_padded | |||
| args["padded_sample"] = padded_sample | |||
| args["sampler"] = self.sampler | |||
| @@ -2788,14 +2777,6 @@ class MindDataset(SourceDataset): | |||
| else: | |||
| dataset_file = self.dataset_file | |||
| num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded) | |||
| if self.partitions is not None and self.partitions[0] > 0: | |||
| if num_rows % self.partitions[0] == 0: | |||
| num_rows = num_rows // self.partitions[0] | |||
| else: | |||
| if self.num_padded > 0: | |||
| raise RuntimeError( | |||
| "Dataset size plus number of padded samples is not divisible by number of shards.") | |||
| num_rows = num_rows // self.partitions[0] + 1 | |||
| return num_rows | |||
| return self._dataset_size | |||
| @@ -141,7 +141,12 @@ class BuiltinSampler: | |||
| c_child_sampler = None | |||
| if self.child_sampler is not None: | |||
| c_child_sampler = self.child_sampler.create() | |||
| return c_child_sampler | |||
| def create_child_for_minddataset(self): | |||
| c_child_sampler = None | |||
| if self.child_sampler is not None: | |||
| c_child_sampler = self.child_sampler.create_for_minddataset() | |||
| return c_child_sampler | |||
| def is_shuffled(self): | |||
| @@ -262,6 +267,12 @@ class DistributedSampler(BuiltinSampler): | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def create_for_minddataset(self): | |||
| c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def is_shuffled(self): | |||
| if self.child_sampler is None: | |||
| return self.shuffle | |||
| @@ -318,7 +329,7 @@ class PKSampler(BuiltinSampler): | |||
| self.num_val = num_val | |||
| self.shuffle = shuffle | |||
| self.class_column = class_column # work for minddataset | |||
| self.class_column = class_column # work for minddataset | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| @@ -340,12 +351,14 @@ class PKSampler(BuiltinSampler): | |||
| return self.child_sampler.is_sharded() | |||
| def _create_for_minddataset(self): | |||
| def create_for_minddataset(self): | |||
| if not self.class_column or not isinstance(self.class_column, str): | |||
| raise ValueError("class_column should be a not empty string value, \ | |||
| but got class_column={}".format(class_column)) | |||
| return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) | |||
| c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| class RandomSampler(BuiltinSampler): | |||
| """ | |||
| @@ -390,6 +403,13 @@ class RandomSampler(BuiltinSampler): | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def create_for_minddataset(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def is_shuffled(self): | |||
| return True | |||
| @@ -440,6 +460,14 @@ class SequentialSampler(BuiltinSampler): | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def create_for_minddataset(self): | |||
| start_index = self.start_index if self.start_index is not None else 0 | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def is_shuffled(self): | |||
| if self.child_sampler is None: | |||
| return False | |||
| @@ -501,8 +529,11 @@ class SubsetRandomSampler(BuiltinSampler): | |||
| return self.child_sampler.is_sharded() | |||
| def _create_for_minddataset(self): | |||
| return cde.MindrecordSubsetRandomSampler(self.indices) | |||
| def create_for_minddataset(self): | |||
| c_sampler = cde.MindrecordSubsetRandomSampler(self.indices) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def get_num_samples(self): | |||
| num_samples = super().get_num_samples() | |||
| @@ -17,6 +17,7 @@ This is the test module for mindrecord | |||
| """ | |||
| import os | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| @@ -64,10 +65,12 @@ def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file): | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| @@ -82,12 +85,14 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): | |||
| assert data_set.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[data]: \ | |||
| {}------------------------".format(item["data"][:10])) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| @@ -102,10 +107,12 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): | |||
| assert data_set.get_dataset_size() == 9 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| @@ -119,10 +126,12 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file): | |||
| assert data_set.get_dataset_size() == 15 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| @@ -219,7 +228,6 @@ def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file | |||
| def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): | |||
| """tutorial for cv minderdataset.""" | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| indices = [1, 2, 4, -1, -2] | |||
| @@ -241,6 +249,344 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): | |||
| assert num_iter == 5 | |||
| def test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file): | |||
| data = get_data(CV_DIR_NAME, True) | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| sampler = ds.RandomSampler() | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | |||
| sampler=sampler) | |||
| assert data_set.get_dataset_size() == 10 | |||
| num_iter = 0 | |||
| new_dataset = [] | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| new_dataset.append(item['file_name']) | |||
| assert num_iter == 10 | |||
| assert new_dataset != [x['file_name'] for x in data] | |||
| def test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file): | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| sampler = ds.RandomSampler() | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | |||
| sampler=sampler) | |||
| assert data_set.get_dataset_size() == 10 | |||
| ds1 = data_set.repeat(3) | |||
| num_iter = 0 | |||
| epoch1_dataset = [] | |||
| epoch2_dataset = [] | |||
| epoch3_dataset = [] | |||
| for item in ds1.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| if num_iter <= 10: | |||
| epoch1_dataset.append(item['file_name']) | |||
| elif num_iter <= 20: | |||
| epoch2_dataset.append(item['file_name']) | |||
| else: | |||
| epoch3_dataset.append(item['file_name']) | |||
| assert num_iter == 30 | |||
| assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset) | |||
| assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset) | |||
| assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset) | |||
| def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file): | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| sampler = ds.RandomSampler(replacement=True, num_samples=5) | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | |||
| sampler=sampler) | |||
| assert data_set.get_dataset_size() == 5 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 5 | |||
| def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file): | |||
| data = get_data(CV_DIR_NAME, True) | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| sampler = ds.SequentialSampler(1, 4) | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | |||
| sampler=sampler) | |||
| assert data_set.get_dataset_size() == 4 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| assert item['file_name'] == np.array( | |||
| data[num_iter+1]['file_name'], dtype='S') | |||
| num_iter += 1 | |||
| assert num_iter == 4 | |||
| def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file): | |||
| data = get_data(CV_DIR_NAME, True) | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| sampler = ds.SequentialSampler(2, 10) | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | |||
| sampler=sampler) | |||
| dataset_size = data_set.get_dataset_size() | |||
| assert dataset_size == 10 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| assert item['file_name'] == np.array( | |||
| data[(num_iter + 2) % dataset_size]['file_name'], dtype='S') | |||
| num_iter += 1 | |||
| assert num_iter == 10 | |||
| def test_cv_minddataset_split_basic(add_and_remove_cv_file): | |||
| data = get_data(CV_DIR_NAME, True) | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, | |||
| num_readers, shuffle=False) | |||
| d1, d2 = d.split([8, 2], randomize=False) | |||
| assert d.get_dataset_size() == 10 | |||
| assert d1.get_dataset_size() == 8 | |||
| assert d2.get_dataset_size() == 2 | |||
| num_iter = 0 | |||
| for item in d1.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| assert item['file_name'] == np.array(data[num_iter]['file_name'], | |||
| dtype='S') | |||
| num_iter += 1 | |||
| assert num_iter == 8 | |||
| num_iter = 0 | |||
| for item in d2.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| assert item['file_name'] == np.array(data[num_iter + 8]['file_name'], | |||
| dtype='S') | |||
| num_iter += 1 | |||
| assert num_iter == 2 | |||
| def test_cv_minddataset_split_exact_percent(add_and_remove_cv_file): | |||
| data = get_data(CV_DIR_NAME, True) | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, | |||
| num_readers, shuffle=False) | |||
| d1, d2 = d.split([0.8, 0.2], randomize=False) | |||
| assert d.get_dataset_size() == 10 | |||
| assert d1.get_dataset_size() == 8 | |||
| assert d2.get_dataset_size() == 2 | |||
| num_iter = 0 | |||
| for item in d1.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| assert item['file_name'] == np.array( | |||
| data[num_iter]['file_name'], dtype='S') | |||
| num_iter += 1 | |||
| assert num_iter == 8 | |||
| num_iter = 0 | |||
| for item in d2.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| assert item['file_name'] == np.array(data[num_iter + 8]['file_name'], | |||
| dtype='S') | |||
| num_iter += 1 | |||
| assert num_iter == 2 | |||
| def test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file): | |||
| data = get_data(CV_DIR_NAME, True) | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, | |||
| num_readers, shuffle=False) | |||
| d1, d2 = d.split([0.41, 0.59], randomize=False) | |||
| assert d.get_dataset_size() == 10 | |||
| assert d1.get_dataset_size() == 4 | |||
| assert d2.get_dataset_size() == 6 | |||
| num_iter = 0 | |||
| for item in d1.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| assert item['file_name'] == np.array( | |||
| data[num_iter]['file_name'], dtype='S') | |||
| num_iter += 1 | |||
| assert num_iter == 4 | |||
| num_iter = 0 | |||
| for item in d2.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| assert item['file_name'] == np.array(data[num_iter + 4]['file_name'], | |||
| dtype='S') | |||
| num_iter += 1 | |||
| assert num_iter == 6 | |||
| def test_cv_minddataset_split_deterministic(add_and_remove_cv_file): | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, | |||
| num_readers, shuffle=False) | |||
| # should set seed to avoid data overlap | |||
| ds.config.set_seed(111) | |||
| d1, d2 = d.split([0.8, 0.2]) | |||
| assert d.get_dataset_size() == 10 | |||
| assert d1.get_dataset_size() == 8 | |||
| assert d2.get_dataset_size() == 2 | |||
| d1_dataset = [] | |||
| d2_dataset = [] | |||
| num_iter = 0 | |||
| for item in d1.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| d1_dataset.append(item['file_name']) | |||
| num_iter += 1 | |||
| assert num_iter == 8 | |||
| num_iter = 0 | |||
| for item in d2.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| d2_dataset.append(item['file_name']) | |||
| num_iter += 1 | |||
| assert num_iter == 2 | |||
| inter_dataset = [x for x in d1_dataset if x in d2_dataset] | |||
| assert inter_dataset == [] # intersection of d1 and d2 | |||
| def test_cv_minddataset_split_sharding(add_and_remove_cv_file): | |||
| data = get_data(CV_DIR_NAME, True) | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, | |||
| num_readers, shuffle=False) | |||
| # should set seed to avoid data overlap | |||
| ds.config.set_seed(111) | |||
| d1, d2 = d.split([0.8, 0.2]) | |||
| assert d.get_dataset_size() == 10 | |||
| assert d1.get_dataset_size() == 8 | |||
| assert d2.get_dataset_size() == 2 | |||
| distributed_sampler = ds.DistributedSampler(2, 0) | |||
| d1.use_sampler(distributed_sampler) | |||
| assert d1.get_dataset_size() == 4 | |||
| num_iter = 0 | |||
| d1_shard1 = [] | |||
| for item in d1.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| d1_shard1.append(item['file_name']) | |||
| assert num_iter == 4 | |||
| assert d1_shard1 != [x['file_name'] for x in data[0:4]] | |||
| distributed_sampler = ds.DistributedSampler(2, 1) | |||
| d1.use_sampler(distributed_sampler) | |||
| assert d1.get_dataset_size() == 4 | |||
| d1s = d1.repeat(3) | |||
| epoch1_dataset = [] | |||
| epoch2_dataset = [] | |||
| epoch3_dataset = [] | |||
| num_iter = 0 | |||
| for item in d1s.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- item[data]: {} -----------------------------".format(item["data"])) | |||
| logger.info( | |||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| if num_iter <= 4: | |||
| epoch1_dataset.append(item['file_name']) | |||
| elif num_iter <= 8: | |||
| epoch2_dataset.append(item['file_name']) | |||
| else: | |||
| epoch3_dataset.append(item['file_name']) | |||
| assert len(epoch1_dataset) == 4 | |||
| assert len(epoch2_dataset) == 4 | |||
| assert len(epoch3_dataset) == 4 | |||
| inter_dataset = [x for x in d1_shard1 if x in epoch1_dataset] | |||
| assert inter_dataset == [] # intersection of d1's shard1 and d1's shard2 | |||
| assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset) | |||
| assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset) | |||
| assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset) | |||
| def get_data(dir_name, sampler=False): | |||
| """ | |||
| usage: get data from imagenet dataset | |||