Browse Source

!10074 Fix a testcase error when shuffle is above cache

From: @lixiachen
Reviewed-by: @liucunwei,@pandoublefeng
Signed-off-by: @liucunwei,@liucunwei
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
fc11b7dd68
3 changed files with 42 additions and 18 deletions
  1. +4
    -14
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc
  2. +4
    -4
      mindspore/dataset/engine/cache_client.py
  3. +34
    -0
      tests/ut/python/dataset/test_cache_map.py

+ 4
- 14
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc View File

@@ -275,24 +275,14 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
}

Status CacheMergeOp::EoeReceived(int32_t worker_id) {
// If we are in a repeat path, send the eoe up.
// Otherwise ignore it.
if (op_total_repeats_ != 1) {
return DatasetOp::EoeReceived(worker_id);
}
return Status::OK();
// Send the eoe up.
MS_LOG(DEBUG) << "Cache merge sending eoe";
return DatasetOp::EoeReceived(worker_id);
}

// Base-class override for handling cases when an eof is received.
Status CacheMergeOp::EofReceived(int32_t worker_id) {
// If we are not in a repeated path, then the merge op gets a eof by itself, without first
// getting an eoe. However, the logic demands that all epochs close with an eoe first before eof.
// Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class
// provides that for us.
if (op_total_repeats_ == 1) {
MS_LOG(DEBUG) << "Cache merge sending eoe";
RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id));
}
// Send the eof up.
MS_LOG(DEBUG) << "Cache merge sending eof";
return DatasetOp::EofReceived(worker_id);
}


+ 4
- 4
mindspore/dataset/engine/cache_client.py View File

@@ -23,7 +23,10 @@ from ..core.validator_helpers import type_check, check_uint32, check_uint64, che

class DatasetCache:
"""
A client to interface with tensor caching service
A client to interface with tensor caching service.

For details, please check `Chinese tutorial <https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_cache.html>`_,
`Chinese programming guide <https://www.mindspore.cn/doc/programming_guide/zh-CN/master/cache.html?highlight=datasetcache>`_.

Args:
session_id (int): A user assigned session id for the current pipeline.
@@ -34,9 +37,6 @@ class DatasetCache:
num_connections (int, optional): Number of tcp/ip connections (default=12).
prefetch_size (int, optional): Prefetch size (default=20).

Tutorials:
https://www.mindspore.cn/doc/programming_guide/zh-CN/master/cache.html?highlight=datasetcache
https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_cache.html
"""

def __init__(self, session_id, size=0, spilling=False, hostname=None, port=None, num_connections=None,


+ 34
- 0
tests/ut/python/dataset/test_cache_map.py View File

@@ -1857,6 +1857,40 @@ def test_cache_map_cifar3():
logger.info("test_cache_map_cifar3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar4():
"""
Test mappable cifar10 leaf with cache op right over the leaf, and shuffle op over the cache op

shuffle
|
cache
|
Cifar10
"""

logger.info("Test cache map cifar4")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")

some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
ds1 = ds1.shuffle(10)

num_epoch = 1
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 10
epoch_count += 1
assert epoch_count == num_epoch

logger.info("test_cache_map_cifar4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_voc1():
"""


Loading…
Cancel
Save