Browse Source

!273 [MD] update subset random sampler in minddataset

Merge pull request !273 from liyong126/mindrecord_subset_sampler_python
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
f1fa2a9941
15 changed files with 115 additions and 107 deletions
  1. +4
    -30
      mindspore/ccsrc/dataset/api/de_pipeline.cc
  2. +0
    -3
      mindspore/ccsrc/dataset/api/de_pipeline.h
  3. +8
    -0
      mindspore/ccsrc/dataset/api/python_bindings.cc
  4. +1
    -1
      mindspore/ccsrc/mindrecord/include/shard_category.h
  5. +19
    -1
      mindspore/ccsrc/mindrecord/include/shard_operator.h
  6. +7
    -3
      mindspore/ccsrc/mindrecord/include/shard_sample.h
  7. +1
    -1
      mindspore/ccsrc/mindrecord/include/shard_shuffle.h
  8. +9
    -2
      mindspore/ccsrc/mindrecord/io/shard_reader.cc
  9. +1
    -1
      mindspore/ccsrc/mindrecord/meta/shard_category.cc
  10. +14
    -3
      mindspore/ccsrc/mindrecord/meta/shard_sample.cc
  11. +1
    -1
      mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
  12. +1
    -3
      mindspore/dataset/engine/datasets.py
  13. +2
    -0
      mindspore/dataset/engine/samplers.py
  14. +47
    -48
      tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
  15. +0
    -10
      tests/ut/python/dataset/test_minddataset_sampler.py

+ 4
- 30
mindspore/ccsrc/dataset/api/de_pipeline.cc View File

@@ -391,30 +391,6 @@ Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vecto
return Status::OK(); return Status::OK();
} }


// Creates the mindrecord shard operator that corresponds to a Python-side sampler.
//   sampler_name: class name of the sampler; only "SubsetRandomSampler" is accepted.
//   args:         sampler attributes; entries whose value is None are ignored and
//                 "indices" is the only key that is accepted.
//   ptr:          output, receives the newly created mindrecord::ShardSample.
// Returns Status::OK() on success; any other key or sampler name leaves through
// RETURN_STATUS_UNEXPECTED with the corresponding message.
Status DEPipeline::GetMindrecordSampler(const std::string &sampler_name, const py::dict &args,
                                        std::shared_ptr<mindrecord::ShardOperator> *ptr) {
  std::vector<int> subset_indices;
  for (auto &entry : args) {
    std::string name = py::str(entry.first);
    py::handle val = entry.second;
    if (val.is_none()) {
      continue;  // unset attributes are not validated
    }
    if (name != "indices") {
      std::string err_msg = "ERROR: parameter " + name + " is invalid.";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
    subset_indices = ToIntVector(val);
  }
  // The sampler name is only checked after every attribute has been validated.
  if (sampler_name != "SubsetRandomSampler") {
    std::string err_msg = "ERROR: parameter sampler_name is invalid.";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  *ptr = std::make_shared<mindrecord::ShardSample>(subset_indices);
  return Status::OK();
}

Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
if (args["dataset_file"].is_none()) { if (args["dataset_file"].is_none()) {
std::string err_msg = "Error: at least one of dataset_files is missing"; std::string err_msg = "Error: at least one of dataset_files is missing";
@@ -446,12 +422,10 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
} else if (key == "global_shuffle" && ToBool(value) == true) { } else if (key == "global_shuffle" && ToBool(value) == true) {
uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0; uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0;
operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed)); operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
} else if (key == "sampler_name") {
std::shared_ptr<mindrecord::ShardOperator> sample_op;
auto ret = GetMindrecordSampler(ToString(value), args["sampler_params"], &sample_op);
if (Status::OK() != ret) {
return ret;
}
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("_create_for_minddataset");
std::shared_ptr<mindrecord::ShardOperator> sample_op =
create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
operators.push_back(sample_op); operators.push_back(sample_op);
} }
} }


+ 0
- 3
mindspore/ccsrc/dataset/api/de_pipeline.h View File

@@ -145,9 +145,6 @@ class DEPipeline {


Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);


Status GetMindrecordSampler(const std::string &sampler_name, const py::dict &args,
std::shared_ptr<mindrecord::ShardOperator> *ptr);

private: private:
// Execution tree that links the dataset operators. // Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_; std::shared_ptr<ExecutionTree> tree_;


+ 8
- 0
mindspore/ccsrc/dataset/api/python_bindings.cc View File

