Browse Source

!11329 [MD] Rename SamplerObj::Build() to SamplerBuild()

From: @lixiachen
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
c155dfee7a
35 changed files with 185 additions and 103 deletions
  1. +12
    -12
      mindspore/ccsrc/minddata/dataset/api/samplers.cc
  2. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
  3. +8
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc
  4. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  5. +3
    -2
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc
  6. +6
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h
  7. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc
  8. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h
  9. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc
  10. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h
  11. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc
  12. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h
  13. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc
  14. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc
  15. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h
  16. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc
  17. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h
  18. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
  19. +7
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h
  20. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc
  21. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h
  22. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc
  23. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h
  24. +4
    -3
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc
  25. +7
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h
  26. +4
    -10
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc
  27. +0
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h
  28. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc
  29. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc
  30. +5
    -4
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc
  31. +7
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h
  32. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc
  33. +23
    -23
      mindspore/ccsrc/minddata/dataset/include/samplers.h
  34. +15
    -15
      tests/ut/cpp/dataset/c_api_samplers_test.cc
  35. +1
    -1
      tests/ut/python/dataset/test_cache_map.py

+ 12
- 12
mindspore/ccsrc/minddata/dataset/api/samplers.cc View File

@@ -52,12 +52,12 @@ SamplerObj::SamplerObj() {}


void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) { void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
for (auto child : children_) { for (auto child : children_) {
auto sampler_rt = child->Build();
auto sampler_rt = child->SamplerBuild();
sampler->AddChild(sampler_rt); sampler->AddChild(sampler_rt);
} }
} }


Status SamplerObj::AddChild(std::shared_ptr<SamplerObj> child) {
Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
if (child == nullptr) { if (child == nullptr) {
return Status::OK(); return Status::OK();
} }
@@ -183,7 +183,7 @@ Status DistributedSamplerObj::ValidateParams() {
return Status::OK(); return Status::OK();
} }


std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() {
std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_); offset_, even_dist_);
@@ -215,7 +215,7 @@ Status PKSamplerObj::ValidateParams() {
return Status::OK(); return Status::OK();
} }


std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
BuildChildren(sampler); BuildChildren(sampler);
@@ -232,7 +232,7 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator


Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); } Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }


std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() {
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() {
BuildChildren(sp_); BuildChildren(sp_);
return sp_; return sp_;
} }
@@ -241,19 +241,19 @@ std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() {
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
#endif #endif


std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() {
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
if (sp_minddataset_ != nullptr) { if (sp_minddataset_ != nullptr) {
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_); auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
for (auto child : children_) { for (auto child : children_) {
sampler->AddChild(child);
sampler->AddChildSampler(child);
} }
return sampler; return sampler;
} }
#endif #endif
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_); auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
for (auto child : children_) { for (auto child : children_) {
sampler->AddChild(child);
sampler->AddChildSampler(child);
} }
return sampler; return sampler;
} }
@@ -289,7 +289,7 @@ Status RandomSamplerObj::ValidateParams() {
return Status::OK(); return Status::OK();
} }


std::shared_ptr<SamplerRT> RandomSamplerObj::Build() {
std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
// runtime sampler object // runtime sampler object
bool reshuffle_each_epoch = true; bool reshuffle_each_epoch = true;
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch); auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
@@ -324,7 +324,7 @@ Status SequentialSamplerObj::ValidateParams() {
return Status::OK(); return Status::OK();
} }


std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() {
std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
BuildChildren(sampler); BuildChildren(sampler);
@@ -352,7 +352,7 @@ Status SubsetRandomSamplerObj::ValidateParams() {
return Status::OK(); return Status::OK();
} }


std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() {
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() {
// runtime sampler object // runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_); auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
BuildChildren(sampler); BuildChildren(sampler);
@@ -395,7 +395,7 @@ Status WeightedRandomSamplerObj::ValidateParams() {
return Status::OK(); return Status::OK();
} }


std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() {
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
BuildChildren(sampler); BuildChildren(sampler);
return sampler; return sampler;


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc View File

@@ -40,7 +40,7 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets
} }


