From: @lixiachen
Reviewed-by: @nsyca, @mikef
Signed-off-by: @nsyca
Tag: tags/v1.1.0
@@ -377,7 +377,7 @@ CacheClient::CacheMissKeys::CacheMissKeys(const std::vector<row_id_type> &v) {
     gap_.insert(*it);
     ++it;
   }
-  MS_LOG(WARNING) << "# of cache miss keys between min(" << min_ << ") and max(" << max_ << ") is " << gap_.size();
+  MS_LOG(INFO) << "# of cache miss keys between min(" << min_ << ") and max(" << max_ << ") is " << gap_.size();
 }

 bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
@@ -116,7 +116,7 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic
     } else {
       // If asked to spill to disk instead but there is no storage set up, simply return no memory
       // instead.
-      return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
+      return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Not enough storage for cache server to cache data");
     }
   } else {
     return rc;
@@ -271,7 +271,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \brief Getter function
   /// \return The number of repeats per epoch for the operator
-  int32_t op_num_repeats_per_epoch() { return op_num_repeats_per_epoch_; }
+  int32_t op_num_repeats_per_epoch() const { return op_num_repeats_per_epoch_; }

   /// \brief Register the internal worker connectors. No op unless it is a parallel op
   /// \return Status
@@ -354,7 +354,7 @@ Status MapOp::ComputeColMap() {
     RETURN_IF_NOT_OK(InitPrivateVariable(&current_name_id_map));
     // Create the final column name to index mapping in the base class field
     CreateFinalColMap(&current_name_id_map);
-    MS_LOG(DEBUG) << "Column name map for map op set: " << this->ColumnNameMapAsString();
+    MS_LOG(DEBUG) << "Column name map for map op is: " << this->ColumnNameMapAsString();
   } else {
     MS_LOG(WARNING) << "Column name map is already set!";
   }
@@ -130,6 +130,12 @@ Status SkipOp::Accept(NodePass *p, bool *modified) {
   return p->RunOnNode(shared_from_base<SkipOp>(), modified);
 }

+// Visitor pre-accept method for NodePass
+Status SkipOp::PreAccept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->PreRunOnNode(shared_from_base<SkipOp>(), modified);
+}
+
 // Get Dataset size
 Status SkipOp::GetDatasetSize(int64_t *dataset_size) {
   if (dataset_size_ > 0) {
@@ -80,6 +80,12 @@ class SkipOp : public PipelineOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  /// \brief Base-class override for NodePass pre-visit acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status PreAccept(NodePass *p, bool *modified) override;
+
   /// \brief Base-class override for GetDatasetSize
   /// \param[out] dataset_size the size of the dataset
   /// \return Status of the function
@@ -133,6 +133,12 @@ Status TakeOp::Accept(NodePass *p, bool *modified) {
   return p->RunOnNode(shared_from_base<TakeOp>(), modified);
 }

+// Visitor pre-accept method for NodePass
+Status TakeOp::PreAccept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->PreRunOnNode(shared_from_base<TakeOp>(), modified);
+}
+
 // Get Dataset size
 Status TakeOp::GetDatasetSize(int64_t *dataset_size) {
   if (dataset_size_ > 0) {
@@ -84,6 +84,12 @@ class TakeOp : public PipelineOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  /// \brief Base-class override for NodePass pre-visit acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status PreAccept(NodePass *p, bool *modified) override;
+
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return kTakeOp; }
@@ -31,13 +31,13 @@ class DatasetCacheImpl : public DatasetCache {
  public:
   ///
   /// \brief Constructor
-  /// \param id A user assigned session id for the current pipeline
-  /// \param mem_sz Size of the memory set aside for the row caching. 0 for unlimited
-  /// \param spill Spill to disk if out of memory
-  /// \param hostname optional host name
-  /// \param port optional port
-  /// \param num_connections optional number of connections
-  /// \param prefetch_sz optional prefetch size
+  /// \param id A user assigned session id for the current pipeline.
+  /// \param mem_sz Size of the memory set aside for the row caching (default=0, which means unlimited).
+  /// \param spill Spill to disk if out of memory (default=false).
+  /// \param hostname Optional host name (default="127.0.0.1").
+  /// \param port Optional port (default=50052).
+  /// \param num_connections Optional number of connections (default=12).
+  /// \param prefetch_sz Optional prefetch size (default=20).
   DatasetCacheImpl(session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname,
                    std::optional<int32_t> port, std::optional<int32_t> num_connections,
                    std::optional<int32_t> prefetch_sz)
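Note: these constructor defaults line up with the user-facing Python `DatasetCache` documented later in this change. As a minimal sketch of the equivalent call from Python — assuming a cache server is already running on the default address and the session id was created out of band (the hard-coded `1` below is a placeholder):

```python
import mindspore.dataset as ds

# Placeholder session id; a real one comes from the running cache server.
session_id = 1

# Mirrors the C++ defaults above: size=0 (unlimited memory), no spilling,
# server at 127.0.0.1:50052, 12 connections, prefetch size of 20.
some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=False)
```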
@@ -444,6 +444,16 @@ Status NodePass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) {
   return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
 }

+Status NodePass::PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
 #ifndef ENABLE_ANDROID
 Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
   // Fallback to base class visitor by default
@@ -303,6 +303,8 @@ class NodePass : public Pass {
   virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
   virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified);
   virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified);
+  virtual Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
+  virtual Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified);

 #ifndef ENABLE_ANDROID
   virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);
   virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);
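For readers less familiar with this pattern: `PreAccept`/`PreRunOnNode` form a pre-order double dispatch. The tree walker calls `PreAccept` on each node before descending, and the node forwards to the visitor entry matching its concrete type, so a pass like `CacheErrorPass` can intercept specific ops. A minimal Python sketch of the idea (the names mirror the C++ code, but this is illustrative only, not MindSpore API):

```python
class NodePass:
    """Base visitor: every per-op entry falls back to the generic one."""

    def pre_run_on_dataset_op(self, node, modified):
        return "OK"  # default: accept any node

    def pre_run_on_take_op(self, node, modified):
        # Subclasses (e.g. a cache-error pass) override this entry.
        return self.pre_run_on_dataset_op(node, modified)


class DatasetOp:
    def pre_accept(self, visitor, modified):
        return visitor.pre_run_on_dataset_op(self, modified)


class TakeOp(DatasetOp):
    def pre_accept(self, visitor, modified):
        # Equivalent of the shared_from_base downcast: dispatch on the
        # concrete type so type-specific visitor entries run.
        return visitor.pre_run_on_take_op(self, modified)
```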
@@ -65,6 +65,24 @@ Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modifi
   return Status::OK();
 }

+// Returns an error if TakeOp exists under a cache
+Status CacheErrorPass::PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
+  if (is_cached_) {
+    RETURN_STATUS_UNEXPECTED("TakeOp/SplitOp is currently not supported as a descendant operator under a cache.");
+  }
+
+  return Status::OK();
+}
+
+// Returns an error if SkipOp exists under a cache
+Status CacheErrorPass::PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
+  if (is_cached_) {
+    RETURN_STATUS_UNEXPECTED("SkipOp is currently not supported as a descendant operator under a cache.");
+  }
+
+  return Status::OK();
+}
+
 #ifdef ENABLE_PYTHON
 // Returns an error if FilterOp exists under a cache
 Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
@@ -59,6 +59,18 @@ class CacheErrorPass : public NodePass {
   /// \return Status The error code return
   Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;

+  /// \brief Returns an error if TakeOp exists under a cache
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
+
+  /// \brief Returns an error if SkipOp exists under a cache
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
+
 #ifdef ENABLE_PYTHON
   /// \brief Returns an error if FilterOp exists under a cache
   /// \param[in] node The node being visited
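In pipeline terms, the new checks reject `take`/`skip` (and any `split` that is lowered to them) anywhere below a cache, while the same ops remain legal above the cache. A hedged sketch of the two orderings, using placeholder paths and session id:

```python
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision

some_cache = ds.DatasetCache(session_id=1, size=0, spilling=False)
decode_op = c_vision.Decode()

# Rejected by CacheErrorPass: take feeds the cached map, making TakeOp a
# descendant of the cache. The error surfaces when iteration starts.
bad = ds.ImageFolderDataset(dataset_dir="/path/to/images").take(2)
bad = bad.map(input_columns=["image"], operations=decode_op, cache=some_cache)

# Fine: take sits above the cache and consumes already-cached rows.
good = ds.ImageFolderDataset(dataset_dir="/path/to/images")
good = good.map(input_columns=["image"], operations=decode_op, cache=some_cache)
good = good.take(2)
```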
@@ -287,7 +287,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
   ///     name as the input columns, i.e., the columns will be replaced
   /// \param[in] project_columns A list of column names to project
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current MapDataset
   std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorOperation>> operations,
                                   std::vector<std::string> input_columns = {},
@@ -553,7 +552,6 @@ class AlbumDataset : public Dataset {
   /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
   ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current Dataset
   std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
                                       const std::vector<std::string> &column_names = {}, bool decode = false,
@@ -580,7 +578,6 @@ class CelebADataset : public Dataset {
   /// \param[in] decode Decode the images after reading (default=false).
   /// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current Dataset
   std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage = "all",
                                         const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), bool decode = false,
@@ -602,7 +599,6 @@ class Cifar10Dataset : public Dataset {
   /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
   ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current Dataset
   std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage = "all",
                                           const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
@@ -623,7 +619,6 @@ class Cifar100Dataset : public Dataset {
   /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
   ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current Dataset
   std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage = "all",
                                             const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
@@ -655,7 +650,6 @@ class CLUEDataset : public Dataset {
   /// \param[in] shard_id The shard ID within num_shards. This argument should be
   ///     specified only when num_shards is also specified. (Default = 0)
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current CLUEDataset
   std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
                                     const std::string &usage = "train", int64_t num_samples = 0,
@@ -686,7 +680,6 @@ class CocoDataset : public Dataset {
   /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
   ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current Dataset
   std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
                                     const std::string &task = "Detection", const bool &decode = false,
@@ -723,7 +716,6 @@ class CSVDataset : public Dataset {
   /// \param[in] shard_id The shard ID within num_shards. This argument should be
   ///     specified only when num_shards is also specified. (Default = 0)
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current Dataset
   std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',',
                                   const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {},
@@ -752,7 +744,6 @@ class ImageFolderDataset : public Dataset {
   /// \param[in] extensions File extensions to be read
   /// \param[in] class_indexing a class name to label map
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current ImageFolderDataset
   std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode = false,
                                                   const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
@@ -779,7 +770,6 @@ class ManifestDataset : public Dataset {
   ///     names will be sorted alphabetically and each class will be given a unique index starting from 0).
   /// \param[in] decode Decode the images after reading (default=false).
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current ManifestDataset
   std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const std::string &usage = "train",
                                             const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
@@ -842,7 +832,6 @@ class MnistDataset : public Dataset {
   /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
   ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current MnistDataset
   std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage = "all",
                                       const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
@@ -874,7 +863,6 @@ class RandomDataDataset : public Dataset {
   /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
   ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current Dataset
   template <typename T = std::shared_ptr<SchemaObj>>
   std::shared_ptr<RandomDataDataset> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr,
@@ -913,7 +901,6 @@ class TextFileDataset : public Dataset {
   /// \param[in] shard_id The shard ID within num_shards. This argument should be
   ///     specified only when num_shards is also specified. (Default = 0)
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current TextFileDataset
   std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples = 0,
                                             ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
@@ -956,7 +943,6 @@ class TFRecordDataset : public Dataset {
   /// \param[in] shard_equal_rows Get equal rows for all shards. (Default = False, number of rows of
   ///     each shard may be not equal)
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current TFRecordDataset
   template <typename T = std::shared_ptr<SchemaObj>>
   std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &dataset_files, const T &schema = nullptr,
@@ -1006,7 +992,6 @@ class VOCDataset : public Dataset {
   /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
   ///     a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
   /// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
-  ///     The cache feature is under development and is not recommended.
   /// \return Shared pointer to the current Dataset
   std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
                                   const std::string &usage = "train",
@@ -1015,13 +1000,13 @@ std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::strin
                                 const std::shared_ptr<DatasetCache> &cache = nullptr);

 /// \brief Function the create a cache to be attached to a dataset
-/// \param id A user assigned session id for the current pipeline
-/// \param mem_sz Size of the memory set aside for the row caching. 0 for unlimited
-/// \param spill Spill to disk if out of memory
-/// \param hostname optional host name
-/// \param port optional port
-/// \param num_connections optional number of connections
-/// \param prefetch_sz optional prefetch size
+/// \param id A user assigned session id for the current pipeline.
+/// \param mem_sz Size of the memory set aside for the row caching (default=0, which means unlimited).
+/// \param spill Spill to disk if out of memory (default=false).
+/// \param hostname Optional host name (default="127.0.0.1").
+/// \param port Optional port (default=50052).
+/// \param num_connections Optional number of connections (default=12).
+/// \param prefetch_sz Optional prefetch size (default=20).
 /// \return Shared pointer to DatasetCache. If error, nullptr is returned.
 std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill,
                                                  std::optional<std::string> hostname = std::nullopt,
@@ -24,9 +24,18 @@ from ..core.validator_helpers import type_check, check_uint32, check_uint64, che

 class DatasetCache:
     """
     A client to interface with tensor caching service
+
+    Args:
+        session_id (int): A user assigned session id for the current pipeline.
+        size (int, optional): Size of the memory set aside for row caching (default=0, which means unlimited).
+        spilling (bool, optional): Whether or not to spill to disk if out of memory (default=False).
+        hostname (str, optional): Host name (default="127.0.0.1").
+        port (int, optional): Port to connect to the server (default=50052).
+        num_connections (int, optional): Number of TCP/IP connections (default=12).
+        prefetch_size (int, optional): Prefetch size (default=20).
     """

-    def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None,
+    def __init__(self, session_id, size=0, spilling=False, hostname=None, port=None, num_connections=None,
                  prefetch_size=None):
         check_uint32(session_id, "session_id")
         type_check(size, (int,), "size")
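With `session_id` now a required argument, a minimal usage sketch (the session id and path are placeholders, and a running cache server is assumed):

```python
import mindspore.dataset as ds

# session_id is mandatory; the other arguments keep the documented defaults
# (size=0 means unlimited, spilling=False, server at 127.0.0.1:50052).
some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

# Attach the cache to a source dataset (or to a map), as the tests below do.
ds1 = ds.ImageFolderDataset(dataset_dir="/path/to/images", cache=some_cache)
```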
@@ -489,7 +489,6 @@ class Dataset:
             python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
                 option could be beneficial if the Python operation is computational heavy (default=False).
             cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-                The cache feature is under development and is not recommended.
             callbacks: (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None).
@@ -2203,7 +2202,6 @@ class MapDataset(Dataset):
         python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This
             option could be beneficial if the Python operation is computational heavy (default=False).
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.
         callbacks: (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None)

     Raises:
@@ -2944,7 +2942,6 @@ class ImageFolderDataset(MappableDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Raises:
         RuntimeError: If sampler and shuffle are specified at the same time.
@@ -3092,7 +3089,6 @@ class MnistDataset(MappableDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Raises:
         RuntimeError: If sampler and shuffle are specified at the same time.
@@ -3782,7 +3778,7 @@ class TFRecordDataset(SourceDataset):
         shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows
             is false, number of rows of each shard may be not equal.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Examples:
         >>> import mindspore.dataset as ds
         >>> import mindspore.common.dtype as mstype
@@ -3972,7 +3968,6 @@ class ManifestDataset(MappableDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Raises:
         RuntimeError: If sampler and shuffle are specified at the same time.
@@ -4135,7 +4130,6 @@ class Cifar10Dataset(MappableDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Raises:
         RuntimeError: If sampler and shuffle are specified at the same time.
@@ -4276,7 +4270,6 @@ class Cifar100Dataset(MappableDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Raises:
         RuntimeError: If sampler and shuffle are specified at the same time.
@@ -4358,7 +4351,6 @@ class RandomDataset(SourceDataset):
         num_parallel_workers (int, optional): Number of workers to read the data
             (default=None, number set in the config).
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.
         shuffle (bool, optional): Whether or not to perform shuffle on the dataset
             (default=None, expected order behavior shown in the table).
         num_shards (int, optional): Number of shards that the dataset will be divided
@@ -4596,7 +4588,6 @@ class VOCDataset(MappableDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Raises:
         RuntimeError: If xml of Annotations is an invalid format.
@@ -4791,7 +4782,6 @@ class CocoDataset(MappableDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Raises:
         RuntimeError: If sampler and shuffle are specified at the same time.
@@ -4944,7 +4934,6 @@ class CelebADataset(MappableDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Examples:
         >>> import mindspore.dataset as ds
@@ -5057,7 +5046,6 @@ class CLUEDataset(SourceDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Examples:
         >>> import mindspore.dataset as ds
@@ -5291,7 +5279,6 @@ class CSVDataset(SourceDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Examples:
@@ -5403,7 +5390,6 @@ class TextFileDataset(SourceDataset):
         shard_id (int, optional): The shard ID within num_shards (default=None). This
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
-            The cache feature is under development and is not recommended.

     Examples:
         >>> import mindspore.dataset as ds
@@ -55,6 +55,9 @@ export SESSION_ID=$session_id
 PytestCmd "test_cache_map.py" "test_cache_map_failure" 1
 HandleRcExit $? 0 0

+PytestCmd "test_cache_map.py" "test_cache_map_split" 1
+HandleRcExit $? 0 0
+
 # DatasetCache parameter check
 PytestCmd "test_cache_map.py" "test_cache_map_parameter_check"
 HandleRcExit $? 0 0
@@ -528,6 +528,190 @@ def test_cache_map_failure8():
     logger.info('test_cache_failure8 Ended.\n')

+
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_map_failure9():
+    """
+    Test take under cache (failure)
+
+        repeat
+           |
+         Cache
+           |
+     Map(decode)
+           |
+         Take
+           |
+     ImageFolder
+
+    """
+    logger.info("Test cache failure 9")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
+
+    # This DATA_DIR only has 2 images in it
+    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
+    ds1 = ds1.take(2)
+    decode_op = c_vision.Decode()
+    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
+    ds1 = ds1.repeat(4)
+
+    with pytest.raises(RuntimeError) as e:
+        num_iter = 0
+        for _ in ds1.create_dict_iterator():
+            num_iter += 1
+    assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value)
+
+    assert num_iter == 0
+    logger.info('test_cache_failure9 Ended.\n')
+
+
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_map_failure10():
+    """
+    Test skip under cache (failure)
+
+        repeat
+           |
+         Cache
+           |
+     Map(decode)
+           |
+         Skip
+           |
+     ImageFolder
+
+    """
+    logger.info("Test cache failure 10")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
+
+    # This DATA_DIR only has 2 images in it
+    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
+    ds1 = ds1.skip(1)
+    decode_op = c_vision.Decode()
+    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
+    ds1 = ds1.repeat(4)
+
+    with pytest.raises(RuntimeError) as e:
+        num_iter = 0
+        for _ in ds1.create_dict_iterator():
+            num_iter += 1
+    assert "SkipOp is currently not supported as a descendant operator under a cache" in str(e.value)
+
+    assert num_iter == 0
+    logger.info('test_cache_failure10 Ended.\n')
+
+
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_map_split1():
+    """
+    Test split (after a non-source node) under cache (failure).
+    Split after a non-source node is implemented with TakeOp/SkipOp, hence the failure.
+
+        repeat
+           |
+         Cache
+           |
+     Map(resize)
+           |
+         Split
+           |
+     Map(decode)
+           |
+     ImageFolder
+
+    """
+    logger.info("Test cache split 1")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
+
+    # This DATA_DIR only has 2 images in it
+    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
+    decode_op = c_vision.Decode()
+    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
+    ds1, ds2 = ds1.split([0.5, 0.5])
+    resize_op = c_vision.Resize((224, 224))
+    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
+    ds2 = ds2.map(input_columns=["image"], operations=resize_op, cache=some_cache)
+    ds1 = ds1.repeat(4)
+    ds2 = ds2.repeat(4)
+
+    with pytest.raises(RuntimeError) as e:
+        num_iter = 0
+        for _ in ds1.create_dict_iterator():
+            num_iter += 1
+    assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value)
+
+    with pytest.raises(RuntimeError) as e:
+        num_iter = 0
+        for _ in ds2.create_dict_iterator():
+            num_iter += 1
+    assert "TakeOp/SplitOp is currently not supported as a descendant operator under a cache" in str(e.value)
+    logger.info('test_cache_split1 Ended.\n')
+
+
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_map_split2():
+    """
+    Test split (after a source node) under cache (ok).
+    Split after a source node is implemented with a subset sampler, hence ok.
+
+        repeat
+           |
+         Cache
+           |
+     Map(resize)
+           |
+         Split
+           |
+     VOCDataset
+
+    """
+    logger.info("Test cache split 2")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
+
+    # This dataset has 9 records
+    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    ds1, ds2 = ds1.split([0.3, 0.7])
+    resize_op = c_vision.Resize((224, 224))
+    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
+    ds2 = ds2.map(input_columns=["image"], operations=resize_op, cache=some_cache)
+    ds1 = ds1.repeat(4)
+    ds2 = ds2.repeat(4)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator():
+        num_iter += 1
+    # 9 records * 0.3 -> 3 rows per epoch, repeated 4 times
+    assert num_iter == 12
+
+    num_iter = 0
+    for _ in ds2.create_dict_iterator():
+        num_iter += 1
+    # 9 records * 0.7 -> 6 rows per epoch, repeated 4 times
+    assert num_iter == 24
+    logger.info('test_cache_split2 Ended.\n')
+
+
 @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
 def test_cache_map_parameter_check():
     """
@@ -1748,7 +1748,7 @@ def test_cache_nomap_textfile1():
     # However, the sharding will be done by the sampler, not by the clue leaf node
     # In this case, it is a row-based sharding, not the file-based sharding that would happen if
     # there was not any cache.
-    ds1 = ds.CSVDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache)
+    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache)
     num_epoch = 4
     iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)