From: @mhmotallebi Reviewed-by: @robingrosman, @nsyca Signed-off-by: @robingrosman tags/v1.2.0-rc1
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -51,10 +51,12 @@ Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, std::shared_pt | |||
| std::shared_ptr<SamplerObj> sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); | |||
| std::shared_ptr<CacheLookupOp> lookup_op = nullptr; | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt)); | |||
| RETURN_IF_NOT_OK(CacheLookupOp::Builder() | |||
| .SetNumWorkers(num_workers) | |||
| .SetClient(cache_client_) | |||
| .SetSampler(sampler->SamplerBuild()) | |||
| .SetSampler(sampler_rt) | |||
| .Build(&lookup_op)); | |||
| *ds = lookup_op; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -53,10 +53,13 @@ Status PreBuiltDatasetCache::CreateCacheLookupOp(int32_t num_workers, std::share | |||
| std::shared_ptr<SamplerObj> sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); | |||
| std::shared_ptr<CacheLookupOp> lookup_op = nullptr; | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt)); | |||
| RETURN_IF_NOT_OK(CacheLookupOp::Builder() | |||
| .SetNumWorkers(num_workers) | |||
| .SetClient(cache_client_) | |||
| .SetSampler(sampler->SamplerBuild()) | |||
| .SetSampler(sampler_rt) | |||
| .Build(&lookup_op)); | |||
| *ds = lookup_op; | |||
| @@ -74,10 +74,11 @@ std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() { | |||
| return std::static_pointer_cast<SamplerObj>(lookup_node_copy_); | |||
| } | |||
| std::shared_ptr<SamplerRT> CacheLookupNode::SamplerBuild() { | |||
| Status CacheLookupNode::SamplerBuild(std::shared_ptr<SamplerRT> *out) { | |||
| // Runtime cache lookup op should already been built, so we just return it here | |||
| auto lookup_op = std::dynamic_pointer_cast<CacheLookupOp>(lookup_op_); | |||
| return std::shared_ptr<SamplerRT>(lookup_op); | |||
| *out = std::shared_ptr<SamplerRT>(lookup_op); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| @@ -48,8 +48,9 @@ class CacheLookupNode : public DatasetNode, public SamplerObj { | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to convert a SamplerObj class into a runtime sampler object | |||
| /// \return Shared pointers to the newly created Sampler | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| /// \param[out] out Shared pointer to the newly created Sampler | |||
| /// \return The Status code of the function. It returns OK status if sampler is created successfully. | |||
| Status SamplerBuild(std::shared_ptr<SamplerRT> *out) override; | |||
| /// \brief a base class override function to copy a SamplerObj class | |||
| /// \return Shared pointers to the newly copied SamplerObj | |||
| @@ -52,7 +52,9 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
| RETURN_IF_NOT_OK(cache_->Build()); | |||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||
| RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op)); | |||
| cache_op->SetSampler(sampler_->SamplerBuild()); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| cache_op->SetSampler(sampler_rt); | |||
| cache_op->set_total_repeats(GetTotalRepeats()); | |||
| cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(cache_op); | |||
| @@ -95,8 +95,10 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size | |||
| // calculate the size of the shard | |||
| int64_t shard_dataset_size = 0; | |||
| std::shared_ptr<SamplerRT> sampler_rt_base = nullptr; | |||
| if (sampler_) RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt_base)); | |||
| std::shared_ptr<DistributedSamplerRT> sampler_rt = | |||
| sampler_ ? std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_->SamplerBuild()) : nullptr; | |||
| sampler_ ? std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_rt_base) : nullptr; | |||
| if (sampler_rt != nullptr) { | |||
| sampler_rt->SetNumRowsInDataset(total_dataset_size); | |||
| sampler_rt->InitSampler(); | |||
| @@ -123,8 +125,10 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops | |||
| if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) { | |||
| op = std::make_shared<ConcatOp>(connector_que_size_); | |||
| } else { | |||
| op = std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(), children_flag_and_nums_, | |||
| children_start_end_index_); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| op = | |||
| std::make_shared<ConcatOp>(connector_que_size_, sampler_rt, children_flag_and_nums_, children_start_end_index_); | |||
| } | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| @@ -71,9 +71,11 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| // Argument that is not exposed to user in the API. | |||
| std::set<std::string> extensions = {}; | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| auto album_op = std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, | |||
| extensions, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| extensions, std::move(schema), std::move(sampler_rt)); | |||
| album_op->set_total_repeats(GetTotalRepeats()); | |||
| album_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(album_op); | |||
| @@ -66,10 +66,11 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||
| // label is like this:0 1 0 0 1...... | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| auto celeba_op = | |||
| std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, usage_, | |||
| extensions_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| auto celeba_op = std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| decode_, usage_, extensions_, std::move(schema), std::move(sampler_rt)); | |||
| celeba_op->set_total_repeats(GetTotalRepeats()); | |||
| celeba_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(celeba_op); | |||
| @@ -140,7 +141,9 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size | |||
| num_rows = std::min(num_rows, partition_num); | |||
| } | |||
| sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| return Status::OK(); | |||
| } | |||
| @@ -63,10 +63,12 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| auto cifar_op = | |||
| std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, dataset_dir_, | |||
| connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| connector_que_size_, std::move(schema), std::move(sampler_rt)); | |||
| cifar_op->set_total_repeats(GetTotalRepeats()); | |||
| cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(cifar_op); | |||
| @@ -90,7 +92,10 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows)); | |||
| sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| @@ -61,10 +61,12 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op | |||
| TensorShape scalar = TensorShape::CreateScalar(); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| auto cifar_op = | |||
| std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, dataset_dir_, | |||
| connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| connector_que_size_, std::move(schema), std::move(sampler_rt)); | |||
| cifar_op->set_total_repeats(GetTotalRepeats()); | |||
| cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(cifar_op); | |||
| @@ -88,7 +90,10 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows)); | |||
| sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| @@ -119,9 +119,12 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| std::shared_ptr<CocoOp> op = | |||
| std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, | |||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_rt)); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| @@ -145,7 +148,9 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g | |||
| } | |||
| int64_t num_rows = 0, sample_size; | |||
| RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows)); | |||
| sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| @@ -78,7 +78,8 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ | |||
| column_types_.push_back((col.type())); | |||
| } | |||
| } | |||
| std::shared_ptr<SamplerRT> sampler_rt = sampler_ ? sampler_->SamplerBuild() : nullptr; | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| if (sampler_) RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by | |||
| // GeneratorOp internally. Here it is given a zero which is the default in generator builder | |||
| @@ -145,7 +146,9 @@ Status GeneratorNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &s | |||
| int64_t sample_size; | |||
| int64_t num_rows; | |||
| num_rows = source_len_; | |||
| sample_size = sampler_ ? sampler_->SamplerBuild()->CalculateNumSamples(num_rows) : num_rows; | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| if (sampler_) RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_ ? sampler_rt->CalculateNumSamples(num_rows) : num_rows; | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| @@ -69,10 +69,12 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| auto op = std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| recursive_, decode_, exts_, class_indexing_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild())); | |||
| auto op = | |||
| std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, recursive_, | |||
| decode_, exts_, class_indexing_, std::move(schema), std::move(sampler_rt)); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| @@ -95,7 +97,9 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> | |||
| } | |||
| int64_t sample_size, num_rows; | |||
| RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {})); | |||
| sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| @@ -91,9 +91,11 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| std::shared_ptr<ManifestOp> manifest_op; | |||
| manifest_op = | |||
| std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, | |||
| class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| manifest_op = std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, | |||
| decode_, class_index_, std::move(schema), std::move(sampler_rt), usage_); | |||
| manifest_op->set_total_repeats(GetTotalRepeats()); | |||
| manifest_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(manifest_op); | |||
| @@ -118,7 +120,9 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||
| int64_t num_rows, sample_size; | |||
| int64_t num_classes; // dummy variable | |||
| RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes)); | |||
| sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| @@ -57,9 +57,11 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| TensorShape scalar = TensorShape::CreateScalar(); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| auto op = std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| std::move(schema), std::move(sampler_rt)); | |||
| op->set_total_repeats(GetTotalRepeats()); | |||
| op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| @@ -83,7 +85,9 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); | |||
| sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -42,11 +42,13 @@ namespace dataset { | |||
| // Constructor | |||
| SamplerObj::SamplerObj() {} | |||
| void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) { | |||
| Status SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> *sampler) { | |||
| for (auto child : children_) { | |||
| auto sampler_rt = child->SamplerBuild(); | |||
| sampler->AddChild(sampler_rt); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(child->SamplerBuild(&sampler_rt)); | |||
| RETURN_IF_NOT_OK((*sampler)->AddChild(sampler_rt)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) { | |||
| @@ -113,12 +115,13 @@ Status DistributedSamplerObj::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() { | |||
| Status DistributedSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, | |||
| offset_, even_dist_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| *sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, | |||
| offset_, even_dist_); | |||
| Status s = BuildChildren(sampler); | |||
| sampler = s.IsOk() ? sampler : nullptr; | |||
| return s; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -186,11 +189,12 @@ Status PKSamplerObj::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() { | |||
| Status PKSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| *sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_); | |||
| Status s = BuildChildren(sampler); | |||
| sampler = s.IsOk() ? sampler : nullptr; | |||
| return s; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -218,9 +222,14 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator | |||
| Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); } | |||
| std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() { | |||
| BuildChildren(sp_); | |||
| return sp_; | |||
| Status PreBuiltSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| Status s = BuildChildren(&sp_); | |||
| if (s.IsOk()) | |||
| *sampler = sp_; | |||
| else | |||
| *sampler = nullptr; | |||
| // FIXME: what to do with sp_ if status is not OK? | |||
| return s; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -280,11 +289,12 @@ Status RandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() { | |||
| Status RandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| *sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_); | |||
| Status s = BuildChildren(sampler); | |||
| sampler = s.IsOk() ? sampler : nullptr; | |||
| return s; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -333,11 +343,12 @@ Status SequentialSamplerObj::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() { | |||
| Status SequentialSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| *sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_); | |||
| Status s = BuildChildren(sampler); | |||
| sampler = s.IsOk() ? sampler : nullptr; | |||
| return s; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -362,11 +373,12 @@ Status SubsetSamplerObj::ValidateParams() { | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> SubsetSamplerObj::SamplerBuild() { | |||
| Status SubsetSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| *sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_); | |||
| Status s = BuildChildren(sampler); | |||
| sampler = s.IsOk() ? sampler : nullptr; | |||
| return s; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -399,11 +411,12 @@ Status SubsetSamplerObj::to_json(nlohmann::json *out_json) { | |||
| SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples) | |||
| : SubsetSamplerObj(std::move(indices), num_samples) {} | |||
| std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() { | |||
| Status SubsetRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| // runtime sampler object | |||
| auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| *sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_); | |||
| Status s = BuildChildren(sampler); | |||
| sampler = s.IsOk() ? sampler : nullptr; | |||
| return s; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -480,10 +493,11 @@ Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) { | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() { | |||
| auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); | |||
| BuildChildren(sampler); | |||
| return sampler; | |||
| Status WeightedRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) { | |||
| *sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_); | |||
| Status s = BuildChildren(sampler); | |||
| sampler = s.IsOk() ? sampler : nullptr; | |||
| return s; | |||
| } | |||
| } // namespace dataset | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -48,8 +48,9 @@ class SamplerObj { | |||
| virtual Status ValidateParams() = 0; | |||
| /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object | |||
| /// \return Shared pointers to the newly created Sampler | |||
| virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0; | |||
| /// \param[out] sampler Shared pointers to the newly created Sampler | |||
| /// \return The Status code of the function. It returns OK status if sampler is created successfully. | |||
| virtual Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) = 0; | |||
| /// \brief Pure virtual function to copy a SamplerObj class | |||
| /// \return Shared pointers to the newly copied SamplerObj | |||
| @@ -78,7 +79,8 @@ class SamplerObj { | |||
| protected: | |||
| /// \brief A function that calls build on the children of this sampler | |||
| /// \param[in] sampler The samplerRT object built from this sampler | |||
| void BuildChildren(std::shared_ptr<SamplerRT> sampler); | |||
| /// \return the Status code returned | |||
| Status BuildChildren(std::shared_ptr<SamplerRT> *sampler); | |||
| std::vector<std::shared_ptr<SamplerObj>> children_; | |||
| }; | |||
| @@ -91,7 +93,7 @@ class DistributedSamplerObj : public SamplerObj { | |||
| ~DistributedSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, | |||
| @@ -133,7 +135,7 @@ class PKSamplerObj : public SamplerObj { | |||
| ~PKSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_); | |||
| @@ -169,7 +171,7 @@ class PreBuiltSamplerObj : public SamplerObj { | |||
| ~PreBuiltSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override; | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| @@ -194,7 +196,7 @@ class RandomSamplerObj : public SamplerObj { | |||
| ~RandomSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_); | |||
| @@ -227,7 +229,7 @@ class SequentialSamplerObj : public SamplerObj { | |||
| ~SequentialSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_); | |||
| @@ -259,7 +261,7 @@ class SubsetSamplerObj : public SamplerObj { | |||
| ~SubsetSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_); | |||
| @@ -293,7 +295,7 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj { | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_); | |||
| @@ -316,7 +318,7 @@ class WeightedRandomSamplerObj : public SamplerObj { | |||
| ~WeightedRandomSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
| Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override; | |||
| std::shared_ptr<SamplerObj> SamplerCopy() override { | |||
| auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); | |||
| @@ -108,11 +108,12 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| RETURN_IF_NOT_OK(schema->AddColumn( | |||
| ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| } | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| std::shared_ptr<VOCOp> voc_op; | |||
| voc_op = | |||
| std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, | |||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, | |||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_rt)); | |||
| voc_op->set_total_repeats(GetTotalRepeats()); | |||
| voc_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(voc_op); | |||
| @@ -135,7 +136,9 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge | |||
| } | |||
| int64_t num_rows = 0, sample_size; | |||
| RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows)); | |||
| sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -90,67 +90,74 @@ TEST_F(MindDataTestPipeline, TestCalculateNumSamples) { | |||
| int64_t num_rows = 30; // dummy variable for number of rows in the dataset | |||
| std::shared_ptr<SamplerObj> sampl = DistributedSampler(2, 1, false, 6); | |||
| EXPECT_NE(sampl, nullptr); | |||
| std::shared_ptr<SamplerRT> sampler_rt = sampl->SamplerBuild(); | |||
| std::shared_ptr<SamplerRT> sampler_rt; | |||
| sampl->SamplerBuild(&sampler_rt); | |||
| EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6); | |||
| sampl = PKSampler(3, false); | |||
| EXPECT_NE(sampl, nullptr); | |||
| sampler_rt = sampl->SamplerBuild(); | |||
| sampl->SamplerBuild(&sampler_rt); | |||
| EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30); | |||
| sampl = RandomSampler(false, 12); | |||
| EXPECT_NE(sampl, nullptr); | |||
| sampler_rt = sampl->SamplerBuild(); | |||
| sampl->SamplerBuild(&sampler_rt); | |||
| EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); | |||
| sampl = SequentialSampler(0, 10); | |||
| EXPECT_NE(sampl, nullptr); | |||
| sampler_rt = sampl->SamplerBuild(); | |||
| sampl->SamplerBuild(&sampler_rt); | |||
| EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10); | |||
| std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1}; | |||
| sampl = WeightedRandomSampler(weights, 12); | |||
| EXPECT_NE(sampl, nullptr); | |||
| sampler_rt = sampl->SamplerBuild(); | |||
| sampl->SamplerBuild(&sampler_rt); | |||
| EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); | |||
| std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21}; | |||
| sampl = SubsetRandomSampler(indices, 11); | |||
| EXPECT_NE(sampl, nullptr); | |||
| sampler_rt = sampl->SamplerBuild(); | |||
| sampl->SamplerBuild(&sampler_rt); | |||
| EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11); | |||
| // Testing chains | |||
| // Parent and child have num_samples | |||
| std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights, 12); | |||
| EXPECT_NE(sampl1, nullptr); | |||
| std::shared_ptr<SamplerRT> sampler_rt1 = sampl1->SamplerBuild(); | |||
| std::shared_ptr<SamplerRT> sampler_rt1; | |||
| sampl1->SamplerBuild(&sampler_rt1); | |||
| std::shared_ptr<SamplerObj> sampl2 = SequentialSampler(0, 10); | |||
| EXPECT_NE(sampl2, nullptr); | |||
| std::shared_ptr<SamplerRT> sampler_rt2 = sampl2->SamplerBuild(); | |||
| std::shared_ptr<SamplerRT> sampler_rt2; | |||
| sampl2->SamplerBuild(&sampler_rt2); | |||
| sampler_rt2->AddChild(sampler_rt1); | |||
| EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10); | |||
| // Parent doesn't have num_samples | |||
| std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights, 12); | |||
| EXPECT_NE(sampl3, nullptr); | |||
| std::shared_ptr<SamplerRT> sampler_rt3 = sampl3->SamplerBuild(); | |||
| std::shared_ptr<SamplerRT> sampler_rt3; | |||
| sampl3->SamplerBuild(&sampler_rt3); | |||
| std::shared_ptr<SamplerObj> sampl4 = SubsetRandomSampler(indices); | |||
| EXPECT_NE(sampl4, nullptr); | |||
| std::shared_ptr<SamplerRT> sampler_rt4 = sampl4->SamplerBuild(); | |||
| std::shared_ptr<SamplerRT> sampler_rt4; | |||
| sampl4->SamplerBuild(&sampler_rt4); | |||
| sampler_rt4->AddChild(sampler_rt3); | |||
| EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 11); | |||
| // Child doesn't have num_samples | |||
| std::shared_ptr<SamplerObj> sampl5 = RandomSampler(false); | |||
| EXPECT_NE(sampl5, nullptr); | |||
| std::shared_ptr<SamplerRT> sampler_rt5 = sampl5->SamplerBuild(); | |||
| std::shared_ptr<SamplerRT> sampler_rt5; | |||
| sampl5->SamplerBuild(&sampler_rt5); | |||
| std::shared_ptr<SamplerObj> sampl6 = PKSampler(3, false, 7); | |||
| EXPECT_NE(sampl6, nullptr); | |||
| std::shared_ptr<SamplerRT> sampler_rt6 = sampl6->SamplerBuild(); | |||
| std::shared_ptr<SamplerRT> sampler_rt6; | |||
| sampl6->SamplerBuild(&sampler_rt6); | |||
| sampler_rt6->AddChild(sampler_rt5); | |||
| EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7); | |||
| } | |||
| @@ -159,10 +166,14 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) { | |||
| std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; | |||
| std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices); | |||
| EXPECT_FALSE(indices.empty()); | |||
| EXPECT_NE(sampl1->SamplerBuild(), nullptr); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| sampl1->SamplerBuild(&sampler_rt); | |||
| EXPECT_NE(sampler_rt, nullptr); | |||
| std::shared_ptr<SamplerObj> sampl2 = SubsetRandomSampler(std::move(indices)); | |||
| EXPECT_TRUE(indices.empty()); | |||
| EXPECT_NE(sampl2->SamplerBuild(), nullptr); | |||
| std::shared_ptr<SamplerRT> sampler_rt2 = nullptr; | |||
| sampl2->SamplerBuild(&sampler_rt2); | |||
| // Check the runtime sampler built from sampl2 (moved-in indices), not the earlier sampler_rt. | |||
| EXPECT_NE(sampler_rt2, nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) { | |||