std::shared_ptr<DatasetNode> ConcatNode::Copy() { std::shared_ptr<DatasetNode> ConcatNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
// create an empty vector to copy a concat // create an empty vector to copy a concat
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>(), sampler, auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>(), sampler,
children_flag_and_nums_, children_start_end_index_); children_flag_and_nums_, children_start_end_index_);
@@ -77,8 +77,8 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) { if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_)); node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_));
} else { } else {
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->Build(), children_flag_and_nums_,
children_start_end_index_));
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(),
children_flag_and_nums_, children_start_end_index_));
} }


return Status::OK(); return Status::OK();


+ 8
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc View File

@@ -594,9 +594,14 @@ Status DatasetNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
} }


Status DatasetNode::GetShardId(int32_t *const shard_id) { Status DatasetNode::GetShardId(int32_t *const shard_id) {
if (!Children().empty()) {
if (children_.size() == 1) {
// Get shard id from the child node // Get shard id from the child node
return Children()[0]->GetShardId(shard_id);
return children_[0]->GetShardId(shard_id);
} else if (children_.size() > 1) {
// It is okay for a dataset node to have more than one child; GetShardId shouldn't fail in this case.
// This is done mostly for cache, which injects cache lookup/merge operators. The cache path will
// always be in front of the children_ structure, so we get the shard id from the last child.
return children_.back()->GetShardId(shard_id);
} else { } else {
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n"); RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
} }
@@ -648,7 +653,7 @@ Status MappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
} }


Status NonMappableSourceNode::Accept(IRNodePass *const p, bool *const modified) { Status NonMappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
return p->Visit(shared_from_base<MappableSourceNode>(), modified);
return p->Visit(shared_from_base<NonMappableSourceNode>(), modified);
} }


} // namespace dataset } // namespace dataset


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h View File

@@ -330,6 +330,13 @@ class MappableSourceNode : public DatasetNode {
/// \brief Node name getter /// \brief Node name getter
/// \return Name of the current node /// \return Name of the current node
virtual std::string Name() const = 0; virtual std::string Name() const = 0;

/// \brief Sampler getter
/// \return SamplerObj of the current node
virtual std::shared_ptr<SamplerObj> Sampler() = 0;

/// \brief Sampler setter
virtual void SetSampler(std::shared_ptr<SamplerObj> sampler) = 0;
}; };


// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed. // NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.


+ 3
- 2
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc View File

@@ -40,7 +40,7 @@ AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_sch
sampler_(sampler) {} sampler_(sampler) {}


std::shared_ptr<DatasetNode> AlbumNode::Copy() { std::shared_ptr<DatasetNode> AlbumNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_); auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_);
return node; return node;
} }
@@ -75,7 +75,8 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
RETURN_IF_NOT_OK(AddCacheOp(node_ops)); RETURN_IF_NOT_OK(AddCacheOp(node_ops));


node_ops->push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, node_ops->push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, extensions, std::move(schema), std::move(sampler_->Build())));
decode_, extensions, std::move(schema),
std::move(sampler_->SamplerBuild())));
return Status::OK(); return Status::OK();
} }




+ 6
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h View File

@@ -66,6 +66,12 @@ class AlbumNode : public MappableSourceNode {
const std::string &SchemaPath() const { return schema_path_; } const std::string &SchemaPath() const { return schema_path_; }
const std::vector<std::string> &ColumnNames() const { return column_names_; } const std::vector<std::string> &ColumnNames() const { return column_names_; }
bool Decode() const { return decode_; } bool Decode() const { return decode_; }
/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }


private: private:
std::string dataset_dir_; std::string dataset_dir_;


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc View File

@@ -40,7 +40,7 @@ CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
extensions_(extensions) {} extensions_(extensions) {}


std::shared_ptr<DatasetNode> CelebANode::Copy() { std::shared_ptr<DatasetNode> CelebANode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_); auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_);
return node; return node;
} }
@@ -71,7 +71,7 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops


node_ops->push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, node_ops->push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, usage_, extensions_, std::move(schema), decode_, usage_, extensions_, std::move(schema),
std::move(sampler_->Build())));
std::move(sampler_->SamplerBuild())));


return Status::OK(); return Status::OK();
} }
@@ -139,7 +139,7 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
num_rows = std::min(num_rows, partition_num); num_rows = std::min(num_rows, partition_num);
} }


sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
*dataset_size = sample_size; *dataset_size = sample_size;
return Status::OK(); return Status::OK();
} }


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h View File

@@ -82,6 +82,13 @@ class CelebANode : public MappableSourceNode {
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override; Status to_json(nlohmann::json *out_json) override;


/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

private: private:
std::string dataset_dir_; std::string dataset_dir_;
std::string usage_; std::string usage_;


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc View File

@@ -33,7 +33,7 @@ Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &us
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}


std::shared_ptr<DatasetNode> Cifar100Node::Copy() { std::shared_ptr<DatasetNode> Cifar100Node::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_); auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_);
return node; return node;
} }
@@ -68,7 +68,7 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o


node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema), dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->Build())));
std::move(sampler_->SamplerBuild())));


return Status::OK(); return Status::OK();
} }
@@ -89,7 +89,7 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
} }
int64_t num_rows, sample_size; int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows)); RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h View File

@@ -78,6 +78,13 @@ class Cifar100Node : public MappableSourceNode {
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override; Status to_json(nlohmann::json *out_json) override;


/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

private: private:
std::string dataset_dir_; std::string dataset_dir_;
std::string usage_; std::string usage_;


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc View File

@@ -33,7 +33,7 @@ Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usag
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}


std::shared_ptr<DatasetNode> Cifar10Node::Copy() { std::shared_ptr<DatasetNode> Cifar10Node::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_); auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_);
return node; return node;
} }
@@ -66,7 +66,7 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op


node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema), dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->Build())));
std::move(sampler_->SamplerBuild())));


return Status::OK(); return Status::OK();
} }
@@ -87,7 +87,7 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
} }
int64_t num_rows, sample_size; int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows)); RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h View File

@@ -78,6 +78,13 @@ class Cifar10Node : public MappableSourceNode {
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override; Status to_json(nlohmann::json *out_json) override;


/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

private: private:
std::string dataset_dir_; std::string dataset_dir_;
std::string usage_; std::string usage_;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc View File

@@ -205,7 +205,7 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)


std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>( std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files, num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files,
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));


RETURN_IF_NOT_OK(clue_op->Init()); RETURN_IF_NOT_OK(clue_op->Init());




+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc View File

@@ -38,7 +38,7 @@ CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation
sampler_(sampler) {} sampler_(sampler) {}


std::shared_ptr<DatasetNode> CocoNode::Copy() { std::shared_ptr<DatasetNode> CocoNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_); auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
return node; return node;
} }
@@ -121,7 +121,7 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
} }
std::shared_ptr<CocoOp> op = std::shared_ptr<CocoOp> op =
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
RETURN_IF_NOT_OK(AddCacheOp(node_ops)); RETURN_IF_NOT_OK(AddCacheOp(node_ops));


node_ops->push_back(op); node_ops->push_back(op);
@@ -145,7 +145,7 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
} }
int64_t num_rows = 0, sample_size; int64_t num_rows = 0, sample_size;
RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows)); RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h View File

@@ -80,6 +80,13 @@ class CocoNode : public MappableSourceNode {
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override; Status to_json(nlohmann::json *out_json) override;


/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

private: private:
std::string dataset_dir_; std::string dataset_dir_;
std::string annotation_file_; std::string annotation_file_;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc View File

@@ -122,7 +122,7 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
std::shared_ptr<CsvOp> csv_op = std::shared_ptr<CsvOp> csv_op =
std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_,
rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
num_shards_, shard_id_, std::move(sampler_->Build()));
num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));


RETURN_IF_NOT_OK(csv_op->Init()); RETURN_IF_NOT_OK(csv_op->Init());




+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h View File

@@ -89,6 +89,13 @@ class GeneratorNode : public MappableSourceNode {
const std::vector<DataType> &ColumnTypes() const { return column_types_; } const std::vector<DataType> &ColumnTypes() const { return column_types_; }
const std::shared_ptr<SchemaObj> &Schema() const { return schema_; } const std::shared_ptr<SchemaObj> &Schema() const { return schema_; }


/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return nullptr; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override {}

private: private:
py::function generator_function_; py::function generator_function_;
std::vector<std::string> column_names_; std::vector<std::string> column_names_;


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc View File

@@ -42,7 +42,7 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar
exts_(extensions) {} exts_(extensions) {}


std::shared_ptr<DatasetNode> ImageFolderNode::Copy() { std::shared_ptr<DatasetNode> ImageFolderNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = auto node =
std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_); std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_);
return node; return node;
@@ -74,7 +74,7 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod


