
add split in minddataset

tags/v0.5.0-beta
liyong 5 years ago
parent commit d4f8f57c7e
17 changed files with 699 additions and 127 deletions
1. +23 -49  mindspore/ccsrc/dataset/api/de_pipeline.cc
2. +5 -2  mindspore/ccsrc/dataset/api/de_pipeline.h
3. +22 -3  mindspore/ccsrc/dataset/api/python_bindings.cc
4. +4 -0  mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h
5. +13 -1  mindspore/ccsrc/mindrecord/include/shard_operator.h
6. +2 -0  mindspore/ccsrc/mindrecord/include/shard_reader.h
7. +1 -1  mindspore/ccsrc/mindrecord/include/shard_sample.h
8. +48 -0  mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h
9. +8 -0  mindspore/ccsrc/mindrecord/include/shard_shuffle.h
10. +48 -17  mindspore/ccsrc/mindrecord/io/shard_reader.cc
11. +3 -0  mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc
12. +74 -0  mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc
13. +39 -4  mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
14. +1 -0  mindspore/ccsrc/mindrecord/meta/shard_task.cc
15. +16 -35  mindspore/dataset/engine/datasets.py
16. +37 -6  mindspore/dataset/engine/samplers.py
17. +355 -9  tests/ut/python/dataset/test_minddataset_sampler.py

+23 -49  mindspore/ccsrc/dataset/api/de_pipeline.cc

@@ -391,35 +391,27 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp
   return Status::OK();
 }
 
-Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *in_partitions) {
-  if (args["partitions"].is_none()) {
-    std::string err_msg = "Error: partitions is not set (None)";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  py::list list = py::reinterpret_borrow<py::list>(args["partitions"]);
-  for (auto l : list) {
-    if (!l.is_none()) {
-      in_partitions->push_back(ToInt(l));
-    }
-  }
-
-  if (in_partitions->size() != 2) {
-    std::string err_msg = "Error: partitions is invalid or not set.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  constexpr int kMaxPartitions = 1024;
-  if (in_partitions->at(0) <= 0 || in_partitions->at(0) > kMaxPartitions) {
-    std::string err_msg = "Error: partitions is invalid or not set.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  if (in_partitions->at(1) < 0 || in_partitions->at(1) >= in_partitions->at(0)) {
-    std::string err_msg = "Error: partitions is invalid or not set.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
+Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
+                                               std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
+                                               int num_padded) {
+  auto sampler = py::reinterpret_borrow<py::object>(handle);
+  auto create = sampler.attr("create_for_minddataset");
+  auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
+  std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
+  while (op != nullptr) {
+    auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
+    if (sampler_op && num_padded > 0) {
+      sampler_op->SetNumPaddedSamples(num_padded);
+      stack_ops.push(sampler_op);
+    } else {
+      stack_ops.push(op);
+    }
+    op = op->GetChildOp();
+  }
+  while (!stack_ops.empty()) {
+    operators->push_back(stack_ops.top());
+    stack_ops.pop();
+  }
   return Status::OK();
 }


@@ -460,34 +452,16 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
       (void)builder->SetNumMindRecordWorkers(ToInt(value));
     } else if (key == "block_reader" && ToBool(value) == true) {
       (void)builder->SetBlockReader();
-    } else if (key == "shuffle_option" && ToBool(value) == true) {
-      if (!args["partitions"].is_none()) continue;
-      uint32_t seed = GetSeed();
-      operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
     } else if (key == "sampler") {
-      auto sampler = py::reinterpret_borrow<py::object>(value);
-      auto create = sampler.attr("_create_for_minddataset");
-      auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
-      operators.push_back(op);
+      int num_padded = 0;
+      if (!args["num_padded"].is_none()) {
+        num_padded = ToInt(args["num_padded"]);
+      }
+      RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded));
     }
   }
 }
 
-std::vector<int> in_partitions;
-if (!args["partitions"].is_none()) {
-  auto ret = CheckMindRecordPartitionInfo(args, &in_partitions);
-  if (Status::OK() != ret) {
-    return ret;
-  }
-  auto shuffle = ToBool(args["shuffle_option"]);
-  int num_padded = 0;
-  if (!args["num_padded"].is_none()) {
-    num_padded = ToInt(args["num_padded"]);
-  }
-  operators.push_back(
-    std::make_shared<mindrecord::ShardDistributedSample>(in_partitions[0], in_partitions[1], num_padded, shuffle, 0));
-}
-
 if (!operators.empty()) {
   (void)builder->SetOperators(operators);
 }
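
Note: BuildMindrecordSamplerChain flattens a parent-to-child sampler chain into the flat operator list the MindRecord reader consumes. A minimal Python sketch of the same traversal, with illustrative class names that are not MindSpore source: operators are pushed root-first onto a stack, so popping yields child-first execution order, and padding is attached only to the distributed (sharding) sampler.

```python
# Illustrative sketch of the chain-flattening logic (hypothetical names).
class ShardOp:
    def __init__(self, name, child=None):
        self.name = name
        self.child = child

class DistributedShardOp(ShardOp):
    num_padded = 0

def build_sampler_chain(root, num_padded):
    stack, op = [], root
    while op is not None:
        if isinstance(op, DistributedShardOp) and num_padded > 0:
            op.num_padded = num_padded  # padding only targets the sharding op
        stack.append(op)
        op = op.child
    return stack[::-1]  # reversed: children run before their parents

# A distributed sampler with a sequential child flattens to
# [sequential, distributed]: select sequentially first, then shard.
chain = build_sampler_chain(
    DistributedShardOp("distributed", ShardOp("sequential")), num_padded=2)
assert [op.name for op in chain] == ["sequential", "distributed"]
```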


+5 -2  mindspore/ccsrc/dataset/api/de_pipeline.h

@@ -18,6 +18,7 @@
 
 #include <iostream>
 #include <memory>
+#include <stack>
 #include <string>
 #include <unordered_map>
 #include <utility>
@@ -108,10 +109,12 @@ class DEPipeline {
 
   Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
-  Status CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *ptr);
-
   Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
+  Status BuildMindrecordSamplerChain(const py::handle &handle,
+                                     std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
+                                     int num_padded);
+
   Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
   Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);


+22 -3  mindspore/ccsrc/dataset/api/python_bindings.cc

@@ -71,6 +71,7 @@
 #include "mindrecord/include/shard_pk_sample.h"
 #include "mindrecord/include/shard_distributed_sample.h"
 #include "mindrecord/include/shard_sample.h"
+#include "mindrecord/include/shard_sequential_sample.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
 #include "pybind11/stl_bind.h"
@@ -165,8 +166,8 @@ void bindDatasetOps(py::module *m) {
              const int64_t num_padded) {
         int64_t count = 0;
         std::shared_ptr<mindrecord::ShardOperator> op;
-        if (py::hasattr(sampler, "_create_for_minddataset")) {
-          auto create = sampler.attr("_create_for_minddataset");
+        if (py::hasattr(sampler, "create_for_minddataset")) {
+          auto create = sampler.attr("create_for_minddataset");
           op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
         }
         THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
@@ -486,7 +487,9 @@ void bindSamplerOps(py::module *m) {
     .def("add_child",
          [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
 
-  (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
+  (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator")
+    .def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
+                         std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
 
   (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
     .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
@@ -518,6 +521,22 @@ void bindSamplerOps(py::module *m) {
       }
     }));
 
+  (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
+                   std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
+    .def(py::init<int64_t, int64_t, bool, uint32_t>());
+
+  (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
+    *m, "MindrecordRandomSampler")
+    .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
+      return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
+    }));
+
+  (void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
+                   std::shared_ptr<mindrecord::ShardSequentialSample>>(*m, "MindrecordSequentialSampler")
+    .def(py::init([](int num_samples, int start_index) {
+      return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
+    }));
+
   (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
     .def(py::init<int64_t, std::vector<double>, bool>());
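
Note: the renamed `create_for_minddataset` hook (previously `_create_for_minddataset`) and the new `add_child` binding on `ShardOperator` are what let Python sampler chains cross into C++. A hedged sketch of direct use; normally these objects are created by the samplers module rather than by hand:

```python
# Hypothetical direct use of the new bindings (normally reached through
# mindspore.dataset samplers, not instantiated manually).
import mindspore._c_dataengine as cde

seq = cde.MindrecordSequentialSampler(4, 1)              # num_samples=4, start_index=1
dist = cde.MindrecordDistributedSampler(2, 0, False, 0)  # 2 shards, shard 0, no shuffle, seed 0
dist.add_child(seq)  # ShardOperator.add_child calls SetChildOp underneath
```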




+4 -0  mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h

@@ -31,6 +31,10 @@ class ShardDistributedSample : public ShardSample {
  public:
   ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
 
+  ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
+
+  void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }
+
   ~ShardDistributedSample() override{};
 
   MSRStatus PreExecute(ShardTask &tasks) override;


+13 -1  mindspore/ccsrc/mindrecord/include/shard_operator.h

@@ -17,6 +17,7 @@
 #ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
 #define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
 
+#include <memory>
 #include "mindrecord/include/shard_task.h"
 
 namespace mindspore {
@@ -37,6 +38,14 @@ class ShardOperator {
     }
     return SUCCESS;
   }
+  virtual bool HasChildOp() { return child_op_ != nullptr; }
+
+  virtual MSRStatus SetChildOp(std::shared_ptr<ShardOperator> child_op) {
+    if (child_op != nullptr) child_op_ = child_op;
+    return SUCCESS;
+  }
+
+  virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }
 
   virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }
 
@@ -44,7 +53,10 @@ class ShardOperator {
 
   virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }
 
-  virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; }
+  virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
+
+ private:
+  std::shared_ptr<ShardOperator> child_op_ = nullptr;
 };
 }  // namespace mindrecord
 }  // namespace mindspore


+2 -0  mindspore/ccsrc/mindrecord/include/shard_reader.h

@@ -34,6 +34,7 @@
 #include <memory>
 #include <mutex>
 #include <set>
+#include <stack>
 #include <string>
 #include <thread>
 #include <tuple>
@@ -44,6 +45,7 @@
 #include "mindrecord/include/common/shard_utils.h"
 #include "mindrecord/include/shard_category.h"
 #include "mindrecord/include/shard_column.h"
+#include "mindrecord/include/shard_distributed_sample.h"
 #include "mindrecord/include/shard_error.h"
 #include "mindrecord/include/shard_index_generator.h"
 #include "mindrecord/include/shard_operator.h"


+1 -1  mindspore/ccsrc/mindrecord/include/shard_sample.h

@@ -48,10 +48,10 @@ class ShardSample : public ShardOperator {
   int numerator_;
   int denominator_;
   int partition_id_;
+  int no_of_samples_;
   std::shared_ptr<ShardShuffle> shuffle_op_;
 
  private:
-  int no_of_samples_;
   std::vector<int64_t> indices_;
   SamplerType sampler_type_;
 };


+48 -0  mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_
#define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_sample.h"

namespace mindspore {
namespace mindrecord {
class ShardSequentialSample : public ShardSample {
public:
ShardSequentialSample(int n, int offset);

ShardSequentialSample(float per, float per_offset);

~ShardSequentialSample() override{};

MSRStatus Execute(ShardTask &tasks) override;

int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;

private:
int offset_;
float per_;
float per_offset_;
};
} // namespace mindrecord
} // namespace mindspore

#endif // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_

+8 -0  mindspore/ccsrc/mindrecord/include/shard_shuffle.h

@@ -26,12 +26,20 @@ class ShardShuffle : public ShardOperator {
  public:
   explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory);
 
+  ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
+               ShuffleType shuffle_type = kShuffleSample);
+
   ~ShardShuffle() override{};
 
   MSRStatus Execute(ShardTask &tasks) override;
 
+  int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
+
  private:
   uint32_t shuffle_seed_;
+  int64_t no_of_samples_;
+  bool replacement_;
+  bool reshuffle_each_epoch_;
   ShuffleType shuffle_type_;
 };
 }  // namespace mindrecord


+48 -17  mindspore/ccsrc/mindrecord/io/shard_reader.cc

@@ -792,24 +792,51 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
 }
 
 MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
-                                      const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded) {
+                                      const std::shared_ptr<ShardOperator> &ops, int64_t *count, const int num_padded) {
   if (SUCCESS != Init(file_paths, load_dataset)) {
     return FAILED;
   }
   int64_t num_samples = num_rows_;
-  if (std::dynamic_pointer_cast<ShardCategory>(op)) {
-    auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
-    std::string category_field = category_op->GetCategoryField();
-    auto num_classes = GetNumClasses(category_field);
-    num_samples = category_op->GetNumSamples(num_rows_, num_classes);
-  } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
-    num_samples = op->GetNumSamples(num_rows_, 0);
-    if (-1 == num_samples) {
-      MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
-      return FAILED;
-    }
-  } else {
-    if (num_padded > 0) num_samples += num_padded;
+  bool root = true;
+  std::stack<std::shared_ptr<ShardOperator>> stack_ops;
+  std::shared_ptr<ShardOperator> op(ops);
+  while (op != nullptr) {
+    stack_ops.push(op);
+    op = op->GetChildOp();
+  }
+  while (!stack_ops.empty()) {
+    op = stack_ops.top();
+    stack_ops.pop();
+    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
+      num_samples = op->GetNumSamples(num_samples, 0);
+      if (num_padded > 0 && root == true) {
+        num_samples += num_padded;
+        MS_LOG(DEBUG) << "Padding samples work on shuffle sampler.";
+        root = false;
+      }
+    } else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
+      auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
+      std::string category_field = category_op->GetCategoryField();
+      auto num_classes = GetNumClasses(category_field);
+      num_samples = category_op->GetNumSamples(num_samples, num_classes);
+    } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
+      if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
+        auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
+        if (root == true) {
+          sampler_op->SetNumPaddedSamples(num_padded);
+          num_samples = op->GetNumSamples(num_samples, 0);
+          if (-1 == num_samples) {
+            MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
+            return FAILED;
+          }
+          root = false;
+        }
+      } else {
+        num_samples = op->GetNumSamples(num_samples, 0);
+      }
+    } else {
+      if (num_padded > 0) num_samples += num_padded;
+    }
   }
   *count = num_samples;
   return SUCCESS;
@@ -1385,12 +1412,16 @@ void ShardReader::Reset() {
 }
 
 void ShardReader::ShuffleTask() {
+  if (block_reader_) return;
+  // if both a shuffle op and a distributed sampler exist in ops, skip the shuffle
+  bool has_sharding = false;
   for (const auto &op : operators_) {
-    if (block_reader_) {
-      continue;
+    if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
+      has_sharding = true;
     }
-
-    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
+  }
+  for (const auto &op : operators_) {
+    if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
       if (SUCCESS != (*op)(tasks_)) {
         MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
       }
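
Note: CountTotalRows (first hunk above) now replays the whole operator chain child-first to predict the dataset size. A simplified Python rendering of that flow, with hypothetical names; the category and plain-sample branches are collapsed into the final `else`:

```python
# Simplified sketch of the new CountTotalRows control flow (hypothetical names).
def count_total_rows(num_rows, root_op, num_padded):
    stack, op = [], root_op
    while op is not None:      # walk root -> child onto a stack ...
        stack.append(op)
        op = op.child
    num_samples, first = num_rows, True
    while stack:               # ... then pop, so children are handled first
        op = stack.pop()
        if op.kind == "shuffle":
            num_samples = op.get_num_samples(num_samples)
            if num_padded > 0 and first:
                num_samples += num_padded   # padding is counted exactly once
                first = False
        elif op.kind == "distributed":
            if first:
                op.num_padded = num_padded
                num_samples = op.get_num_samples(num_samples)
                if num_samples == -1:       # padded size not divisible by shards
                    raise RuntimeError("padded size not divisible by num_shards")
                first = False
        else:                  # category / plain sample operators
            num_samples = op.get_num_samples(num_samples)
    return num_samples
```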


+3 -0  mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc

@@ -31,6 +31,9 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int
   shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
 }
 
+ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed)
+    : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {}
+
 int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (no_of_padded_samples_ <= 0) {
     if (dataset_size % denominator_ == 0) {


+74 -0  mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc

@@ -0,0 +1,74 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mindrecord/include/shard_sequential_sample.h"

using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::ERROR;

namespace mindspore {
namespace mindrecord {
ShardSequentialSample::ShardSequentialSample(int n, int offset)
: ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {}

ShardSequentialSample::ShardSequentialSample(float per, float per_offset)
: ShardSample(0), offset_(0), per_(per), per_offset_(per_offset) {}

int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
return dataset_size;
}
if (per_ > kEpsilon && per_ <= 1.0f) {
return dataset_size * per_;
}
return no_of_samples_;
}

MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
int total_no = static_cast<int>(tasks.Size());
int taking;
if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
taking = total_no;
} else if (per_ > kEpsilon && per_ <= 1.0f) {
taking = total_no * per_;
} else {
taking = no_of_samples_;
}

if (tasks.permutation_.empty()) {
ShardTask new_tasks;
total_no = static_cast<int>(tasks.Size());
for (int i = offset_; i < taking + offset_; ++i) {
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
}
std::swap(tasks, new_tasks);
} else { // shuffled
ShardTask new_tasks;
if (taking > static_cast<int>(tasks.permutation_.size())) {
return FAILED;
}
total_no = static_cast<int>(tasks.permutation_.size());
for (size_t i = offset_; i < taking + offset_; ++i) {
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
}
std::swap(tasks, new_tasks);
}
return SUCCESS;
}

} // namespace mindrecord
} // namespace mindspore
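
Note: ShardSequentialSample takes `n` tasks starting at `offset`, wrapping modulo the task count; a fraction `per` in (0, 1] selects that share of the dataset instead. The same selection rule in a few lines of Python (a sketch of the semantics, not library code):

```python
# Sketch of ShardSequentialSample's selection rule.
def sequential_take(items, n=0, offset=0, per=0.0):
    total = len(items)
    if n == 0 and per == 0.0:
        taking = total                 # default: take everything
    elif per > 0.0:
        taking = int(total * per)      # fractional selection
    else:
        taking = n
    return [items[i % total] for i in range(offset, offset + taking)]

assert sequential_take(list(range(10)), n=4, offset=8) == [8, 9, 0, 1]  # wraps
assert sequential_take(list(range(10)), per=0.4) == [0, 1, 2, 3]
```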

+39 -4  mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc

@@ -21,17 +21,52 @@
 namespace mindspore {
 namespace mindrecord {
 ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
-    : shuffle_seed_(seed), shuffle_type_(shuffle_type) {}
+    : shuffle_seed_(seed),
+      no_of_samples_(0),
+      replacement_(false),
+      reshuffle_each_epoch_(true),
+      shuffle_type_(shuffle_type) {}
+
+ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
+                           ShuffleType shuffle_type)
+    : shuffle_seed_(seed),
+      no_of_samples_(no_of_samples),
+      replacement_(replacement),
+      reshuffle_each_epoch_(reshuffle_each_epoch),
+      shuffle_type_(shuffle_type) {}
+
+int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
+  if (replacement_) {
+    return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
+  }
+  return dataset_size;
+}
 
 MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
   if (tasks.categories < 1) {
     return FAILED;
   }
-  if (shuffle_type_ == kShuffleSample) {
+  if (shuffle_type_ == kShuffleSample) {  // shuffle each sample
     if (tasks.permutation_.empty() == true) {
       tasks.MakePerm();
     }
-    std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+    if (replacement_ == true) {
+      ShardTask new_tasks;
+      if (no_of_samples_ == 0) {
+        no_of_samples_ = static_cast<int>(tasks.Size());
+      }
+      if (no_of_samples_ <= 0) {
+        MS_LOG(ERROR) << "no_of_samples needs to be positive.";
+        return FAILED;
+      }
+      new_tasks.task_list_.reserve(no_of_samples_);
+      for (uint32_t i = 0; i < no_of_samples_; ++i) {
+        new_tasks.InsertTask(tasks.GetRandomTask());
+      }
+      std::swap(tasks, new_tasks);
+    } else {
+      std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+    }
   } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
     uint32_t individual_size = tasks.Size() / tasks.categories;
     std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
@@ -46,7 +81,7 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
       }
     }
   }
-  shuffle_seed_++;
+  if (reshuffle_each_epoch_) shuffle_seed_++;
   return SUCCESS;
 }
 }  // namespace mindrecord
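
Note: with `replacement=True` the shuffle operator stops permuting and instead draws `no_of_samples` tasks independently (defaulting to the dataset size), so duplicates are possible; `reshuffle_each_epoch` controls whether the seed advances between epochs. A Python sketch of the two modes (illustrative, not library code):

```python
import random

# Sketch of ShardShuffle::Execute's two sample-level modes.
def shard_shuffle(tasks, seed, no_of_samples=0, replacement=False,
                  reshuffle_each_epoch=True):
    rng = random.Random(seed)
    if replacement:
        n = no_of_samples or len(tasks)              # default: dataset size
        out = [rng.choice(tasks) for _ in range(n)]  # duplicates possible
    else:
        out = list(tasks)
        rng.shuffle(out)                             # plain permutation
    if reshuffle_each_epoch:
        seed += 1                                    # new order next epoch
    return out, seed
```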


+1 -0  mindspore/ccsrc/mindrecord/meta/shard_task.cc

@@ -72,6 +72,7 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
   std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
   return task_list_[dis(gen)];
 }
+
 ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
   ShardTask res;
   if (category_tasks.empty()) return res;


+16 -35  mindspore/dataset/engine/datasets.py

@@ -1015,10 +1015,8 @@ class Dataset:
 
 def get_distribution(output_dataset):
     dev_id = 0
-    if isinstance(output_dataset, (MindDataset)):
-        return output_dataset.distribution, dev_id
     if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
-                                   ManifestDataset, MnistDataset, VOCDataset, CelebADataset)):
+                                   ManifestDataset, MnistDataset, VOCDataset, CelebADataset, MindDataset)):
         sampler = output_dataset.sampler
         if isinstance(sampler, samplers.DistributedSampler):
             dev_id = sampler.shard_id
@@ -2670,7 +2668,7 @@ class MnistDataset(MappableDataset):
         return self.sampler.is_sharded()
 
 
-class MindDataset(SourceDataset):
+class MindDataset(MappableDataset):
     """
     A source dataset that reads from shard files and database.


@@ -2687,11 +2685,13 @@ class MindDataset(SourceDataset):
         sampler (Sampler, optional): Object used to choose samples from the
             dataset (default=None, sampler is exclusive
             with shuffle and block_reader). Support list: SubsetRandomSampler,
-            PkSampler.
+            PkSampler, RandomSampler, SequentialSampler, DistributedSampler.
         padded_sample (dict, optional): Samples will be appended to dataset, which
             keys are the same as column_list.
         num_padded (int, optional): Number of padding samples. Dataset size
             plus num_padded should be divisible by num_shards.
+        num_samples (int, optional): The number of samples to be included in the dataset
+            (default=None, all samples).
 
     Raises:
         ValueError: If num_shards is specified but shard_id is None.
@@ -2703,7 +2703,7 @@ class MindDataset(SourceDataset):
     def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
                  shuffle=None, num_shards=None, shard_id=None,
                  block_reader=False, sampler=None, padded_sample=None,
-                 num_padded=None):
+                 num_padded=None, num_samples=None):
         super().__init__(num_parallel_workers)
         if isinstance(dataset_file, list):
             self.load_dataset = False
@@ -2712,15 +2712,10 @@ class MindDataset(SourceDataset):
         self.dataset_file = dataset_file
         self.columns_list = columns_list
         self.shuffle_option = shuffle
-        self.distribution = ""
-        self.sampler = sampler
-
-        if num_shards is None or shard_id is None:
-            self.partitions = None
-        else:
-            self.partitions = [num_shards, shard_id]
+        self.num_shards = num_shards
+        self.shard_id = shard_id
 
-        if block_reader is True and self.partitions is not None:
+        if block_reader is True and num_shards is not None:
             raise ValueError("block reader not allowed true when use partitions")
 
         if block_reader is True and shuffle is True:
@@ -2730,25 +2725,21 @@ class MindDataset(SourceDataset):
             logger.warning("WARN: global shuffle is not used.")
 
         if sampler is not None:
-            if isinstance(sampler, samplers.SubsetRandomSampler) is False and \
-                    isinstance(sampler, samplers.PKSampler) is False:
+            if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
+                                    samplers.DistributedSampler, samplers.RandomSampler,
+                                    samplers.SequentialSampler)) is False:
                 raise ValueError("the sampler is not supported yet.")
 
+        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.num_samples = num_samples
+
         # sampler exclusive
         if block_reader is True and sampler is not None:
             raise ValueError("block reader not allowed true when use sampler")
 
-        if shuffle is not None and sampler is not None:
-            raise ValueError("shuffle not allowed when use sampler")
-
-        if block_reader is False and sampler is None:
-            self.shuffle_option = not bool(shuffle is False)
-
         if num_padded is None:
             num_padded = 0
 
-        self.num_shards = num_shards
-        self.shard_id = shard_id
         self.block_reader = block_reader
         self.padded_sample = padded_sample
         self.num_padded = num_padded
@@ -2766,10 +2757,8 @@ class MindDataset(SourceDataset):
         args["load_dataset"] = self.load_dataset
         args["columns_list"] = self.columns_list
         args["shuffle_option"] = self.shuffle_option
-        args["partitions"] = self.partitions
+        args["num_samples"] = self.num_samples
         args["block_reader"] = self.block_reader
-        args["num_shards"] = self.num_shards
-        args["shard_id"] = self.shard_id
         args["num_padded"] = self.num_padded
         args["padded_sample"] = padded_sample
         args["sampler"] = self.sampler
@@ -2788,14 +2777,6 @@ class MindDataset(SourceDataset):
         else:
             dataset_file = self.dataset_file
         num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
-        if self.partitions is not None and self.partitions[0] > 0:
-            if num_rows % self.partitions[0] == 0:
-                num_rows = num_rows // self.partitions[0]
-            else:
-                if self.num_padded > 0:
-                    raise RuntimeError(
-                        "Dataset size plus number of padded samples is not divisible by number of shards.")
-                num_rows = num_rows // self.partitions[0] + 1
         return num_rows
     return self._dataset_size
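
Note: taken together, the Python changes replace the old `partitions` plumbing with ordinary samplers: sharding, shuffling, and sequential reads all flow through `_select_sampler` and the per-sampler `create_for_minddataset` hooks. A usage sketch against the new surface (the MindRecord file name is a placeholder):

```python
import mindspore.dataset as ds

# Shard the file across 2 workers and read shard 0; the dataset-size math,
# including num_padded handling, now lives in the C++ sampler chain.
sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False)
data_set = ds.MindDataset("imagenet.mindrecord0", sampler=sampler)

# Or cap the read with the new num_samples argument.
data_set = ds.MindDataset("imagenet.mindrecord0", num_samples=100)
```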




+37 -6  mindspore/dataset/engine/samplers.py

@@ -141,7 +141,12 @@ class BuiltinSampler:
         c_child_sampler = None
         if self.child_sampler is not None:
             c_child_sampler = self.child_sampler.create()
         return c_child_sampler
 
+    def create_child_for_minddataset(self):
+        c_child_sampler = None
+        if self.child_sampler is not None:
+            c_child_sampler = self.child_sampler.create_for_minddataset()
+        return c_child_sampler
+
     def is_shuffled(self):
@@ -262,6 +267,12 @@ class DistributedSampler(BuiltinSampler):
             c_sampler.add_child(c_child_sampler)
         return c_sampler
 
+    def create_for_minddataset(self):
+        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
     def is_shuffled(self):
         if self.child_sampler is None:
             return self.shuffle
@@ -318,7 +329,7 @@ class PKSampler(BuiltinSampler):
 
         self.num_val = num_val
         self.shuffle = shuffle
-        self.class_column = class_column # work for minddataset
+        self.class_column = class_column  # work for minddataset
         super().__init__(num_samples)
 
     def create(self):
@@ -340,12 +351,14 @@ class PKSampler(BuiltinSampler):
 
         return self.child_sampler.is_sharded()
 
-    def _create_for_minddataset(self):
+    def create_for_minddataset(self):
         if not self.class_column or not isinstance(self.class_column, str):
             raise ValueError("class_column should be a not empty string value, \
                              but got class_column={}".format(class_column))
-        return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
+        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
 
 
 class RandomSampler(BuiltinSampler):
     """
@@ -390,6 +403,13 @@ class RandomSampler(BuiltinSampler):
             c_sampler.add_child(c_child_sampler)
         return c_sampler
 
+    def create_for_minddataset(self):
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
     def is_shuffled(self):
         return True


@@ -440,6 +460,14 @@ class SequentialSampler(BuiltinSampler):
             c_sampler.add_child(c_child_sampler)
         return c_sampler
 
+    def create_for_minddataset(self):
+        start_index = self.start_index if self.start_index is not None else 0
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
+
     def is_shuffled(self):
         if self.child_sampler is None:
             return False
@@ -501,8 +529,11 @@ class SubsetRandomSampler(BuiltinSampler):
 
         return self.child_sampler.is_sharded()
 
-    def _create_for_minddataset(self):
-        return cde.MindrecordSubsetRandomSampler(self.indices)
+    def create_for_minddataset(self):
+        c_sampler = cde.MindrecordSubsetRandomSampler(self.indices)
+        c_child_sampler = self.create_child_for_minddataset()
+        c_sampler.add_child(c_child_sampler)
+        return c_sampler
 
     def get_num_samples(self):
         num_samples = super().get_num_samples()
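
Note: every sampler now builds its MindRecord counterpart and recursively attaches its child, so a chain like "read sequentially, then shard" survives the crossing into C++. A short chaining sketch (constructor arguments as defined above; the file name is a placeholder):

```python
import mindspore.dataset as ds

# The child sampler runs first and feeds its parent: here samples are taken
# sequentially, then split across 4 shards (this process reads shard 1).
parent = ds.DistributedSampler(num_shards=4, shard_id=1)
parent.add_child(ds.SequentialSampler())

data_set = ds.MindDataset("imagenet.mindrecord0", sampler=parent)
```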


+355 -9  tests/ut/python/dataset/test_minddataset_sampler.py

@@ -17,6 +17,7 @@ This is the test module for mindrecord
 """
 import os
 import pytest
+import numpy as np
 
 import mindspore.dataset as ds
 from mindspore import log as logger
@@ -64,10 +65,12 @@ def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 6
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1




@@ -82,12 +85,14 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 6
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[data]: \
 {}------------------------".format(item["data"][:10]))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1




@@ -102,10 +107,12 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 9
     num_iter = 0
     for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1




@@ -119,10 +126,12 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
     assert data_set.get_dataset_size() == 15
     num_iter = 0
    for item in data_set.create_dict_iterator():
-        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
         logger.info("-------------- item[file_name]: \
 {}------------------------".format(to_str(item["file_name"])))
-        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1




@@ -219,7 +228,6 @@ def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file
 
 
 def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
-    """tutorial for cv minderdataset."""
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
     indices = [1, 2, 4, -1, -2]
@@ -241,6 +249,344 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
    assert num_iter == 5




def test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.RandomSampler()
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 10
num_iter = 0
new_dataset = []
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
new_dataset.append(item['file_name'])
assert num_iter == 10
assert new_dataset != [x['file_name'] for x in data]

def test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file):
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.RandomSampler()
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 10
ds1 = data_set.repeat(3)
num_iter = 0
epoch1_dataset = []
epoch2_dataset = []
epoch3_dataset = []
for item in ds1.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
if num_iter <= 10:
epoch1_dataset.append(item['file_name'])
elif num_iter <= 20:
epoch2_dataset.append(item['file_name'])
else:
epoch3_dataset.append(item['file_name'])
assert num_iter == 30
assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)

def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.RandomSampler(replacement=True, num_samples=5)
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 5
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 5


def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.SequentialSampler(1, 4)
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 4
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(
data[num_iter+1]['file_name'], dtype='S')
num_iter += 1
assert num_iter == 4


def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.SequentialSampler(2, 10)
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
dataset_size = data_set.get_dataset_size()
assert dataset_size == 10
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(
data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
num_iter += 1
assert num_iter == 10


def test_cv_minddataset_split_basic(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
num_readers, shuffle=False)
d1, d2 = d.split([8, 2], randomize=False)
assert d.get_dataset_size() == 10
assert d1.get_dataset_size() == 8
assert d2.get_dataset_size() == 2
num_iter = 0
for item in d1.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(data[num_iter]['file_name'],
dtype='S')
num_iter += 1
assert num_iter == 8
num_iter = 0
for item in d2.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
dtype='S')
num_iter += 1
assert num_iter == 2


def test_cv_minddataset_split_exact_percent(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
num_readers, shuffle=False)
d1, d2 = d.split([0.8, 0.2], randomize=False)
assert d.get_dataset_size() == 10
assert d1.get_dataset_size() == 8
assert d2.get_dataset_size() == 2
num_iter = 0
for item in d1.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(
data[num_iter]['file_name'], dtype='S')
num_iter += 1
assert num_iter == 8
num_iter = 0
for item in d2.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
dtype='S')
num_iter += 1
assert num_iter == 2


def test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
num_readers, shuffle=False)
d1, d2 = d.split([0.41, 0.59], randomize=False)
assert d.get_dataset_size() == 10
assert d1.get_dataset_size() == 4
assert d2.get_dataset_size() == 6
num_iter = 0
for item in d1.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(
data[num_iter]['file_name'], dtype='S')
num_iter += 1
assert num_iter == 4
num_iter = 0
for item in d2.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
assert item['file_name'] == np.array(data[num_iter + 4]['file_name'],
dtype='S')
num_iter += 1
assert num_iter == 6


def test_cv_minddataset_split_deterministic(add_and_remove_cv_file):
columns_list = ["data", "file_name", "label"]
num_readers = 4
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
num_readers, shuffle=False)
# should set seed to avoid data overlap
ds.config.set_seed(111)
d1, d2 = d.split([0.8, 0.2])
assert d.get_dataset_size() == 10
assert d1.get_dataset_size() == 8
assert d2.get_dataset_size() == 2

d1_dataset = []
d2_dataset = []
num_iter = 0
for item in d1.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
d1_dataset.append(item['file_name'])
num_iter += 1
assert num_iter == 8
num_iter = 0
for item in d2.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
d2_dataset.append(item['file_name'])
num_iter += 1
assert num_iter == 2
inter_dataset = [x for x in d1_dataset if x in d2_dataset]
assert inter_dataset == [] # intersection of d1 and d2


def test_cv_minddataset_split_sharding(add_and_remove_cv_file):
data = get_data(CV_DIR_NAME, True)
columns_list = ["data", "file_name", "label"]
num_readers = 4
d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
num_readers, shuffle=False)
# should set seed to avoid data overlap
ds.config.set_seed(111)
d1, d2 = d.split([0.8, 0.2])
assert d.get_dataset_size() == 10
assert d1.get_dataset_size() == 8
assert d2.get_dataset_size() == 2
distributed_sampler = ds.DistributedSampler(2, 0)
d1.use_sampler(distributed_sampler)
assert d1.get_dataset_size() == 4

num_iter = 0
d1_shard1 = []
for item in d1.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
d1_shard1.append(item['file_name'])
assert num_iter == 4
assert d1_shard1 != [x['file_name'] for x in data[0:4]]

distributed_sampler = ds.DistributedSampler(2, 1)
d1.use_sampler(distributed_sampler)
assert d1.get_dataset_size() == 4

d1s = d1.repeat(3)
epoch1_dataset = []
epoch2_dataset = []
epoch3_dataset = []
num_iter = 0
for item in d1s.create_dict_iterator():
logger.info(
"-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info(
"-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
if num_iter <= 4:
epoch1_dataset.append(item['file_name'])
elif num_iter <= 8:
epoch2_dataset.append(item['file_name'])
else:
epoch3_dataset.append(item['file_name'])
assert len(epoch1_dataset) == 4
assert len(epoch2_dataset) == 4
assert len(epoch3_dataset) == 4
inter_dataset = [x for x in d1_shard1 if x in epoch1_dataset]
assert inter_dataset == [] # intersection of d1's shard1 and d1's shard2
assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)


def get_data(dir_name, sampler=False):
    """
    usage: get data from imagenet dataset

