Merge pull request !6905 from lixiachen/CacheOp_dev
@@ -1075,7 +1075,7 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
  std::shared_ptr<ClueOp> clue_op =
    std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
                             sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
                             sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
  RETURN_EMPTY_IF_ERROR(clue_op->Init());
  if (shuffle_ == ShuffleMode::kGlobal) {
    // Inject ShuffleOp
@@ -1256,7 +1256,7 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
  std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
    sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
    num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
    num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
  RETURN_EMPTY_IF_ERROR(csv_op->Init());
  if (shuffle_ == ShuffleMode::kGlobal) {
    // Inject ShuffleOp
@@ -1502,7 +1502,7 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
  // Create and initialize TextFileOp
  std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
    num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
    connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
    connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
  RETURN_EMPTY_IF_ERROR(text_file_op->Init());
  if (shuffle_ == ShuffleMode::kGlobal) {
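These first three hunks apply one pattern: each non-mappable source op (ClueOp, CsvOp, TextFileOp) gains a trailing sampler parameter, and the C++ API Build() paths simply pass nullptr (the third hunk also drops a pointless std::move on the nullptr literal). A minimal standalone sketch of that pattern, using hypothetical stand-in types rather than the real MindSpore classes:

```cpp
// Sketch only: ParallelOp, Sampler, and FileSourceOp are illustrative stand-ins.
#include <memory>
#include <utility>

struct Sampler {};

class ParallelOp {
 public:
  // After this change, the base class owns the (possibly null) sampler.
  ParallelOp(int num_workers, int connector_size, std::shared_ptr<Sampler> sampler)
      : num_workers_(num_workers), connector_size_(connector_size), sampler_(std::move(sampler)) {}
  virtual ~ParallelOp() = default;

 protected:
  int num_workers_;
  int connector_size_;
  std::shared_ptr<Sampler> sampler_;
};

class FileSourceOp : public ParallelOp {
 public:
  // Non-caching callers pass nullptr, exactly as the Build() hunks above do.
  FileSourceOp(int num_workers, int connector_size, std::shared_ptr<Sampler> sampler = nullptr)
      : ParallelOp(num_workers, connector_size, std::move(sampler)) {}
};

int main() {
  FileSourceOp plain(4, 16);                                // no cache: null sampler
  FileSourceOp cached(4, 16, std::make_shared<Sampler>());  // sampler kept for a cache to adopt
}
```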
@@ -1345,6 +1345,9 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
    std::string err_msg = "Error: No dataset files specified for manifest";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  int num_workers = 0;
  std::shared_ptr<CacheClient> cache_client = nullptr;
  std::shared_ptr<ManifestOp::Builder> builder = std::make_shared<ManifestOp::Builder>();
  (void)builder->SetManifestFile(ToString(args["dataset_file"]));
@@ -1354,7 +1357,8 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
        num_workers = ToInt(value);
        (void)builder->SetNumWorkers(num_workers);
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
@@ -1365,12 +1369,27 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
        (void)builder->SetDecode(ToBool(value));
      } else if (key == "usage") {
        (void)builder->SetUsage(ToString(value));
      } else if (key == "cache") {
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
      }
    }
  }
  std::shared_ptr<ManifestOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *top = op;
  std::shared_ptr<ManifestOp> manifest_op;
  RETURN_IF_NOT_OK(builder->Build(&manifest_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(manifest_op));
  *top = manifest_op;
  // Additionally, add a cache if required.
  // Note that this cache op is only acting as a placeholder for the caching position
  // within the tree. Later, a pre-pass will execute a tree transform to set up the actual
  // caching logic in the tree.
  if (cache_client) {
    std::shared_ptr<DatasetOp> cache_op = nullptr;
    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, manifest_op, &cache_op));
    *top = cache_op;
    *bottom = manifest_op;
  }
  return Status::OK();
}
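ParseManifestOp is the first of seven parsers that gain this cache placeholder logic. The diff only calls DEPipeline::AddCacheOp; its definition is not shown here, though a later hunk header reveals its signature. A hedged sketch of what such a helper plausibly does, inferred purely from the call sites; CacheOp::Builder and its setter names are assumptions, not confirmed by this diff:

```cpp
// Hypothetical reconstruction, not the merged implementation.
Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers,
                              std::shared_ptr<DatasetOp> input_op, std::shared_ptr<DatasetOp> *cache_op) {
  std::shared_ptr<CacheOp> new_cache_op = nullptr;
  CacheOp::Builder cache_builder;  // assumed builder type
  if (num_workers > 0) {
    (void)cache_builder.SetNumWorkers(num_workers);  // 0 means the caller never saw the arg
  }
  (void)cache_builder.SetClient(cache_client);  // assumed setter name
  RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op));
  RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op));  // placeholder sits directly over the leaf
  *cache_op = new_cache_op;
  return Status::OK();
}
```

Whatever the exact body, the contract visible above is fixed: on return, *top is the cache placeholder and *bottom the original leaf, so the caller can splice the two-node subtree into the pipeline.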
@@ -1380,6 +1399,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
  CHECK_FAIL_RETURN_UNEXPECTED(!args["task"].is_none(), "Error: No task specified.");
  CHECK_FAIL_RETURN_UNEXPECTED(!args["usage"].is_none(), "Error: No usage specified.");
  int num_workers = 0;
  std::shared_ptr<CacheClient> cache_client = nullptr;
  std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
  (void)builder->SetDir(ToString(args["dataset_dir"]));
  (void)builder->SetTask(ToString(args["task"]));
@@ -1389,7 +1410,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
        num_workers = ToInt(value);
        (void)builder->SetNumWorkers(num_workers);
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
@@ -1398,12 +1420,26 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
        (void)builder->SetDecode(ToBool(value));
      } else if (key == "class_indexing") {
        (void)builder->SetClassIndex(ToStringMap(value));
      } else if (key == "cache") {
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
      }
    }
  }
  std::shared_ptr<VOCOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *top = op;
  std::shared_ptr<VOCOp> voc_op;
  RETURN_IF_NOT_OK(builder->Build(&voc_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(voc_op));
  *top = voc_op;
  // Additionally, add a cache if required.
  // Note that this cache op is only acting as a placeholder for the caching position
  // within the tree. Later, a pre-pass will execute a tree transform to set up the actual
  // caching logic in the tree.
  if (cache_client) {
    std::shared_ptr<DatasetOp> cache_op = nullptr;
    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, voc_op, &cache_op));
    *top = cache_op;
    *bottom = voc_op;
  }
  return Status::OK();
}
@@ -1425,6 +1461,8 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp>
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  int num_workers = 0;
  std::shared_ptr<CacheClient> cache_client = nullptr;
  std::shared_ptr<CocoOp::Builder> builder = std::make_shared<CocoOp::Builder>();
  (void)builder->SetDir(ToString(args["dataset_dir"]));
  (void)builder->SetFile(ToString(args["annotation_file"]));
@@ -1434,19 +1472,35 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp>
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
        num_workers = ToInt(value);
        (void)builder->SetNumWorkers(num_workers);
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
        (void)builder->SetSampler(std::move(sampler));
      } else if (key == "decode") {
        (void)builder->SetDecode(ToBool(value));
      } else if (key == "cache") {
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
      }
    }
  }
  std::shared_ptr<CocoOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *top = op;
  std::shared_ptr<CocoOp> coco_op;
  RETURN_IF_NOT_OK(builder->Build(&coco_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(coco_op));
  *top = coco_op;
  // Additionally, add a cache if required.
  // Note that this cache op is only acting as a placeholder for the caching position
  // within the tree. Later, a pre-pass will execute a tree transform to set up the actual
  // caching logic in the tree.
  if (cache_client) {
    std::shared_ptr<DatasetOp> cache_op = nullptr;
    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, coco_op, &cache_op));
    *top = cache_op;
    *bottom = coco_op;
  }
  return Status::OK();
}
@@ -1458,6 +1512,8 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  int num_workers = 0;
  std::shared_ptr<CacheClient> cache_client = nullptr;
  std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>();
  (void)builder->SetCifarDir(ToString(args["dataset_dir"]));
@@ -1467,22 +1523,38 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
        num_workers = ToInt(value);
        (void)builder->SetNumWorkers(num_workers);
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
        (void)builder->SetSampler(std::move(sampler));
      } else if (key == "usage") {
        (void)builder->SetUsage(ToString(value));
      } else if (key == "cache") {
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
      }
    }
  }
  (void)builder->SetCifarType(true);
  std::shared_ptr<CifarOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *top = op;
  std::shared_ptr<CifarOp> cifar_op;
  RETURN_IF_NOT_OK(builder->Build(&cifar_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(cifar_op));
  *top = cifar_op;
  // Additionally, add a cache if required.
  // Note that this cache op is only acting as a placeholder for the caching position
  // within the tree. Later, a pre-pass will execute a tree transform to set up the actual
  // caching logic in the tree.
  if (cache_client) {
    std::shared_ptr<DatasetOp> cache_op = nullptr;
    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, cifar_op, &cache_op));
    *top = cache_op;
    *bottom = cifar_op;
  }
  return Status::OK();
}
@@ -1494,6 +1566,8 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  int num_workers = 0;
  std::shared_ptr<CacheClient> cache_client = nullptr;
  std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>();
  (void)builder->SetCifarDir(ToString(args["dataset_dir"]));
@@ -1503,22 +1577,37 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
        num_workers = ToInt(value);
        (void)builder->SetNumWorkers(num_workers);
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
        (void)builder->SetSampler(std::move(sampler));
      } else if (key == "usage") {
        (void)builder->SetUsage(ToString(value));
      } else if (key == "cache") {
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
      }
    }
  }
  (void)builder->SetCifarType(false);
  std::shared_ptr<CifarOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *top = op;
  std::shared_ptr<CifarOp> cifar_op;
  RETURN_IF_NOT_OK(builder->Build(&cifar_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(cifar_op));
  *top = cifar_op;
  // Additionally, add a cache if required.
  // Note that this cache op is only acting as a placeholder for the caching position
  // within the tree. Later, a pre-pass will execute a tree transform to set up the actual
  // caching logic in the tree.
  if (cache_client) {
    std::shared_ptr<DatasetOp> cache_op = nullptr;
    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, cifar_op, &cache_op));
    *top = cache_op;
    *bottom = cifar_op;
  }
  return Status::OK();
}
@@ -1609,6 +1698,8 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  int num_workers = 0;
  std::shared_ptr<CacheClient> cache_client = nullptr;
  std::shared_ptr<MnistOp::Builder> builder = std::make_shared<MnistOp::Builder>();
  (void)builder->SetDir(ToString(args["dataset_dir"]));
@@ -1618,19 +1709,35 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
        num_workers = ToInt(value);
        (void)builder->SetNumWorkers(num_workers);
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
        (void)builder->SetSampler(std::move(sampler));
      } else if (key == "usage") {
        (void)builder->SetUsage(ToString(value));
      } else if (key == "cache") {
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
      }
    }
  }
  std::shared_ptr<MnistOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *top = op;
  std::shared_ptr<MnistOp> mnist_op;
  RETURN_IF_NOT_OK(builder->Build(&mnist_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(mnist_op));
  *top = mnist_op;
  // Additionally, add a cache if required.
  // Note that this cache op is only acting as a placeholder for the caching position
  // within the tree. Later, a pre-pass will execute a tree transform to set up the actual
  // caching logic in the tree.
  if (cache_client) {
    std::shared_ptr<DatasetOp> cache_op = nullptr;
    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, mnist_op, &cache_op));
    *top = cache_op;
    *bottom = mnist_op;
  }
  return Status::OK();
}
@@ -1642,6 +1749,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
  }
  int num_workers = 0;
  std::shared_ptr<CacheClient> cache_client = nullptr;
  std::shared_ptr<CelebAOp::Builder> builder = std::make_shared<CelebAOp::Builder>();
  if (builder == nullptr) {
    std::string err_msg = "Create CelebAOp builder failed";
@@ -1653,7 +1762,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
        num_workers = ToInt(value);
        (void)builder->SetNumWorkers(num_workers);
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
@@ -1664,13 +1774,28 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
        (void)builder->SetExtensions(ToStringSet(value));
      } else if (key == "usage") {
        (void)builder->SetUsage(ToString(value));
      } else if (key == "cache") {
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
      }
    }
  }
  std::shared_ptr<CelebAOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *top = op;
  std::shared_ptr<CelebAOp> celeba_op;
  RETURN_IF_NOT_OK(builder->Build(&celeba_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(celeba_op));
  *top = celeba_op;
  // Additionally, add a cache if required.
  // Note that this cache op is only acting as a placeholder for the caching position
  // within the tree. Later, a pre-pass will execute a tree transform to set up the actual
  // caching logic in the tree.
  if (cache_client) {
    std::shared_ptr<DatasetOp> cache_op = nullptr;
    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, celeba_op, &cache_op));
    *top = cache_op;
    *bottom = celeba_op;
  }
  return Status::OK();
}
@@ -1678,6 +1803,9 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
                                   std::shared_ptr<DatasetOp> *bottom) {
  // Required arguments
  std::vector<std::string> files_list;
  std::shared_ptr<CacheClient> cache_client = nullptr;
  std::shared_ptr<Sampler> sampler = nullptr;
  int num_workers = 0;
  std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>();
  if (!args["dataset_files"].is_none()) {
    files_list = ToStringVector(args["dataset_files"]);
@@ -1693,7 +1821,8 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
        num_workers = ToInt(value);
        (void)builder->SetNumWorkers(num_workers);
      } else if (key == "shuffle_files") {
        (void)builder->SetShuffleFiles(ToBool(value));
      } else if (key == "shuffle_global") {
@@ -1705,16 +1834,35 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
        (void)builder->SetNumDevices(num_devices);
      } else if (key == "shard_id") {
        (void)builder->SetDeviceId(ToInt(value));
      } else if (key == "cache") {
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        sampler = create().cast<std::shared_ptr<Sampler>>();
      }
    }
  }
  // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
  // because TextFileOp is a non-mappable dataset that does not support sampling.
  // However, if a cache operator is injected at some other place higher in the tree, that cache can
  // inherit this sampler from the leaf, providing sampling support from the caching layer.
  // That is why we save the sampler here in a leaf node that does not use sampling.
  if (sampler) {
    (void)builder->SetSampler(std::move(sampler));
  } else if (cache_client) {
    int64_t num_samples = 0;
    int64_t start_index = 0;
    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
    (void)builder->SetSampler(std::move(sampler));
  }
  std::shared_ptr<TextFileOp> txt_op;
  RETURN_IF_NOT_OK(builder->Build(&txt_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op));
  *top = txt_op;
  if (shuffle_required) {
  if (!cache_client && shuffle_required) {
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t shuffle_size = 0;
    int64_t num_rows = 0;
@@ -1729,6 +1877,15 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
    *bottom = txt_op;
  }
  // Add a cache op over this op if required and update the output subtree (top/bottom)
  if (cache_client) {
    // Note, it is not allowed to have both shuffle and cache
    std::shared_ptr<DatasetOp> cache_op = nullptr;
    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, txt_op, &cache_op));
    *top = cache_op;
    *bottom = txt_op;
  }
  return Status::OK();
}
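Two distinct rules appear in this parser and repeat verbatim for ClueOp and CsvOp below: a global shuffle is suppressed whenever a cache is present (shuffle and cache are mutually exclusive), and a leaf sampler is retained purely so the cache above can inherit it, defaulting to a sequential pass-through. A standalone condensation of the sampler rule, using a hypothetical helper name and stand-in types:

```cpp
#include <memory>

// Stand-ins for the real Sampler / SequentialSampler classes.
struct Sampler {
  virtual ~Sampler() = default;
};
struct SequentialSampler : Sampler {
  // (num_samples = 0, start_index = 0) means "every row, from the beginning".
  SequentialSampler(long num_samples, long start_index)
      : num_samples_(num_samples), start_index_(start_index) {}
  long num_samples_;
  long start_index_;
};

// ResolveLeafSampler is illustrative only; the diff inlines this logic per parser.
std::shared_ptr<Sampler> ResolveLeafSampler(std::shared_ptr<Sampler> user_sampler, bool has_cache) {
  if (user_sampler) {
    return user_sampler;  // kept on the leaf so an injected cache can adopt it
  }
  if (has_cache) {
    return std::make_shared<SequentialSampler>(0, 0);  // default pass-through under cache
  }
  return nullptr;  // plain non-mappable leaf: sampling is unsupported
}
```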
@@ -1829,6 +1986,10 @@ Status DEPipeline::ParseBuildSentencePieceVocabOp(const py::dict &args, std::sha
Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                               std::shared_ptr<DatasetOp> *bottom) {
  std::vector<std::string> files_list;
  std::shared_ptr<CacheClient> cache_client = nullptr;
  std::shared_ptr<Sampler> sampler = nullptr;
  int num_workers = 0;
  std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
  if (!args["dataset_files"].is_none()) {
    files_list = ToStringVector(args["dataset_files"]);
@@ -1844,7 +2005,8 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
        num_workers = ToInt(value);
        (void)builder->SetNumWorkers(num_workers);
      } else if (key == "shuffle_files") {
        (void)builder->SetShuffleFiles(ToBool(value));
      } else if (key == "shuffle_global") {
@@ -1866,16 +2028,35 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
          }
        }
        (void)builder->SetColsKeyMap(map_dict);
      } else if (key == "cache") {
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        sampler = create().cast<std::shared_ptr<Sampler>>();
      }
    }
  }
  // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
  // because ClueOp is a non-mappable dataset that does not support sampling.
  // However, if a cache operator is injected at some other place higher in the tree, that cache can
  // inherit this sampler from the leaf, providing sampling support from the caching layer.
  // That is why we save the sampler here in a leaf node that does not use sampling.
  if (sampler) {
    (void)builder->SetSampler(std::move(sampler));
  } else if (cache_client) {
    int64_t num_samples = 0;
    int64_t start_index = 0;
    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
    (void)builder->SetSampler(std::move(sampler));
  }
  std::shared_ptr<ClueOp> clue_op;
  RETURN_IF_NOT_OK(builder->Build(&clue_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op));
  *top = clue_op;
  if (shuffle_required) {
  if (!cache_client && shuffle_required) {
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t shuffle_size = 0;
    int64_t num_rows = 0;
@@ -1890,6 +2071,15 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
    *bottom = clue_op;
  }
  // Add a cache op over this op if required and update the output subtree (top/bottom)
  if (cache_client) {
    // Note, it is not allowed to have both shuffle and cache
    std::shared_ptr<DatasetOp> cache_op = nullptr;
    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, clue_op, &cache_op));
    *top = cache_op;
    *bottom = clue_op;
  }
  return Status::OK();
}
@@ -1921,6 +2111,9 @@ Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num
Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                              std::shared_ptr<DatasetOp> *bottom) {
  std::vector<std::string> files_list;
  std::shared_ptr<CacheClient> cache_client = nullptr;
  std::shared_ptr<Sampler> sampler = nullptr;
  int num_workers = 0;
  std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
  if (!args["dataset_files"].is_none()) {
    files_list = ToStringVector(args["dataset_files"]);
@@ -1938,7 +2131,8 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
    py::handle value = arg.second;
    if (!value.is_none()) {
      if (key == "num_parallel_workers") {
        (void)builder->SetNumWorkers(ToInt(value));
        num_workers = ToInt(value);
        (void)builder->SetNumWorkers(num_workers);
      } else if (key == "shuffle_files") {
        (void)builder->SetShuffleFiles(ToBool(value));
      } else if (key == "shuffle_global") {
@@ -1971,16 +2165,35 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
      } else if (key == "column_names") {
        col_names = ToStringVector(value);
        (void)builder->SetColumName(col_names);
      } else if (key == "cache") {
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
      } else if (key == "sampler") {
        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
        sampler = create().cast<std::shared_ptr<Sampler>>();
      }
    }
  }
  // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
  // because CsvOp is a non-mappable dataset that does not support sampling.
  // However, if a cache operator is injected at some other place higher in the tree, that cache can
  // inherit this sampler from the leaf, providing sampling support from the caching layer.
  // That is why we save the sampler here in a leaf node that does not use sampling.
  if (sampler) {
    (void)builder->SetSampler(std::move(sampler));
  } else if (cache_client) {
    int64_t num_samples = 0;
    int64_t start_index = 0;
    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
    (void)builder->SetSampler(std::move(sampler));
  }
  std::shared_ptr<CsvOp> csv_op;
  RETURN_IF_NOT_OK(builder->Build(&csv_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op));
  *top = csv_op;
  if (shuffle_required) {
  if (!cache_client && shuffle_required) {
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t shuffle_size = 0;
    int64_t num_rows = 0;
@@ -1995,6 +2208,15 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
    *bottom = csv_op;
  }
  // Add a cache op over this op if required and update the output subtree (top/bottom)
  if (cache_client) {
    // Note, it is not allowed to have both shuffle and cache
    std::shared_ptr<DatasetOp> cache_op = nullptr;
    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, csv_op, &cache_op));
    *top = cache_op;
    *bottom = csv_op;
  }
  return Status::OK();
}
@@ -70,13 +70,12 @@ Status AlbumOp::Builder::SanityCheck() {
AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
                 const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
                 std::shared_ptr<Sampler> sampler)
    : ParallelOp(num_wkrs, queue_size),
    : ParallelOp(num_wkrs, queue_size, std::move(sampler)),
      rows_per_buffer_(rows_per_buffer),
      folder_path_(file_dir),
      decode_(do_decode),
      extensions_(exts),
      data_schema_(std::move(data_schema)),
      sampler_(std::move(sampler)),
      row_cnt_(0),
      buf_cnt_(0),
      sampler_ind_(0),
@@ -284,7 +284,6 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
  std::set<std::string> extensions_; // extensions allowed
  std::unordered_map<std::string, int32_t> col_name_map_;
  std::unique_ptr<DataSchema> data_schema_;
  std::shared_ptr<Sampler> sampler_;
  int64_t row_cnt_;
  int64_t buf_cnt_;
  int64_t sampler_ind_;
@@ -25,13 +25,18 @@
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
ClueOp::Builder::Builder()
    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
    : builder_device_id_(0),
      builder_num_devices_(1),
      builder_num_samples_(0),
      builder_shuffle_files_(false),
      builder_sampler_(nullptr) {
  std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
  builder_num_workers_ = config_manager->num_parallel_workers();
  builder_op_connector_size_ = config_manager->op_connector_size();
@@ -68,7 +73,7 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
  std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
    builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
    builder_device_id_);
    builder_device_id_, std::move(builder_sampler_));
  RETURN_IF_NOT_OK(clue_op->Init());
  *op = std::move(clue_op);
@@ -88,8 +93,8 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
               ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
               bool shuffle_files, int32_t num_device, int32_t device_id)
    : ParallelOp(num_workers, op_connector_size),
               bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
      rows_per_buffer_(rows_per_buffer),
      num_rows_per_shard_(0),
      all_num_rows_(0),
@@ -539,5 +544,21 @@ Status ClueOp::ComputeColMap() {
  }
  return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
// a sampler for fetching the data. As such, any options in the clue op need to be reset to their defaults so
// that this clue op will produce the full set of data into the cache.
void ClueOp::MakeSimpleProducer() {
  device_id_ = 0;
  num_devices_ = 1;
  shuffle_files_ = false;
  num_samples_ = 0;
}
// Visitor accept method for NodePass
Status ClueOp::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call visitor
  return p->RunOnNode(shared_from_base<ClueOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
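MakeSimpleProducer and Accept close out the ClueOp changes (CsvOp and TextFileOp receive identical methods below). Accept is the op half of a classic visitor double dispatch: the op downcasts itself so the pass resolves the ClueOp-specific overload. A self-contained sketch of that mechanic with stand-in types (the real code uses shared_from_base on the DatasetOp hierarchy rather than shared_from_this, and returns Status):

```cpp
#include <iostream>
#include <memory>

class ClueOp;  // forward declaration so the visitor can name it

class NodePass {
 public:
  virtual ~NodePass() = default;
  // Fallback overload, analogous to RunOnNode(std::shared_ptr<DatasetOp>, ...).
  virtual void RunOnNode(std::shared_ptr<ClueOp> node);
};

class ClueOp : public std::enable_shared_from_this<ClueOp> {
 public:
  // Mirrors ClueOp::Accept: downcast the shared pointer, then call the visitor.
  void Accept(NodePass *p) { p->RunOnNode(shared_from_this()); }
  void MakeSimpleProducer() { std::cout << "ClueOp reset to a simple producer\n"; }
};

void NodePass::RunOnNode(std::shared_ptr<ClueOp> /*node*/) { /* default: do nothing */ }

class CachePass : public NodePass {
 public:
  void RunOnNode(std::shared_ptr<ClueOp> node) override { node->MakeSimpleProducer(); }
};

int main() {
  auto op = std::make_shared<ClueOp>();
  CachePass pass;
  op->Accept(&pass);  // dispatches to the ClueOp overload of CachePass
}
```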
@@ -20,6 +20,7 @@
#include <map>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include <nlohmann/json.hpp>
@@ -122,6 +123,14 @@ class ClueOp : public ParallelOp {
    // @return - a string vector
    std::vector<std::string> split(const std::string &s, char delim);
    // Setter method
    // @param std::shared_ptr<Sampler> sampler
    // @return Builder setter method returns reference to the builder.
    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
      builder_sampler_ = std::move(sampler);
      return *this;
    }
   private:
    int32_t builder_device_id_;
    int32_t builder_num_devices_;
@@ -133,12 +142,13 @@ class ClueOp : public ParallelOp {
    std::vector<std::string> builder_clue_files_list_;
    bool builder_shuffle_files_;
    std::map<std::string, std::string> builder_cols_to_keyword_;
    std::shared_ptr<Sampler> builder_sampler_;
  };
  // Constructor of ClueOp
  ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
         ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
         bool shuffle_files, int32_t num_devices, int32_t device_id);
         bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);
  // Default destructor
  ~ClueOp() = default;
@@ -173,6 +183,17 @@ class ClueOp : public ParallelOp {
  // @return Vector of the input file names
  std::vector<std::string> FileNames() { return clue_files_list_; }
  /// \brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
  /// a sampler for fetching the data. As such, any options in the clue op need to be reset to their defaults so
  /// that this clue op will produce the full set of data into the cache.
  void MakeSimpleProducer();
  // Base-class override for NodePass visitor acceptor.
  // @param p - Pointer to the NodePass to be accepted.
  // @param modified - Whether this node visit modified the pipeline.
  // @return - Status of the node visit.
  Status Accept(NodePass *p, bool *modified) override;
 private:
  // The entry point for when workers are launched.
  // @param worker_id - the id of the worker that is executing this function.
@@ -124,7 +124,7 @@ Status CocoOp::Builder::SanityCheck() {
CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
               int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
               std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
    : ParallelOp(num_workers, queue_size),
    : ParallelOp(num_workers, queue_size, std::move(sampler)),
      decode_(decode),
      row_cnt_(0),
      buf_cnt_(0),
@@ -132,7 +132,6 @@ CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path,
      image_folder_path_(image_folder_path),
      annotation_path_(annotation_path),
      rows_per_buffer_(rows_per_buffer),
      sampler_(std::move(sampler)),
      data_schema_(std::move(data_schema)) {
  io_block_queues_.Init(num_workers_, queue_size);
}
@@ -206,6 +206,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
  /// \return Status of the node visit
  Status Accept(NodePass *p, bool *modified) override;
  // Op name getter
  // @return Name of the current Op
  std::string Name() const override { return "CocoOp"; }
 private:
  // Initialize Sampler, calls sampler->Init() within
  // @return Status - The error code return
@@ -324,7 +328,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
  std::string annotation_path_;
  TaskType task_type_;
  int32_t rows_per_buffer_;
  std::shared_ptr<Sampler> sampler_;
  std::unique_ptr<DataSchema> data_schema_;
  WaitPost wp_;
@@ -22,12 +22,17 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
CsvOp::Builder::Builder()
    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
    : builder_device_id_(0),
      builder_num_devices_(1),
      builder_num_samples_(0),
      builder_shuffle_files_(false),
      builder_sampler_(nullptr) {
  std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
  builder_num_workers_ = config_manager->num_parallel_workers();
  builder_op_connector_size_ = config_manager->op_connector_size();
@@ -59,7 +64,8 @@ Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
  std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
    builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
    builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
    builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
    builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_,
    std::move(builder_sampler_));
  RETURN_IF_NOT_OK(csv_op->Init());
  *op = std::move(csv_op);
@@ -70,8 +76,8 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
             const std::vector<std::shared_ptr<BaseRecord>> &column_default,
             const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
             int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
             int32_t num_device, int32_t device_id)
    : ParallelOp(num_workers, op_connector_size),
             int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
      csv_files_list_(std::move(csv_files_list)),
      field_delim_(field_delim),
      column_default_list_(column_default),
@@ -889,5 +895,21 @@ Status CsvOp::ComputeColMap() {
  }
  return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
// a sampler for fetching the data. As such, any options in the csv op need to be reset to their defaults so
// that this csv op will produce the full set of data into the cache.
void CsvOp::MakeSimpleProducer() {
  device_id_ = 0;
  num_devices_ = 1;
  shuffle_files_ = false;
  num_samples_ = 0;
}
// Visitor accept method for NodePass
Status CsvOp::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call visitor
  return p->RunOnNode(shared_from_base<CsvOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
@@ -240,6 +240,14 @@ class CsvOp : public ParallelOp {
      return *this;
    }
    // Setter method
    // @param std::shared_ptr<Sampler> sampler
    // @return Builder setter method returns reference to the builder.
    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
      builder_sampler_ = std::move(sampler);
      return *this;
    }
   private:
    int32_t builder_device_id_;
    int32_t builder_num_devices_;
@@ -253,6 +261,7 @@ class CsvOp : public ParallelOp {
    char builder_field_delim_;
    std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
    std::vector<std::string> builder_column_name_list_;
    std::shared_ptr<Sampler> builder_sampler_;
  };
  // Constructor of CsvOp
@@ -261,7 +270,8 @@ class CsvOp : public ParallelOp {
  CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
        const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
        int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
        int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
        int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
        std::shared_ptr<Sampler> sampler);
  // Default destructor
  ~CsvOp() = default;
@@ -297,6 +307,17 @@ class CsvOp : public ParallelOp {
  // @return Vector of the input file names
  std::vector<std::string> FileNames() { return csv_files_list_; }
  /// \brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
  /// a sampler for fetching the data. As such, any options in the csv op need to be reset to their defaults so
  /// that this csv op will produce the full set of data into the cache.
  void MakeSimpleProducer();
  // Base-class override for NodePass visitor acceptor.
  // @param p - Pointer to the NodePass to be accepted.
  // @param modified - Whether this node visit modified the pipeline.
  // @return - Status of the node visit.
  Status Accept(NodePass *p, bool *modified) override;
 private:
  // The entry point for when workers are launched.
  // @param worker_id - the id of the worker that is executing this function.
@@ -29,6 +29,7 @@
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
@@ -499,5 +500,21 @@ Status TextFileOp::ComputeColMap() {
  }
  return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
// a sampler for fetching the data. As such, any options in the text file op need to be reset to their defaults so
// that this text file op will produce the full set of data into the cache.
void TextFileOp::MakeSimpleProducer() {
  device_id_ = 0;
  num_devices_ = 1;
  shuffle_files_ = false;
  total_rows_ = 0;
}
// Visitor accept method for NodePass
Status TextFileOp::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call visitor
  return p->RunOnNode(shared_from_base<TextFileOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
@@ -188,6 +188,17 @@ class TextFileOp : public ParallelOp {
  // @return Vector of the input file names
  std::vector<std::string> FileNames() { return text_files_list_; }
  /// \brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
  /// a sampler for fetching the data. As such, any options in the text file op need to be reset to their defaults so
  /// that this text file op will produce the full set of data into the cache.
  void MakeSimpleProducer();
  // Base-class override for NodePass visitor acceptor.
  // @param p - Pointer to the NodePass to be accepted.
  // @param modified - Whether this node visit modified the pipeline.
  // @return - Status of the node visit.
  Status Accept(NodePass *p, bool *modified) override;
 private:
  // The entry point for when workers are launched.
  // @param worker_id - the id of the worker that is executing this function.
@@ -212,6 +212,7 @@ Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Ten
      folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension);
    RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image));
    RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation));
    trow->setId(row_id);
    trow->push_back(std::move(image));
    trow->insert(trow->end(), annotation.begin(), annotation.end());
  }
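The one-line VOCOp change above is easy to miss but load-bearing: a mappable leaf under a cache is fetched by row id, so every TensorRow it produces must carry the id the sampler requested. A tiny stand-in showing the invariant (this TensorRow is a minimal mock, not the real class):

```cpp
#include <cassert>
#include <vector>

// Minimal stand-in for the real TensorRow; only the id plumbing is modeled.
struct TensorRow {
  long id = -1;
  std::vector<int> columns;
  void setId(long v) { id = v; }
  long getId() const { return id; }
};

int main() {
  TensorRow trow;
  trow.setId(42);              // mirrors the line added to VOCOp::LoadTensorRow
  trow.columns.push_back(7);   // payload comes after the id is stamped
  assert(trow.getId() == 42);  // a cache keys this row by that id
}
```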
@@ -45,6 +45,9 @@
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#endif
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#ifdef ENABLE_PYTHON
@@ -260,6 +263,21 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified)
  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) {
  // Fallback to base class visitor by default
  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) {
  // Fallback to base class visitor by default
  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) {
  // Fallback to base class visitor by default
  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
  // Fallback to base class visitor by default
  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
@@ -81,6 +81,12 @@ class CacheMergeOp;
class CacheLookupOp;
class BuildSentencePieceVocabOp;
class ClueOp;
class CsvOp;
class TextFileOp;
#endif
#ifdef ENABLE_PYTHON
@@ -211,6 +217,12 @@ class NodePass : public Pass {
  virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
  virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified);
  virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified);
  virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified);
  virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
  virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
@@ -36,6 +36,9 @@
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#endif
#ifdef ENABLE_PYTHON
@@ -141,6 +144,36 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node
  }
  return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) {
  if (is_caching_) {
    // If we are a ClueOp in a caching tree, then change our config so that it becomes a basic
    // ClueOp that parses all files. Selection of data will come from the sampler on the cache instead.
    node->MakeSimpleProducer();
  }
  return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) {
  if (is_caching_) {
    // If we are a CsvOp in a caching tree, then change our config so that it becomes a basic
    // CsvOp that parses all files. Selection of data will come from the sampler on the cache instead.
    node->MakeSimpleProducer();
  }
  return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) {
  if (is_caching_) {
    // If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic
    // TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead.
    node->MakeSimpleProducer();
  }
  return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
#endif
// Perform leaf node cache transform identification
@@ -163,34 +196,22 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, b
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
  if (is_caching_) {
    RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache.");
  }
  return Status::OK();
  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
  if (is_caching_) {
    RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache.");
  }
  return Status::OK();
  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
  if (is_caching_) {
    RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache.");
  }
  return Status::OK();
  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
  if (is_caching_) {
    RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache.");
  }
  return Status::OK();
  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
#ifndef ENABLE_ANDROID
@@ -214,18 +235,12 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> nod
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
  if (is_caching_) {
    RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache.");
  }
  return Status::OK();
  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
  if (is_caching_) {
    RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache.");
  }
  return Status::OK();
  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
#endif
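These hunks flip every mappable leaf (Mnist, Cifar, Coco, CelebA, Manifest, VOC) from an outright error under cache to MappableCacheLeafSetup, while the non-mappable leaves above use NonMappableCacheLeafSetup plus MakeSimpleProducer. A runnable stand-in for just that dispatch split; the two Setup bodies below are placeholders, since the real implementations live elsewhere in cache_transform_pass.cc and are not part of this diff:

```cpp
#include <iostream>
#include <string>

struct LeafInfo {
  std::string name;
  bool mappable;  // mappable leaves support random access by row id
};

// Placeholder bodies; the real setup routines record the leaf for the transform.
void NonMappableCacheLeafSetup(const LeafInfo &leaf) {
  std::cout << leaf.name << ": cache runs the sampler; leaf becomes a simple producer\n";
}
void MappableCacheLeafSetup(const LeafInfo &leaf) {
  std::cout << leaf.name << ": cache fetches rows from the leaf by row id\n";
}

void VisitLeaf(const LeafInfo &leaf, bool is_caching) {
  if (!is_caching) {
    return;  // outside a caching subtree the pass leaves the op untouched
  }
  if (leaf.mappable) {
    MappableCacheLeafSetup(leaf);  // newly permitted by this diff
  } else {
    NonMappableCacheLeafSetup(leaf);
  }
}

int main() {
  VisitLeaf({"MnistOp", true}, true);
  VisitLeaf({"TextFileOp", false}, true);
}
```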
@@ -65,6 +65,24 @@ class CacheTransformPass : public TreePass {
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
  /// \brief Perform leaf node cache transform identification
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override;
  /// \brief Perform leaf node cache transform identification
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override;
  /// \brief Perform leaf node cache transform identification
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override;
#endif
  /// \brief Perform leaf node cache transform identification
| @@ -2969,6 +2969,8 @@ class MnistDataset(MappableDataset): | |||
| into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -2988,7 +2990,7 @@ class MnistDataset(MappableDataset): | |||
| @check_mnist_cifar_dataset | |||
| def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | |||
| shuffle=None, sampler=None, num_shards=None, shard_id=None): | |||
| shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_dir = dataset_dir | |||
| @@ -2998,6 +3000,7 @@ class MnistDataset(MappableDataset): | |||
| self.shuffle_level = shuffle | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.cache = cache | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| @@ -3008,6 +3011,7 @@ class MnistDataset(MappableDataset): | |||
| args["sampler"] = self.sampler | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -3872,6 +3876,8 @@ class ManifestDataset(MappableDataset): | |||
| into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -3897,7 +3903,8 @@ class ManifestDataset(MappableDataset): | |||
| @check_manifestdataset | |||
| def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None, | |||
| shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None): | |||
| shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, | |||
| cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_file = dataset_file | |||
| @@ -3913,6 +3920,7 @@ class ManifestDataset(MappableDataset): | |||
| self.shuffle_level = shuffle | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.cache = cache | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| @@ -3925,6 +3933,7 @@ class ManifestDataset(MappableDataset): | |||
| args["decode"] = self.decode | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -4055,6 +4064,8 @@ class Cifar10Dataset(MappableDataset): | |||
| into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -4082,7 +4093,7 @@ class Cifar10Dataset(MappableDataset): | |||
| @check_mnist_cifar_dataset | |||
| def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | |||
| shuffle=None, sampler=None, num_shards=None, shard_id=None): | |||
| shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_dir = dataset_dir | |||
| @@ -4092,6 +4103,7 @@ class Cifar10Dataset(MappableDataset): | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.shuffle_level = shuffle | |||
| self.cache = cache | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| @@ -4102,6 +4114,7 @@ class Cifar10Dataset(MappableDataset): | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["shuffle"] = self.shuffle_level | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -4202,6 +4215,8 @@ class Cifar100Dataset(MappableDataset): | |||
| into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -4226,7 +4241,7 @@ class Cifar100Dataset(MappableDataset): | |||
| @check_mnist_cifar_dataset | |||
| def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | |||
| shuffle=None, sampler=None, num_shards=None, shard_id=None): | |||
| shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_dir = dataset_dir | |||
| @@ -4236,6 +4251,7 @@ class Cifar100Dataset(MappableDataset): | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.shuffle_level = shuffle | |||
| self.cache = cache | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| @@ -4246,6 +4262,7 @@ class Cifar100Dataset(MappableDataset): | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["shuffle"] = self.shuffle_level | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -4630,6 +4647,8 @@ class VOCDataset(MappableDataset): | |||
| into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| Raises: | |||
| RuntimeError: If xml of Annotations is an invalid format. | |||
| @@ -4667,7 +4686,8 @@ class VOCDataset(MappableDataset): | |||
| @check_vocdataset | |||
| def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None, | |||
| num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): | |||
| num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, | |||
| cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_dir = dataset_dir | |||
| self.task = task | |||
| @@ -4679,6 +4699,7 @@ class VOCDataset(MappableDataset): | |||
| self.shuffle_level = shuffle | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.cache = cache | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| @@ -4692,6 +4713,7 @@ class VOCDataset(MappableDataset): | |||
| args["shuffle"] = self.shuffle_level | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -4838,6 +4860,8 @@ class CocoDataset(MappableDataset): | |||
| into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -4873,7 +4897,7 @@ class CocoDataset(MappableDataset): | |||
| @check_cocodataset | |||
| def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None, | |||
| shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): | |||
| shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_dir = dataset_dir | |||
| self.annotation_file = annotation_file | |||
| @@ -4884,6 +4908,7 @@ class CocoDataset(MappableDataset): | |||
| self.shuffle_level = shuffle | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.cache = cache | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| @@ -4896,6 +4921,7 @@ class CocoDataset(MappableDataset): | |||
| args["shuffle"] = self.shuffle_level | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -4993,6 +5019,8 @@ class CelebADataset(MappableDataset): | |||
| into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -5003,7 +5031,7 @@ class CelebADataset(MappableDataset): | |||
| @check_celebadataset | |||
| def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False, | |||
| extensions=None, num_samples=None, num_shards=None, shard_id=None): | |||
| extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_dir = dataset_dir | |||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | |||
| @@ -5015,6 +5043,7 @@ class CelebADataset(MappableDataset): | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.shuffle_level = shuffle | |||
| self.cache = cache | |||
| if usage != "all": | |||
| dir = os.path.realpath(self.dataset_dir) | |||
| @@ -5033,6 +5062,7 @@ class CelebADataset(MappableDataset): | |||
| args["usage"] = self.usage | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -5142,6 +5172,8 @@ class CLUEDataset(SourceDataset): | |||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -5152,7 +5184,7 @@ class CLUEDataset(SourceDataset): | |||
| @check_cluedataset | |||
| def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None, | |||
| num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): | |||
| num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_files = self._find_files(dataset_files) | |||
| self.dataset_files.sort() | |||
| @@ -5293,6 +5325,15 @@ class CLUEDataset(SourceDataset): | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| # The clue dataset does not directly support a sampler. It provides sampling arguments | |||
| # (shuffle, num_samples, num_shards, shard_id), and it DOES support sampling if there is | |||
| # a cache somewhere above it in the pipeline. If there is no cache above it, this sampler is not used. | |||
| sampler_shuffle = self.shuffle_files | |||
| sampler = None | |||
| self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, | |||
| non_mappable=True) | |||
| self.cache = cache | |||
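A usage sketch of the behavior the comment above describes, patterned on the test_cache_nomap_clue1 case added later in this change; the session id and file path are placeholders:

import mindspore.dataset as ds

some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
# With a cache over the leaf, tree prepare drops the leaf's file-based
# sharding and the internally selected sampler shards row by row using
# the same num_shards/shard_id configuration.
data = ds.CLUEDataset("/path/to/afqmc/train.json", task='AFQMC', usage='train',
                      num_shards=3, shard_id=1, cache=some_cache)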
| def get_args(self): | |||
| args = super().get_args() | |||
| args["dataset_files"] = self.dataset_files | |||
| @@ -5304,6 +5345,8 @@ class CLUEDataset(SourceDataset): | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["cols_to_keyword"] = self.cols_to_keyword | |||
| args["sampler"] = self.sampler | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -5359,6 +5402,9 @@ class CSVDataset(SourceDataset): | |||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -5369,7 +5415,7 @@ class CSVDataset(SourceDataset): | |||
| @check_csvdataset | |||
| def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, | |||
| num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): | |||
| num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_files = self._find_files(dataset_files) | |||
| self.dataset_files.sort() | |||
| @@ -5394,6 +5440,15 @@ class CSVDataset(SourceDataset): | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.cache = cache | |||
| # The CSV dataset does not directly support a sampler. It provides sampling arguments | |||
| # (shuffle, num_samples, num_shards, shard_id), and it DOES support sampling if there is | |||
| # a cache somewhere above it in the pipeline. If there is no cache above it, this sampler is not used. | |||
| sampler_shuffle = self.shuffle_files | |||
| sampler = None | |||
| self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, | |||
| non_mappable=True) | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| args["dataset_files"] = self.dataset_files | |||
| @@ -5407,6 +5462,8 @@ class CSVDataset(SourceDataset): | |||
| args["shuffle"] = self.shuffle_level | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["sampler"] = self.sampler | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -5457,6 +5514,9 @@ class TextFileDataset(SourceDataset): | |||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> | |||
| @@ -5466,7 +5526,7 @@ class TextFileDataset(SourceDataset): | |||
| @check_textfiledataset | |||
| def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, | |||
| shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): | |||
| shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.dataset_files = self._find_files(dataset_files) | |||
| self.dataset_files.sort() | |||
| @@ -5488,6 +5548,15 @@ class TextFileDataset(SourceDataset): | |||
| self.num_shards = num_shards | |||
| self.shard_id = shard_id | |||
| self.cache = cache | |||
| # The text file dataset does not directly support a sampler. It provides sampling arguments | |||
| # (shuffle, num_samples, num_shards, shard_id), and it DOES support sampling if there is | |||
| # a cache somewhere above it in the pipeline. If there is no cache above it, this sampler is not used. | |||
| sampler_shuffle = self.shuffle_files | |||
| sampler = None | |||
| self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, | |||
| non_mappable=True) | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| args["dataset_files"] = self.dataset_files | |||
| @@ -5498,6 +5567,8 @@ class TextFileDataset(SourceDataset): | |||
| args["shuffle"] = self.shuffle_level | |||
| args["num_shards"] = self.num_shards | |||
| args["shard_id"] = self.shard_id | |||
| args["sampler"] = self.sampler | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -83,6 +83,9 @@ def check_mnist_cifar_dataset(method): | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
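Each of the validators below calls the same check_cache_option helper on the parameter dictionary; a minimal sketch of what such a helper might look like, assuming it only type-checks the argument (the isinstance-based body is an assumption, not taken from this patch):

import mindspore.dataset as ds

def check_cache_option(cache):
    # Sketch (assumption): accept None, otherwise require a DatasetCache.
    if cache is not None and not isinstance(cache, ds.DatasetCache):
        raise TypeError("cache should be a DatasetCache instance.")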
| @@ -110,6 +113,9 @@ def check_manifestdataset(method): | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -180,6 +186,9 @@ def check_vocdataset(method): | |||
| validate_dataset_param_value(nreq_param_dict, param_dict, dict) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -216,6 +225,9 @@ def check_cocodataset(method): | |||
| raise ValueError("CocoDataset doesn't support PKSampler") | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -252,6 +264,9 @@ def check_celebadataset(method): | |||
| if sampler is not None and isinstance(sampler, samplers.PKSampler): | |||
| raise ValueError("CelebADataset does not support PKSampler.") | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -843,6 +858,9 @@ def check_cluedataset(method): | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -886,6 +904,9 @@ def check_csvdataset(method): | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -905,6 +926,9 @@ def check_textfiledataset(method): | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -103,6 +103,24 @@ HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_epoch_ctrl" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_coco" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_mnist" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_celeba" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_manifest" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_cifar" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_voc" 1 | |||
| HandleRcExit $? 0 0 | |||
| # Run two parallel pipelines (sharing cache) | |||
| for i in $(seq 1 2) | |||
| do | |||
| @@ -282,6 +300,15 @@ HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_epoch_ctrl" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_clue" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_csv" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1 | |||
| HandleRcExit $? 0 0 | |||
| for i in $(seq 1 3) | |||
| do | |||
| test_name="test_cache_nomap_multiple_cache${i}" | |||
| @@ -17,6 +17,7 @@ Testing cache operator with mappable datasets | |||
| """ | |||
| import os | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| @@ -26,7 +27,13 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" | |||
| COCO_DATA_DIR = "../data/dataset/testCOCO/train/" | |||
| COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" | |||
| NO_IMAGE_DIR = "../data/dataset/testRandomData/" | |||
| MNIST_DATA_DIR = "../data/dataset/testMnistData/" | |||
| CELEBA_DATA_DIR = "../data/dataset/testCelebAData/" | |||
| VOC_DATA_DIR = "../data/dataset/testVOC2012/" | |||
| MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest" | |||
| CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/" | |||
| CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/" | |||
| MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord" | |||
| GENERATE_GOLDEN = False | |||
| @@ -443,7 +450,7 @@ def test_cache_map_failure5(): | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_failure6(): | |||
| """ | |||
| Test no-cache-supporting leaf ops with Map under cache (failure) | |||
| Test no-cache-supporting MindRecord leaf with Map under cache (failure) | |||
| repeat | |||
| | | |||
| @@ -451,7 +458,7 @@ def test_cache_map_failure6(): | |||
| | | |||
| Map(resize) | |||
| | | |||
| Coco | |||
| MindRecord | |||
| """ | |||
| logger.info("Test cache failure 6") | |||
| @@ -461,22 +468,66 @@ def test_cache_map_failure6(): | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| data = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True) | |||
| columns_list = ["id", "file_name", "label_name", "img_data", "label_data"] | |||
| num_readers = 1 | |||
| # The dataset has 5 records | |||
| data = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, num_readers) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| data = data.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||
| data = data.map(input_columns=["img_data"], operations=resize_op, cache=some_cache) | |||
| data = data.repeat(4) | |||
| with pytest.raises(RuntimeError) as e: | |||
| num_iter = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "There is currently no support for CocoOp under cache" in str(e.value) | |||
| assert "There is currently no support for MindRecordOp under cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_map_failure6 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_failure7(): | |||
| """ | |||
| Test no-cache-supporting Generator leaf with Map under cache (failure) | |||
| repeat | |||
| | | |||
| Cache | |||
| | | |||
| Map(lambda x: x) | |||
| | | |||
| Generator | |||
| """ | |||
| def generator_1d(): | |||
| for i in range(64): | |||
| yield (np.array(i),) | |||
| logger.info("Test cache failure 7") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| data = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| data = data.map((lambda x: x), ["data"], cache=some_cache) | |||
| data = data.repeat(4) | |||
| with pytest.raises(RuntimeError) as e: | |||
| num_iter = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "There is currently no support for GeneratorOp under cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_map_failure7 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_parameter_check(): | |||
| """ | |||
| @@ -1236,6 +1287,421 @@ def test_cache_map_epoch_ctrl3(): | |||
| logger.info("test_cache_map_epoch_ctrl3 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_coco1(): | |||
| """ | |||
| Test mappable coco leaf with cache op right over the leaf | |||
| cache | |||
| | | |||
| Coco | |||
| """ | |||
| logger.info("Test cache map coco1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 6 records | |||
| ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True, | |||
| cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 6 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_coco1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_coco2(): | |||
| """ | |||
| Test mappable coco leaf with the cache op later in the tree above the map(resize) | |||
| cache | |||
| | | |||
| Map(resize) | |||
| | | |||
| Coco | |||
| """ | |||
| logger.info("Test cache map coco2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 6 records | |||
| ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 6 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_coco2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_mnist1(): | |||
| """ | |||
| Test mappable mnist leaf with cache op right over the leaf | |||
| cache | |||
| | | |||
| Mnist | |||
| """ | |||
| logger.info("Test cache map mnist1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 10 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_mnist1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_mnist2(): | |||
| """ | |||
| Test mappable mnist leaf with the cache op later in the tree above the map(resize) | |||
| cache | |||
| | | |||
| Map(resize) | |||
| | | |||
| Mnist | |||
| """ | |||
| logger.info("Test cache map mnist2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 10 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_mnist2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_celeba1(): | |||
| """ | |||
| Test mappable celeba leaf with cache op right over the leaf | |||
| cache | |||
| | | |||
| CelebA | |||
| """ | |||
| logger.info("Test cache map celeba1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 4 records | |||
| ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 4 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_celeba1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_celeba2(): | |||
| """ | |||
| Test mappable celeba leaf with the cache op later in the tree above the map(resize) | |||
| cache | |||
| | | |||
| Map(resize) | |||
| | | |||
| CelebA | |||
| """ | |||
| logger.info("Test cache map celeba2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 4 records | |||
| ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 4 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_celeba2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_manifest1(): | |||
| """ | |||
| Test mappable manifest leaf with cache op right over the leaf | |||
| cache | |||
| | | |||
| Manifest | |||
| """ | |||
| logger.info("Test cache map manifest1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 4 records | |||
| ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 4 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_manifest1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_manifest2(): | |||
| """ | |||
| Test mappable manifest leaf with the cache op later in the tree above the map(resize) | |||
| cache | |||
| | | |||
| Map(resize) | |||
| | | |||
| Manifest | |||
| """ | |||
| logger.info("Test cache map manifest2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 4 records | |||
| ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 4 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_manifest2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_cifar1(): | |||
| """ | |||
| Test mappable cifar10 leaf with cache op right over the leaf | |||
| cache | |||
| | | |||
| Cifar10 | |||
| """ | |||
| logger.info("Test cache map cifar1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 10 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_cifar1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_cifar2(): | |||
| """ | |||
| Test mappable cifar100 leaf with the cache op later in the tree above the map(resize) | |||
| cache | |||
| | | |||
| Map(resize) | |||
| | | |||
| Cifar100 | |||
| """ | |||
| logger.info("Test cache map cifar2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 10 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_cifar2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_voc1(): | |||
| """ | |||
| Test mappable voc leaf with cache op right over the leaf | |||
| cache | |||
| | | |||
| VOC | |||
| """ | |||
| logger.info("Test cache map voc1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 9 records | |||
| ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 9 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_voc1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_voc2(): | |||
| """ | |||
| Test mappable voc leaf with the cache op later in the tree above the map(resize) | |||
| cache | |||
| | | |||
| Map(resize) | |||
| | | |||
| VOC | |||
| """ | |||
| logger.info("Test cache map voc2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 9 records | |||
| ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 9 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_map_voc2 Ended.\n") | |||
| if __name__ == '__main__': | |||
| test_cache_map_basic1() | |||
| test_cache_map_basic2() | |||
| @@ -20,22 +20,26 @@ import itertools | |||
| import pytest | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.text as text | |||
| import mindspore.dataset.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| DATA_DIR2 = ["../data/dataset/testTextTFRecord/text.tfrecord"] | |||
| TEXT_TF_DATA_DIR = ["../data/dataset/testTextTFRecord/text.tfrecord"] | |||
| SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json" | |||
| DATA_DIR3 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] | |||
| SCHEMA_DIR3 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" | |||
| TRAIN_DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] | |||
| TRAIN_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" | |||
| DATA_DIR4 = "../data/dataset/testImageNetData/train/" | |||
| IMAGE_FOLDER_DATA_DIR = "../data/dataset/testImageNetData/train/" | |||
| CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json' | |||
| CSV_DATA_DIR = '../data/dataset/testCSV/1.csv' | |||
| TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt" | |||
| GENERATE_GOLDEN = False | |||
| @@ -1310,7 +1314,7 @@ def test_cache_nomap_multiple_cache1(): | |||
| eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 12 records in it | |||
| train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3) | |||
| train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) | |||
| decode_op = c_vision.Decode() | |||
| train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) | |||
| @@ -1359,7 +1363,7 @@ def test_cache_nomap_multiple_cache2(): | |||
| image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) | |||
| # This dataset has 3 records in it only | |||
| text_dataset = ds.TFRecordDataset(DATA_DIR2, SCHEMA_DIR2, cache=text_cache) | |||
| text_dataset = ds.TFRecordDataset(TEXT_TF_DATA_DIR, SCHEMA_DIR2, cache=text_cache) | |||
| num_epoch = 5 | |||
| image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch) | |||
| @@ -1404,7 +1408,7 @@ def test_cache_nomap_multiple_cache3(): | |||
| tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache) | |||
| # This DATA_DIR only has 2 images in it | |||
| image_dataset = ds.ImageFolderDataset(dataset_dir=DATA_DIR4) | |||
| image_dataset = ds.ImageFolderDataset(dataset_dir=IMAGE_FOLDER_DATA_DIR) | |||
| image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) | |||
| num_epoch = 5 | |||
| @@ -1443,7 +1447,7 @@ def test_cache_nomap_multiple_cache_train(): | |||
| train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 12 records in it | |||
| train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3) | |||
| train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) | |||
| decode_op = c_vision.Decode() | |||
| train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) | |||
| @@ -1497,6 +1501,239 @@ def test_cache_nomap_multiple_cache_eval(): | |||
| logger.info("test_cache_nomap_multiple_cache_eval Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_clue1(): | |||
| """ | |||
| A clue dataset (a non-mappable dataset) with a cache over it just after the leaf | |||
| In this one, the clue dataset is given a sharding configuration; however, since a cache is | |||
| used, the tree prepare should undo the sharding configuration and instead choose a distributed | |||
| sampler with the same shard config. | |||
| Cache | |||
| | | |||
| CLUE | |||
| """ | |||
| logger.info("Test cache nomap clue 1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # With only 3 records sharded into 3, we expect only 1 record returned for this shard | |||
| # However, the sharding will be done by the sampler, not by the clue leaf node | |||
| # In this case, it is row-based sharding, not the file-based sharding that would happen if | |||
| # there were no cache. | |||
| ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 1 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_nomap_clue1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_clue2(): | |||
| """ | |||
| A clue dataset (a non-mappable dataset) with a cache over it after the map | |||
| In this one, a num_samples argument is given | |||
| Cache | |||
| | | |||
| map(lambda x: x) | |||
| | | |||
| CLUE | |||
| """ | |||
| logger.info("Test cache nomap clue 2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2) | |||
| ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 2 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_nomap_clue2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_csv1(): | |||
| """ | |||
| A csv dataset (a non-mappable dataset) with a cache over it just after the leaf | |||
| In this one, the csv dataset is given a sharding configuration; however, since a cache is | |||
| used, the tree prepare should undo the sharding configuration and instead choose a distributed | |||
| sampler with the same shard config. | |||
| Cache | |||
| | | |||
| CSV | |||
| """ | |||
| logger.info("Test cache nomap csv 1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # With only 3 records sharded into 3, we expect only 1 record returned for this shard | |||
| # However, the sharding will be done by the sampler, not by the csv leaf node | |||
| # In this case, it is row-based sharding, not the file-based sharding that would happen if | |||
| # there were no cache. | |||
| ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], num_shards=3, shard_id=1, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 1 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_nomap_csv1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_csv2(): | |||
| """ | |||
| A csv dataset (a non-mappable dataset) with a cache over it after the map | |||
| In this one, a num_samples argument is given | |||
| Cache | |||
| | | |||
| map(lambda x: x) | |||
| | | |||
| CSV | |||
| """ | |||
| logger.info("Test cache nomap csv 2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2) | |||
| ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 2 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_nomap_csv2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_textfile1(): | |||
| """ | |||
| A text file dataset (a non-mappable dataset) with a cache over it just after the leaf | |||
| In this one, the text file dataset is given a sharding configuration; however, since a cache is | |||
| used, the tree prepare should undo the sharding configuration and instead choose a distributed | |||
| sampler with the same shard config. | |||
| Cache | |||
| | | |||
| TextFile | |||
| """ | |||
| logger.info("Test cache nomap textfile 1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # With only 3 records sharded into 3, we expect only 1 record returned for this shard | |||
| # However, the sharding will be done by the sampler, not by the text file leaf node | |||
| # In this case, it is row-based sharding, not the file-based sharding that would happen if | |||
| # there were no cache. | |||
| ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 1 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_nomap_textfile1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_textfile2(): | |||
| """ | |||
| A text file dataset (a non-mappable dataset) with a cache over it after the map | |||
| In this one, a num_samples argument is given | |||
| Cache | |||
| | | |||
| Map(tokenizer) | |||
| | | |||
| TextFile | |||
| """ | |||
| def my_tokenizer(line): | |||
| words = line.split() | |||
| if not words: | |||
| return [""] | |||
| return words | |||
| logger.info("Test cache nomap textfile 2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2) | |||
| tokenizer = text.PythonTokenizer(my_tokenizer) | |||
| ds1 = ds1.map(operations=tokenizer, cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| assert sum([1 for _ in iter1]) == 2 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_nomap_textfile2 Ended.\n") | |||
| if __name__ == '__main__': | |||
| test_cache_nomap_basic1() | |||
| test_cache_nomap_basic2() | |||
| @@ -40,8 +40,9 @@ def test_textline_dataset_all_file(): | |||
| assert count == 5 | |||
| def test_textline_dataset_num_samples_zero(): | |||
| data = ds.TextFileDataset(DATA_FILE, num_samples=0) | |||
| def test_textline_dataset_num_samples_none(): | |||
| # Do not provide a num_samples argument, so it would be None by default | |||
| data = ds.TextFileDataset(DATA_FILE) | |||
| count = 0 | |||
| for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| logger.info("{}".format(i["text"])) | |||
| @@ -208,7 +209,7 @@ def test_textline_dataset_exceptions(): | |||
| if __name__ == "__main__": | |||
| test_textline_dataset_one_file() | |||
| test_textline_dataset_all_file() | |||
| test_textline_dataset_num_samples_zero() | |||
| test_textline_dataset_num_samples_none() | |||
| test_textline_dataset_shuffle_false4() | |||
| test_textline_dataset_shuffle_false1() | |||
| test_textline_dataset_shuffle_files4() | |||