Merge pull request !6905 from lixiachen/CacheOp_devtags/v1.1.0
| @@ -1075,7 +1075,7 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() { | |||||
| std::shared_ptr<ClueOp> clue_op = | std::shared_ptr<ClueOp> clue_op = | ||||
| std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, | std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, | ||||
| sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_); | |||||
| sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr); | |||||
| RETURN_EMPTY_IF_ERROR(clue_op->Init()); | RETURN_EMPTY_IF_ERROR(clue_op->Init()); | ||||
| if (shuffle_ == ShuffleMode::kGlobal) { | if (shuffle_ == ShuffleMode::kGlobal) { | ||||
| // Inject ShuffleOp | // Inject ShuffleOp | ||||
| @@ -1256,7 +1256,7 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() { | |||||
| std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>( | std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>( | ||||
| sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, | sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, | ||||
| num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_); | |||||
| num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr); | |||||
| RETURN_EMPTY_IF_ERROR(csv_op->Init()); | RETURN_EMPTY_IF_ERROR(csv_op->Init()); | ||||
| if (shuffle_ == ShuffleMode::kGlobal) { | if (shuffle_ == ShuffleMode::kGlobal) { | ||||
| // Inject ShuffleOp | // Inject ShuffleOp | ||||
| @@ -1502,7 +1502,7 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() { | |||||
| // Create and initalize TextFileOp | // Create and initalize TextFileOp | ||||
| std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | ||||
| num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, | num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, | ||||
| connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr)); | |||||
| connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr); | |||||
| RETURN_EMPTY_IF_ERROR(text_file_op->Init()); | RETURN_EMPTY_IF_ERROR(text_file_op->Init()); | ||||
| if (shuffle_ == ShuffleMode::kGlobal) { | if (shuffle_ == ShuffleMode::kGlobal) { | ||||
| @@ -1345,6 +1345,9 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| std::string err_msg = "Error: No dataset files specified for manifest"; | std::string err_msg = "Error: No dataset files specified for manifest"; | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } | } | ||||
| int num_workers = 0; | |||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||||
| std::shared_ptr<ManifestOp::Builder> builder = std::make_shared<ManifestOp::Builder>(); | std::shared_ptr<ManifestOp::Builder> builder = std::make_shared<ManifestOp::Builder>(); | ||||
| (void)builder->SetManifestFile(ToString(args["dataset_file"])); | (void)builder->SetManifestFile(ToString(args["dataset_file"])); | ||||
| @@ -1354,7 +1357,8 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "num_parallel_workers") { | if (key == "num_parallel_workers") { | ||||
| (void)builder->SetNumWorkers(ToInt(value)); | |||||
| num_workers = ToInt(value); | |||||
| (void)builder->SetNumWorkers(num_workers); | |||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | ||||
| @@ -1365,12 +1369,27 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| (void)builder->SetDecode(ToBool(value)); | (void)builder->SetDecode(ToBool(value)); | ||||
| } else if (key == "usage") { | } else if (key == "usage") { | ||||
| (void)builder->SetUsage(ToString(value)); | (void)builder->SetUsage(ToString(value)); | ||||
| } else if (key == "cache") { | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| std::shared_ptr<ManifestOp> op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||||
| *top = op; | |||||
| std::shared_ptr<ManifestOp> manifest_op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&manifest_op)); | |||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(manifest_op)); | |||||
| *top = manifest_op; | |||||
| // Additionally, add a cache if required. | |||||
| // Note that this cache op is only acting as a place holder for the caching position | |||||
| // within the tree. Later, a pre-pass will execute a tree transform to set up the actual | |||||
| // caching logic in the tree. | |||||
| if (cache_client) { | |||||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, manifest_op, &cache_op)); | |||||
| *top = cache_op; | |||||
| *bottom = manifest_op; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1380,6 +1399,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!args["task"].is_none(), "Error: No task specified."); | CHECK_FAIL_RETURN_UNEXPECTED(!args["task"].is_none(), "Error: No task specified."); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(!args["usage"].is_none(), "Error: No usage specified."); | CHECK_FAIL_RETURN_UNEXPECTED(!args["usage"].is_none(), "Error: No usage specified."); | ||||
| int num_workers = 0; | |||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||||
| std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>(); | std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>(); | ||||
| (void)builder->SetDir(ToString(args["dataset_dir"])); | (void)builder->SetDir(ToString(args["dataset_dir"])); | ||||
| (void)builder->SetTask(ToString(args["task"])); | (void)builder->SetTask(ToString(args["task"])); | ||||
| @@ -1389,7 +1410,8 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "num_parallel_workers") { | if (key == "num_parallel_workers") { | ||||
| (void)builder->SetNumWorkers(ToInt(value)); | |||||
| num_workers = ToInt(value); | |||||
| (void)builder->SetNumWorkers(num_workers); | |||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | ||||
| @@ -1398,12 +1420,26 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||||
| (void)builder->SetDecode(ToBool(value)); | (void)builder->SetDecode(ToBool(value)); | ||||
| } else if (key == "class_indexing") { | } else if (key == "class_indexing") { | ||||
| (void)builder->SetClassIndex(ToStringMap(value)); | (void)builder->SetClassIndex(ToStringMap(value)); | ||||
| } else if (key == "cache") { | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| std::shared_ptr<VOCOp> op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||||
| *top = op; | |||||
| std::shared_ptr<VOCOp> voc_op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&voc_op)); | |||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(voc_op)); | |||||
| *top = voc_op; | |||||
| // Additionally, add a cache if required. | |||||
| // Note that this cache op is only acting as a place holder for the caching position | |||||
| // within the tree. Later, a pre-pass will execute a tree transform to set up the actual | |||||
| // caching logic in the tree. | |||||
| if (cache_client) { | |||||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, voc_op, &cache_op)); | |||||
| *top = cache_op; | |||||
| *bottom = voc_op; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1425,6 +1461,8 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } | } | ||||
| int num_workers = 0; | |||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||||
| std::shared_ptr<CocoOp::Builder> builder = std::make_shared<CocoOp::Builder>(); | std::shared_ptr<CocoOp::Builder> builder = std::make_shared<CocoOp::Builder>(); | ||||
| (void)builder->SetDir(ToString(args["dataset_dir"])); | (void)builder->SetDir(ToString(args["dataset_dir"])); | ||||
| (void)builder->SetFile(ToString(args["annotation_file"])); | (void)builder->SetFile(ToString(args["annotation_file"])); | ||||
| @@ -1434,19 +1472,35 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "num_parallel_workers") { | if (key == "num_parallel_workers") { | ||||
| (void)builder->SetNumWorkers(ToInt(value)); | |||||
| num_workers = ToInt(value); | |||||
| (void)builder->SetNumWorkers(num_workers); | |||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | ||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "decode") { | } else if (key == "decode") { | ||||
| (void)builder->SetDecode(ToBool(value)); | (void)builder->SetDecode(ToBool(value)); | ||||
| } else if (key == "cache") { | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| std::shared_ptr<CocoOp> op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||||
| *top = op; | |||||
| std::shared_ptr<CocoOp> coco_op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&coco_op)); | |||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(coco_op)); | |||||
| *top = coco_op; | |||||
| // Additionally, add a cache if required. | |||||
| // Note that this cache op is only acting as a place holder for the caching position | |||||
| // within the tree. Later, a pre-pass will execute a tree transform to set up the actual | |||||
| // caching logic in the tree. | |||||
| if (cache_client) { | |||||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, coco_op, &cache_op)); | |||||
| *top = cache_op; | |||||
| *bottom = coco_op; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1458,6 +1512,8 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } | } | ||||
| int num_workers = 0; | |||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||||
| std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>(); | std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>(); | ||||
| (void)builder->SetCifarDir(ToString(args["dataset_dir"])); | (void)builder->SetCifarDir(ToString(args["dataset_dir"])); | ||||
| @@ -1467,22 +1523,38 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO | |||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "num_parallel_workers") { | if (key == "num_parallel_workers") { | ||||
| (void)builder->SetNumWorkers(ToInt(value)); | |||||
| num_workers = ToInt(value); | |||||
| (void)builder->SetNumWorkers(num_workers); | |||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | ||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "usage") { | } else if (key == "usage") { | ||||
| (void)builder->SetUsage(ToString(value)); | (void)builder->SetUsage(ToString(value)); | ||||
| } else if (key == "cache") { | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| (void)builder->SetCifarType(true); | (void)builder->SetCifarType(true); | ||||
| std::shared_ptr<CifarOp> op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||||
| *top = op; | |||||
| std::shared_ptr<CifarOp> cifar_op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&cifar_op)); | |||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(cifar_op)); | |||||
| *top = cifar_op; | |||||
| // Additionally, add a cache if required. | |||||
| // Note that this cache op is only acting as a place holder for the caching position | |||||
| // within the tree. Later, a pre-pass will execute a tree transform to set up the actual | |||||
| // caching logic in the tree. | |||||
| if (cache_client) { | |||||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, cifar_op, &cache_op)); | |||||
| *top = cache_op; | |||||
| *bottom = cifar_op; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1494,6 +1566,8 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } | } | ||||
| int num_workers = 0; | |||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||||
| std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>(); | std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>(); | ||||
| (void)builder->SetCifarDir(ToString(args["dataset_dir"])); | (void)builder->SetCifarDir(ToString(args["dataset_dir"])); | ||||
| @@ -1503,22 +1577,37 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset | |||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "num_parallel_workers") { | if (key == "num_parallel_workers") { | ||||
| (void)builder->SetNumWorkers(ToInt(value)); | |||||
| num_workers = ToInt(value); | |||||
| (void)builder->SetNumWorkers(num_workers); | |||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | ||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "usage") { | } else if (key == "usage") { | ||||
| (void)builder->SetUsage(ToString(value)); | (void)builder->SetUsage(ToString(value)); | ||||
| } else if (key == "cache") { | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| (void)builder->SetCifarType(false); | (void)builder->SetCifarType(false); | ||||
| std::shared_ptr<CifarOp> op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||||
| *top = op; | |||||
| std::shared_ptr<CifarOp> cifar_op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&cifar_op)); | |||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(cifar_op)); | |||||
| *top = cifar_op; | |||||
| // Additionally, add a cache if required. | |||||
| // Note that this cache op is only acting as a place holder for the caching position | |||||
| // within the tree. Later, a pre-pass will execute a tree transform to set up the actual | |||||
| // caching logic in the tree. | |||||
| if (cache_client) { | |||||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, cifar_op, &cache_op)); | |||||
| *top = cache_op; | |||||
| *bottom = cifar_op; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1609,6 +1698,8 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } | } | ||||
| int num_workers = 0; | |||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||||
| std::shared_ptr<MnistOp::Builder> builder = std::make_shared<MnistOp::Builder>(); | std::shared_ptr<MnistOp::Builder> builder = std::make_shared<MnistOp::Builder>(); | ||||
| (void)builder->SetDir(ToString(args["dataset_dir"])); | (void)builder->SetDir(ToString(args["dataset_dir"])); | ||||
| @@ -1618,19 +1709,35 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "num_parallel_workers") { | if (key == "num_parallel_workers") { | ||||
| (void)builder->SetNumWorkers(ToInt(value)); | |||||
| num_workers = ToInt(value); | |||||
| (void)builder->SetNumWorkers(num_workers); | |||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | ||||
| (void)builder->SetSampler(std::move(sampler)); | (void)builder->SetSampler(std::move(sampler)); | ||||
| } else if (key == "usage") { | } else if (key == "usage") { | ||||
| (void)builder->SetUsage(ToString(value)); | (void)builder->SetUsage(ToString(value)); | ||||
| } else if (key == "cache") { | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| std::shared_ptr<MnistOp> op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||||
| *top = op; | |||||
| std::shared_ptr<MnistOp> mnist_op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&mnist_op)); | |||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(mnist_op)); | |||||
| *top = mnist_op; | |||||
| // Additionally, add a cache if required. | |||||
| // Note that this cache op is only acting as a place holder for the caching position | |||||
| // within the tree. Later, a pre-pass will execute a tree transform to set up the actual | |||||
| // caching logic in the tree. | |||||
| if (cache_client) { | |||||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, mnist_op, &cache_op)); | |||||
| *top = cache_op; | |||||
| *bottom = mnist_op; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1642,6 +1749,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | ||||
| } | } | ||||
| int num_workers = 0; | |||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||||
| std::shared_ptr<CelebAOp::Builder> builder = std::make_shared<CelebAOp::Builder>(); | std::shared_ptr<CelebAOp::Builder> builder = std::make_shared<CelebAOp::Builder>(); | ||||
| if (builder == nullptr) { | if (builder == nullptr) { | ||||
| std::string err_msg = "Create celebaop builder failed"; | std::string err_msg = "Create celebaop builder failed"; | ||||
| @@ -1653,7 +1762,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp | |||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "num_parallel_workers") { | if (key == "num_parallel_workers") { | ||||
| (void)builder->SetNumWorkers(ToInt(value)); | |||||
| num_workers = ToInt(value); | |||||
| (void)builder->SetNumWorkers(num_workers); | |||||
| } else if (key == "sampler") { | } else if (key == "sampler") { | ||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | ||||
| std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>(); | ||||
| @@ -1664,13 +1774,28 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp | |||||
| (void)builder->SetExtensions(ToStringSet(value)); | (void)builder->SetExtensions(ToStringSet(value)); | ||||
| } else if (key == "usage") { | } else if (key == "usage") { | ||||
| (void)builder->SetUsage(ToString(value)); | (void)builder->SetUsage(ToString(value)); | ||||
| } else if (key == "cache") { | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| std::shared_ptr<CelebAOp> op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||||
| *top = op; | |||||
| std::shared_ptr<CelebAOp> celeba_op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&celeba_op)); | |||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(celeba_op)); | |||||
| *top = celeba_op; | |||||
| // Additionally, add a cache if required. | |||||
| // Note that this cache op is only acting as a place holder for the caching position | |||||
| // within the tree. Later, a pre-pass will execute a tree transform to set up the actual | |||||
| // caching logic in the tree. | |||||
| if (cache_client) { | |||||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, celeba_op, &cache_op)); | |||||
| *top = cache_op; | |||||
| *bottom = celeba_op; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1678,6 +1803,9 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| std::shared_ptr<DatasetOp> *bottom) { | std::shared_ptr<DatasetOp> *bottom) { | ||||
| // Required arguments | // Required arguments | ||||
| std::vector<std::string> files_list; | std::vector<std::string> files_list; | ||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||||
| std::shared_ptr<Sampler> sampler = nullptr; | |||||
| int num_workers = 0; | |||||
| std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>(); | std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>(); | ||||
| if (!args["dataset_files"].is_none()) { | if (!args["dataset_files"].is_none()) { | ||||
| files_list = ToStringVector(args["dataset_files"]); | files_list = ToStringVector(args["dataset_files"]); | ||||
| @@ -1693,7 +1821,8 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "num_parallel_workers") { | if (key == "num_parallel_workers") { | ||||
| (void)builder->SetNumWorkers(ToInt(value)); | |||||
| num_workers = ToInt(value); | |||||
| (void)builder->SetNumWorkers(num_workers); | |||||
| } else if (key == "shuffle_files") { | } else if (key == "shuffle_files") { | ||||
| (void)builder->SetShuffleFiles(ToBool(value)); | (void)builder->SetShuffleFiles(ToBool(value)); | ||||
| } else if (key == "shuffle_global") { | } else if (key == "shuffle_global") { | ||||
| @@ -1705,16 +1834,35 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| (void)builder->SetNumDevices(num_devices); | (void)builder->SetNumDevices(num_devices); | ||||
| } else if (key == "shard_id") { | } else if (key == "shard_id") { | ||||
| (void)builder->SetDeviceId(ToInt(value)); | (void)builder->SetDeviceId(ToInt(value)); | ||||
| } else if (key == "cache") { | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||||
| } else if (key == "sampler") { | |||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||||
| sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed | |||||
| // because TextFileOp is a non-mappable dataset that does not support sampling. | |||||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||||
| // That is why we save the sampler here in a leaf node that does not use sampling. | |||||
| if (sampler) { | |||||
| (void)builder->SetSampler(std::move(sampler)); | |||||
| } else if (cache_client) { | |||||
| int64_t num_samples = 0; | |||||
| int64_t start_index = 0; | |||||
| sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||||
| (void)builder->SetSampler(std::move(sampler)); | |||||
| } | |||||
| std::shared_ptr<TextFileOp> txt_op; | std::shared_ptr<TextFileOp> txt_op; | ||||
| RETURN_IF_NOT_OK(builder->Build(&txt_op)); | RETURN_IF_NOT_OK(builder->Build(&txt_op)); | ||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op)); | RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op)); | ||||
| *top = txt_op; | *top = txt_op; | ||||
| if (shuffle_required) { | |||||
| if (!cache_client && shuffle_required) { | |||||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | std::shared_ptr<DatasetOp> shuffle_op = nullptr; | ||||
| int64_t shuffle_size = 0; | int64_t shuffle_size = 0; | ||||
| int64_t num_rows = 0; | int64_t num_rows = 0; | ||||
| @@ -1729,6 +1877,15 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset | |||||
| *bottom = txt_op; | *bottom = txt_op; | ||||
| } | } | ||||
| // Add a cache op over this op if required and update the output subtree (top/bottom) | |||||
| if (cache_client) { | |||||
| // Note, it is not allowed to have both shuffle and cache | |||||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, txt_op, &cache_op)); | |||||
| *top = cache_op; | |||||
| *bottom = txt_op; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1829,6 +1986,10 @@ Status DEPipeline::ParseBuildSentencePieceVocabOp(const py::dict &args, std::sha | |||||
| Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | ||||
| std::shared_ptr<DatasetOp> *bottom) { | std::shared_ptr<DatasetOp> *bottom) { | ||||
| std::vector<std::string> files_list; | std::vector<std::string> files_list; | ||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||||
| std::shared_ptr<Sampler> sampler = nullptr; | |||||
| int num_workers = 0; | |||||
| std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>(); | std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>(); | ||||
| if (!args["dataset_files"].is_none()) { | if (!args["dataset_files"].is_none()) { | ||||
| files_list = ToStringVector(args["dataset_files"]); | files_list = ToStringVector(args["dataset_files"]); | ||||
| @@ -1844,7 +2005,8 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "num_parallel_workers") { | if (key == "num_parallel_workers") { | ||||
| (void)builder->SetNumWorkers(ToInt(value)); | |||||
| num_workers = ToInt(value); | |||||
| (void)builder->SetNumWorkers(num_workers); | |||||
| } else if (key == "shuffle_files") { | } else if (key == "shuffle_files") { | ||||
| (void)builder->SetShuffleFiles(ToBool(value)); | (void)builder->SetShuffleFiles(ToBool(value)); | ||||
| } else if (key == "shuffle_global") { | } else if (key == "shuffle_global") { | ||||
| @@ -1866,16 +2028,35 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| } | } | ||||
| } | } | ||||
| (void)builder->SetColsKeyMap(map_dict); | (void)builder->SetColsKeyMap(map_dict); | ||||
| } else if (key == "cache") { | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||||
| } else if (key == "sampler") { | |||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||||
| sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed | |||||
| // because ClueOp is a non-mappable dataset that does not support sampling. | |||||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||||
| // That is why we save the sampler here in a leaf node that does not use sampling. | |||||
| if (sampler) { | |||||
| (void)builder->SetSampler(std::move(sampler)); | |||||
| } else if (cache_client) { | |||||
| int64_t num_samples = 0; | |||||
| int64_t start_index = 0; | |||||
| sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||||
| (void)builder->SetSampler(std::move(sampler)); | |||||
| } | |||||
| std::shared_ptr<ClueOp> clue_op; | std::shared_ptr<ClueOp> clue_op; | ||||
| RETURN_IF_NOT_OK(builder->Build(&clue_op)); | RETURN_IF_NOT_OK(builder->Build(&clue_op)); | ||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op)); | RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op)); | ||||
| *top = clue_op; | *top = clue_op; | ||||
| if (shuffle_required) { | |||||
| if (!cache_client && shuffle_required) { | |||||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | std::shared_ptr<DatasetOp> shuffle_op = nullptr; | ||||
| int64_t shuffle_size = 0; | int64_t shuffle_size = 0; | ||||
| int64_t num_rows = 0; | int64_t num_rows = 0; | ||||
| @@ -1890,6 +2071,15 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| *bottom = clue_op; | *bottom = clue_op; | ||||
| } | } | ||||
| // Add a cache op over this op if required and update the output subtree (top/bottom) | |||||
| if (cache_client) { | |||||
| // Note, it is not allowed to have both shuffle and cache | |||||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, clue_op, &cache_op)); | |||||
| *top = cache_op; | |||||
| *bottom = clue_op; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1921,6 +2111,9 @@ Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num | |||||
| Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | ||||
| std::shared_ptr<DatasetOp> *bottom) { | std::shared_ptr<DatasetOp> *bottom) { | ||||
| std::vector<std::string> files_list; | std::vector<std::string> files_list; | ||||
| std::shared_ptr<CacheClient> cache_client = nullptr; | |||||
| std::shared_ptr<Sampler> sampler = nullptr; | |||||
| int num_workers = 0; | |||||
| std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>(); | std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>(); | ||||
| if (!args["dataset_files"].is_none()) { | if (!args["dataset_files"].is_none()) { | ||||
| files_list = ToStringVector(args["dataset_files"]); | files_list = ToStringVector(args["dataset_files"]); | ||||
| @@ -1938,7 +2131,8 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||||
| py::handle value = arg.second; | py::handle value = arg.second; | ||||
| if (!value.is_none()) { | if (!value.is_none()) { | ||||
| if (key == "num_parallel_workers") { | if (key == "num_parallel_workers") { | ||||
| (void)builder->SetNumWorkers(ToInt(value)); | |||||
| num_workers = ToInt(value); | |||||
| (void)builder->SetNumWorkers(num_workers); | |||||
| } else if (key == "shuffle_files") { | } else if (key == "shuffle_files") { | ||||
| (void)builder->SetShuffleFiles(ToBool(value)); | (void)builder->SetShuffleFiles(ToBool(value)); | ||||
| } else if (key == "shuffle_global") { | } else if (key == "shuffle_global") { | ||||
| @@ -1971,16 +2165,35 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||||
| } else if (key == "column_names") { | } else if (key == "column_names") { | ||||
| col_names = ToStringVector(value); | col_names = ToStringVector(value); | ||||
| (void)builder->SetColumName(col_names); | (void)builder->SetColumName(col_names); | ||||
| } else if (key == "cache") { | |||||
| cache_client = value.cast<std::shared_ptr<CacheClient>>(); | |||||
| } else if (key == "sampler") { | |||||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||||
| sampler = create().cast<std::shared_ptr<Sampler>>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed | |||||
| // because CsvOp is a non-mappable dataset that does not support sampling. | |||||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||||
| // That is why we save the sampler here in a leaf node that does not use sampling. | |||||
| if (sampler) { | |||||
| (void)builder->SetSampler(std::move(sampler)); | |||||
| } else if (cache_client) { | |||||
| int64_t num_samples = 0; | |||||
| int64_t start_index = 0; | |||||
| sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||||
| (void)builder->SetSampler(std::move(sampler)); | |||||
| } | |||||
| std::shared_ptr<CsvOp> csv_op; | std::shared_ptr<CsvOp> csv_op; | ||||
| RETURN_IF_NOT_OK(builder->Build(&csv_op)); | RETURN_IF_NOT_OK(builder->Build(&csv_op)); | ||||
| RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op)); | RETURN_IF_NOT_OK(tree_->AssociateNode(csv_op)); | ||||
| *top = csv_op; | *top = csv_op; | ||||
| if (shuffle_required) { | |||||
| if (!cache_client && shuffle_required) { | |||||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | std::shared_ptr<DatasetOp> shuffle_op = nullptr; | ||||
| int64_t shuffle_size = 0; | int64_t shuffle_size = 0; | ||||
| int64_t num_rows = 0; | int64_t num_rows = 0; | ||||
| @@ -1995,6 +2208,15 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||||
| *bottom = csv_op; | *bottom = csv_op; | ||||
| } | } | ||||
| // Add a cache op over this op if required and update the output subtree (top/bottom) | |||||
| if (cache_client) { | |||||
| // Note, it is not allowed to have both shuffle and cache | |||||
| std::shared_ptr<DatasetOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, csv_op, &cache_op)); | |||||
| *top = cache_op; | |||||
| *bottom = csv_op; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -70,13 +70,12 @@ Status AlbumOp::Builder::SanityCheck() { | |||||
| AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, | AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode, | ||||
| const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, | const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema, | ||||
| std::shared_ptr<Sampler> sampler) | std::shared_ptr<Sampler> sampler) | ||||
| : ParallelOp(num_wkrs, queue_size), | |||||
| : ParallelOp(num_wkrs, queue_size, std::move(sampler)), | |||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| folder_path_(file_dir), | folder_path_(file_dir), | ||||
| decode_(do_decode), | decode_(do_decode), | ||||
| extensions_(exts), | extensions_(exts), | ||||
| data_schema_(std::move(data_schema)), | data_schema_(std::move(data_schema)), | ||||
| sampler_(std::move(sampler)), | |||||
| row_cnt_(0), | row_cnt_(0), | ||||
| buf_cnt_(0), | buf_cnt_(0), | ||||
| sampler_ind_(0), | sampler_ind_(0), | ||||
| @@ -284,7 +284,6 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||||
| std::set<std::string> extensions_; // extensions allowed | std::set<std::string> extensions_; // extensions allowed | ||||
| std::unordered_map<std::string, int32_t> col_name_map_; | std::unordered_map<std::string, int32_t> col_name_map_; | ||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| std::shared_ptr<Sampler> sampler_; | |||||
| int64_t row_cnt_; | int64_t row_cnt_; | ||||
| int64_t buf_cnt_; | int64_t buf_cnt_; | ||||
| int64_t sampler_ind_; | int64_t sampler_ind_; | ||||
| @@ -25,13 +25,18 @@ | |||||
| #include "minddata/dataset/util/task_manager.h" | #include "minddata/dataset/util/task_manager.h" | ||||
| #include "minddata/dataset/engine/jagged_connector.h" | #include "minddata/dataset/engine/jagged_connector.h" | ||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | #include "minddata/dataset/engine/datasetops/source/io_block.h" | ||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| ClueOp::Builder::Builder() | ClueOp::Builder::Builder() | ||||
| : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { | |||||
| : builder_device_id_(0), | |||||
| builder_num_devices_(1), | |||||
| builder_num_samples_(0), | |||||
| builder_shuffle_files_(false), | |||||
| builder_sampler_(nullptr) { | |||||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | ||||
| builder_num_workers_ = config_manager->num_parallel_workers(); | builder_num_workers_ = config_manager->num_parallel_workers(); | ||||
| builder_op_connector_size_ = config_manager->op_connector_size(); | builder_op_connector_size_ = config_manager->op_connector_size(); | ||||
| @@ -68,7 +73,7 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) { | |||||
| std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>( | std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>( | ||||
| builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map, | builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map, | ||||
| builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, | builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, | ||||
| builder_device_id_); | |||||
| builder_device_id_, std::move(builder_sampler_)); | |||||
| RETURN_IF_NOT_OK(clue_op->Init()); | RETURN_IF_NOT_OK(clue_op->Init()); | ||||
| *op = std::move(clue_op); | *op = std::move(clue_op); | ||||
| @@ -88,8 +93,8 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim | |||||
| ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ||||
| ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | ||||
| bool shuffle_files, int32_t num_device, int32_t device_id) | |||||
| : ParallelOp(num_workers, op_connector_size), | |||||
| bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| num_rows_per_shard_(0), | num_rows_per_shard_(0), | ||||
| all_num_rows_(0), | all_num_rows_(0), | ||||
| @@ -539,5 +544,21 @@ Status ClueOp::ComputeColMap() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing | |||||
| // a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so | |||||
| // that this clue op will produce the full set of data into the cache. | |||||
| void ClueOp::MakeSimpleProducer() { | |||||
| device_id_ = 0; | |||||
| num_devices_ = 1; | |||||
| shuffle_files_ = false; | |||||
| num_samples_ = 0; | |||||
| } | |||||
| // Visitor accept method for NodePass | |||||
| Status ClueOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(shared_from_base<ClueOp>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include <nlohmann/json.hpp> | #include <nlohmann/json.hpp> | ||||
| @@ -122,6 +123,14 @@ class ClueOp : public ParallelOp { | |||||
| // @return - the a string vector | // @return - the a string vector | ||||
| std::vector<std::string> split(const std::string &s, char delim); | std::vector<std::string> split(const std::string &s, char delim); | ||||
| // Setter method | |||||
| // @param std::shared_ptr<Sampler> sampler | |||||
| // @return Builder setter method returns reference to the builder. | |||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | |||||
| return *this; | |||||
| } | |||||
| private: | private: | ||||
| int32_t builder_device_id_; | int32_t builder_device_id_; | ||||
| int32_t builder_num_devices_; | int32_t builder_num_devices_; | ||||
| @@ -133,12 +142,13 @@ class ClueOp : public ParallelOp { | |||||
| std::vector<std::string> builder_clue_files_list_; | std::vector<std::string> builder_clue_files_list_; | ||||
| bool builder_shuffle_files_; | bool builder_shuffle_files_; | ||||
| std::map<std::string, std::string> builder_cols_to_keyword_; | std::map<std::string, std::string> builder_cols_to_keyword_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| }; | }; | ||||
| // Constructor of ClueOp | // Constructor of ClueOp | ||||
| ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ||||
| ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | ||||
| bool shuffle_files, int32_t num_devices, int32_t device_id); | |||||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler); | |||||
| // Default destructor | // Default destructor | ||||
| ~ClueOp() = default; | ~ClueOp() = default; | ||||
| @@ -173,6 +183,17 @@ class ClueOp : public ParallelOp { | |||||
| // @return Vector of the input file names | // @return Vector of the input file names | ||||
| std::vector<std::string> FileNames() { return clue_files_list_; } | std::vector<std::string> FileNames() { return clue_files_list_; } | ||||
| /// \Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing | |||||
| /// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so | |||||
| /// that this clue op will produce the full set of data into the cache. | |||||
| void MakeSimpleProducer(); | |||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // The entry point for when workers are launched. | // The entry point for when workers are launched. | ||||
| // @param worker_id - the id of the worker that is executing this function. | // @param worker_id - the id of the worker that is executing this function. | ||||
| @@ -124,7 +124,7 @@ Status CocoOp::Builder::SanityCheck() { | |||||
| CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, | CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, | ||||
| int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, | int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, | ||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | ||||
| : ParallelOp(num_workers, queue_size), | |||||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | |||||
| decode_(decode), | decode_(decode), | ||||
| row_cnt_(0), | row_cnt_(0), | ||||
| buf_cnt_(0), | buf_cnt_(0), | ||||
| @@ -132,7 +132,6 @@ CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, | |||||
| image_folder_path_(image_folder_path), | image_folder_path_(image_folder_path), | ||||
| annotation_path_(annotation_path), | annotation_path_(annotation_path), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| sampler_(std::move(sampler)), | |||||
| data_schema_(std::move(data_schema)) { | data_schema_(std::move(data_schema)) { | ||||
| io_block_queues_.Init(num_workers_, queue_size); | io_block_queues_.Init(num_workers_, queue_size); | ||||
| } | } | ||||
| @@ -206,6 +206,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| /// \return Status of the node visit | /// \return Status of the node visit | ||||
| Status Accept(NodePass *p, bool *modified) override; | Status Accept(NodePass *p, bool *modified) override; | ||||
| // Op name getter | |||||
| // @return Name of the current Op | |||||
| std::string Name() const override { return "CocoOp"; } | |||||
| private: | private: | ||||
| // Initialize Sampler, calls sampler->Init() within | // Initialize Sampler, calls sampler->Init() within | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| @@ -324,7 +328,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| std::string annotation_path_; | std::string annotation_path_; | ||||
| TaskType task_type_; | TaskType task_type_; | ||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| std::shared_ptr<Sampler> sampler_; | |||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| WaitPost wp_; | WaitPost wp_; | ||||
| @@ -22,12 +22,17 @@ | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/engine/jagged_connector.h" | #include "minddata/dataset/engine/jagged_connector.h" | ||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| CsvOp::Builder::Builder() | CsvOp::Builder::Builder() | ||||
| : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { | |||||
| : builder_device_id_(0), | |||||
| builder_num_devices_(1), | |||||
| builder_num_samples_(0), | |||||
| builder_shuffle_files_(false), | |||||
| builder_sampler_(nullptr) { | |||||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | ||||
| builder_num_workers_ = config_manager->num_parallel_workers(); | builder_num_workers_ = config_manager->num_parallel_workers(); | ||||
| builder_op_connector_size_ = config_manager->op_connector_size(); | builder_op_connector_size_ = config_manager->op_connector_size(); | ||||
| @@ -59,7 +64,8 @@ Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) { | |||||
| std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>( | std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>( | ||||
| builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_, | builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_, | ||||
| builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, | builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, | ||||
| builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_); | |||||
| builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_, | |||||
| std::move(builder_sampler_)); | |||||
| RETURN_IF_NOT_OK(csv_op->Init()); | RETURN_IF_NOT_OK(csv_op->Init()); | ||||
| *op = std::move(csv_op); | *op = std::move(csv_op); | ||||
| @@ -70,8 +76,8 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim, | |||||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, | const std::vector<std::shared_ptr<BaseRecord>> &column_default, | ||||
| const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, | const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, | ||||
| int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, | int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, | ||||
| int32_t num_device, int32_t device_id) | |||||
| : ParallelOp(num_workers, op_connector_size), | |||||
| int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||||
| csv_files_list_(std::move(csv_files_list)), | csv_files_list_(std::move(csv_files_list)), | ||||
| field_delim_(field_delim), | field_delim_(field_delim), | ||||
| column_default_list_(column_default), | column_default_list_(column_default), | ||||
| @@ -889,5 +895,21 @@ Status CsvOp::ComputeColMap() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing | |||||
| // a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so | |||||
| // that this csv op will produce the full set of data into the cache. | |||||
| void CsvOp::MakeSimpleProducer() { | |||||
| device_id_ = 0; | |||||
| num_devices_ = 1; | |||||
| shuffle_files_ = false; | |||||
| num_samples_ = 0; | |||||
| } | |||||
| // Visitor accept method for NodePass | |||||
| Status CsvOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(shared_from_base<CsvOp>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -240,6 +240,14 @@ class CsvOp : public ParallelOp { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Setter method | |||||
| // @param std::shared_ptr<Sampler> sampler | |||||
| // @return Builder setter method returns reference to the builder. | |||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | |||||
| return *this; | |||||
| } | |||||
| private: | private: | ||||
| int32_t builder_device_id_; | int32_t builder_device_id_; | ||||
| int32_t builder_num_devices_; | int32_t builder_num_devices_; | ||||
| @@ -253,6 +261,7 @@ class CsvOp : public ParallelOp { | |||||
| char builder_field_delim_; | char builder_field_delim_; | ||||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_; | std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_; | ||||
| std::vector<std::string> builder_column_name_list_; | std::vector<std::string> builder_column_name_list_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| }; | }; | ||||
| // Constructor of CsvOp | // Constructor of CsvOp | ||||
| @@ -261,7 +270,8 @@ class CsvOp : public ParallelOp { | |||||
| CsvOp(const std::vector<std::string> &csv_files_list, char field_delim, | CsvOp(const std::vector<std::string> &csv_files_list, char field_delim, | ||||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name, | const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name, | ||||
| int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ||||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); | |||||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, | |||||
| std::shared_ptr<Sampler> sampler); | |||||
| // Default destructor | // Default destructor | ||||
| ~CsvOp() = default; | ~CsvOp() = default; | ||||
| @@ -297,6 +307,17 @@ class CsvOp : public ParallelOp { | |||||
| // @return Vector of the input file names | // @return Vector of the input file names | ||||
| std::vector<std::string> FileNames() { return csv_files_list_; } | std::vector<std::string> FileNames() { return csv_files_list_; } | ||||
| /// \Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing | |||||
| /// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so | |||||
| /// that this csv op will produce the full set of data into the cache. | |||||
| void MakeSimpleProducer(); | |||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // The entry point for when workers are launched. | // The entry point for when workers are launched. | ||||
| // @param worker_id - the id of the worker that is executing this function. | // @param worker_id - the id of the worker that is executing this function. | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | #include "minddata/dataset/engine/datasetops/source/io_block.h" | ||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -499,5 +500,21 @@ Status TextFileOp::ComputeColMap() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing | |||||
| // a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so | |||||
| // that this text file op will produce the full set of data into the cache. | |||||
| void TextFileOp::MakeSimpleProducer() { | |||||
| device_id_ = 0; | |||||
| num_devices_ = 1; | |||||
| shuffle_files_ = false; | |||||
| total_rows_ = 0; | |||||
| } | |||||
| // Visitor accept method for NodePass | |||||
| Status TextFileOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->RunOnNode(shared_from_base<TextFileOp>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -188,6 +188,17 @@ class TextFileOp : public ParallelOp { | |||||
| // @return Vector of the input file names | // @return Vector of the input file names | ||||
| std::vector<std::string> FileNames() { return text_files_list_; } | std::vector<std::string> FileNames() { return text_files_list_; } | ||||
| /// \Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing | |||||
| /// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so | |||||
| /// that this text file op will produce the full set of data into the cache. | |||||
| void MakeSimpleProducer(); | |||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| // The entry point for when workers are launched. | // The entry point for when workers are launched. | ||||
| // @param worker_id - the id of the worker that is executing this function. | // @param worker_id - the id of the worker that is executing this function. | ||||
| @@ -212,6 +212,7 @@ Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Ten | |||||
| folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension); | folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension); | ||||
| RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); | RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); | ||||
| RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation)); | RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation)); | ||||
| trow->setId(row_id); | |||||
| trow->push_back(std::move(image)); | trow->push_back(std::move(image)); | ||||
| trow->insert(trow->end(), annotation.begin(), annotation.end()); | trow->insert(trow->end(), annotation.begin(), annotation.end()); | ||||
| } | } | ||||
| @@ -45,6 +45,9 @@ | |||||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/clue_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/text_file_op.h" | |||||
| #endif | #endif | ||||
| #include "minddata/dataset/engine/datasetops/source/voc_op.h" | #include "minddata/dataset/engine/datasetops/source/voc_op.h" | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| @@ -260,6 +263,21 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| @@ -81,6 +81,12 @@ class CacheMergeOp; | |||||
| class CacheLookupOp; | class CacheLookupOp; | ||||
| class BuildSentencePieceVocabOp; | class BuildSentencePieceVocabOp; | ||||
| class ClueOp; | |||||
| class CsvOp; | |||||
| class TextFileOp; | |||||
| #endif | #endif | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| @@ -211,6 +217,12 @@ class NodePass : public Pass { | |||||
| virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified); | |||||
| virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | ||||
| @@ -36,6 +36,9 @@ | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/clue_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/text_file_op.h" | |||||
| #endif | #endif | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| @@ -141,6 +144,36 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node | |||||
| } | } | ||||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | ||||
| } | } | ||||
| // Perform leaf node cache transform identification | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) { | |||||
| if (is_caching_) { | |||||
| // If we are a ClueOp in a caching tree, then change our config so that it becomes a basic | |||||
| // ClueOp that parses all files. Selection of data will come from the sampler on the cache instead. | |||||
| node->MakeSimpleProducer(); | |||||
| } | |||||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache transform identification | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) { | |||||
| if (is_caching_) { | |||||
| // If we are a CsvOp in a caching tree, then change our config so that it becomes a basic | |||||
| // CsvOp that parses all files. Selection of data will come from the sampler on the cache instead. | |||||
| node->MakeSimpleProducer(); | |||||
| } | |||||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| // Perform leaf node cache transform identification | |||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) { | |||||
| if (is_caching_) { | |||||
| // If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic | |||||
| // TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead. | |||||
| node->MakeSimpleProducer(); | |||||
| } | |||||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | |||||
| #endif | #endif | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| @@ -163,34 +196,22 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, b | |||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | ||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | } | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | ||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | } | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | ||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | } | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | ||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| @@ -214,18 +235,12 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> nod | |||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | ||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | } | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | ||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -65,6 +65,24 @@ class CacheTransformPass : public TreePass { | |||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override; | |||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override; | |||||
| #endif | #endif | ||||
| /// \brief Perform leaf node cache tranform identifications | /// \brief Perform leaf node cache tranform identifications | ||||
| @@ -2969,6 +2969,8 @@ class MnistDataset(MappableDataset): | |||||
| into (default=None). | into (default=None). | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | shard_id (int, optional): The shard ID within num_shards (default=None). This | ||||
| argument can only be specified when num_shards is also specified. | argument can only be specified when num_shards is also specified. | ||||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||||
| The cache feature is under development and is not recommended. | |||||
| Raises: | Raises: | ||||
| RuntimeError: If sampler and shuffle are specified at the same time. | RuntimeError: If sampler and shuffle are specified at the same time. | ||||
| @@ -2988,7 +2990,7 @@ class MnistDataset(MappableDataset): | |||||
| @check_mnist_cifar_dataset | @check_mnist_cifar_dataset | ||||
| def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | ||||
| shuffle=None, sampler=None, num_shards=None, shard_id=None): | |||||
| shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| @@ -2998,6 +3000,7 @@ class MnistDataset(MappableDataset): | |||||
| self.shuffle_level = shuffle | self.shuffle_level = shuffle | ||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.cache = cache | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| @@ -3008,6 +3011,7 @@ class MnistDataset(MappableDataset): | |||||
| args["sampler"] = self.sampler | args["sampler"] = self.sampler | ||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -3872,6 +3876,8 @@ class ManifestDataset(MappableDataset): | |||||
| into (default=None). | into (default=None). | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | shard_id (int, optional): The shard ID within num_shards (default=None). This | ||||
| argument can only be specified when num_shards is also specified. | argument can only be specified when num_shards is also specified. | ||||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||||
| The cache feature is under development and is not recommended. | |||||
| Raises: | Raises: | ||||
| RuntimeError: If sampler and shuffle are specified at the same time. | RuntimeError: If sampler and shuffle are specified at the same time. | ||||
| @@ -3897,7 +3903,8 @@ class ManifestDataset(MappableDataset): | |||||
| @check_manifestdataset | @check_manifestdataset | ||||
| def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None, | def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None, | ||||
| shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None): | |||||
| shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, | |||||
| cache=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_file = dataset_file | self.dataset_file = dataset_file | ||||
| @@ -3913,6 +3920,7 @@ class ManifestDataset(MappableDataset): | |||||
| self.shuffle_level = shuffle | self.shuffle_level = shuffle | ||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.cache = cache | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| @@ -3925,6 +3933,7 @@ class ManifestDataset(MappableDataset): | |||||
| args["decode"] = self.decode | args["decode"] = self.decode | ||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -4055,6 +4064,8 @@ class Cifar10Dataset(MappableDataset): | |||||
| into (default=None). | into (default=None). | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | shard_id (int, optional): The shard ID within num_shards (default=None). This | ||||
| argument can only be specified when num_shards is also specified. | argument can only be specified when num_shards is also specified. | ||||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||||
| The cache feature is under development and is not recommended. | |||||
| Raises: | Raises: | ||||
| RuntimeError: If sampler and shuffle are specified at the same time. | RuntimeError: If sampler and shuffle are specified at the same time. | ||||
| @@ -4082,7 +4093,7 @@ class Cifar10Dataset(MappableDataset): | |||||
| @check_mnist_cifar_dataset | @check_mnist_cifar_dataset | ||||
| def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | ||||
| shuffle=None, sampler=None, num_shards=None, shard_id=None): | |||||
| shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| @@ -4092,6 +4103,7 @@ class Cifar10Dataset(MappableDataset): | |||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.shuffle_level = shuffle | self.shuffle_level = shuffle | ||||
| self.cache = cache | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| @@ -4102,6 +4114,7 @@ class Cifar10Dataset(MappableDataset): | |||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| args["shuffle"] = self.shuffle_level | args["shuffle"] = self.shuffle_level | ||||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -4202,6 +4215,8 @@ class Cifar100Dataset(MappableDataset): | |||||
| into (default=None). | into (default=None). | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | shard_id (int, optional): The shard ID within num_shards (default=None). This | ||||
| argument can only be specified when num_shards is also specified. | argument can only be specified when num_shards is also specified. | ||||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||||
| The cache feature is under development and is not recommended. | |||||
| Raises: | Raises: | ||||
| RuntimeError: If sampler and shuffle are specified at the same time. | RuntimeError: If sampler and shuffle are specified at the same time. | ||||
| @@ -4226,7 +4241,7 @@ class Cifar100Dataset(MappableDataset): | |||||
| @check_mnist_cifar_dataset | @check_mnist_cifar_dataset | ||||
| def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | ||||
| shuffle=None, sampler=None, num_shards=None, shard_id=None): | |||||
| shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| @@ -4236,6 +4251,7 @@ class Cifar100Dataset(MappableDataset): | |||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.shuffle_level = shuffle | self.shuffle_level = shuffle | ||||
| self.cache = cache | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| @@ -4246,6 +4262,7 @@ class Cifar100Dataset(MappableDataset): | |||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| args["shuffle"] = self.shuffle_level | args["shuffle"] = self.shuffle_level | ||||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -4630,6 +4647,8 @@ class VOCDataset(MappableDataset): | |||||
| into (default=None). | into (default=None). | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | shard_id (int, optional): The shard ID within num_shards (default=None). This | ||||
| argument can only be specified when num_shards is also specified. | argument can only be specified when num_shards is also specified. | ||||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||||
| The cache feature is under development and is not recommended. | |||||
| Raises: | Raises: | ||||
| RuntimeError: If xml of Annotations is an invalid format. | RuntimeError: If xml of Annotations is an invalid format. | ||||
| @@ -4667,7 +4686,8 @@ class VOCDataset(MappableDataset): | |||||
| @check_vocdataset | @check_vocdataset | ||||
| def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None, | def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None, | ||||
| num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): | |||||
| num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, | |||||
| cache=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| self.task = task | self.task = task | ||||
| @@ -4679,6 +4699,7 @@ class VOCDataset(MappableDataset): | |||||
| self.shuffle_level = shuffle | self.shuffle_level = shuffle | ||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.cache = cache | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| @@ -4692,6 +4713,7 @@ class VOCDataset(MappableDataset): | |||||
| args["shuffle"] = self.shuffle_level | args["shuffle"] = self.shuffle_level | ||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -4838,6 +4860,8 @@ class CocoDataset(MappableDataset): | |||||
| into (default=None). | into (default=None). | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | shard_id (int, optional): The shard ID within num_shards (default=None). This | ||||
| argument can only be specified when num_shards is also specified. | argument can only be specified when num_shards is also specified. | ||||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||||
| The cache feature is under development and is not recommended. | |||||
| Raises: | Raises: | ||||
| RuntimeError: If sampler and shuffle are specified at the same time. | RuntimeError: If sampler and shuffle are specified at the same time. | ||||
| @@ -4873,7 +4897,7 @@ class CocoDataset(MappableDataset): | |||||
| @check_cocodataset | @check_cocodataset | ||||
| def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None, | def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None, | ||||
| shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): | |||||
| shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| self.annotation_file = annotation_file | self.annotation_file = annotation_file | ||||
| @@ -4884,6 +4908,7 @@ class CocoDataset(MappableDataset): | |||||
| self.shuffle_level = shuffle | self.shuffle_level = shuffle | ||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.cache = cache | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| @@ -4896,6 +4921,7 @@ class CocoDataset(MappableDataset): | |||||
| args["shuffle"] = self.shuffle_level | args["shuffle"] = self.shuffle_level | ||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -4993,6 +5019,8 @@ class CelebADataset(MappableDataset): | |||||
| into (default=None). | into (default=None). | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | shard_id (int, optional): The shard ID within num_shards (default=None). This | ||||
| argument can only be specified when num_shards is also specified. | argument can only be specified when num_shards is also specified. | ||||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||||
| The cache feature is under development and is not recommended. | |||||
| Examples: | Examples: | ||||
| >>> import mindspore.dataset as ds | >>> import mindspore.dataset as ds | ||||
| @@ -5003,7 +5031,7 @@ class CelebADataset(MappableDataset): | |||||
| @check_celebadataset | @check_celebadataset | ||||
| def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False, | def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False, | ||||
| extensions=None, num_samples=None, num_shards=None, shard_id=None): | |||||
| extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | ||||
| @@ -5015,6 +5043,7 @@ class CelebADataset(MappableDataset): | |||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.shuffle_level = shuffle | self.shuffle_level = shuffle | ||||
| self.cache = cache | |||||
| if usage != "all": | if usage != "all": | ||||
| dir = os.path.realpath(self.dataset_dir) | dir = os.path.realpath(self.dataset_dir) | ||||
| @@ -5033,6 +5062,7 @@ class CelebADataset(MappableDataset): | |||||
| args["usage"] = self.usage | args["usage"] = self.usage | ||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -5142,6 +5172,8 @@ class CLUEDataset(SourceDataset): | |||||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | shard_id (int, optional): The shard ID within num_shards (default=None). This | ||||
| argument can only be specified when num_shards is also specified. | argument can only be specified when num_shards is also specified. | ||||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||||
| The cache feature is under development and is not recommended. | |||||
| Examples: | Examples: | ||||
| >>> import mindspore.dataset as ds | >>> import mindspore.dataset as ds | ||||
| @@ -5152,7 +5184,7 @@ class CLUEDataset(SourceDataset): | |||||
| @check_cluedataset | @check_cluedataset | ||||
| def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None, | def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None, | ||||
| num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): | |||||
| num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_files = self._find_files(dataset_files) | self.dataset_files = self._find_files(dataset_files) | ||||
| self.dataset_files.sort() | self.dataset_files.sort() | ||||
| @@ -5293,6 +5325,15 @@ class CLUEDataset(SourceDataset): | |||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| # The clue dataset does not directly support a sampler. It has provided sampling arguments | |||||
| # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in | |||||
| # the pipeline contains a cache. If there is no cache above it, then this sampler is not used. | |||||
| sampler_shuffle = self.shuffle_files | |||||
| sampler = None | |||||
| self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, | |||||
| non_mappable=True) | |||||
| self.cache = cache | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| args["dataset_files"] = self.dataset_files | args["dataset_files"] = self.dataset_files | ||||
| @@ -5304,6 +5345,8 @@ class CLUEDataset(SourceDataset): | |||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| args["cols_to_keyword"] = self.cols_to_keyword | args["cols_to_keyword"] = self.cols_to_keyword | ||||
| args["sampler"] = self.sampler | |||||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -5359,6 +5402,9 @@ class CSVDataset(SourceDataset): | |||||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | shard_id (int, optional): The shard ID within num_shards (default=None). This | ||||
| argument can only be specified when num_shards is also specified. | argument can only be specified when num_shards is also specified. | ||||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||||
| The cache feature is under development and is not recommended. | |||||
| Examples: | Examples: | ||||
| >>> import mindspore.dataset as ds | >>> import mindspore.dataset as ds | ||||
| @@ -5369,7 +5415,7 @@ class CSVDataset(SourceDataset): | |||||
| @check_csvdataset | @check_csvdataset | ||||
| def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, | def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, | ||||
| num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): | |||||
| num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_files = self._find_files(dataset_files) | self.dataset_files = self._find_files(dataset_files) | ||||
| self.dataset_files.sort() | self.dataset_files.sort() | ||||
| @@ -5394,6 +5440,15 @@ class CSVDataset(SourceDataset): | |||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.cache = cache | |||||
| # The CSV dataset does not directly support a sampler. It has provided sampling arguments | |||||
| # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in | |||||
| # the pipeline contains a cache. If there is no cache above it, then this sampler is not used. | |||||
| sampler_shuffle = self.shuffle_files | |||||
| sampler = None | |||||
| self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, | |||||
| non_mappable=True) | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| args["dataset_files"] = self.dataset_files | args["dataset_files"] = self.dataset_files | ||||
| @@ -5407,6 +5462,8 @@ class CSVDataset(SourceDataset): | |||||
| args["shuffle"] = self.shuffle_level | args["shuffle"] = self.shuffle_level | ||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| args["sampler"] = self.sampler | |||||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -5457,6 +5514,9 @@ class TextFileDataset(SourceDataset): | |||||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | shard_id (int, optional): The shard ID within num_shards (default=None). This | ||||
| argument can only be specified when num_shards is also specified. | argument can only be specified when num_shards is also specified. | ||||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||||
| The cache feature is under development and is not recommended. | |||||
| Examples: | Examples: | ||||
| >>> import mindspore.dataset as ds | >>> import mindspore.dataset as ds | ||||
| >>> | >>> | ||||
| @@ -5466,7 +5526,7 @@ class TextFileDataset(SourceDataset): | |||||
| @check_textfiledataset | @check_textfiledataset | ||||
| def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, | def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, | ||||
| shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): | |||||
| shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_files = self._find_files(dataset_files) | self.dataset_files = self._find_files(dataset_files) | ||||
| self.dataset_files.sort() | self.dataset_files.sort() | ||||
| @@ -5488,6 +5548,15 @@ class TextFileDataset(SourceDataset): | |||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.cache = cache | |||||
| # The text file dataset does not directly support a sampler. It has provided sampling arguments | |||||
| # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in | |||||
| # the pipeline contains a cache. If there is no cache above it, then this sampler is not used. | |||||
| sampler_shuffle = self.shuffle_files | |||||
| sampler = None | |||||
| self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, | |||||
| non_mappable=True) | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| args["dataset_files"] = self.dataset_files | args["dataset_files"] = self.dataset_files | ||||
| @@ -5498,6 +5567,8 @@ class TextFileDataset(SourceDataset): | |||||
| args["shuffle"] = self.shuffle_level | args["shuffle"] = self.shuffle_level | ||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| args["sampler"] = self.sampler | |||||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||||
| return args | return args | ||||
| def get_dataset_size(self): | def get_dataset_size(self): | ||||
| @@ -83,6 +83,9 @@ def check_mnist_cifar_dataset(method): | |||||
| check_sampler_shuffle_shard_options(param_dict) | check_sampler_shuffle_shard_options(param_dict) | ||||
| cache = param_dict.get('cache') | |||||
| check_cache_option(cache) | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -110,6 +113,9 @@ def check_manifestdataset(method): | |||||
| check_sampler_shuffle_shard_options(param_dict) | check_sampler_shuffle_shard_options(param_dict) | ||||
| cache = param_dict.get('cache') | |||||
| check_cache_option(cache) | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -180,6 +186,9 @@ def check_vocdataset(method): | |||||
| validate_dataset_param_value(nreq_param_dict, param_dict, dict) | validate_dataset_param_value(nreq_param_dict, param_dict, dict) | ||||
| check_sampler_shuffle_shard_options(param_dict) | check_sampler_shuffle_shard_options(param_dict) | ||||
| cache = param_dict.get('cache') | |||||
| check_cache_option(cache) | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -216,6 +225,9 @@ def check_cocodataset(method): | |||||
| raise ValueError("CocoDataset doesn't support PKSampler") | raise ValueError("CocoDataset doesn't support PKSampler") | ||||
| check_sampler_shuffle_shard_options(param_dict) | check_sampler_shuffle_shard_options(param_dict) | ||||
| cache = param_dict.get('cache') | |||||
| check_cache_option(cache) | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -252,6 +264,9 @@ def check_celebadataset(method): | |||||
| if sampler is not None and isinstance(sampler, samplers.PKSampler): | if sampler is not None and isinstance(sampler, samplers.PKSampler): | ||||
| raise ValueError("CelebADataset does not support PKSampler.") | raise ValueError("CelebADataset does not support PKSampler.") | ||||
| cache = param_dict.get('cache') | |||||
| check_cache_option(cache) | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -843,6 +858,9 @@ def check_cluedataset(method): | |||||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | validate_dataset_param_value(nreq_param_int, param_dict, int) | ||||
| check_sampler_shuffle_shard_options(param_dict) | check_sampler_shuffle_shard_options(param_dict) | ||||
| cache = param_dict.get('cache') | |||||
| check_cache_option(cache) | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -886,6 +904,9 @@ def check_csvdataset(method): | |||||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | validate_dataset_param_value(nreq_param_int, param_dict, int) | ||||
| check_sampler_shuffle_shard_options(param_dict) | check_sampler_shuffle_shard_options(param_dict) | ||||
| cache = param_dict.get('cache') | |||||
| check_cache_option(cache) | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -905,6 +926,9 @@ def check_textfiledataset(method): | |||||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | validate_dataset_param_value(nreq_param_int, param_dict, int) | ||||
| check_sampler_shuffle_shard_options(param_dict) | check_sampler_shuffle_shard_options(param_dict) | ||||
| cache = param_dict.get('cache') | |||||
| check_cache_option(cache) | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -103,6 +103,24 @@ HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_map.py" "test_cache_map_epoch_ctrl" 1 | PytestCmd "test_cache_map.py" "test_cache_map_epoch_ctrl" 1 | ||||
| HandleRcExit $? 0 0 | HandleRcExit $? 0 0 | ||||
| PytestCmd "test_cache_map.py" "test_cache_map_coco" 1 | |||||
| HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_map.py" "test_cache_map_mnist" 1 | |||||
| HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_map.py" "test_cache_map_celeba" 1 | |||||
| HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_map.py" "test_cache_map_manifest" 1 | |||||
| HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_map.py" "test_cache_map_cifar" 1 | |||||
| HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_map.py" "test_cache_map_voc" 1 | |||||
| HandleRcExit $? 0 0 | |||||
| # Run two parallel pipelines (sharing cache) | # Run two parallel pipelines (sharing cache) | ||||
| for i in $(seq 1 2) | for i in $(seq 1 2) | ||||
| do | do | ||||
| @@ -282,6 +300,15 @@ HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_epoch_ctrl" 1 | PytestCmd "test_cache_nomap.py" "test_cache_nomap_epoch_ctrl" 1 | ||||
| HandleRcExit $? 0 0 | HandleRcExit $? 0 0 | ||||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_clue" 1 | |||||
| HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_csv" 1 | |||||
| HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1 | |||||
| HandleRcExit $? 0 0 | |||||
| for i in $(seq 1 3) | for i in $(seq 1 3) | ||||
| do | do | ||||
| test_name="test_cache_nomap_multiple_cache${i}" | test_name="test_cache_nomap_multiple_cache${i}" | ||||
| @@ -17,6 +17,7 @@ Testing cache operator with mappable datasets | |||||
| """ | """ | ||||
| import os | import os | ||||
| import pytest | import pytest | ||||
| import numpy as np | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.vision.c_transforms as c_vision | import mindspore.dataset.vision.c_transforms as c_vision | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| @@ -26,7 +27,13 @@ DATA_DIR = "../data/dataset/testImageNetData/train/" | |||||
| COCO_DATA_DIR = "../data/dataset/testCOCO/train/" | COCO_DATA_DIR = "../data/dataset/testCOCO/train/" | ||||
| COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" | COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" | ||||
| NO_IMAGE_DIR = "../data/dataset/testRandomData/" | NO_IMAGE_DIR = "../data/dataset/testRandomData/" | ||||
| MNIST_DATA_DIR = "../data/dataset/testMnistData/" | |||||
| CELEBA_DATA_DIR = "../data/dataset/testCelebAData/" | |||||
| VOC_DATA_DIR = "../data/dataset/testVOC2012/" | |||||
| MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest" | |||||
| CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/" | |||||
| CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/" | |||||
| MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord" | |||||
| GENERATE_GOLDEN = False | GENERATE_GOLDEN = False | ||||
| @@ -443,7 +450,7 @@ def test_cache_map_failure5(): | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | ||||
| def test_cache_map_failure6(): | def test_cache_map_failure6(): | ||||
| """ | """ | ||||
| Test no-cache-supporting leaf ops with Map under cache (failure) | |||||
| Test no-cache-supporting MindRecord leaf with Map under cache (failure) | |||||
| repeat | repeat | ||||
| | | | | ||||
| @@ -451,7 +458,7 @@ def test_cache_map_failure6(): | |||||
| | | | | ||||
| Map(resize) | Map(resize) | ||||
| | | | | ||||
| Coco | |||||
| MindRecord | |||||
| """ | """ | ||||
| logger.info("Test cache failure 6") | logger.info("Test cache failure 6") | ||||
| @@ -461,22 +468,66 @@ def test_cache_map_failure6(): | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | raise RuntimeError("Testcase requires SESSION_ID environment variable") | ||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | ||||
| data = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True) | |||||
| columns_list = ["id", "file_name", "label_name", "img_data", "label_data"] | |||||
| num_readers = 1 | |||||
| # The dataset has 5 records | |||||
| data = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, num_readers) | |||||
| resize_op = c_vision.Resize((224, 224)) | resize_op = c_vision.Resize((224, 224)) | ||||
| data = data.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||||
| data = data.map(input_columns=["img_data"], operations=resize_op, cache=some_cache) | |||||
| data = data.repeat(4) | data = data.repeat(4) | ||||
| with pytest.raises(RuntimeError) as e: | with pytest.raises(RuntimeError) as e: | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data.create_dict_iterator(): | for _ in data.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert "There is currently no support for CocoOp under cache" in str(e.value) | |||||
| assert "There is currently no support for MindRecordOp under cache" in str(e.value) | |||||
| assert num_iter == 0 | assert num_iter == 0 | ||||
| logger.info('test_cache_failure6 Ended.\n') | logger.info('test_cache_failure6 Ended.\n') | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_failure7(): | |||||
| """ | |||||
| Test no-cache-supporting Generator leaf with Map under cache (failure) | |||||
| repeat | |||||
| | | |||||
| Cache | |||||
| | | |||||
| Map(lambda x: x) | |||||
| | | |||||
| Generator | |||||
| """ | |||||
| def generator_1d(): | |||||
| for i in range(64): | |||||
| yield (np.array(i),) | |||||
| logger.info("Test cache failure 7") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| data = ds.GeneratorDataset(generator_1d, ["data"]) | |||||
| data = data.map((lambda x: x), ["data"], cache=some_cache) | |||||
| data = data.repeat(4) | |||||
| with pytest.raises(RuntimeError) as e: | |||||
| num_iter = 0 | |||||
| for _ in data.create_dict_iterator(): | |||||
| num_iter += 1 | |||||
| assert "There is currently no support for GeneratorOp under cache" in str(e.value) | |||||
| assert num_iter == 0 | |||||
| logger.info('test_cache_failure7 Ended.\n') | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | ||||
| def test_cache_map_parameter_check(): | def test_cache_map_parameter_check(): | ||||
| """ | """ | ||||
| @@ -1236,6 +1287,421 @@ def test_cache_map_epoch_ctrl3(): | |||||
| logger.info("test_cache_map_epoch_ctrl3 Ended.\n") | logger.info("test_cache_map_epoch_ctrl3 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_coco1(): | |||||
| """ | |||||
| Test mappable coco leaf with cache op right over the leaf | |||||
| cache | |||||
| | | |||||
| Coco | |||||
| """ | |||||
| logger.info("Test cache map coco1") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 6 records | |||||
| ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True, | |||||
| cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 6 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_coco1 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_coco2(): | |||||
| """ | |||||
| Test mappable coco leaf with the cache op later in the tree above the map(resize) | |||||
| cache | |||||
| | | |||||
| Map(resize) | |||||
| | | |||||
| Coco | |||||
| """ | |||||
| logger.info("Test cache map coco2") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 6 records | |||||
| ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True) | |||||
| resize_op = c_vision.Resize((224, 224)) | |||||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 6 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_coco2 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_mnist1(): | |||||
| """ | |||||
| Test mappable mnist leaf with cache op right over the leaf | |||||
| cache | |||||
| | | |||||
| Mnist | |||||
| """ | |||||
| logger.info("Test cache map mnist1") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 10 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_mnist1 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_mnist2(): | |||||
| """ | |||||
| Test mappable mnist leaf with the cache op later in the tree above the map(resize) | |||||
| cache | |||||
| | | |||||
| Map(resize) | |||||
| | | |||||
| Mnist | |||||
| """ | |||||
| logger.info("Test cache map mnist2") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10) | |||||
| resize_op = c_vision.Resize((224, 224)) | |||||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 10 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_mnist2 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_celeba1(): | |||||
| """ | |||||
| Test mappable celeba leaf with cache op right over the leaf | |||||
| cache | |||||
| | | |||||
| CelebA | |||||
| """ | |||||
| logger.info("Test cache map celeba1") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 4 records | |||||
| ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 4 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_celeba1 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_celeba2(): | |||||
| """ | |||||
| Test mappable celeba leaf with the cache op later in the tree above the map(resize) | |||||
| cache | |||||
| | | |||||
| Map(resize) | |||||
| | | |||||
| CelebA | |||||
| """ | |||||
| logger.info("Test cache map celeba2") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 4 records | |||||
| ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) | |||||
| resize_op = c_vision.Resize((224, 224)) | |||||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 4 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_celeba2 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_manifest1(): | |||||
| """ | |||||
| Test mappable manifest leaf with cache op right over the leaf | |||||
| cache | |||||
| | | |||||
| Manifest | |||||
| """ | |||||
| logger.info("Test cache map manifest1") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 4 records | |||||
| ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 4 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_manifest1 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_manifest2(): | |||||
| """ | |||||
| Test mappable manifest leaf with the cache op later in the tree above the map(resize) | |||||
| cache | |||||
| | | |||||
| Map(resize) | |||||
| | | |||||
| Manifest | |||||
| """ | |||||
| logger.info("Test cache map manifest2") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 4 records | |||||
| ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True) | |||||
| resize_op = c_vision.Resize((224, 224)) | |||||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 4 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_manifest2 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_cifar1(): | |||||
| """ | |||||
| Test mappable cifar10 leaf with cache op right over the leaf | |||||
| cache | |||||
| | | |||||
| Cifar10 | |||||
| """ | |||||
| logger.info("Test cache map cifar1") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 10 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_cifar1 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_cifar2(): | |||||
| """ | |||||
| Test mappable cifar100 leaf with the cache op later in the tree above the map(resize) | |||||
| cache | |||||
| | | |||||
| Map(resize) | |||||
| | | |||||
| Cifar100 | |||||
| """ | |||||
| logger.info("Test cache map cifar2") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10) | |||||
| resize_op = c_vision.Resize((224, 224)) | |||||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 10 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_cifar2 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_voc1(): | |||||
| """ | |||||
| Test mappable voc leaf with cache op right over the leaf | |||||
| cache | |||||
| | | |||||
| VOC | |||||
| """ | |||||
| logger.info("Test cache map voc1") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 9 records | |||||
| ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 9 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_voc1 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_voc2(): | |||||
| """ | |||||
| Test mappable voc leaf with the cache op later in the tree above the map(resize) | |||||
| cache | |||||
| | | |||||
| Map(resize) | |||||
| | | |||||
| VOC | |||||
| """ | |||||
| logger.info("Test cache map voc2") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 9 records | |||||
| ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) | |||||
| resize_op = c_vision.Resize((224, 224)) | |||||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||||
| num_epoch = 4 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| assert sum([1 for _ in iter1]) == 9 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_map_voc2 Ended.\n") | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_cache_map_basic1() | test_cache_map_basic1() | ||||
| test_cache_map_basic2() | test_cache_map_basic2() | ||||
| @@ -20,22 +20,26 @@ import itertools | |||||
| import pytest | import pytest | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.text as text | |||||
| import mindspore.dataset.vision.c_transforms as c_vision | import mindspore.dataset.vision.c_transforms as c_vision | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | ||||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | ||||
| DATA_DIR2 = ["../data/dataset/testTextTFRecord/text.tfrecord"] | |||||
| TEXT_TF_DATA_DIR = ["../data/dataset/testTextTFRecord/text.tfrecord"] | |||||
| SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json" | SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json" | ||||
| DATA_DIR3 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", | |||||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", | |||||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", | |||||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] | |||||
| SCHEMA_DIR3 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" | |||||
| TRAIN_DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", | |||||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", | |||||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", | |||||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] | |||||
| TRAIN_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" | |||||
| DATA_DIR4 = "../data/dataset/testImageNetData/train/" | |||||
| IMAGE_FOLDER_DATA_DIR = "../data/dataset/testImageNetData/train/" | |||||
| CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json' | |||||
| CSV_DATA_DIR = '../data/dataset/testCSV/1.csv' | |||||
| TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt" | |||||
| GENERATE_GOLDEN = False | GENERATE_GOLDEN = False | ||||
| @@ -1310,7 +1314,7 @@ def test_cache_nomap_multiple_cache1(): | |||||
| eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | ||||
| # This dataset has 12 records in it | # This dataset has 12 records in it | ||||
| train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3) | |||||
| train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) | |||||
| decode_op = c_vision.Decode() | decode_op = c_vision.Decode() | ||||
| train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) | train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) | ||||
| @@ -1359,7 +1363,7 @@ def test_cache_nomap_multiple_cache2(): | |||||
| image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) | image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) | ||||
| # This dataset has 3 records in it only | # This dataset has 3 records in it only | ||||
| text_dataset = ds.TFRecordDataset(DATA_DIR2, SCHEMA_DIR2, cache=text_cache) | |||||
| text_dataset = ds.TFRecordDataset(TEXT_TF_DATA_DIR, SCHEMA_DIR2, cache=text_cache) | |||||
| num_epoch = 5 | num_epoch = 5 | ||||
| image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch) | image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch) | ||||
| @@ -1404,7 +1408,7 @@ def test_cache_nomap_multiple_cache3(): | |||||
| tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache) | tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache) | ||||
| # This DATA_DIR only has 2 images in it | # This DATA_DIR only has 2 images in it | ||||
| image_dataset = ds.ImageFolderDataset(dataset_dir=DATA_DIR4) | |||||
| image_dataset = ds.ImageFolderDataset(dataset_dir=IMAGE_FOLDER_DATA_DIR) | |||||
| image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) | image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) | ||||
| num_epoch = 5 | num_epoch = 5 | ||||
| @@ -1443,7 +1447,7 @@ def test_cache_nomap_multiple_cache_train(): | |||||
| train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | ||||
| # This dataset has 12 records in it | # This dataset has 12 records in it | ||||
| train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3) | |||||
| train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) | |||||
| decode_op = c_vision.Decode() | decode_op = c_vision.Decode() | ||||
| train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) | train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) | ||||
| @@ -1497,6 +1501,239 @@ def test_cache_nomap_multiple_cache_eval(): | |||||
| logger.info("test_cache_nomap_multiple_cache_eval Ended.\n") | logger.info("test_cache_nomap_multiple_cache_eval Ended.\n") | ||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue1():
    """
    A clue dataset (a non mappable dataset) with a cache directly over the leaf.

    The CLUE leaf is handed a sharding configuration here, but because a cache sits
    above it, tree prepare is expected to strip that configuration from the leaf and
    install a distributed sampler carrying the same shard settings instead.

       Cache
         |
       CLUE
    """
    logger.info("Test cache nomap clue 1")
    session_env = os.environ.get('SESSION_ID')
    if session_env is None:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(session_env)

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # Only 3 records sharded 3 ways, so exactly 1 record should come back for this shard.
    # With the cache present the split is row-based and done by the sampler, rather than
    # the file-based sharding the clue leaf would perform on its own.
    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        rows_seen = sum(1 for _ in iter1)
        assert rows_seen == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_clue1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue2():
    """
    A clue dataset (a non mappable dataset) with a cache placed above a map op,
    while the leaf itself is limited with a num_samples argument.

       Cache
         |
     map(lambda x: x)
         |
       CLUE
    """
    logger.info("Test cache nomap clue 2")
    session_env = os.environ.get('SESSION_ID')
    if session_env is None:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(session_env)

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # num_samples=2 caps the leaf, so every epoch should yield exactly 2 rows.
    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
    ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        rows_seen = sum(1 for _ in iter1)
        assert rows_seen == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_clue2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv1():
    """
    A csv dataset (a non mappable dataset) with a cache directly over the leaf.

    The CSV leaf is handed a sharding configuration here, but because a cache sits
    above it, tree prepare is expected to strip that configuration from the leaf and
    install a distributed sampler carrying the same shard settings instead.

       Cache
         |
        CSV
    """
    logger.info("Test cache nomap csv 1")
    session_env = os.environ.get('SESSION_ID')
    if session_env is None:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(session_env)

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # Only 3 records sharded 3 ways, so exactly 1 record should come back for this shard.
    # With the cache present the split is row-based and done by the sampler, rather than
    # the file-based sharding the csv leaf would perform on its own.
    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'], num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        rows_seen = sum(1 for _ in iter1)
        assert rows_seen == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_csv1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv2():
    """
    A csv dataset (a non mappable dataset) with a cache placed above a map op,
    while the leaf itself is limited with a num_samples argument.

       Cache
         |
     map(lambda x: x)
         |
        CSV
    """
    logger.info("Test cache nomap csv 2")
    session_env = os.environ.get('SESSION_ID')
    if session_env is None:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(session_env)

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # num_samples=2 caps the leaf, so every epoch should yield exactly 2 rows.
    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
    ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        rows_seen = sum(1 for _ in iter1)
        assert rows_seen == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_csv2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile1():
    """
    A text file dataset (a non mappable dataset) with a cache over it just after the leaf
    In this one, the text file dataset will be given sharding configuration, however since a cache is
    used, the tree prepare should undo the sharding configuration and instead, a distributed
    sampler will be chosen with the same shard config.

       Cache
         |
      TextFile
    """
    logger.info("Test cache nomap textfile 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With only 3 records shard into 3, we expect only 1 record returned for this shard
    # However, the sharding will be done by the sampler, not by the text file leaf node
    # In this case, it is a row-based sharding, not the file-based sharding that would happen if
    # there was not any cache.
    # Fixed: this test previously constructed a CSVDataset over the text file (copy-paste from
    # the csv tests); it must use TextFileDataset to exercise the TextFile leaf + cache path.
    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_textfile1 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile2():
    """
    A text file dataset (a non mappable dataset) with a cache placed above a tokenizer
    map op, while the leaf itself is limited with a num_samples argument.

       Cache
         |
     Map(tokenizer)
         |
      TextFile
    """
    def my_tokenizer(line):
        # Split on whitespace; an all-whitespace/empty line tokenizes to a single empty token.
        return line.split() or [""]

    logger.info("Test cache nomap textfile 2")
    session_env = os.environ.get('SESSION_ID')
    if session_env is None:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(session_env)

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # num_samples=2 caps the leaf, so every epoch should yield exactly 2 rows.
    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2)
    tokenizer = text.PythonTokenizer(my_tokenizer)
    ds1 = ds1.map(operations=tokenizer, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        rows_seen = sum(1 for _ in iter1)
        assert rows_seen == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_textfile2 Ended.\n")
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_cache_nomap_basic1() | test_cache_nomap_basic1() | ||||
| test_cache_nomap_basic2() | test_cache_nomap_basic2() | ||||
| @@ -40,8 +40,9 @@ def test_textline_dataset_all_file(): | |||||
| assert count == 5 | assert count == 5 | ||||
| def test_textline_dataset_num_samples_zero(): | |||||
| data = ds.TextFileDataset(DATA_FILE, num_samples=0) | |||||
| def test_textline_dataset_num_samples_none(): | |||||
| # Do not provide a num_samples argument, so it would be None by default | |||||
| data = ds.TextFileDataset(DATA_FILE) | |||||
| count = 0 | count = 0 | ||||
| for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): | for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): | ||||
| logger.info("{}".format(i["text"])) | logger.info("{}".format(i["text"])) | ||||
| @@ -208,7 +209,7 @@ def test_textline_dataset_exceptions(): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_textline_dataset_one_file() | test_textline_dataset_one_file() | ||||
| test_textline_dataset_all_file() | test_textline_dataset_all_file() | ||||
| test_textline_dataset_num_samples_zero() | |||||
| test_textline_dataset_num_samples_none() | |||||
| test_textline_dataset_shuffle_false4() | test_textline_dataset_shuffle_false4() | ||||
| test_textline_dataset_shuffle_false1() | test_textline_dataset_shuffle_false1() | ||||
| test_textline_dataset_shuffle_files4() | test_textline_dataset_shuffle_files4() | ||||