@@ -391,35 +391,27 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp
   return Status::OK();
 }
 
-Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *in_partitions) {
-  if (args["partitions"].is_none()) {
-    std::string err_msg = "Error: partitions is not set (None)";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  py::list list = py::reinterpret_borrow<py::list>(args["partitions"]);
-  for (auto l : list) {
-    if (!l.is_none()) {
-      in_partitions->push_back(ToInt(l));
+Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
+                                               std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
+                                               int num_padded) {
+  auto sampler = py::reinterpret_borrow<py::object>(handle);
+  auto create = sampler.attr("create_for_minddataset");
+  auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
+  std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
+  while (op != nullptr) {
+    auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
+    if (sampler_op && num_padded > 0) {
+      sampler_op->SetNumPaddedSamples(num_padded);
+      stack_ops.push(sampler_op);
+    } else {
+      stack_ops.push(op);
     }
+    op = op->GetChildOp();
   }
-  if (in_partitions->size() != 2) {
-    std::string err_msg = "Error: partitions is invalid or not set.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  constexpr int kMaxPartitions = 1024;
-  if (in_partitions->at(0) <= 0 || in_partitions->at(0) > kMaxPartitions) {
-    std::string err_msg = "Error: partitions is invalid or not set.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  if (in_partitions->at(1) < 0 || in_partitions->at(1) >= in_partitions->at(0)) {
-    std::string err_msg = "Error: partitions is invalid or not set.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
+  while (!stack_ops.empty()) {
+    operators->push_back(stack_ops.top());
+    stack_ops.pop();
   }
   return Status::OK();
 }
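Note on the ordering here: create_for_minddataset returns the root of a linked chain of ShardOperators (each node reachable through GetChildOp), and the stack reverses that order so the leaf-most operator lands first in the operators vector and therefore executes first. A minimal, self-contained Python sketch of the same linearization (the Op class and names are illustrative, not part of the codebase):

    # Toy model of BuildMindrecordSamplerChain's ordering: push root-to-leaf,
    # pop leaf-to-root, so child samplers execute before their parents.
    class Op:
        def __init__(self, name, child=None):
            self.name = name
            self.child = child  # stands in for ShardOperator::GetChildOp()

    root = Op("ShardDistributedSample", child=Op("ShardShuffle"))

    stack, operators = [], []
    op = root
    while op is not None:
        stack.append(op)
        op = op.child
    while stack:
        operators.append(stack.pop().name)

    print(operators)  # ['ShardShuffle', 'ShardDistributedSample']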
@@ -460,34 +452,16 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
       (void)builder->SetNumMindRecordWorkers(ToInt(value));
     } else if (key == "block_reader" && ToBool(value) == true) {
       (void)builder->SetBlockReader();
-    } else if (key == "shuffle_option" && ToBool(value) == true) {
-      if (!args["partitions"].is_none()) continue;
-      uint32_t seed = GetSeed();
-      operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
     } else if (key == "sampler") {
-      auto sampler = py::reinterpret_borrow<py::object>(value);
-      auto create = sampler.attr("_create_for_minddataset");
-      auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
-      operators.push_back(op);
+      int num_padded = 0;
+      if (!args["num_padded"].is_none()) {
+        num_padded = ToInt(args["num_padded"]);
+      }
+      RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded));
     }
   }
 }
 
-  std::vector<int> in_partitions;
-  if (!args["partitions"].is_none()) {
-    auto ret = CheckMindRecordPartitionInfo(args, &in_partitions);
-    if (Status::OK() != ret) {
-      return ret;
-    }
-    auto shuffle = ToBool(args["shuffle_option"]);
-    int num_padded = 0;
-    if (!args["num_padded"].is_none()) {
-      num_padded = ToInt(args["num_padded"]);
-    }
-    operators.push_back(
-      std::make_shared<mindrecord::ShardDistributedSample>(in_partitions[0], in_partitions[1], num_padded, shuffle, 0));
-  }
-
   if (!operators.empty()) {
     (void)builder->SetOperators(operators);
   }
@@ -18,6 +18,7 @@
 #include <iostream>
 #include <memory>
+#include <stack>
 #include <string>
 #include <unordered_map>
 #include <utility>
@@ -108,10 +109,12 @@ class DEPipeline {
   Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
-  Status CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *ptr);
-
   Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
+  Status BuildMindrecordSamplerChain(const py::handle &handle,
+                                     std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
+                                     int num_padded);
+
   Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
   Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@@ -71,6 +71,7 @@
 #include "mindrecord/include/shard_pk_sample.h"
 #include "mindrecord/include/shard_distributed_sample.h"
 #include "mindrecord/include/shard_sample.h"
+#include "mindrecord/include/shard_sequential_sample.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
 #include "pybind11/stl_bind.h"
@@ -165,8 +166,8 @@ void bindDatasetOps(py::module *m) {
                     const int64_t num_padded) {
           int64_t count = 0;
           std::shared_ptr<mindrecord::ShardOperator> op;
-          if (py::hasattr(sampler, "_create_for_minddataset")) {
-            auto create = sampler.attr("_create_for_minddataset");
+          if (py::hasattr(sampler, "create_for_minddataset")) {
+            auto create = sampler.attr("create_for_minddataset");
             op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
           }
           THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
@@ -486,7 +487,9 @@ void bindSamplerOps(py::module *m) {
     .def("add_child",
          [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
 
-  (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
+  (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator")
+    .def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
+                         std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
 
   (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
     .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
@@ -518,6 +521,22 @@ void bindSamplerOps(py::module *m) {
       }
     }));
 
+  (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
+                   std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
+    .def(py::init<int64_t, int64_t, bool, uint32_t>());
+
+  (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
+    *m, "MindrecordRandomSampler")
+    .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
+      return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
+    }));
+
+  (void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
+                   std::shared_ptr<mindrecord::ShardSequentialSample>>(*m, "MindrecordSequentialSampler")
+    .def(py::init([](int num_samples, int start_index) {
+      return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
+    }));
+
   (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
     .def(py::init<int64_t, std::vector<double>, bool>());
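These three bindings are what the Python-side create_for_minddataset methods further below instantiate; note that MindrecordRandomSampler is simply a ShardShuffle seeded with GetSeed(). A usage sketch from the Python API (the file name is illustrative):

    import mindspore.dataset as ds

    # Each Python sampler's create_for_minddataset() maps 1:1 to one of the
    # cde.* classes bound above; DistributedSampler -> MindrecordDistributedSampler.
    data_set = ds.MindDataset("test.mindrecord0",
                              sampler=ds.DistributedSampler(2, 0))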
@@ -31,6 +31,10 @@ class ShardDistributedSample : public ShardSample {
  public:
   ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
 
+  ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
+
+  void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }
+
   ~ShardDistributedSample() override{};
 
   MSRStatus PreExecute(ShardTask &tasks) override;
@@ -17,6 +17,7 @@
 #ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
 #define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
 
+#include <memory>
 #include "mindrecord/include/shard_task.h"
 
 namespace mindspore {
@@ -37,6 +38,14 @@ class ShardOperator {
     }
     return SUCCESS;
   }
 
+  virtual bool HasChildOp() { return child_op_ != nullptr; }
+
+  virtual MSRStatus SetChildOp(std::shared_ptr<ShardOperator> child_op) {
+    if (child_op != nullptr) child_op_ = child_op;
+    return SUCCESS;
+  }
+
+  virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }
+
   virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }
@@ -44,7 +53,10 @@ class ShardOperator {
   virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }
 
-  virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; }
+  virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
+
+ private:
+  std::shared_ptr<ShardOperator> child_op_ = nullptr;
 };
 }  // namespace mindrecord
 }  // namespace mindspore
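One consequence of SetChildOp ignoring a null child: the Python samplers can call add_child(c_child_sampler) unconditionally, even when create_child_for_minddataset() returned None. A minimal Python model of that contract (illustrative, not the bound C++ class):

    class ShardOperatorModel:
        """Models SetChildOp/HasChildOp/GetChildOp from shard_operator.h."""
        def __init__(self):
            self._child = None

        def set_child_op(self, child):
            if child is not None:  # a null child is silently ignored
                self._child = child

        def has_child_op(self):
            return self._child is not None

    op = ShardOperatorModel()
    op.set_child_op(None)      # no-op, mirrors add_child(None) from the Python side
    assert not op.has_child_op()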
@@ -34,6 +34,7 @@
 #include <memory>
 #include <mutex>
 #include <set>
+#include <stack>
 #include <string>
 #include <thread>
 #include <tuple>
@@ -44,6 +45,7 @@
 #include "mindrecord/include/common/shard_utils.h"
 #include "mindrecord/include/shard_category.h"
 #include "mindrecord/include/shard_column.h"
+#include "mindrecord/include/shard_distributed_sample.h"
 #include "mindrecord/include/shard_error.h"
 #include "mindrecord/include/shard_index_generator.h"
 #include "mindrecord/include/shard_operator.h"
@@ -48,10 +48,10 @@ class ShardSample : public ShardOperator {
   int numerator_;
   int denominator_;
   int partition_id_;
+  int no_of_samples_;
   std::shared_ptr<ShardShuffle> shuffle_op_;
 
  private:
-  int no_of_samples_;
   std::vector<int64_t> indices_;
   SamplerType sampler_type_;
 };
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ | |||||
| #define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "mindrecord/include/shard_sample.h" | |||||
| namespace mindspore { | |||||
| namespace mindrecord { | |||||
| class ShardSequentialSample : public ShardSample { | |||||
| public: | |||||
| ShardSequentialSample(int n, int offset); | |||||
| ShardSequentialSample(float per, float per_offset); | |||||
| ~ShardSequentialSample() override{}; | |||||
| MSRStatus Execute(ShardTask &tasks) override; | |||||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||||
| private: | |||||
| int offset_; | |||||
| float per_; | |||||
| float per_offset_; | |||||
| }; | |||||
| } // namespace mindrecord | |||||
| } // namespace mindspore | |||||
| #endif // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ | |||||
@@ -26,12 +26,20 @@ class ShardShuffle : public ShardOperator {
  public:
   explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory);
 
+  ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
+               ShuffleType shuffle_type = kShuffleSample);
+
   ~ShardShuffle() override{};
 
   MSRStatus Execute(ShardTask &tasks) override;
 
+  int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
+
  private:
   uint32_t shuffle_seed_;
+  int64_t no_of_samples_;
+  bool replacement_;
+  bool reshuffle_each_epoch_;
   ShuffleType shuffle_type_;
 };
 }  // namespace mindrecord
@@ -792,24 +792,51 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
 }
 
 MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
-                                      const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded) {
+                                      const std::shared_ptr<ShardOperator> &ops, int64_t *count, const int num_padded) {
   if (SUCCESS != Init(file_paths, load_dataset)) {
     return FAILED;
   }
   int64_t num_samples = num_rows_;
-  if (std::dynamic_pointer_cast<ShardCategory>(op)) {
-    auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
-    std::string category_field = category_op->GetCategoryField();
-    auto num_classes = GetNumClasses(category_field);
-    num_samples = category_op->GetNumSamples(num_rows_, num_classes);
-  } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
-    num_samples = op->GetNumSamples(num_rows_, 0);
-    if (-1 == num_samples) {
-      MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
-      return FAILED;
+  bool root = true;
+  std::stack<std::shared_ptr<ShardOperator>> stack_ops;
+  std::shared_ptr<ShardOperator> op(ops);
+  while (op != nullptr) {
+    stack_ops.push(op);
+    op = op->GetChildOp();
+  }
+  while (!stack_ops.empty()) {
+    op = stack_ops.top();
+    stack_ops.pop();
+    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
+      num_samples = op->GetNumSamples(num_samples, 0);
+      if (num_padded > 0 && root == true) {
+        num_samples += num_padded;
+        MS_LOG(DEBUG) << "Padding samples work on shuffle sampler.";
+        root = false;
+      }
+    } else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
+      auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
+      std::string category_field = category_op->GetCategoryField();
+      auto num_classes = GetNumClasses(category_field);
+      num_samples = category_op->GetNumSamples(num_samples, num_classes);
+    } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
+      if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
+        auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
+        if (root == true) {
+          sampler_op->SetNumPaddedSamples(num_padded);
+          num_samples = op->GetNumSamples(num_samples, 0);
+          if (-1 == num_samples) {
+            MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
+            return FAILED;
+          }
+          root = false;
+        }
+      } else {
+        num_samples = op->GetNumSamples(num_samples, 0);
+      }
+    } else {
+      if (num_padded > 0) num_samples += num_padded;
     }
-  } else {
-    if (num_padded > 0) num_samples += num_padded;
   }
   *count = num_samples;
   return SUCCESS;
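The traversal mirrors BuildMindrecordSamplerChain: operators are evaluated leaf-to-root, each mapping the running num_samples, and num_padded is folded in exactly once, at the first shuffle or distributed-sample operator encountered (the root flag guards against double counting). A rough worked example of the distributed case, assuming the divisibility rule stated in the error message above:

    # Sketch: 10 real rows + 2 padded rows sharded 4 ways gives 3 rows per shard;
    # if (rows + padded) were not divisible by the shard count, the C++ side
    # reports -1 and CountTotalRows fails.
    num_rows, num_padded, num_shards = 10, 2, 4
    padded_size = num_rows + num_padded
    per_shard = padded_size // num_shards if padded_size % num_shards == 0 else -1
    print(per_shard)  # 3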
@@ -1385,12 +1412,16 @@ void ShardReader::Reset() {
 }
 
 void ShardReader::ShuffleTask() {
+  if (block_reader_) return;
+
+  // exist shuffle and distributed sampler in ops, skip shuffle
+  bool has_sharding = false;
   for (const auto &op : operators_) {
-    if (block_reader_) {
-      continue;
+    if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
+      has_sharding = true;
     }
-    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
+  }
+  for (const auto &op : operators_) {
+    if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
       if (SUCCESS != (*op)(tasks_)) {
         MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
       }
@@ -31,6 +31,9 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int
   shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
 }
 
+ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed)
+    : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {}
+
 int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (no_of_padded_samples_ <= 0) {
     if (dataset_size % denominator_ == 0) {
| @@ -0,0 +1,74 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "mindrecord/include/shard_sequential_sample.h" | |||||
| using mindspore::LogStream; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::MsLogLevel::ERROR; | |||||
| namespace mindspore { | |||||
| namespace mindrecord { | |||||
| ShardSequentialSample::ShardSequentialSample(int n, int offset) | |||||
| : ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {} | |||||
| ShardSequentialSample::ShardSequentialSample(float per, float per_offset) | |||||
| : ShardSample(0), offset_(0), per_(per), per_offset_(per_offset) {} | |||||
| int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||||
| if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { | |||||
| return dataset_size; | |||||
| } | |||||
| if (per_ > kEpsilon && per_ <= 1.0f) { | |||||
| return dataset_size * kEpsilon; | |||||
| } | |||||
| return no_of_samples_; | |||||
| } | |||||
| MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) { | |||||
| int total_no = static_cast<int>(tasks.Size()); | |||||
| int taking; | |||||
| if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { | |||||
| taking = total_no; | |||||
| } else if (per_ > kEpsilon && per_ <= 1.0f) { | |||||
| taking = total_no * kEpsilon; | |||||
| } else { | |||||
| taking = no_of_samples_; | |||||
| } | |||||
| if (tasks.permutation_.empty()) { | |||||
| ShardTask new_tasks; | |||||
| total_no = static_cast<int>(tasks.Size()); | |||||
| for (int i = offset_; i < taking + offset_; ++i) { | |||||
| new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); | |||||
| } | |||||
| std::swap(tasks, new_tasks); | |||||
| } else { // shuffled | |||||
| ShardTask new_tasks; | |||||
| if (taking > static_cast<int>(tasks.permutation_.size())) { | |||||
| return FAILED; | |||||
| } | |||||
| total_no = static_cast<int>(tasks.permutation_.size()); | |||||
| for (size_t i = offset_; i < taking + offset_; ++i) { | |||||
| new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); | |||||
| } | |||||
| std::swap(tasks, new_tasks); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace mindrecord | |||||
| } // namespace mindspore | |||||
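(The two percentage branches above originally multiplied by kEpsilon, the float-comparison tolerance, which would always truncate to 0; multiplying by per_ is the evident intent given the per_ <= 1.0f guard, and is what the reconstruction uses.) On the Python side this operator backs SequentialSampler(start_index, num_samples): it takes num_samples tasks starting at offset_, wrapping modulo the task count. A usage sketch matching the tests further below (file name illustrative):

    import mindspore.dataset as ds

    # Take 4 samples starting at row 1; indices wrap modulo the dataset size,
    # as exercised by test_cv_minddataset_sequential_sampler_exceed_size.
    sampler = ds.SequentialSampler(1, 4)
    data_set = ds.MindDataset("test.mindrecord0", sampler=sampler)
    assert data_set.get_dataset_size() == 4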
@@ -21,17 +21,52 @@
 namespace mindspore {
 namespace mindrecord {
 ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
-    : shuffle_seed_(seed), shuffle_type_(shuffle_type) {}
+    : shuffle_seed_(seed),
+      no_of_samples_(0),
+      replacement_(false),
+      reshuffle_each_epoch_(true),
+      shuffle_type_(shuffle_type) {}
+
+ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
+                           ShuffleType shuffle_type)
+    : shuffle_seed_(seed),
+      no_of_samples_(no_of_samples),
+      replacement_(replacement),
+      reshuffle_each_epoch_(reshuffle_each_epoch),
+      shuffle_type_(shuffle_type) {}
+
+int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
+  if (replacement_) {
+    return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
+  }
+  return dataset_size;
+}
 
 MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
   if (tasks.categories < 1) {
     return FAILED;
   }
-  if (shuffle_type_ == kShuffleSample) {
+  if (shuffle_type_ == kShuffleSample) {  // shuffle each sample
     if (tasks.permutation_.empty() == true) {
       tasks.MakePerm();
     }
-    std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+    if (replacement_ == true) {
+      ShardTask new_tasks;
+      if (no_of_samples_ == 0) {
+        no_of_samples_ = static_cast<int>(tasks.Size());
+      }
+      if (no_of_samples_ <= 0) {
+        MS_LOG(ERROR) << "no_of_samples need to be positive.";
+        return FAILED;
+      }
+      new_tasks.task_list_.reserve(no_of_samples_);
+      for (uint32_t i = 0; i < no_of_samples_; ++i) {
+        new_tasks.InsertTask(tasks.GetRandomTask());
+      }
+      std::swap(tasks, new_tasks);
+    } else {
+      std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+    }
   } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
     uint32_t individual_size = tasks.Size() / tasks.categories;
     std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
@@ -46,7 +81,7 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
       }
     }
   }
-  shuffle_seed_++;
+  if (reshuffle_each_epoch_) shuffle_seed_++;
   return SUCCESS;
 }
 }  // namespace mindrecord
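The replacement branch corresponds to RandomSampler(replacement=True, num_samples=n): instead of permuting, it draws n tasks via GetRandomTask(), and reshuffle_each_epoch_ now controls whether the seed advances between epochs (so reshuffle_each_epoch=False would reproduce the same order every epoch). Usage sketch, grounded in the replacement test below (file name illustrative):

    import mindspore.dataset as ds

    # Draw 5 samples with replacement; exercises the branch added above.
    sampler = ds.RandomSampler(replacement=True, num_samples=5)
    data_set = ds.MindDataset("test.mindrecord0", sampler=sampler)
    assert data_set.get_dataset_size() == 5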
@@ -72,6 +72,7 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
   std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
   return task_list_[dis(gen)];
 }
+
 ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
   ShardTask res;
   if (category_tasks.empty()) return res;
@@ -1015,10 +1015,8 @@ class Dataset:
         def get_distribution(output_dataset):
             dev_id = 0
-            if isinstance(output_dataset, (MindDataset)):
-                return output_dataset.distribution, dev_id
             if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
-                                           ManifestDataset, MnistDataset, VOCDataset, CelebADataset)):
+                                           ManifestDataset, MnistDataset, VOCDataset, CelebADataset, MindDataset)):
                 sampler = output_dataset.sampler
                 if isinstance(sampler, samplers.DistributedSampler):
                     dev_id = sampler.shard_id
@@ -2670,7 +2668,7 @@ class MnistDataset(MappableDataset):
         return self.sampler.is_sharded()
 
 
-class MindDataset(SourceDataset):
+class MindDataset(MappableDataset):
     """
     A source dataset that reads from shard files and database.
@@ -2687,11 +2685,13 @@ class MindDataset(SourceDataset):
         sampler (Sampler, optional): Object used to choose samples from the
             dataset (default=None, sampler is exclusive
             with shuffle and block_reader). Support list: SubsetRandomSampler,
-            PkSampler.
+            PkSampler, RandomSampler, SequentialSampler, DistributedSampler.
         padded_sample (dict, optional): Samples will be appended to dataset, which
             keys are the same as column_list.
         num_padded (int, optional): Number of padding samples. Dataset size
             plus num_padded should be divisible by num_shards.
+        num_samples (int, optional): The number of samples to be included in the dataset
+            (default=None, all samples).
 
     Raises:
         ValueError: If num_shards is specified but shard_id is None.
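Taken together with the __init__ changes below, num_shards/shard_id and the new num_samples are now routed through _select_sampler rather than the removed partitions argument, so a sharded read needs no explicit sampler. A sketch (the file name is illustrative):

    import mindspore.dataset as ds

    # Shard 0 of 2; _select_sampler turns num_shards/shard_id into a
    # MindrecordDistributedSampler under the hood.
    data_set = ds.MindDataset("test.mindrecord0", num_shards=2, shard_id=0)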
@@ -2703,7 +2703,7 @@ class MindDataset(SourceDataset):
     def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
                  shuffle=None, num_shards=None, shard_id=None,
                  block_reader=False, sampler=None, padded_sample=None,
-                 num_padded=None):
+                 num_padded=None, num_samples=None):
         super().__init__(num_parallel_workers)
         if isinstance(dataset_file, list):
             self.load_dataset = False
@@ -2712,15 +2712,10 @@
         self.dataset_file = dataset_file
         self.columns_list = columns_list
         self.shuffle_option = shuffle
-        self.distribution = ""
-        self.sampler = sampler
 
-        if num_shards is None or shard_id is None:
-            self.partitions = None
-        else:
-            self.partitions = [num_shards, shard_id]
+        self.num_shards = num_shards
+        self.shard_id = shard_id
 
-        if block_reader is True and self.partitions is not None:
+        if block_reader is True and num_shards is not None:
             raise ValueError("block reader not allowed true when use partitions")
 
         if block_reader is True and shuffle is True:
@@ -2730,25 +2725,21 @@
             logger.warning("WARN: global shuffle is not used.")
 
         if sampler is not None:
-            if isinstance(sampler, samplers.SubsetRandomSampler) is False and \
-                    isinstance(sampler, samplers.PKSampler) is False:
+            if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
+                                    samplers.DistributedSampler, samplers.RandomSampler,
+                                    samplers.SequentialSampler)) is False:
                 raise ValueError("the sampler is not supported yet.")
 
+        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.num_samples = num_samples
+
         # sampler exclusive
         if block_reader is True and sampler is not None:
             raise ValueError("block reader not allowed true when use sampler")
 
-        if shuffle is not None and sampler is not None:
-            raise ValueError("shuffle not allowed when use sampler")
-
-        if block_reader is False and sampler is None:
-            self.shuffle_option = not bool(shuffle is False)
-
         if num_padded is None:
             num_padded = 0
 
-        self.num_shards = num_shards
-        self.shard_id = shard_id
         self.block_reader = block_reader
         self.padded_sample = padded_sample
         self.num_padded = num_padded
@@ -2766,10 +2757,8 @@
         args["load_dataset"] = self.load_dataset
         args["columns_list"] = self.columns_list
         args["shuffle_option"] = self.shuffle_option
-        args["partitions"] = self.partitions
+        args["num_samples"] = self.num_samples
         args["block_reader"] = self.block_reader
-        args["num_shards"] = self.num_shards
-        args["shard_id"] = self.shard_id
         args["num_padded"] = self.num_padded
         args["padded_sample"] = padded_sample
         args["sampler"] = self.sampler
@@ -2788,14 +2777,6 @@
             else:
                 dataset_file = self.dataset_file
             num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
-            if self.partitions is not None and self.partitions[0] > 0:
-                if num_rows % self.partitions[0] == 0:
-                    num_rows = num_rows // self.partitions[0]
-                else:
-                    if self.num_padded > 0:
-                        raise RuntimeError(
-                            "Dataset size plus number of padded samples is not divisible by number of shards.")
-                    num_rows = num_rows // self.partitions[0] + 1
             return num_rows
         return self._dataset_size
@@ -141,7 +141,12 @@ class BuiltinSampler:
         c_child_sampler = None
         if self.child_sampler is not None:
             c_child_sampler = self.child_sampler.create()
+        return c_child_sampler
 
+    def create_child_for_minddataset(self):
+        c_child_sampler = None
+        if self.child_sampler is not None:
+            c_child_sampler = self.child_sampler.create_for_minddataset()
         return c_child_sampler
 
     def is_shuffled(self):
@@ -262,6 +267,12 @@ class DistributedSampler(BuiltinSampler):
             c_sampler.add_child(c_child_sampler)
         return c_sampler
 
+    def create_for_minddataset(self):
+        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
     def is_shuffled(self):
         if self.child_sampler is None:
             return self.shuffle
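With create_child_for_minddataset in place, samplers chain end to end: the child operator is applied first (leaf-to-root, per BuildMindrecordSamplerChain), then its parent. A sketch assuming BuiltinSampler exposes add_child on the Python side (file name illustrative):

    import mindspore.dataset as ds

    # Take 4 rows sequentially, then shard the result 2 ways.
    sampler = ds.DistributedSampler(2, 0)
    sampler.add_child(ds.SequentialSampler(0, 4))
    data_set = ds.MindDataset("test.mindrecord0", sampler=sampler)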
@@ -318,7 +329,7 @@ class PKSampler(BuiltinSampler):
         self.num_val = num_val
         self.shuffle = shuffle
-        self.class_column = class_column # work for minddataset
+        self.class_column = class_column  # work for minddataset
         super().__init__(num_samples)
 
     def create(self):
@@ -340,12 +351,14 @@ class PKSampler(BuiltinSampler):
             return self.child_sampler.is_sharded()
 
-    def _create_for_minddataset(self):
+    def create_for_minddataset(self):
         if not self.class_column or not isinstance(self.class_column, str):
             raise ValueError("class_column should be a not empty string value, \
                              but got class_column={}".format(self.class_column))
-        return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
+        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
 
 
 class RandomSampler(BuiltinSampler):
     """
| class RandomSampler(BuiltinSampler): | class RandomSampler(BuiltinSampler): | ||||
| """ | """ | ||||
| @@ -390,6 +403,13 @@ class RandomSampler(BuiltinSampler): | |||||
| c_sampler.add_child(c_child_sampler) | c_sampler.add_child(c_child_sampler) | ||||
| return c_sampler | return c_sampler | ||||
| def create_for_minddataset(self): | |||||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||||
| c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch) | |||||
| c_child_sampler = self.create_child_for_minddataset() | |||||
| c_sampler.add_child(c_child_sampler) | |||||
| return c_sampler | |||||
| def is_shuffled(self): | def is_shuffled(self): | ||||
| return True | return True | ||||
@@ -440,6 +460,14 @@ class SequentialSampler(BuiltinSampler):
             c_sampler.add_child(c_child_sampler)
         return c_sampler
 
+    def create_for_minddataset(self):
+        start_index = self.start_index if self.start_index is not None else 0
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
     def is_shuffled(self):
         if self.child_sampler is None:
             return False
@@ -501,8 +529,11 @@ class SubsetRandomSampler(BuiltinSampler):
             return self.child_sampler.is_sharded()
 
-    def _create_for_minddataset(self):
-        return cde.MindrecordSubsetRandomSampler(self.indices)
+    def create_for_minddataset(self):
+        c_sampler = cde.MindrecordSubsetRandomSampler(self.indices)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
 
     def get_num_samples(self):
         num_samples = super().get_num_samples()
@@ -17,6 +17,7 @@ This is the test module for mindrecord
 """
 import os
 import pytest
+import numpy as np
 
 import mindspore.dataset as ds
 from mindspore import log as logger
@@ -64,10 +65,12 @@ def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 6
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
@@ -82,12 +85,14 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 6
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[data]: \
 {}------------------------".format(item["data"][:10]))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
@@ -102,10 +107,12 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 9
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
@@ -119,10 +126,12 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 15
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
@@ -219,7 +228,6 @@ def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file
 
 
 def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
-    """tutorial for cv minderdataset."""
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
     indices = [1, 2, 4, -1, -2]
@@ -241,6 +249,344 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
     assert num_iter == 5
 
 
+def test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler()
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    new_dataset = []
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+        new_dataset.append(item['file_name'])
+    assert num_iter == 10
+    assert new_dataset != [x['file_name'] for x in data]
+
+
+def test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler()
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 10
+    ds1 = data_set.repeat(3)
+    num_iter = 0
+    epoch1_dataset = []
+    epoch2_dataset = []
+    epoch3_dataset = []
+    for item in ds1.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+        if num_iter <= 10:
+            epoch1_dataset.append(item['file_name'])
+        elif num_iter <= 20:
+            epoch2_dataset.append(item['file_name'])
+        else:
+            epoch3_dataset.append(item['file_name'])
+    assert num_iter == 30
+    assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
+    assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
+    assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
+
+
+def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler(replacement=True, num_samples=5)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 5
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 5
+
+
+def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.SequentialSampler(1, 4)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 4
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[num_iter+1]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 4
+
+
+def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.SequentialSampler(2, 10)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    dataset_size = data_set.get_dataset_size()
+    assert dataset_size == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 10
+
+
+def test_cv_minddataset_split_basic(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
+                       num_readers, shuffle=False)
+    d1, d2 = d.split([8, 2], randomize=False)
+    assert d.get_dataset_size() == 10
+    assert d1.get_dataset_size() == 8
+    assert d2.get_dataset_size() == 2
+    num_iter = 0
+    for item in d1.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(data[num_iter]['file_name'],
+                                             dtype='S')
+        num_iter += 1
+    assert num_iter == 8
+    num_iter = 0
+    for item in d2.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
+                                             dtype='S')
+        num_iter += 1
+    assert num_iter == 2
+
+
+def test_cv_minddataset_split_exact_percent(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
+                       num_readers, shuffle=False)
+    d1, d2 = d.split([0.8, 0.2], randomize=False)
+    assert d.get_dataset_size() == 10
+    assert d1.get_dataset_size() == 8
+    assert d2.get_dataset_size() == 2
+    num_iter = 0
+    for item in d1.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[num_iter]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 8
+    num_iter = 0
+    for item in d2.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
+                                             dtype='S')
+        num_iter += 1
+    assert num_iter == 2
+
+
+def test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
+                       num_readers, shuffle=False)
+    d1, d2 = d.split([0.41, 0.59], randomize=False)
+    assert d.get_dataset_size() == 10
+    assert d1.get_dataset_size() == 4
+    assert d2.get_dataset_size() == 6
+    num_iter = 0
+    for item in d1.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[num_iter]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 4
+    num_iter = 0
+    for item in d2.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(data[num_iter + 4]['file_name'],
+                                             dtype='S')
+        num_iter += 1
+    assert num_iter == 6
+
+
+def test_cv_minddataset_split_deterministic(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
+                       num_readers, shuffle=False)
+    # should set seed to avoid data overlap
+    ds.config.set_seed(111)
+    d1, d2 = d.split([0.8, 0.2])
+    assert d.get_dataset_size() == 10
+    assert d1.get_dataset_size() == 8
+    assert d2.get_dataset_size() == 2
+    d1_dataset = []
+    d2_dataset = []
+    num_iter = 0
+    for item in d1.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        d1_dataset.append(item['file_name'])
+        num_iter += 1
+    assert num_iter == 8
+    num_iter = 0
+    for item in d2.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        d2_dataset.append(item['file_name'])
+        num_iter += 1
+    assert num_iter == 2
+    inter_dataset = [x for x in d1_dataset if x in d2_dataset]
+    assert inter_dataset == []  # intersection of d1 and d2
+
+
+def test_cv_minddataset_split_sharding(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
+                       num_readers, shuffle=False)
+    # should set seed to avoid data overlap
+    ds.config.set_seed(111)
+    d1, d2 = d.split([0.8, 0.2])
+    assert d.get_dataset_size() == 10
+    assert d1.get_dataset_size() == 8
+    assert d2.get_dataset_size() == 2
+    distributed_sampler = ds.DistributedSampler(2, 0)
+    d1.use_sampler(distributed_sampler)
+    assert d1.get_dataset_size() == 4
+    num_iter = 0
+    d1_shard1 = []
+    for item in d1.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+        d1_shard1.append(item['file_name'])
+    assert num_iter == 4
+    assert d1_shard1 != [x['file_name'] for x in data[0:4]]
+
+    distributed_sampler = ds.DistributedSampler(2, 1)
+    d1.use_sampler(distributed_sampler)
+    assert d1.get_dataset_size() == 4
+    d1s = d1.repeat(3)
+    epoch1_dataset = []
+    epoch2_dataset = []
+    epoch3_dataset = []
+    num_iter = 0
+    for item in d1s.create_dict_iterator():
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+        if num_iter <= 4:
+            epoch1_dataset.append(item['file_name'])
+        elif num_iter <= 8:
+            epoch2_dataset.append(item['file_name'])
+        else:
+            epoch3_dataset.append(item['file_name'])
+    assert len(epoch1_dataset) == 4
+    assert len(epoch2_dataset) == 4
+    assert len(epoch3_dataset) == 4
+    inter_dataset = [x for x in d1_shard1 if x in epoch1_dataset]
+    assert inter_dataset == []  # intersection of d1's shard1 and d1's shard2
+    assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
+    assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
+    assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
+
+
 def get_data(dir_name, sampler=False):
     """
     usage: get data from imagenet dataset