node_ops->push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, node_ops->push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
recursive_, decode_, exts_, class_indexing_, std::move(schema), recursive_, decode_, exts_, class_indexing_, std::move(schema),
std::move(sampler_->Build())));
std::move(sampler_->SamplerBuild())));
return Status::OK(); return Status::OK();
} }


@@ -94,7 +94,7 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter>
} }
int64_t sample_size, num_rows; int64_t sample_size, num_rows;
RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {})); RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {}));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 7
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h View File

@@ -79,7 +79,6 @@ class ImageFolderNode : public MappableSourceNode {
const std::string &DatasetDir() const { return dataset_dir_; } const std::string &DatasetDir() const { return dataset_dir_; }
bool Decode() const { return decode_; } bool Decode() const { return decode_; }
bool Recursive() const { return recursive_; } bool Recursive() const { return recursive_; }
const std::shared_ptr<SamplerObj> &Sampler() const { return sampler_; }
const std::map<std::string, int32_t> &ClassIndexing() const { return class_indexing_; } const std::map<std::string, int32_t> &ClassIndexing() const { return class_indexing_; }
const std::set<std::string> &Exts() const { return exts_; } const std::set<std::string> &Exts() const { return exts_; }


@@ -88,6 +87,13 @@ class ImageFolderNode : public MappableSourceNode {
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override; Status to_json(nlohmann::json *out_json) override;


/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

private: private:
std::string dataset_dir_; std::string dataset_dir_;
bool decode_; bool decode_;


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc View File

@@ -40,7 +40,7 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u
sampler_(sampler) {} sampler_(sampler) {}


std::shared_ptr<DatasetNode> ManifestNode::Copy() { std::shared_ptr<DatasetNode> ManifestNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_); auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_);
return node; return node;
} }
@@ -93,7 +93,7 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
std::shared_ptr<ManifestOp> manifest_op; std::shared_ptr<ManifestOp> manifest_op;
manifest_op = manifest_op =
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
class_index_, std::move(schema), std::move(sampler_->Build()), usage_);
class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_);
RETURN_IF_NOT_OK(AddCacheOp(node_ops)); RETURN_IF_NOT_OK(AddCacheOp(node_ops));


node_ops->push_back(manifest_op); node_ops->push_back(manifest_op);
@@ -118,7 +118,7 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
int64_t num_rows, sample_size; int64_t num_rows, sample_size;
int64_t num_classes; // dummy variable int64_t num_classes; // dummy variable
RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes)); RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h View File

@@ -81,6 +81,13 @@ class ManifestNode : public MappableSourceNode {
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override; Status to_json(nlohmann::json *out_json) override;


/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

private: private:
std::string dataset_file_; std::string dataset_file_;
std::string usage_; std::string usage_;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc View File

@@ -54,7 +54,7 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st


std::shared_ptr<DatasetNode> MindDataNode::Copy() { std::shared_ptr<DatasetNode> MindDataNode::Copy() {
std::shared_ptr<MindDataNode> node; std::shared_ptr<MindDataNode> node;
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
if (dataset_files_.empty()) { if (dataset_files_.empty()) {
node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_); node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_);
} else { } else {


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h View File

@@ -85,6 +85,13 @@ class MindDataNode : public MappableSourceNode {
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override; int64_t *dataset_size) override;


/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

private: private:
std::string dataset_file_; // search_for_pattern_ will be true in this mode std::string dataset_file_; // search_for_pattern_ will be true in this mode
std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode


+ 4
- 3
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc View File

@@ -32,7 +32,7 @@ MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}


std::shared_ptr<DatasetNode> MnistNode::Copy() { std::shared_ptr<DatasetNode> MnistNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_); auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_);
return node; return node;
} }
@@ -60,7 +60,8 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
RETURN_IF_NOT_OK(AddCacheOp(node_ops)); RETURN_IF_NOT_OK(AddCacheOp(node_ops));