@@ -54,6 +54,9 @@
#include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h" #include "dataset/engine/jagged_connector.h"
#include "dataset/kernels/data/to_float16_op.h" #include "dataset/kernels/data/to_float16_op.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_sample.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
#include "pybind11/stl_bind.h" #include "pybind11/stl_bind.h"
@@ -382,6 +385,7 @@ void bindTensorOps4(py::module *m) {


void bindSamplerOps(py::module *m) { void bindSamplerOps(py::module *m) {
(void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler"); (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler");
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");


(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler") (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
.def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"), .def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
@@ -399,6 +403,10 @@ void bindSamplerOps(py::module *m) {
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler") (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
.def(py::init<std::vector<int64_t>>(), py::arg("indices")); .def(py::init<std::vector<int64_t>>(), py::arg("indices"));


// Expose mindrecord::ShardSample to Python as "MindrecordSubsetRandomSampler", registered as a
// subclass of the ShardOperator binding and constructible from (indices, seed).
// NOTE(review): pybind11 converts default argument values when the binding is registered, so
// GetSeed() is evaluated once here and every call that omits `seed` reuses that same value.
// Confirm this is intended (e.g. with respect to a seed configured after module import).
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
*m, "MindrecordSubsetRandomSampler")
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());

(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler") (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
py::arg("replacement")); py::arg("replacement"));


+ 1
- 1
mindspore/ccsrc/mindrecord/include/shard_category.h View File

@@ -32,7 +32,7 @@ class ShardCategory : public ShardOperator {


const std::vector<std::pair<std::string, std::string>> &get_categories() const; const std::vector<std::pair<std::string, std::string>> &get_categories() const;


MSRStatus operator()(ShardTask &tasks) override;
MSRStatus execute(ShardTask &tasks) override;


private: private:
std::vector<std::pair<std::string, std::string>> categories_; std::vector<std::pair<std::string, std::string>> categories_;


+ 19
- 1
mindspore/ccsrc/mindrecord/include/shard_operator.h View File

@@ -24,7 +24,25 @@ namespace mindrecord {
class ShardOperator { class ShardOperator {
public: public:
virtual ~ShardOperator() = default; virtual ~ShardOperator() = default;
virtual MSRStatus operator()(ShardTask &tasks) = 0;

// Applies this operator to `tasks` in three stages: pre_execute, execute, suf_execute.
// The stages run in that order and the chain stops at the first one that does not
// return SUCCESS, in which case FAILED is reported. SUCCESS is returned only when
// all three stages succeed.
MSRStatus operator()(ShardTask &tasks) {
  // && short-circuits, so a later stage never runs after an earlier one has failed.
  const bool all_stages_ok = this->pre_execute(tasks) == SUCCESS &&
                             this->execute(tasks) == SUCCESS &&
                             this->suf_execute(tasks) == SUCCESS;
  return all_stages_ok ? SUCCESS : FAILED;
}

// Optional hook that operator() calls before execute(); the default does nothing and reports SUCCESS.
virtual MSRStatus pre_execute(ShardTask &tasks) { return SUCCESS; }

// Main stage of the operator; pure virtual, so every concrete operator must implement it.
virtual MSRStatus execute(ShardTask &tasks) = 0;

// Optional hook that operator() calls after execute() has succeeded; the default does nothing and
// reports SUCCESS.
virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; }
}; };
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore


+ 7
- 3
mindspore/ccsrc/mindrecord/include/shard_sample.h View File

@@ -17,10 +17,12 @@
#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ #ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ #define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_


#include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "mindrecord/include/shard_operator.h" #include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_shuffle.h"


namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
@@ -32,21 +34,23 @@ class ShardSample : public ShardOperator {


ShardSample(int num, int den, int par); ShardSample(int num, int den, int par);


explicit ShardSample(const std::vector<int> &indices);
ShardSample(const std::vector<int64_t> &indices, uint32_t seed);


~ShardSample() override{}; ~ShardSample() override{};


const std::pair<int, int> get_partitions() const; const std::pair<int, int> get_partitions() const;


MSRStatus operator()(ShardTask &tasks) override;
MSRStatus execute(ShardTask &tasks) override;
MSRStatus suf_execute(ShardTask &tasks) override;


private: private:
int numerator_; int numerator_;
int denominator_; int denominator_;
int no_of_samples_; int no_of_samples_;
int partition_id_; int partition_id_;
std::vector<int> indices_;
std::vector<int64_t> indices_;
SamplerType sampler_type_; SamplerType sampler_type_;
std::shared_ptr<ShardShuffle> shuffle_op_;
}; };
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore


+ 1
- 1
mindspore/ccsrc/mindrecord/include/shard_shuffle.h View File

@@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator {


~ShardShuffle() override{}; ~ShardShuffle() override{};


MSRStatus operator()(ShardTask &tasks) override;
MSRStatus execute(ShardTask &tasks) override;


private: private:
uint32_t shuffle_seed_; uint32_t shuffle_seed_;


+ 9
- 2
mindspore/ccsrc/mindrecord/io/shard_reader.cc View File

@@ -779,8 +779,12 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {


// Sort row group by (group_id, shard_id), prepare for parallel reading // Sort row group by (group_id, shard_id), prepare for parallel reading
std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups); std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups);
CreateTasks(row_group_summary, operators_);
MS_LOG(INFO) << "Launching read threads";
if (CreateTasks(row_group_summary, operators_) != SUCCESS) {
MS_LOG(ERROR) << "Failed to launch read threads.";
interrupt_ = true;
return FAILED;
}
MS_LOG(INFO) << "Launching read threads.";


if (isSimpleReader) return SUCCESS; if (isSimpleReader) return SUCCESS;


@@ -1152,6 +1156,9 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetBlockNext()
} }


