From: @lixiachen
Tag: v1.2.0-rc1
@@ -52,12 +52,12 @@ SamplerObj::SamplerObj() {}

 void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
   for (auto child : children_) {
-    auto sampler_rt = child->Build();
+    auto sampler_rt = child->SamplerBuild();
     sampler->AddChild(sampler_rt);
   }
 }

-Status SamplerObj::AddChild(std::shared_ptr<SamplerObj> child) {
+Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
   if (child == nullptr) {
     return Status::OK();
   }
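For orientation, the pattern these renames preserve looks like the following minimal standalone sketch (simplified stand-in types, not the actual MindSpore classes): an IR-side sampler owns IR-side children, and SamplerBuild() wires each child's freshly built runtime sampler into the parent's runtime object.

// Minimal standalone sketch of the renamed IR/runtime split; SamplerObj and
// SamplerRT here are simplified stand-ins, not the MindSpore classes.
#include <memory>
#include <vector>

struct SamplerRT {  // runtime-side sampler
  std::vector<std::shared_ptr<SamplerRT>> children;
  void AddChild(std::shared_ptr<SamplerRT> child) { children.push_back(std::move(child)); }
};

struct SamplerObj {  // IR-side sampler
  std::vector<std::shared_ptr<SamplerObj>> children_;

  // Renamed from AddChild(): attach an IR-level child; a null child is a no-op.
  void AddChildSampler(std::shared_ptr<SamplerObj> child) {
    if (child != nullptr) children_.push_back(std::move(child));
  }

  // Renamed from Build(): create the runtime sampler, then mirror the IR
  // children into it, as BuildChildren() does in the hunk above.
  std::shared_ptr<SamplerRT> SamplerBuild() {
    auto sampler = std::make_shared<SamplerRT>();
    for (auto &child : children_) sampler->AddChild(child->SamplerBuild());
    return sampler;
  }
};

int main() {
  auto parent = std::make_shared<SamplerObj>();
  parent->AddChildSampler(std::make_shared<SamplerObj>());
  auto rt = parent->SamplerBuild();  // runtime tree mirrors the IR tree
  return rt->children.size() == 1 ? 0 : 1;
}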
@@ -183,7 +183,7 @@ Status DistributedSamplerObj::ValidateParams() {
   return Status::OK();
 }

-std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() {
+std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() {
   // runtime sampler object
   auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
                                                                  offset_, even_dist_);
@@ -215,7 +215,7 @@ Status PKSamplerObj::ValidateParams() {
   return Status::OK();
 }

-std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
+std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
   // runtime sampler object
   auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
   BuildChildren(sampler);
@@ -232,7 +232,7 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator

 Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }

-std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() {
+std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() {
   BuildChildren(sp_);
   return sp_;
 }
@@ -241,19 +241,19 @@ std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() {
 std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
 #endif

-std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() {
+std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
 #ifndef ENABLE_ANDROID
   if (sp_minddataset_ != nullptr) {
     auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
     for (auto child : children_) {
-      sampler->AddChild(child);
+      sampler->AddChildSampler(child);
     }
     return sampler;
   }
 #endif
   auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
   for (auto child : children_) {
-    sampler->AddChild(child);
+    sampler->AddChildSampler(child);
   }
   return sampler;
 }
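All of the SamplerCopy() bodies in this patch share one shape: rebuild the object from its stored construction parameters, then re-attach the existing IR children by pointer. A standalone sketch of that shape, with stand-in types rather than the real classes:

// Standalone sketch of the SamplerCopy() shape used throughout the patch.
#include <memory>
#include <vector>

struct Node {
  int param;  // stands in for num_shards_, shuffle_, sp_, ...
  std::vector<std::shared_ptr<Node>> children_;
  explicit Node(int p) : param(p) {}
  void AddChildSampler(std::shared_ptr<Node> child) { children_.push_back(std::move(child)); }

  std::shared_ptr<Node> SamplerCopy() {
    auto copy = std::make_shared<Node>(param);  // same construction parameters
    for (auto &child : children_) {
      copy->AddChildSampler(child);  // children are shared, not deep-copied
    }
    return copy;
  }
};

int main() {
  auto original = std::make_shared<Node>(42);
  original->AddChildSampler(std::make_shared<Node>(7));
  auto copy = original->SamplerCopy();
  // The copy is a new object, but it points at the same child instances.
  return (copy != original && copy->children_[0] == original->children_[0]) ? 0 : 1;
}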
@@ -289,7 +289,7 @@ Status RandomSamplerObj::ValidateParams() {
   return Status::OK();
 }

-std::shared_ptr<SamplerRT> RandomSamplerObj::Build() {
+std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
   // runtime sampler object
   bool reshuffle_each_epoch = true;
   auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
@@ -324,7 +324,7 @@ Status SequentialSamplerObj::ValidateParams() {
   return Status::OK();
 }

-std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() {
+std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
   // runtime sampler object
   auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
   BuildChildren(sampler);
@@ -352,7 +352,7 @@ Status SubsetRandomSamplerObj::ValidateParams() {
   return Status::OK();
 }

-std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() {
+std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() {
   // runtime sampler object
   auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
   BuildChildren(sampler);
@@ -395,7 +395,7 @@ Status WeightedRandomSamplerObj::ValidateParams() {
   return Status::OK();
 }

-std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() {
+std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
   auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
   BuildChildren(sampler);
   return sampler;
@@ -40,7 +40,7 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets
 }

 std::shared_ptr<DatasetNode> ConcatNode::Copy() {
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   // create an empty vector to copy a concat
   auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>(), sampler,
                                            children_flag_and_nums_, children_start_end_index_);
@@ -77,8 +77,8 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
   if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
     node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_));
   } else {
-    node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->Build(), children_flag_and_nums_,
-                                                   children_start_end_index_));
+    node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(),
+                                                   children_flag_and_nums_, children_start_end_index_));
   }

   return Status::OK();
@@ -594,9 +594,14 @@ Status DatasetNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
 }

 Status DatasetNode::GetShardId(int32_t *const shard_id) {
-  if (!Children().empty()) {
+  if (children_.size() == 1) {
     // Get shard id from the child node
-    return Children()[0]->GetShardId(shard_id);
+    return children_[0]->GetShardId(shard_id);
+  } else if (children_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetShardId shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in front of the child_ structure, so we get the dataset size from the last child.
+    return children_.back()->GetShardId(shard_id);
   } else {
     RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
   }
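The rewritten GetShardId() distinguishes three cases instead of two. A condensed standalone sketch of the new control flow (stand-in types; the real method returns Status and the error branch raises a syntax error):

// Sketch of the new GetShardId() dispatch. With one child, recurse into it;
// with several children (e.g. after cache injects lookup/merge operators in
// front), the original pipeline is the last child, so ask that one.
#include <memory>
#include <vector>

struct Node {
  std::vector<std::shared_ptr<Node>> children_;
  int shard_id_ = -1;  // set only on source nodes that know their shard

  bool GetShardId(int *shard_id) {
    if (children_.size() == 1) {
      return children_[0]->GetShardId(shard_id);
    } else if (children_.size() > 1) {
      return children_.back()->GetShardId(shard_id);  // skip injected cache ops
    } else if (shard_id_ >= 0) {
      *shard_id = shard_id_;
      return true;
    }
    return false;  // mirrors RETURN_STATUS_SYNTAX_ERROR in the real code
  }
};

int main() {
  auto source = std::make_shared<Node>();
  source->shard_id_ = 3;
  Node parent;
  parent.children_.push_back(source);
  int id = -1;
  return (parent.GetShardId(&id) && id == 3) ? 0 : 1;
}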
@@ -648,7 +653,7 @@ Status MappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
 }

 Status NonMappableSourceNode::Accept(IRNodePass *const p, bool *const modified) {
-  return p->Visit(shared_from_base<MappableSourceNode>(), modified);
+  return p->Visit(shared_from_base<NonMappableSourceNode>(), modified);
 }
 }  // namespace dataset
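This one-line fix matters because IRNodePass::Visit is overloaded and the overload is chosen by the static type handed in; casting to the wrong base silently routed every non-mappable source through the mappable handler. A simplified standalone sketch of the dispatch mechanism (hypothetical types, shared_from_this standing in for the shared_from_base helper):

// Overload resolution happens on the static type of the argument, so the
// shared_ptr cast in Accept() decides which handler runs.
#include <iostream>
#include <memory>

struct MappableSourceNode;
struct NonMappableSourceNode;

struct Pass {
  void Visit(const std::shared_ptr<MappableSourceNode> &) { std::cout << "mappable\n"; }
  void Visit(const std::shared_ptr<NonMappableSourceNode> &) { std::cout << "non-mappable\n"; }
};

struct MappableSourceNode {};

struct NonMappableSourceNode : std::enable_shared_from_this<NonMappableSourceNode> {
  void Accept(Pass *p) {
    // Before the fix, the cast targeted MappableSourceNode, so every
    // non-mappable source was visited by the mappable overload.
    p->Visit(shared_from_this());
  }
};

int main() {
  Pass pass;
  auto node = std::make_shared<NonMappableSourceNode>();
  node->Accept(&pass);  // prints "non-mappable"
}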
@@ -330,6 +330,13 @@ class MappableSourceNode : public DatasetNode {
   /// \brief Node name getter
   /// \return Name of the current node
   virtual std::string Name() const = 0;

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  virtual std::shared_ptr<SamplerObj> Sampler() = 0;
+
+  /// \brief Sampler setter
+  virtual void SetSampler(std::shared_ptr<SamplerObj> sampler) = 0;
 };

 // NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
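Hoisting Sampler()/SetSampler() into the base class lets IR passes read or replace a leaf's sampler without knowing the concrete node type; the per-node overrides in the hunks below are all one-liners. A hedged sketch of a caller (hypothetical pass code, not part of this patch):

#include <memory>
#include <utility>

// Simplified stand-ins for the real IR types.
struct SamplerObj {};

struct MappableSourceNode {
  virtual ~MappableSourceNode() = default;
  virtual std::shared_ptr<SamplerObj> Sampler() = 0;
  virtual void SetSampler(std::shared_ptr<SamplerObj> sampler) = 0;
};

// Hypothetical pass step: swap in a replacement sampler on any mappable leaf
// (AlbumNode, Cifar10Node, VOCNode, ...) through the common interface.
void ReplaceSampler(MappableSourceNode *leaf, std::shared_ptr<SamplerObj> replacement) {
  if (leaf->Sampler() != nullptr) {  // GeneratorNode has no sampler and returns nullptr
    leaf->SetSampler(std::move(replacement));
  }
}

struct FakeLeaf : MappableSourceNode {
  std::shared_ptr<SamplerObj> sampler_ = std::make_shared<SamplerObj>();
  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = std::move(sampler); }
};

int main() {
  FakeLeaf leaf;
  ReplaceSampler(&leaf, std::make_shared<SamplerObj>());
}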
@@ -40,7 +40,7 @@ AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_sch
     sampler_(sampler) {}

 std::shared_ptr<DatasetNode> AlbumNode::Copy() {
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_);
   return node;
 }
@@ -75,7 +75,8 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
   RETURN_IF_NOT_OK(AddCacheOp(node_ops));

   node_ops->push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
-                                                decode_, extensions, std::move(schema), std::move(sampler_->Build())));
+                                                decode_, extensions, std::move(schema),
+                                                std::move(sampler_->SamplerBuild())));

   return Status::OK();
 }
@@ -66,6 +66,12 @@ class AlbumNode : public MappableSourceNode {
   const std::string &SchemaPath() const { return schema_path_; }
   const std::vector<std::string> &ColumnNames() const { return column_names_; }
   bool Decode() const { return decode_; }

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
+
  private:
   std::string dataset_dir_;
@@ -40,7 +40,7 @@ CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
     extensions_(extensions) {}

 std::shared_ptr<DatasetNode> CelebANode::Copy() {
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_);
   return node;
 }
@@ -71,7 +71,7 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
   node_ops->push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
                                                  decode_, usage_, extensions_, std::move(schema),
-                                                 std::move(sampler_->Build())));
+                                                 std::move(sampler_->SamplerBuild())));

   return Status::OK();
 }
@@ -139,7 +139,7 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
     num_rows = std::min(num_rows, partition_num);
   }

-  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
+  sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
   *dataset_size = sample_size;
   return Status::OK();
 }
@@ -82,6 +82,13 @@ class CelebANode : public MappableSourceNode {
   /// \return Status of the function
   Status to_json(nlohmann::json *out_json) override;

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
+
  private:
   std::string dataset_dir_;
   std::string usage_;
@@ -33,7 +33,7 @@ Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &us
   : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

 std::shared_ptr<DatasetNode> Cifar100Node::Copy() {
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_);
   return node;
 }
@@ -68,7 +68,7 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
   node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
                                                 dataset_dir_, connector_que_size_, std::move(schema),
-                                                std::move(sampler_->Build())));
+                                                std::move(sampler_->SamplerBuild())));

   return Status::OK();
 }
@@ -89,7 +89,7 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
   }
   int64_t num_rows, sample_size;
   RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows));
-  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
+  sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
   *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
@@ -78,6 +78,13 @@ class Cifar100Node : public MappableSourceNode {
   /// \return Status of the function
   Status to_json(nlohmann::json *out_json) override;

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
+
  private:
   std::string dataset_dir_;
   std::string usage_;
@@ -33,7 +33,7 @@ Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usag
   : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

 std::shared_ptr<DatasetNode> Cifar10Node::Copy() {
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_);
   return node;
 }
@@ -66,7 +66,7 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
   node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
                                                 dataset_dir_, connector_que_size_, std::move(schema),
-                                                std::move(sampler_->Build())));
+                                                std::move(sampler_->SamplerBuild())));

   return Status::OK();
 }
@@ -87,7 +87,7 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
   }
   int64_t num_rows, sample_size;
   RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows));
-  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
+  sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
   *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
@@ -78,6 +78,13 @@ class Cifar10Node : public MappableSourceNode {
   /// \return Status of the function
   Status to_json(nlohmann::json *out_json) override;

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
+
  private:
   std::string dataset_dir_;
   std::string usage_;
@@ -205,7 +205,7 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
   std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
     num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files,
-    connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
+    connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));

   RETURN_IF_NOT_OK(clue_op->Init());
@@ -38,7 +38,7 @@ CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation
     sampler_(sampler) {}

 std::shared_ptr<DatasetNode> CocoNode::Copy() {
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
   return node;
 }
@@ -121,7 +121,7 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
   }
   std::shared_ptr<CocoOp> op =
     std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
-                             connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
+                             connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
   RETURN_IF_NOT_OK(AddCacheOp(node_ops));

   node_ops->push_back(op);
@@ -145,7 +145,7 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
   }
   int64_t num_rows = 0, sample_size;
   RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows));
-  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
+  sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
   *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
@@ -80,6 +80,13 @@ class CocoNode : public MappableSourceNode {
   /// \return Status of the function
   Status to_json(nlohmann::json *out_json) override;

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
+
  private:
   std::string dataset_dir_;
   std::string annotation_file_;
@@ -122,7 +122,7 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
   std::shared_ptr<CsvOp> csv_op =
     std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_,
                             rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
-                            num_shards_, shard_id_, std::move(sampler_->Build()));
+                            num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));

   RETURN_IF_NOT_OK(csv_op->Init());
@@ -89,6 +89,13 @@ class GeneratorNode : public MappableSourceNode {
   const std::vector<DataType> &ColumnTypes() const { return column_types_; }
   const std::shared_ptr<SchemaObj> &Schema() const { return schema_; }

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return nullptr; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override {}
+
  private:
   py::function generator_function_;
   std::vector<std::string> column_names_;
@@ -42,7 +42,7 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar
     exts_(extensions) {}

 std::shared_ptr<DatasetNode> ImageFolderNode::Copy() {
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node =
     std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_);
   return node;
@@ -74,7 +74,7 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
   node_ops->push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
                                                       recursive_, decode_, exts_, class_indexing_, std::move(schema),
-                                                      std::move(sampler_->Build())));
+                                                      std::move(sampler_->SamplerBuild())));

   return Status::OK();
 }
@@ -94,7 +94,7 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter>
   }
   int64_t sample_size, num_rows;
   RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {}));
-  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
+  sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
   *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
@@ -79,7 +79,6 @@ class ImageFolderNode : public MappableSourceNode {
   const std::string &DatasetDir() const { return dataset_dir_; }
   bool Decode() const { return decode_; }
   bool Recursive() const { return recursive_; }
-  const std::shared_ptr<SamplerObj> &Sampler() const { return sampler_; }
   const std::map<std::string, int32_t> &ClassIndexing() const { return class_indexing_; }
   const std::set<std::string> &Exts() const { return exts_; }
@@ -88,6 +87,13 @@ class ImageFolderNode : public MappableSourceNode {
   /// \return Status of the function
   Status to_json(nlohmann::json *out_json) override;

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
+
  private:
   std::string dataset_dir_;
   bool decode_;
@@ -40,7 +40,7 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u
     sampler_(sampler) {}

 std::shared_ptr<DatasetNode> ManifestNode::Copy() {
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_);
   return node;
 }
@@ -93,7 +93,7 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
   std::shared_ptr<ManifestOp> manifest_op;
   manifest_op =
     std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
-                                 class_index_, std::move(schema), std::move(sampler_->Build()), usage_);
+                                 class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_);
   RETURN_IF_NOT_OK(AddCacheOp(node_ops));

   node_ops->push_back(manifest_op);
@@ -118,7 +118,7 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
   int64_t num_rows, sample_size;
   int64_t num_classes;  // dummy variable
   RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes));
-  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
+  sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
   *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
@@ -81,6 +81,13 @@ class ManifestNode : public MappableSourceNode {
   /// \return Status of the function
   Status to_json(nlohmann::json *out_json) override;

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
+
  private:
   std::string dataset_file_;
   std::string usage_;
@@ -54,7 +54,7 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st
 std::shared_ptr<DatasetNode> MindDataNode::Copy() {
   std::shared_ptr<MindDataNode> node;
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   if (dataset_files_.empty()) {
     node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_);
   } else {
@@ -85,6 +85,13 @@ class MindDataNode : public MappableSourceNode {
   Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                         int64_t *dataset_size) override;

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
+
  private:
   std::string dataset_file_;                // search_for_pattern_ will be true in this mode
   std::vector<std::string> dataset_files_;  // search_for_pattern_ will be false in this mode
@@ -32,7 +32,7 @@ MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr
   : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

 std::shared_ptr<DatasetNode> MnistNode::Copy() {
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_);
   return node;
 }
@@ -60,7 +60,8 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
   RETURN_IF_NOT_OK(AddCacheOp(node_ops));

   node_ops->push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
-                                                connector_que_size_, std::move(schema), std::move(sampler_->Build())));
+                                                connector_que_size_, std::move(schema),
+                                                std::move(sampler_->SamplerBuild())));

   return Status::OK();
 }
@@ -81,7 +82,7 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
   }
   int64_t num_rows, sample_size;
   RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
-  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
+  sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
   *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
@@ -72,13 +72,19 @@ class MnistNode : public MappableSourceNode {
   /// \brief Getter functions
   const std::string &DatasetDir() const { return dataset_dir_; }
   const std::string &Usage() const { return usage_; }
-  const std::shared_ptr<SamplerObj> &Sampler() const { return sampler_; }

   /// \brief Get the arguments of node
   /// \param[out] out_json JSON string of all attributes
   /// \return Status of the function
   Status to_json(nlohmann::json *out_json) override;

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
+
  private:
   std::string dataset_dir_;
   std::string usage_;
@@ -114,7 +114,7 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
   std::shared_ptr<RandomDataOp> op;
   op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
-                                      std::move(data_schema_), std::move(sampler_->Build()));
+                                      std::move(data_schema_), std::move(sampler_->SamplerBuild()));
   RETURN_IF_NOT_OK(AddCacheOp(node_ops));

   node_ops->push_back(op);
@@ -124,8 +124,8 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops

 // Get the shard id of node
 Status RandomNode::GetShardId(int32_t *shard_id) {
-  *shard_id = sampler_->ShardId();
+  // RandomDataset doesn't support multiple shards
+  *shard_id = 0;

   return Status::OK();
 }
@@ -138,13 +138,7 @@ Status RandomNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
   }
   int64_t num_rows;
   num_rows = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
-  if (sampler_ != nullptr) {
-    int64_t sample_size;
-    sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
-    *dataset_size = sample_size;
-  } else {
-    *dataset_size = num_rows;
-  }
+  *dataset_size = num_rows;
   dataset_size_ = *dataset_size;
   return Status::OK();
 }
@@ -110,7 +110,6 @@ class RandomNode : public NonMappableSourceNode {
   std::string schema_path_;
   std::shared_ptr<SchemaObj> schema_;
   std::vector<std::string> columns_list_;
-  std::shared_ptr<SamplerObj> sampler_;
   std::mt19937 rand_gen_;
   std::unique_ptr<DataSchema> data_schema_;
 };
@@ -90,7 +90,7 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
   // Create and initalize TextFileOp
   std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
     num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
-    connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
+    connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild()));
   RETURN_IF_NOT_OK(text_file_op->Init());

   if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) {
@@ -131,7 +131,7 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
   std::shared_ptr<TFReaderOp> tf_reader_op =
     std::make_shared<TFReaderOp>(num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files,
                                  std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_,
-                                 shard_id_, shard_equal_rows_, std::move(sampler_->Build()));
+                                 shard_id_, shard_equal_rows_, std::move(sampler_->SamplerBuild()));

   RETURN_IF_NOT_OK(tf_reader_op->Init());
@@ -41,7 +41,7 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const
     sampler_(sampler) {}

 std::shared_ptr<DatasetNode> VOCNode::Copy() {
-  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->Copy();
+  std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
   auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_);
   return node;
 }
@@ -110,8 +110,9 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
   }

   std::shared_ptr<VOCOp> voc_op;
-  voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
-                                   connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
+  voc_op =
+    std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
+                            connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
   RETURN_IF_NOT_OK(AddCacheOp(node_ops));

   node_ops->push_back(voc_op);
@@ -134,7 +135,7 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
   }
   int64_t num_rows = 0, sample_size;
   RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows));
-  sample_size = sampler_->Build()->CalculateNumSamples(num_rows);
+  sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
   *dataset_size = sample_size;
   dataset_size_ = *dataset_size;
   return Status::OK();
@@ -83,6 +83,13 @@ class VOCNode : public MappableSourceNode {
   /// \return Status of the function
   Status to_json(nlohmann::json *out_json) override;

+  /// \brief Sampler getter
+  /// \return SamplerObj of the current node
+  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
+
+  /// \brief Sampler setter
+  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
+
  private:
   const std::string kColumnImage = "image";
   const std::string kColumnTarget = "target";
@@ -101,7 +101,7 @@ Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MappableSourceNode> n
 }

 Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
-  auto itr = weight_profile_.find("NonMappableSourceNode");
+  auto itr = weight_profile_.find("NonMappableSource");
   CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(),
                                "NonLeafSource::" + node->Name() + "'s weight doesn't exist.");
   int32_t weight = itr->second;
@@ -49,11 +49,11 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
   /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
   /// \return Shared pointers to the newly created Sampler
-  virtual std::shared_ptr<SamplerRT> Build() = 0;
+  virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;

   /// \brief Pure virtual function to copy a SamplerObj class
   /// \return Shared pointers to the newly copied SamplerObj
-  virtual std::shared_ptr<SamplerObj> Copy() = 0;
+  virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;

   /// \brief Function for derived class to get the shard id of sampler
   /// \return The shard id of the derived sampler
@@ -62,7 +62,7 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
   /// \brief Adds a child to the sampler
   /// \param[in] child The sampler to be added as child
   /// \return the Status code returned
-  Status AddChild(std::shared_ptr<SamplerObj> child);
+  Status AddChildSampler(std::shared_ptr<SamplerObj> child);

   virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
@@ -152,13 +152,13 @@ class DistributedSamplerObj : public SamplerObj {
   ~DistributedSamplerObj() = default;

-  std::shared_ptr<SamplerRT> Build() override;
+  std::shared_ptr<SamplerRT> SamplerBuild() override;

-  std::shared_ptr<SamplerObj> Copy() override {
+  std::shared_ptr<SamplerObj> SamplerCopy() override {
     auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
                                                            offset_, even_dist_);
     for (auto child : children_) {
-      sampler->AddChild(child);
+      sampler->AddChildSampler(child);
     }
     return sampler;
   }
@@ -189,12 +189,12 @@ class PKSamplerObj : public SamplerObj {
   ~PKSamplerObj() = default;

-  std::shared_ptr<SamplerRT> Build() override;
+  std::shared_ptr<SamplerRT> SamplerBuild() override;

-  std::shared_ptr<SamplerObj> Copy() override {
+  std::shared_ptr<SamplerObj> SamplerCopy() override {
     auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
     for (auto child : children_) {
-      sampler->AddChild(child);
+      sampler->AddChildSampler(child);
     }
     return sampler;
   }
@@ -220,13 +220,13 @@ class PreBuiltSamplerObj : public SamplerObj {
   ~PreBuiltSamplerObj() = default;

-  std::shared_ptr<SamplerRT> Build() override;
+  std::shared_ptr<SamplerRT> SamplerBuild() override;

 #ifndef ENABLE_ANDROID
   std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
 #endif

-  std::shared_ptr<SamplerObj> Copy() override;
+  std::shared_ptr<SamplerObj> SamplerCopy() override;

   Status ValidateParams() override;
@@ -245,12 +245,12 @@ class RandomSamplerObj : public SamplerObj {
   ~RandomSamplerObj() = default;

-  std::shared_ptr<SamplerRT> Build() override;
+  std::shared_ptr<SamplerRT> SamplerBuild() override;

-  std::shared_ptr<SamplerObj> Copy() override {
+  std::shared_ptr<SamplerObj> SamplerCopy() override {
    auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_);
     for (auto child : children_) {
-      sampler->AddChild(child);
+      sampler->AddChildSampler(child);
     }
     return sampler;
   }
@@ -272,12 +272,12 @@ class SequentialSamplerObj : public SamplerObj {
   ~SequentialSamplerObj() = default;

-  std::shared_ptr<SamplerRT> Build() override;
+  std::shared_ptr<SamplerRT> SamplerBuild() override;

-  std::shared_ptr<SamplerObj> Copy() override {
+  std::shared_ptr<SamplerObj> SamplerCopy() override {
     auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
     for (auto child : children_) {
-      sampler->AddChild(child);
+      sampler->AddChildSampler(child);
     }
     return sampler;
   }
@@ -299,12 +299,12 @@ class SubsetRandomSamplerObj : public SamplerObj {
   ~SubsetRandomSamplerObj() = default;

-  std::shared_ptr<SamplerRT> Build() override;
+  std::shared_ptr<SamplerRT> SamplerBuild() override;

-  std::shared_ptr<SamplerObj> Copy() override {
+  std::shared_ptr<SamplerObj> SamplerCopy() override {
     auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
     for (auto child : children_) {
-      sampler->AddChild(child);
+      sampler->AddChildSampler(child);
     }
     return sampler;
   }
@@ -326,12 +326,12 @@ class WeightedRandomSamplerObj : public SamplerObj {
   ~WeightedRandomSamplerObj() = default;

-  std::shared_ptr<SamplerRT> Build() override;
+  std::shared_ptr<SamplerRT> SamplerBuild() override;

-  std::shared_ptr<SamplerObj> Copy() override {
+  std::shared_ptr<SamplerObj> SamplerCopy() override {
     auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
     for (auto child : children_) {
-      sampler->AddChild(child);
+      sampler->AddChildSampler(child);
     }
     return sampler;
   }
@@ -87,67 +87,67 @@ TEST_F(MindDataTestPipeline, TestCalculateNumSamples) {
   int64_t num_rows = 30;  // dummy variable for number of rows in the dataset
   std::shared_ptr<SamplerObj> sampl = DistributedSampler(2, 1, false, 6);
   EXPECT_NE(sampl, nullptr);
-  std::shared_ptr<SamplerRT> sampler_rt = sampl->Build();
+  std::shared_ptr<SamplerRT> sampler_rt = sampl->SamplerBuild();
   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6);

   sampl = PKSampler(3, false);
   EXPECT_NE(sampl, nullptr);
-  sampler_rt = sampl->Build();
+  sampler_rt = sampl->SamplerBuild();
   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);

   sampl = RandomSampler(false, 12);
   EXPECT_NE(sampl, nullptr);
-  sampler_rt = sampl->Build();
+  sampler_rt = sampl->SamplerBuild();
   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);

   sampl = SequentialSampler(0, 10);
   EXPECT_NE(sampl, nullptr);
-  sampler_rt = sampl->Build();
+  sampler_rt = sampl->SamplerBuild();
   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10);

   std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
   sampl = WeightedRandomSampler(weights, 12);
   EXPECT_NE(sampl, nullptr);
-  sampler_rt = sampl->Build();
+  sampler_rt = sampl->SamplerBuild();
   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);

   std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21};
   sampl = SubsetRandomSampler(indices, 11);
   EXPECT_NE(sampl, nullptr);
-  sampler_rt = sampl->Build();
+  sampler_rt = sampl->SamplerBuild();
   EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11);

   // Testing chains
   // Parent and child have num_samples
   std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights, 12);
   EXPECT_NE(sampl1, nullptr);
-  std::shared_ptr<SamplerRT> sampler_rt1 = sampl1->Build();
+  std::shared_ptr<SamplerRT> sampler_rt1 = sampl1->SamplerBuild();

   std::shared_ptr<SamplerObj> sampl2 = SequentialSampler(0, 10);
   EXPECT_NE(sampl2, nullptr);
-  std::shared_ptr<SamplerRT> sampler_rt2 = sampl2->Build();
+  std::shared_ptr<SamplerRT> sampler_rt2 = sampl2->SamplerBuild();
   sampler_rt2->AddChild(sampler_rt1);
   EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10);

   // Parent doesn't have num_samples
   std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights, 12);
   EXPECT_NE(sampl3, nullptr);
-  std::shared_ptr<SamplerRT> sampler_rt3 = sampl3->Build();
+  std::shared_ptr<SamplerRT> sampler_rt3 = sampl3->SamplerBuild();

   std::shared_ptr<SamplerObj> sampl4 = SubsetRandomSampler(indices);
   EXPECT_NE(sampl4, nullptr);
-  std::shared_ptr<SamplerRT> sampler_rt4 = sampl4->Build();
+  std::shared_ptr<SamplerRT> sampler_rt4 = sampl4->SamplerBuild();
   sampler_rt4->AddChild(sampler_rt3);
   EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 12);

   // Child doesn't have num_samples
   std::shared_ptr<SamplerObj> sampl5 = RandomSampler(false);
   EXPECT_NE(sampl5, nullptr);
-  std::shared_ptr<SamplerRT> sampler_rt5 = sampl5->Build();
+  std::shared_ptr<SamplerRT> sampler_rt5 = sampl5->SamplerBuild();

   std::shared_ptr<SamplerObj> sampl6 = PKSampler(3, false, 7);
   EXPECT_NE(sampl6, nullptr);
-  std::shared_ptr<SamplerRT> sampler_rt6 = sampl6->Build();
+  std::shared_ptr<SamplerRT> sampler_rt6 = sampl6->SamplerBuild();
   sampler_rt6->AddChild(sampler_rt5);
   EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
 }
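The chain cases above pin down the intended semantics of CalculateNumSamples(): each sampler caps the count coming up from its child (or from the raw row count at the leaf), and a sampler without an explicit num_samples passes its child's count through unchanged. A standalone sketch of that rule, inferred from the expectations in this test rather than taken from the library source:

#include <memory>

// Inferred from the test: effective count = child's effective count (or the
// row count when there is no child), capped by this sampler's num_samples.
struct ChainSampler {
  long long num_samples = 0;  // 0 means "not set"
  std::shared_ptr<ChainSampler> child;

  long long CalculateNumSamples(long long num_rows) const {
    long long count = (child != nullptr) ? child->CalculateNumSamples(num_rows) : num_rows;
    if (num_samples > 0 && num_samples < count) count = num_samples;
    return count;
  }
};

int main() {
  // Parent SequentialSampler(0, 10) over child WeightedRandomSampler(12): 10.
  auto child = std::make_shared<ChainSampler>();
  child->num_samples = 12;
  ChainSampler parent;
  parent.num_samples = 10;
  parent.child = child;
  return parent.CalculateNumSamples(30) == 10 ? 0 : 1;
}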
@@ -156,10 +156,10 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
   std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
   std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices);
   EXPECT_FALSE(indices.empty());
-  EXPECT_NE(sampl1->Build(), nullptr);
+  EXPECT_NE(sampl1->SamplerBuild(), nullptr);
   std::shared_ptr<SamplerObj> sampl2 = SubsetRandomSampler(std::move(indices));
   EXPECT_TRUE(indices.empty());
-  EXPECT_NE(sampl2->Build(), nullptr);
+  EXPECT_NE(sampl2->SamplerBuild(), nullptr);
 }

 TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {
@@ -216,7 +216,7 @@ TEST_F(MindDataTestPipeline, TestSamplerAddChild) {
   EXPECT_NE(sampler, nullptr);

   auto child_sampler = SequentialSampler();
-  sampler->AddChild(child_sampler);
+  sampler->AddChildSampler(child_sampler);
   EXPECT_NE(child_sampler, nullptr);

   // Create an ImageFolder Dataset
@@ -406,7 +406,7 @@ def test_cache_map_failure5():
         num_iter = 0
         for _ in data.create_dict_iterator():
             num_iter += 1
-    assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant" in str(e.value)
+    assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value)

     assert num_iter == 0
     logger.info('test_cache_failure5 Ended.\n')