node_ops->push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, node_ops->push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_->Build())));
connector_que_size_, std::move(schema),
std::move(sampler_->SamplerBuild())));


return Status::OK(); return Status::OK();
} }
@@ -81,7 +82,7 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
} }
int64_t num_rows, sample_size; int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 7
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h View File

@@ -72,13 +72,19 @@ class MnistNode : public MappableSourceNode {
/// \brief Getter functions /// \brief Getter functions
const std::string &DatasetDir() const { return dataset_dir_; } const std::string &DatasetDir() const { return dataset_dir_; }
const std::string &Usage() const { return usage_; } const std::string &Usage() const { return usage_; }
const std::shared_ptr<SamplerObj> &Sampler() const { return sampler_; }


/// \brief Get the arguments of node /// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes /// \param[out] out_json JSON string of all attributes
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override; Status to_json(nlohmann::json *out_json) override;


/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

private: private:
std::string dataset_dir_; std::string dataset_dir_;
std::string usage_; std::string usage_;


+ 4
- 10
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc View File

@@ -114,7 +114,7 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops


std::shared_ptr<RandomDataOp> op; std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
std::move(data_schema_), std::move(sampler_->Build()));
std::move(data_schema_), std::move(sampler_->SamplerBuild()));
RETURN_IF_NOT_OK(AddCacheOp(node_ops)); RETURN_IF_NOT_OK(AddCacheOp(node_ops));


node_ops->push_back(op); node_ops->push_back(op);
@@ -124,8 +124,8 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops


// Get the shard id of node // Get the shard id of node
Status RandomNode::GetShardId(int32_t *shard_id) { Status RandomNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
// RandomDataset doesn't support multiple shards
*shard_id = 0;
return Status::OK(); return Status::OK();
} }


@@ -138,13 +138,7 @@ Status RandomNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
} }
int64_t num_rows; int64_t num_rows;
num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
if (sampler_ != nullptr) {
int64_t sample_size;
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
} else {
*dataset_size = num_rows;
}
*dataset_size = num_rows;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();
} }


+ 0
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h View File

@@ -110,7 +110,6 @@ class RandomNode : public NonMappableSourceNode {
std::string schema_path_; std::string schema_path_;
std::shared_ptr<SchemaObj> schema_; std::shared_ptr<SchemaObj> schema_;
std::vector<std::string> columns_list_; std::vector<std::string> columns_list_;
std::shared_ptr<SamplerObj> sampler_;
std::mt19937 rand_gen_; std::mt19937 rand_gen_;
std::unique_ptr<DataSchema> data_schema_; std::unique_ptr<DataSchema> data_schema_;
}; };


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc View File

@@ -90,7 +90,7 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
// Create and initalize TextFileOp // Create and initalize TextFileOp
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
RETURN_IF_NOT_OK(text_file_op->Init()); RETURN_IF_NOT_OK(text_file_op->Init());


if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc View File

@@ -131,7 +131,7 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
std::shared_ptr<TFReaderOp> tf_reader_op = std::shared_ptr<TFReaderOp> tf_reader_op =
std::make_shared<TFReaderOp>(num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::make_shared<TFReaderOp>(num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files,
std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_, std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_,
shard_id_, shard_equal_rows_, std::move(sampler_->Build()));
shard_id_, shard_equal_rows_, std::move(sampler_->SamplerBuild()));


RETURN_IF_NOT_OK(tf_reader_op->Init()); RETURN_IF_NOT_OK(tf_reader_op->Init());




+ 5
- 4
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc View File

@@ -41,7 +41,7 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const
sampler_(sampler) {} sampler_(sampler) {}


std::shared_ptr<DatasetNode> VOCNode::Copy() { std::shared_ptr<DatasetNode> VOCNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_); auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_);
return node; return node;
} }
@@ -110,8 +110,9 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
} }


std::shared_ptr<VOCOp> voc_op; std::shared_ptr<VOCOp> voc_op;
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
voc_op =
std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
RETURN_IF_NOT_OK(AddCacheOp(node_ops)); RETURN_IF_NOT_OK(AddCacheOp(node_ops));


node_ops->push_back(voc_op); node_ops->push_back(voc_op);
@@ -134,7 +135,7 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
} }
int64_t num_rows = 0, sample_size; int64_t num_rows = 0, sample_size;
RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows)); RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows));
sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
*dataset_size = sample_size; *dataset_size = sample_size;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();