std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() { std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() {
if (interrupt_) {
return std::vector<std::tuple<std::vector<uint8_t>, json>>();
}
if (block_reader_) return GetBlockNext(); if (block_reader_) return GetBlockNext();
if (deliver_id_ >= static_cast<int>(tasks_.Size())) { if (deliver_id_ >= static_cast<int>(tasks_.Size())) {
return std::vector<std::tuple<std::vector<uint8_t>, json>>(); return std::vector<std::tuple<std::vector<uint8_t>, json>>();


+ 1
- 1
mindspore/ccsrc/mindrecord/meta/shard_category.cc View File

@@ -23,6 +23,6 @@ ShardCategory::ShardCategory(const std::vector<std::pair<std::string, std::strin


const std::vector<std::pair<std::string, std::string>> &ShardCategory::get_categories() const { return categories_; } const std::vector<std::pair<std::string, std::string>> &ShardCategory::get_categories() const { return categories_; }


MSRStatus ShardCategory::operator()(ShardTask &tasks) { return SUCCESS; }
MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

+ 14
- 3
mindspore/ccsrc/mindrecord/meta/shard_sample.cc View File

@@ -46,13 +46,15 @@ ShardSample::ShardSample(int num, int den, int par)
indices_({}), indices_({}),
sampler_type_(kCustomTopPercentSampler) {} sampler_type_(kCustomTopPercentSampler) {}


ShardSample::ShardSample(const std::vector<int> &indices)
ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed)
: numerator_(0), : numerator_(0),
denominator_(0), denominator_(0),
no_of_samples_(0), no_of_samples_(0),
partition_id_(0), partition_id_(0),
indices_(indices), indices_(indices),
sampler_type_(kSubsetRandomSampler) {}
sampler_type_(kSubsetRandomSampler) {
shuffle_op_ = std::make_shared<ShardShuffle>(seed);
}


const std::pair<int, int> ShardSample::get_partitions() const { const std::pair<int, int> ShardSample::get_partitions() const {
if (numerator_ == 1 && denominator_ > 1) { if (numerator_ == 1 && denominator_ > 1) {
@@ -61,7 +63,7 @@ const std::pair<int, int> ShardSample::get_partitions() const {
return std::pair<int, int>(-1, -1); return std::pair<int, int>(-1, -1);
} }


MSRStatus ShardSample::operator()(ShardTask &tasks) {
MSRStatus ShardSample::execute(ShardTask &tasks) {
int no_of_categories = static_cast<int>(tasks.categories); int no_of_categories = static_cast<int>(tasks.categories);
int total_no = static_cast<int>(tasks.Size()); int total_no = static_cast<int>(tasks.Size());


@@ -115,5 +117,14 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) {
} }
return SUCCESS; return SUCCESS;
} }

// Post-processing stage of the sample operator.
// For the subset random sampler the owned shuffle operator is applied to `tasks`,
// and FAILED is returned if it does not report SUCCESS. Every other sampler type
// returns SUCCESS without touching `tasks`.
MSRStatus ShardSample::suf_execute(ShardTask &tasks) {
  if (sampler_type_ != kSubsetRandomSampler) {
    return SUCCESS;
  }
  return (*shuffle_op_)(tasks) == SUCCESS ? SUCCESS : FAILED;
}
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

+ 1
- 1
mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc View File

@@ -22,7 +22,7 @@ namespace mindspore {
namespace mindrecord { namespace mindrecord {
ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {} ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {}


MSRStatus ShardShuffle::operator()(ShardTask &tasks) {
MSRStatus ShardShuffle::execute(ShardTask &tasks) {
if (tasks.categories < 1) { if (tasks.categories < 1) {
return FAILED; return FAILED;
} }


+ 1
- 3
mindspore/dataset/engine/datasets.py View File

@@ -1683,9 +1683,7 @@ class MindDataset(SourceDataset):
args["block_reader"] = self.block_reader args["block_reader"] = self.block_reader
args["num_shards"] = self.num_shards args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id args["shard_id"] = self.shard_id
if self.sampler:
args["sampler_name"] = self.sampler.__class__.__name__
args["sampler_params"] = self.sampler.__dict__
args["sampler"] = self.sampler
return args return args


def get_dataset_size(self): def get_dataset_size(self):


+ 2
- 0
mindspore/dataset/engine/samplers.py View File

@@ -195,6 +195,8 @@ class SubsetRandomSampler():
def create(self): def create(self):
return cde.SubsetRandomSampler(self.indices) return cde.SubsetRandomSampler(self.indices)


# Builds the C++-side sampler object used when this sampler is attached to a MindDataset:
# a cde.MindrecordSubsetRandomSampler over self.indices. No seed is passed here, so the
# binding's default seed applies.
def _create_for_minddataset(self):
return cde.MindrecordSubsetRandomSampler(self.indices)


class WeightedRandomSampler(): class WeightedRandomSampler():
""" """


+ 47
- 48
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc View File

@@ -30,9 +30,9 @@
#include "mindrecord/include/shard_shuffle.h" #include "mindrecord/include/shard_shuffle.h"
#include "ut_common.h" #include "ut_common.h"


using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream; using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;


namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
@@ -65,31 +65,31 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
ASSERT_TRUE(i <= kSampleCount); ASSERT_TRUE(i <= kSampleCount);
} }


// TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
// MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
//
// std::string file_name = "./imagenet.shard01";
// auto column_list = std::vector<std::string>{"file_name"};
//
// const int kNum = 5;
// const int kDen = 0;
// std::vector<std::shared_ptr<ShardOperator>> ops;
// ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
//
// ShardReader dataset;
// dataset.Open(file_name, 4, column_list, ops);
// dataset.Launch();
//
// int i = 0;
// while (true) {
// auto x = dataset.GetNext();
// if (x.empty()) break;
// MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
// i++;
// }
// dataset.Finish();
// ASSERT_TRUE(i <= 5);
// }
// Reads ./imagenet.shard01 through a ShardSample built from kNum = 5 and kDen = 0 and checks that
// no more than 5 rows come back.
// NOTE(review): kDen = 0 is presumably an invalid denominator (hence "WrongNumber"); the return
// values of Open() and Launch() are not checked here — confirm whether the test should also
// assert that Launch() fails.
TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"file_name"};
const int kNum = 5;
const int kDen = 0;
std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardSample>(kNum, kDen));
ShardReader dataset;
dataset.Open(file_name, 4, column_list, ops);
dataset.Launch();
// Count the rows the reader delivers; an empty result from GetNext() ends the loop.
int i = 0;
while (true) {
auto x = dataset.GetNext();
if (x.empty()) break;
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++;
}
dataset.Finish();
// Zero rows also satisfies this bound, so the check passes when nothing is read at all.
ASSERT_TRUE(i <= 5);
}


TEST_F(TestShardOperator, TestShardSampleRatio) { TEST_F(TestShardOperator, TestShardSampleRatio) {
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
@@ -117,7 +117,6 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
ASSERT_TRUE(i <= 10); ASSERT_TRUE(i <= 10);
} }



TEST_F(TestShardOperator, TestShardSamplePartition) { TEST_F(TestShardOperator, TestShardSamplePartition) {
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
@@ -170,8 +169,8 @@ TEST_F(TestShardOperator, TestShardCategory) {
auto x = dataset.GetNext(); auto x = dataset.GetNext();
if (x.empty()) break; if (x.empty()) break;


MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;


ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
@@ -199,8 +198,8 @@ TEST_F(TestShardOperator, TestShardShuffle) {
while (true) { while (true) {
auto x = dataset.GetNext(); auto x = dataset.GetNext();
if (x.empty()) break; if (x.empty()) break;
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;
} }
dataset.Finish(); dataset.Finish();
@@ -224,8 +223,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
while (true) { while (true) {
auto x = dataset.GetNext(); auto x = dataset.GetNext();
if (x.empty()) break; if (x.empty()) break;
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;
} }
dataset.Finish(); dataset.Finish();
@@ -251,8 +250,8 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
while (true) { while (true) {
auto x = dataset.GetNext(); auto x = dataset.GetNext();
if (x.empty()) break; if (x.empty()) break;
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;
} }
dataset.Finish(); dataset.Finish();
@@ -278,8 +277,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
while (true) { while (true) {
auto x = dataset.GetNext(); auto x = dataset.GetNext();
if (x.empty()) break; if (x.empty()) break;
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;
} }
dataset.Finish(); dataset.Finish();
@@ -307,8 +306,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
while (true) { while (true) {
auto x = dataset.GetNext(); auto x = dataset.GetNext();
if (x.empty()) break; if (x.empty()) break;
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;


auto y = compare_dataset.GetNext(); auto y = compare_dataset.GetNext();
@@ -342,8 +341,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
while (true) { while (true) {
auto x = dataset.GetNext(); auto x = dataset.GetNext();
if (x.empty()) break; if (x.empty()) break;
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;


ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
@@ -376,8 +375,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
while (true) { while (true) {
auto x = dataset.GetNext(); auto x = dataset.GetNext();
if (x.empty()) break; if (x.empty()) break;
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;
ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
category_no++; category_no++;
@@ -410,8 +409,8 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
while (true) { while (true) {
auto x = dataset.GetNext(); auto x = dataset.GetNext();
if (x.empty()) break; if (x.empty()) break;
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;


ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);
@@ -448,8 +447,8 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
while (true) { while (true) {
auto x = dataset.GetNext(); auto x = dataset.GetNext();
if (x.empty()) break; if (x.empty()) break;
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) <<
", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;


ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second);


+ 0
- 10
tests/ut/python/dataset/test_minddataset_sampler.py View File

@@ -81,8 +81,6 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
"-------------- item[file_name]: {} ------------------------".format(item["file_name"])) "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info( logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"])) "-------------- item[label]: {} ----------------------------".format(item["label"]))
assert data[indices[num_iter]]['file_name'] == "".join(
[chr(x) for x in item['file_name']])
num_iter += 1 num_iter += 1
assert num_iter == 5 assert num_iter == 5


@@ -107,8 +105,6 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file):
"-------------- item[file_name]: {} ------------------------".format(item["file_name"])) "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info( logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"])) "-------------- item[label]: {} ----------------------------".format(item["label"]))
assert data[indices[num_iter]]['file_name'] == "".join(
[chr(x) for x in item['file_name']])
num_iter += 1 num_iter += 1
assert num_iter == 6 assert num_iter == 6


@@ -133,8 +129,6 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
"-------------- item[file_name]: {} ------------------------".format(item["file_name"])) "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info( logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"])) "-------------- item[label]: {} ----------------------------".format(item["label"]))
assert data[indices[num_iter]]['file_name'] == "".join(
[chr(x) for x in item['file_name']])
num_iter += 1 num_iter += 1
assert num_iter == 0 assert num_iter == 0


@@ -159,8 +153,6 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file):
"-------------- item[file_name]: {} ------------------------".format(item["file_name"])) "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info( logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"])) "-------------- item[label]: {} ----------------------------".format(item["label"]))
assert data[indices[num_iter] % len(data)]['file_name'] == "".join([
chr(x) for x in item['file_name']])
num_iter += 1 num_iter += 1
assert num_iter == 5 assert num_iter == 5


@@ -185,8 +177,6 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
"-------------- item[file_name]: {} ------------------------".format(item["file_name"])) "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info( logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"])) "-------------- item[label]: {} ----------------------------".format(item["label"]))
assert data[indices[num_iter] % len(data)]['file_name'] == "".join([
chr(x) for x in item['file_name']])
num_iter += 1 num_iter += 1
assert num_iter == 5 assert num_iter == 5




Loading…
Cancel
Save