Merge pull request !273 from liyong126/mindrecord_subset_sampler_python (tags/v0.2.0-alpha)
| @@ -391,30 +391,6 @@ Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vecto | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DEPipeline::GetMindrecordSampler(const std::string &sampler_name, const py::dict &args, | |||||
| std::shared_ptr<mindrecord::ShardOperator> *ptr) { | |||||
| std::vector<int> indices; | |||||
| for (auto &arg : args) { | |||||
| std::string key = py::str(arg.first); | |||||
| py::handle value = arg.second; | |||||
| if (!value.is_none()) { | |||||
| if (key == "indices") { | |||||
| indices = ToIntVector(value); | |||||
| } else { | |||||
| std::string err_msg = "ERROR: parameter " + key + " is invalid."; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (sampler_name == "SubsetRandomSampler") { | |||||
| *ptr = std::make_shared<mindrecord::ShardSample>(indices); | |||||
| } else { | |||||
| std::string err_msg = "ERROR: parameter sampler_name is invalid."; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | ||||
| if (args["dataset_file"].is_none()) { | if (args["dataset_file"].is_none()) { | ||||
| std::string err_msg = "Error: at least one of dataset_files is missing"; | std::string err_msg = "Error: at least one of dataset_files is missing"; | ||||
| @@ -446,12 +422,10 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas | |||||
| } else if (key == "global_shuffle" && ToBool(value) == true) { | } else if (key == "global_shuffle" && ToBool(value) == true) { | ||||
| uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0; | uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0; | ||||
| operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed)); | operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed)); | ||||
| } else if (key == "sampler_name") { | |||||
| std::shared_ptr<mindrecord::ShardOperator> sample_op; | |||||
| auto ret = GetMindrecordSampler(ToString(value), args["sampler_params"], &sample_op); | |||||
| if (Status::OK() != ret) { | |||||
| return ret; | |||||
| } | |||||
| } else if (key == "sampler") { | |||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("_create_for_minddataset"); | |||||
| std::shared_ptr<mindrecord::ShardOperator> sample_op = | |||||
| create().cast<std::shared_ptr<mindrecord::ShardOperator>>(); | |||||
| operators.push_back(sample_op); | operators.push_back(sample_op); | ||||
| } | } | ||||
| } | } | ||||
| @@ -145,9 +145,6 @@ class DEPipeline { | |||||
| Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | ||||
| Status GetMindrecordSampler(const std::string &sampler_name, const py::dict &args, | |||||
| std::shared_ptr<mindrecord::ShardOperator> *ptr); | |||||
| private: | private: | ||||
| // Execution tree that links the dataset operators. | // Execution tree that links the dataset operators. | ||||
| std::shared_ptr<ExecutionTree> tree_; | std::shared_ptr<ExecutionTree> tree_; | ||||
| @@ -54,6 +54,9 @@ | |||||
| #include "dataset/engine/datasetops/source/tf_reader_op.h" | #include "dataset/engine/datasetops/source/tf_reader_op.h" | ||||
| #include "dataset/engine/jagged_connector.h" | #include "dataset/engine/jagged_connector.h" | ||||
| #include "dataset/kernels/data/to_float16_op.h" | #include "dataset/kernels/data/to_float16_op.h" | ||||
| #include "dataset/util/random.h" | |||||
| #include "mindrecord/include/shard_operator.h" | |||||
| #include "mindrecord/include/shard_sample.h" | |||||
| #include "pybind11/pybind11.h" | #include "pybind11/pybind11.h" | ||||
| #include "pybind11/stl.h" | #include "pybind11/stl.h" | ||||
| #include "pybind11/stl_bind.h" | #include "pybind11/stl_bind.h" | ||||
| @@ -382,6 +385,7 @@ void bindTensorOps4(py::module *m) { | |||||
| void bindSamplerOps(py::module *m) { | void bindSamplerOps(py::module *m) { | ||||
| (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler"); | (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler"); | ||||
| (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator"); | |||||
| (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler") | (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler") | ||||
| .def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"), | .def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"), | ||||
| @@ -399,6 +403,10 @@ void bindSamplerOps(py::module *m) { | |||||
| (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler") | (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler") | ||||
| .def(py::init<std::vector<int64_t>>(), py::arg("indices")); | .def(py::init<std::vector<int64_t>>(), py::arg("indices")); | ||||
| (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>( | |||||
| *m, "MindrecordSubsetRandomSampler") | |||||
| .def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); | |||||
| (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler") | (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler") | ||||
| .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), | .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), | ||||
| py::arg("replacement")); | py::arg("replacement")); | ||||
| @@ -32,7 +32,7 @@ class ShardCategory : public ShardOperator { | |||||
| const std::vector<std::pair<std::string, std::string>> &get_categories() const; | const std::vector<std::pair<std::string, std::string>> &get_categories() const; | ||||
| MSRStatus operator()(ShardTask &tasks) override; | |||||
| MSRStatus execute(ShardTask &tasks) override; | |||||
| private: | private: | ||||
| std::vector<std::pair<std::string, std::string>> categories_; | std::vector<std::pair<std::string, std::string>> categories_; | ||||
| @@ -24,7 +24,25 @@ namespace mindrecord { | |||||
| class ShardOperator { | class ShardOperator { | ||||
| public: | public: | ||||
| virtual ~ShardOperator() = default; | virtual ~ShardOperator() = default; | ||||
| virtual MSRStatus operator()(ShardTask &tasks) = 0; | |||||
| MSRStatus operator()(ShardTask &tasks) { | |||||
| if (SUCCESS != this->pre_execute(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| if (SUCCESS != this->execute(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| if (SUCCESS != this->suf_execute(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| virtual MSRStatus pre_execute(ShardTask &tasks) { return SUCCESS; } | |||||
| virtual MSRStatus execute(ShardTask &tasks) = 0; | |||||
| virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; } | |||||
| }; | }; | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,10 +17,12 @@ | |||||
| #ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ | #ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ | ||||
| #define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ | #define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ | ||||
| #include <memory> | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "mindrecord/include/shard_operator.h" | #include "mindrecord/include/shard_operator.h" | ||||
| #include "mindrecord/include/shard_shuffle.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| @@ -32,21 +34,23 @@ class ShardSample : public ShardOperator { | |||||
| ShardSample(int num, int den, int par); | ShardSample(int num, int den, int par); | ||||
| explicit ShardSample(const std::vector<int> &indices); | |||||
| ShardSample(const std::vector<int64_t> &indices, uint32_t seed); | |||||
| ~ShardSample() override{}; | ~ShardSample() override{}; | ||||
| const std::pair<int, int> get_partitions() const; | const std::pair<int, int> get_partitions() const; | ||||
| MSRStatus operator()(ShardTask &tasks) override; | |||||
| MSRStatus execute(ShardTask &tasks) override; | |||||
| MSRStatus suf_execute(ShardTask &tasks) override; | |||||
| private: | private: | ||||
| int numerator_; | int numerator_; | ||||
| int denominator_; | int denominator_; | ||||
| int no_of_samples_; | int no_of_samples_; | ||||
| int partition_id_; | int partition_id_; | ||||
| std::vector<int> indices_; | |||||
| std::vector<int64_t> indices_; | |||||
| SamplerType sampler_type_; | SamplerType sampler_type_; | ||||
| std::shared_ptr<ShardShuffle> shuffle_op_; | |||||
| }; | }; | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator { | |||||
| ~ShardShuffle() override{}; | ~ShardShuffle() override{}; | ||||
| MSRStatus operator()(ShardTask &tasks) override; | |||||
| MSRStatus execute(ShardTask &tasks) override; | |||||
| private: | private: | ||||
| uint32_t shuffle_seed_; | uint32_t shuffle_seed_; | ||||
| @@ -779,8 +779,12 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { | |||||
| // Sort row group by (group_id, shard_id), prepare for parallel reading | // Sort row group by (group_id, shard_id), prepare for parallel reading | ||||
| std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups); | std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups); | ||||
| CreateTasks(row_group_summary, operators_); | |||||
| MS_LOG(INFO) << "Launching read threads"; | |||||
| if (CreateTasks(row_group_summary, operators_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Failed to launch read threads."; | |||||
| interrupt_ = true; | |||||
| return FAILED; | |||||
| } | |||||
| MS_LOG(INFO) << "Launching read threads."; | |||||
| if (isSimpleReader) return SUCCESS; | if (isSimpleReader) return SUCCESS; | ||||
| @@ -1152,6 +1156,9 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetBlockNext() | |||||
| } | } | ||||
| std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() { | std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() { | ||||
| if (interrupt_) { | |||||
| return std::vector<std::tuple<std::vector<uint8_t>, json>>(); | |||||
| } | |||||
| if (block_reader_) return GetBlockNext(); | if (block_reader_) return GetBlockNext(); | ||||
| if (deliver_id_ >= static_cast<int>(tasks_.Size())) { | if (deliver_id_ >= static_cast<int>(tasks_.Size())) { | ||||
| return std::vector<std::tuple<std::vector<uint8_t>, json>>(); | return std::vector<std::tuple<std::vector<uint8_t>, json>>(); | ||||
| @@ -23,6 +23,6 @@ ShardCategory::ShardCategory(const std::vector<std::pair<std::string, std::strin | |||||
| const std::vector<std::pair<std::string, std::string>> &ShardCategory::get_categories() const { return categories_; } | const std::vector<std::pair<std::string, std::string>> &ShardCategory::get_categories() const { return categories_; } | ||||
| MSRStatus ShardCategory::operator()(ShardTask &tasks) { return SUCCESS; } | |||||
| MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } | |||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -46,13 +46,15 @@ ShardSample::ShardSample(int num, int den, int par) | |||||
| indices_({}), | indices_({}), | ||||
| sampler_type_(kCustomTopPercentSampler) {} | sampler_type_(kCustomTopPercentSampler) {} | ||||
| ShardSample::ShardSample(const std::vector<int> &indices) | |||||
| ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed) | |||||
| : numerator_(0), | : numerator_(0), | ||||
| denominator_(0), | denominator_(0), | ||||
| no_of_samples_(0), | no_of_samples_(0), | ||||
| partition_id_(0), | partition_id_(0), | ||||
| indices_(indices), | indices_(indices), | ||||
| sampler_type_(kSubsetRandomSampler) {} | |||||
| sampler_type_(kSubsetRandomSampler) { | |||||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed); | |||||
| } | |||||
| const std::pair<int, int> ShardSample::get_partitions() const { | const std::pair<int, int> ShardSample::get_partitions() const { | ||||
| if (numerator_ == 1 && denominator_ > 1) { | if (numerator_ == 1 && denominator_ > 1) { | ||||
| @@ -61,7 +63,7 @@ const std::pair<int, int> ShardSample::get_partitions() const { | |||||
| return std::pair<int, int>(-1, -1); | return std::pair<int, int>(-1, -1); | ||||
| } | } | ||||
| MSRStatus ShardSample::operator()(ShardTask &tasks) { | |||||
| MSRStatus ShardSample::execute(ShardTask &tasks) { | |||||
| int no_of_categories = static_cast<int>(tasks.categories); | int no_of_categories = static_cast<int>(tasks.categories); | ||||
| int total_no = static_cast<int>(tasks.Size()); | int total_no = static_cast<int>(tasks.Size()); | ||||
| @@ -115,5 +117,14 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) { | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| MSRStatus ShardSample::suf_execute(ShardTask &tasks) { | |||||
| if (sampler_type_ == kSubsetRandomSampler) { | |||||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,7 @@ namespace mindspore { | |||||
| namespace mindrecord { | namespace mindrecord { | ||||
| ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {} | ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {} | ||||
| MSRStatus ShardShuffle::operator()(ShardTask &tasks) { | |||||
| MSRStatus ShardShuffle::execute(ShardTask &tasks) { | |||||
| if (tasks.categories < 1) { | if (tasks.categories < 1) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -1683,9 +1683,7 @@ class MindDataset(SourceDataset): | |||||
| args["block_reader"] = self.block_reader | args["block_reader"] = self.block_reader | ||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| if self.sampler: | |||||
| args["sampler_name"] = self.sampler.__class__.__name__ | |||||
| args["sampler_params"] = self.sampler.__dict__ | |||||
| args["sampler"] = self.sampler | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -195,6 +195,8 @@ class SubsetRandomSampler(): | |||||
| def create(self): | def create(self): | ||||
| return cde.SubsetRandomSampler(self.indices) | return cde.SubsetRandomSampler(self.indices) | ||||
| def _create_for_minddataset(self): | |||||
| return cde.MindrecordSubsetRandomSampler(self.indices) | |||||
| class WeightedRandomSampler(): | class WeightedRandomSampler(): | ||||
| """ | """ | ||||
| @@ -30,9 +30,9 @@ | |||||
| #include "mindrecord/include/shard_shuffle.h" | #include "mindrecord/include/shard_shuffle.h" | ||||
| #include "ut_common.h" | #include "ut_common.h" | ||||
| using mindspore::MsLogLevel::INFO; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | using mindspore::LogStream; | ||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| @@ -65,31 +65,31 @@ TEST_F(TestShardOperator, TestShardSampleBasic) { | |||||
| ASSERT_TRUE(i <= kSampleCount); | ASSERT_TRUE(i <= kSampleCount); | ||||
| } | } | ||||
| // TEST_F(TestShardOperator, TestShardSampleWrongNumber) { | |||||
| // MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); | |||||
| // | |||||
| // std::string file_name = "./imagenet.shard01"; | |||||
| // auto column_list = std::vector<std::string>{"file_name"}; | |||||
| // | |||||
| // const int kNum = 5; | |||||
| // const int kDen = 0; | |||||
| // std::vector<std::shared_ptr<ShardOperator>> ops; | |||||
| // ops.push_back(std::make_shared<ShardSample>(kNum, kDen)); | |||||
| // | |||||
| // ShardReader dataset; | |||||
| // dataset.Open(file_name, 4, column_list, ops); | |||||
| // dataset.Launch(); | |||||
| // | |||||
| // int i = 0; | |||||
| // while (true) { | |||||
| // auto x = dataset.GetNext(); | |||||
| // if (x.empty()) break; | |||||
| // MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | |||||
| // i++; | |||||
| // } | |||||
| // dataset.Finish(); | |||||
| // ASSERT_TRUE(i <= 5); | |||||
| // } | |||||
| TEST_F(TestShardOperator, TestShardSampleWrongNumber) { | |||||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); | |||||
| std::string file_name = "./imagenet.shard01"; | |||||
| auto column_list = std::vector<std::string>{"file_name"}; | |||||
| const int kNum = 5; | |||||
| const int kDen = 0; | |||||
| std::vector<std::shared_ptr<ShardOperator>> ops; | |||||
| ops.push_back(std::make_shared<ShardSample>(kNum, kDen)); | |||||
| ShardReader dataset; | |||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Launch(); | |||||
| int i = 0; | |||||
| while (true) { | |||||
| auto x = dataset.GetNext(); | |||||
| if (x.empty()) break; | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | |||||
| i++; | |||||
| } | |||||
| dataset.Finish(); | |||||
| ASSERT_TRUE(i <= 5); | |||||
| } | |||||
| TEST_F(TestShardOperator, TestShardSampleRatio) { | TEST_F(TestShardOperator, TestShardSampleRatio) { | ||||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); | MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); | ||||
| @@ -117,7 +117,6 @@ TEST_F(TestShardOperator, TestShardSampleRatio) { | |||||
| ASSERT_TRUE(i <= 10); | ASSERT_TRUE(i <= 10); | ||||
| } | } | ||||
| TEST_F(TestShardOperator, TestShardSamplePartition) { | TEST_F(TestShardOperator, TestShardSamplePartition) { | ||||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); | MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); | ||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| @@ -170,8 +169,8 @@ TEST_F(TestShardOperator, TestShardCategory) { | |||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << | |||||
| ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| i++; | i++; | ||||
| ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); | ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); | ||||
| @@ -199,8 +198,8 @@ TEST_F(TestShardOperator, TestShardShuffle) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << | |||||
| ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | dataset.Finish(); | ||||
| @@ -224,8 +223,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << | |||||
| ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | dataset.Finish(); | ||||
| @@ -251,8 +250,8 @@ TEST_F(TestShardOperator, TestShardShuffleSample) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << | |||||
| ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | dataset.Finish(); | ||||
| @@ -278,8 +277,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << | |||||
| ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | dataset.Finish(); | ||||
| @@ -307,8 +306,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << | |||||
| ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| i++; | i++; | ||||
| auto y = compare_dataset.GetNext(); | auto y = compare_dataset.GetNext(); | ||||
| @@ -342,8 +341,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << | |||||
| ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| i++; | i++; | ||||
| ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); | ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); | ||||
| @@ -376,8 +375,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << | |||||
| ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| i++; | i++; | ||||
| ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); | ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); | ||||
| category_no++; | category_no++; | ||||
| @@ -410,8 +409,8 @@ TEST_F(TestShardOperator, TestShardCategorySample) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << | |||||
| ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| i++; | i++; | ||||
| ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); | ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); | ||||
| @@ -448,8 +447,8 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) { | |||||
| while (true) { | while (true) { | ||||
| auto x = dataset.GetNext(); | auto x = dataset.GetNext(); | ||||
| if (x.empty()) break; | if (x.empty()) break; | ||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << | |||||
| ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||||
| i++; | i++; | ||||
| ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); | ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); | ||||
| @@ -81,8 +81,6 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | ||||
| logger.info( | logger.info( | ||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | "-------------- item[label]: {} ----------------------------".format(item["label"])) | ||||
| assert data[indices[num_iter]]['file_name'] == "".join( | |||||
| [chr(x) for x in item['file_name']]) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 5 | assert num_iter == 5 | ||||
| @@ -107,8 +105,6 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file): | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | ||||
| logger.info( | logger.info( | ||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | "-------------- item[label]: {} ----------------------------".format(item["label"])) | ||||
| assert data[indices[num_iter]]['file_name'] == "".join( | |||||
| [chr(x) for x in item['file_name']]) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 6 | assert num_iter == 6 | ||||
| @@ -133,8 +129,6 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | ||||
| logger.info( | logger.info( | ||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | "-------------- item[label]: {} ----------------------------".format(item["label"])) | ||||
| assert data[indices[num_iter]]['file_name'] == "".join( | |||||
| [chr(x) for x in item['file_name']]) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 0 | assert num_iter == 0 | ||||
| @@ -159,8 +153,6 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file): | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | ||||
| logger.info( | logger.info( | ||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | "-------------- item[label]: {} ----------------------------".format(item["label"])) | ||||
| assert data[indices[num_iter] % len(data)]['file_name'] == "".join([ | |||||
| chr(x) for x in item['file_name']]) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 5 | assert num_iter == 5 | ||||
| @@ -185,8 +177,6 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): | |||||
| "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | ||||
| logger.info( | logger.info( | ||||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | "-------------- item[label]: {} ----------------------------".format(item["label"])) | ||||
| assert data[indices[num_iter] % len(data)]['file_name'] == "".join([ | |||||
| chr(x) for x in item['file_name']]) | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 5 | assert num_iter == 5 | ||||