+ 7
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h View File

@@ -83,6 +83,13 @@ class VOCNode : public MappableSourceNode {
/// \return Status of the function /// \return Status of the function
Status to_json(nlohmann::json *out_json) override; Status to_json(nlohmann::json *out_json) override;


/// \brief Sampler getter
/// \return SamplerObj of the current node
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

/// \brief Sampler setter
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

private: private:
const std::string kColumnImage = "image"; const std::string kColumnImage = "image";
const std::string kColumnTarget = "target"; const std::string kColumnTarget = "target";


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/opt/post/auto_worker_pass.cc View File

@@ -101,7 +101,7 @@ Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MappableSourceNode> n
} }


Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) { Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
auto itr = weight_profile_.find("NonMappableSourceNode");
auto itr = weight_profile_.find("NonMappableSource");
CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(), CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(),
"NonLeafSource::" + node->Name() + "'s weight doesn't exist."); "NonLeafSource::" + node->Name() + "'s weight doesn't exist.");
int32_t weight = itr->second; int32_t weight = itr->second;


+ 23
- 23
mindspore/ccsrc/minddata/dataset/include/samplers.h View File

@@ -49,11 +49,11 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {


/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler /// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<SamplerRT> Build() = 0;
virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;


/// \brief Pure virtual function to copy a SamplerObj class /// \brief Pure virtual function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj /// \return Shared pointers to the newly copied SamplerObj
virtual std::shared_ptr<SamplerObj> Copy() = 0;
virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;


/// \brief Function for derived class to get the shard id of sampler /// \brief Function for derived class to get the shard id of sampler
/// \return The shard id of the derived sampler /// \return The shard id of the derived sampler
@@ -62,7 +62,7 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
/// \brief Adds a child to the sampler /// \brief Adds a child to the sampler
/// \param[in] child The sampler to be added as child /// \param[in] child The sampler to be added as child
/// \return the Status code returned /// \return the Status code returned
Status AddChild(std::shared_ptr<SamplerObj> child);
Status AddChildSampler(std::shared_ptr<SamplerObj> child);


virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); } virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }


@@ -152,13 +152,13 @@ class DistributedSamplerObj : public SamplerObj {


~DistributedSamplerObj() = default; ~DistributedSamplerObj() = default;


std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerRT> SamplerBuild() override;


std::shared_ptr<SamplerObj> Copy() override {
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
offset_, even_dist_); offset_, even_dist_);
for (auto child : children_) { for (auto child : children_) {
sampler->AddChild(child);
sampler->AddChildSampler(child);
} }
return sampler; return sampler;
} }
@@ -189,12 +189,12 @@ class PKSamplerObj : public SamplerObj {


~PKSamplerObj() = default; ~PKSamplerObj() = default;


std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerRT> SamplerBuild() override;


std::shared_ptr<SamplerObj> Copy() override {
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_); auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
for (auto child : children_) { for (auto child : children_) {
sampler->AddChild(child);
sampler->AddChildSampler(child);
} }
return sampler; return sampler;
} }
@@ -220,13 +220,13 @@ class PreBuiltSamplerObj : public SamplerObj {


~PreBuiltSamplerObj() = default; ~PreBuiltSamplerObj() = default;


std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerRT> SamplerBuild() override;


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif #endif


std::shared_ptr<SamplerObj> Copy() override;
std::shared_ptr<SamplerObj> SamplerCopy() override;


Status ValidateParams() override; Status ValidateParams() override;


@@ -245,12 +245,12 @@ class RandomSamplerObj : public SamplerObj {


~RandomSamplerObj() = default; ~RandomSamplerObj() = default;


std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerRT> SamplerBuild() override;


std::shared_ptr<SamplerObj> Copy() override {
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_); auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
for (auto child : children_) { for (auto child : children_) {
sampler->AddChild(child);
sampler->AddChildSampler(child);
} }
return sampler; return sampler;
} }
@@ -272,12 +272,12 @@ class SequentialSamplerObj : public SamplerObj {


~SequentialSamplerObj() = default; ~SequentialSamplerObj() = default;


std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerRT> SamplerBuild() override;


std::shared_ptr<SamplerObj> Copy() override {
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_); auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
for (auto child : children_) { for (auto child : children_) {
sampler->AddChild(child);
sampler->AddChildSampler(child);
} }
return sampler; return sampler;
} }
@@ -299,12 +299,12 @@ class SubsetRandomSamplerObj : public SamplerObj {


~SubsetRandomSamplerObj() = default; ~SubsetRandomSamplerObj() = default;


std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerRT> SamplerBuild() override;


std::shared_ptr<SamplerObj> Copy() override {
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
for (auto child : children_) { for (auto child : children_) {
sampler->AddChild(child);
sampler->AddChildSampler(child);
} }
return sampler; return sampler;
} }
@@ -326,12 +326,12 @@ class WeightedRandomSamplerObj : public SamplerObj {


~WeightedRandomSamplerObj() = default; ~WeightedRandomSamplerObj() = default;


std::shared_ptr<SamplerRT> Build() override;
std::shared_ptr<SamplerRT> SamplerBuild() override;


std::shared_ptr<SamplerObj> Copy() override {
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
for (auto child : children_) { for (auto child : children_) {
sampler->AddChild(child);
sampler->AddChildSampler(child);
} }
return sampler; return sampler;
} }


