From: @lixiachen Reviewed-by: @liucunwei,@pandoublefeng Signed-off-by: @liucunwei,@liucunweitags/v1.1.0
| @@ -275,24 +275,14 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) { | |||||
| } | } | ||||
| Status CacheMergeOp::EoeReceived(int32_t worker_id) { | Status CacheMergeOp::EoeReceived(int32_t worker_id) { | ||||
| // If we are in a repeat path, send the eoe up. | |||||
| // Otherwise ignore it. | |||||
| if (op_total_repeats_ != 1) { | |||||
| return DatasetOp::EoeReceived(worker_id); | |||||
| } | |||||
| return Status::OK(); | |||||
| // Send the eoe up. | |||||
| MS_LOG(DEBUG) << "Cache merge sending eoe"; | |||||
| return DatasetOp::EoeReceived(worker_id); | |||||
| } | } | ||||
| // Base-class override for handling cases when an eof is received. | // Base-class override for handling cases when an eof is received. | ||||
| Status CacheMergeOp::EofReceived(int32_t worker_id) { | Status CacheMergeOp::EofReceived(int32_t worker_id) { | ||||
| // If we are not in a repeated path, then the merge op gets a eof by itself, without first | |||||
| // getting an eoe. However, the logic demands that all epochs close with an eoe first before eof. | |||||
| // Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class | |||||
| // provides that for us. | |||||
| if (op_total_repeats_ == 1) { | |||||
| MS_LOG(DEBUG) << "Cache merge sending eoe"; | |||||
| RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id)); | |||||
| } | |||||
| // Send the eof up. | |||||
| MS_LOG(DEBUG) << "Cache merge sending eof"; | MS_LOG(DEBUG) << "Cache merge sending eof"; | ||||
| return DatasetOp::EofReceived(worker_id); | return DatasetOp::EofReceived(worker_id); | ||||
| } | } | ||||
| @@ -23,7 +23,10 @@ from ..core.validator_helpers import type_check, check_uint32, check_uint64, che | |||||
| class DatasetCache: | class DatasetCache: | ||||
| """ | """ | ||||
| A client to interface with tensor caching service | |||||
| A client to interface with tensor caching service. | |||||
    For details, please check the `Chinese tutorial <https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_cache.html>`_
    and the `Chinese programming guide <https://www.mindspore.cn/doc/programming_guide/zh-CN/master/cache.html?highlight=datasetcache>`_.
| Args: | Args: | ||||
| session_id (int): A user assigned session id for the current pipeline. | session_id (int): A user assigned session id for the current pipeline. | ||||
| @@ -34,9 +37,6 @@ class DatasetCache: | |||||
| num_connections (int, optional): Number of tcp/ip connections (default=12). | num_connections (int, optional): Number of tcp/ip connections (default=12). | ||||
| prefetch_size (int, optional): Prefetch size (default=20). | prefetch_size (int, optional): Prefetch size (default=20). | ||||
| Tutorials: | |||||
| https://www.mindspore.cn/doc/programming_guide/zh-CN/master/cache.html?highlight=datasetcache | |||||
| https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_cache.html | |||||
| """ | """ | ||||
| def __init__(self, session_id, size=0, spilling=False, hostname=None, port=None, num_connections=None, | def __init__(self, session_id, size=0, spilling=False, hostname=None, port=None, num_connections=None, | ||||
| @@ -1857,6 +1857,40 @@ def test_cache_map_cifar3(): | |||||
| logger.info("test_cache_map_cifar3 Ended.\n") | logger.info("test_cache_map_cifar3 Ended.\n") | ||||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar4():
    """
    Test mappable cifar10 leaf with cache op right over the leaf, and shuffle op over the cache op

       shuffle
          |
        cache
          |
       Cifar10
    """
    logger.info("Test cache map cifar4")
    # The cache server session must be set up externally before this test runs.
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
    ds1 = ds1.shuffle(10)

    num_epoch = 1
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        # Each epoch must yield all 10 cached samples.
        # Generator form avoids materializing a throwaway list (was sum([1 for ...])).
        assert sum(1 for _ in iter1) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_cifar4 Ended.\n")
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | ||||
| def test_cache_map_voc1(): | def test_cache_map_voc1(): | ||||
| """ | """ | ||||