+ 15
- 15
tests/ut/cpp/dataset/c_api_samplers_test.cc View File

@@ -87,67 +87,67 @@ TEST_F(MindDataTestPipeline, TestCalculateNumSamples) {
int64_t num_rows = 30; // dummy variable for number of rows in the dataset int64_t num_rows = 30; // dummy variable for number of rows in the dataset
std::shared_ptr<SamplerObj> sampl = DistributedSampler(2, 1, false, 6); std::shared_ptr<SamplerObj> sampl = DistributedSampler(2, 1, false, 6);
EXPECT_NE(sampl, nullptr); EXPECT_NE(sampl, nullptr);
std::shared_ptr<SamplerRT> sampler_rt = sampl->Build();
std::shared_ptr<SamplerRT> sampler_rt = sampl->SamplerBuild();
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6);


sampl = PKSampler(3, false); sampl = PKSampler(3, false);
EXPECT_NE(sampl, nullptr); EXPECT_NE(sampl, nullptr);
sampler_rt = sampl->Build();
sampler_rt = sampl->SamplerBuild();
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);


sampl = RandomSampler(false, 12); sampl = RandomSampler(false, 12);
EXPECT_NE(sampl, nullptr); EXPECT_NE(sampl, nullptr);
sampler_rt = sampl->Build();
sampler_rt = sampl->SamplerBuild();
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);


sampl = SequentialSampler(0, 10); sampl = SequentialSampler(0, 10);
EXPECT_NE(sampl, nullptr); EXPECT_NE(sampl, nullptr);
sampler_rt = sampl->Build();
sampler_rt = sampl->SamplerBuild();
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10);


std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1}; std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
sampl = WeightedRandomSampler(weights, 12); sampl = WeightedRandomSampler(weights, 12);
EXPECT_NE(sampl, nullptr); EXPECT_NE(sampl, nullptr);
sampler_rt = sampl->Build();
sampler_rt = sampl->SamplerBuild();
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);


std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21}; std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21};
sampl = SubsetRandomSampler(indices, 11); sampl = SubsetRandomSampler(indices, 11);
EXPECT_NE(sampl, nullptr); EXPECT_NE(sampl, nullptr);
sampler_rt = sampl->Build();
sampler_rt = sampl->SamplerBuild();
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11); EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11);


// Testing chains // Testing chains
// Parent and child have num_samples // Parent and child have num_samples
std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights, 12); std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights, 12);
EXPECT_NE(sampl1, nullptr); EXPECT_NE(sampl1, nullptr);
std::shared_ptr<SamplerRT> sampler_rt1 = sampl1->Build();
std::shared_ptr<SamplerRT> sampler_rt1 = sampl1->SamplerBuild();


std::shared_ptr<SamplerObj> sampl2 = SequentialSampler(0, 10); std::shared_ptr<SamplerObj> sampl2 = SequentialSampler(0, 10);
EXPECT_NE(sampl2, nullptr); EXPECT_NE(sampl2, nullptr);
std::shared_ptr<SamplerRT> sampler_rt2 = sampl2->Build();
std::shared_ptr<SamplerRT> sampler_rt2 = sampl2->SamplerBuild();
sampler_rt2->AddChild(sampler_rt1); sampler_rt2->AddChild(sampler_rt1);
EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10); EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10);


// Parent doesn't have num_samples // Parent doesn't have num_samples
std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights, 12); std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights, 12);
EXPECT_NE(sampl3, nullptr); EXPECT_NE(sampl3, nullptr);
std::shared_ptr<SamplerRT> sampler_rt3 = sampl3->Build();
std::shared_ptr<SamplerRT> sampler_rt3 = sampl3->SamplerBuild();


std::shared_ptr<SamplerObj> sampl4 = SubsetRandomSampler(indices); std::shared_ptr<SamplerObj> sampl4 = SubsetRandomSampler(indices);
EXPECT_NE(sampl4, nullptr); EXPECT_NE(sampl4, nullptr);
std::shared_ptr<SamplerRT> sampler_rt4 = sampl4->Build();
std::shared_ptr<SamplerRT> sampler_rt4 = sampl4->SamplerBuild();
sampler_rt4->AddChild(sampler_rt3); sampler_rt4->AddChild(sampler_rt3);
EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 12); EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 12);


// Child doesn't have num_samples // Child doesn't have num_samples
std::shared_ptr<SamplerObj> sampl5 = RandomSampler(false); std::shared_ptr<SamplerObj> sampl5 = RandomSampler(false);
EXPECT_NE(sampl5, nullptr); EXPECT_NE(sampl5, nullptr);
std::shared_ptr<SamplerRT> sampler_rt5 = sampl5->Build();
std::shared_ptr<SamplerRT> sampler_rt5 = sampl5->SamplerBuild();


std::shared_ptr<SamplerObj> sampl6 = PKSampler(3, false, 7); std::shared_ptr<SamplerObj> sampl6 = PKSampler(3, false, 7);
EXPECT_NE(sampl6, nullptr); EXPECT_NE(sampl6, nullptr);
std::shared_ptr<SamplerRT> sampler_rt6 = sampl6->Build();
std::shared_ptr<SamplerRT> sampler_rt6 = sampl6->SamplerBuild();
sampler_rt6->AddChild(sampler_rt5); sampler_rt6->AddChild(sampler_rt5);
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7); EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
} }
@@ -156,10 +156,10 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices); std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices);
EXPECT_FALSE(indices.empty()); EXPECT_FALSE(indices.empty());
EXPECT_NE(sampl1->Build(), nullptr);
EXPECT_NE(sampl1->SamplerBuild(), nullptr);
std::shared_ptr<SamplerObj> sampl2 = SubsetRandomSampler(std::move(indices)); std::shared_ptr<SamplerObj> sampl2 = SubsetRandomSampler(std::move(indices));
EXPECT_TRUE(indices.empty()); EXPECT_TRUE(indices.empty());
EXPECT_NE(sampl2->Build(), nullptr);
EXPECT_NE(sampl2->SamplerBuild(), nullptr);
} }


TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) { TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
@@ -216,7 +216,7 @@ TEST_F(MindDataTestPipeline, TestSamplerAddChild) {
EXPECT_NE(sampler, nullptr); EXPECT_NE(sampler, nullptr);


auto child_sampler = SequentialSampler(); auto child_sampler = SequentialSampler();
sampler->AddChild(child_sampler);
sampler->AddChildSampler(child_sampler);
EXPECT_NE(child_sampler, nullptr); EXPECT_NE(child_sampler, nullptr);


// Create an ImageFolder Dataset // Create an ImageFolder Dataset


+ 1
- 1
tests/ut/python/dataset/test_cache_map.py View File

@@ -406,7 +406,7 @@ def test_cache_map_failure5():
num_iter = 0 num_iter = 0
for _ in data.create_dict_iterator(): for _ in data.create_dict_iterator():
num_iter += 1 num_iter += 1
assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant" in str(e.value)
assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value)


assert num_iter == 0 assert num_iter == 0
logger.info('test_cache_failure5 Ended.\n') logger.info('test_cache_failure5 Ended.\n')


Loading…
Cancel
Save