
!14822 Fix cache client status reset and final flush

When CreateCache() takes the duplicate-key path (the cache already exists, e.g. after an interrupted run), the client now resets its AsyncBufferStream so stale status and leftover buffer state cannot leak into the rerun; and a blocking (final) flush now waits on the outstanding write requests of all async buffers rather than only the current one, so every row is cached before the pipeline switches to the fetch phase. The unused next_addr_ member is removed, and tests for the interrupt-and-rerun and all-rows-cached scenarios are added.

From: @lixiachen
Reviewed-by: @robingrosman, @mikef
Signed-off-by: @robingrosman
pull/14822/MERGE
Committed by mindspore-ci-bot (Gitee) 4 years ago
parent commit f451625ea9
5 changed files with 207 additions and 12 deletions
  1. mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc (+27 -3)
  2. mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h (+3 -1)
  3. tests/ut/python/cachetests/cachetest_py.sh (+30 -0)
  4. tests/ut/python/dataset/test_cache_map.py (+54 -8)
  5. tests/ut/python/dataset/test_cache_nomap.py (+93 -0)

mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc (+27 -3)

@@ -179,6 +179,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
       return Status(StatusCode::kMDDuplicateKey, __LINE__, __FILE__,
                     "Not an error and we should bypass the build phase");
     }
+    if (async_buffer_stream_) {
+      // Reset the async buffer stream to its initial state. Any stale status and data would get cleaned up.
+      RETURN_IF_NOT_OK(async_buffer_stream_->Reset());
+    }
   } else {
     cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating so save our crc in the client
     // Now execute the cache create request using this identifier and other configs
@@ -343,7 +347,7 @@ bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
   }
 }
 
-CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0), next_addr_(0) {}
+CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0) {}
 
 CacheClient::AsyncBufferStream::~AsyncBufferStream() {
   (void)vg_.ServiceStop();
@@ -426,8 +430,16 @@ Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) {
     // If we are asked to wait, say this is the final flush, just wait for its completion.
     bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
     if (blocking) {
-      flush_rc_ = asyncWriter->rq->Wait();
-      asyncWriter->rq.reset();
+      // Make sure we are done with all the buffers
+      for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+        if (buf_arr_[i].rq) {
+          Status rc = buf_arr_[i].rq->Wait();
+          if (rc.IsError()) {
+            flush_rc_ = rc;
+          }
+          buf_arr_[i].rq.reset();
+        }
+      }
     }
     // Prepare for the next buffer.
     cur_ = (cur_ + 1) % kNumAsyncBuffer;
@@ -458,5 +470,17 @@ Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std:
   ++num_ele_;
   return Status::OK();
 }
+
+Status CacheClient::AsyncBufferStream::Reset() {
+  // Clean up previous running state to be prepared for a new run.
+  cur_ = 0;
+  flush_rc_ = Status::OK();
+  for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+    buf_arr_[i].bytes_avail_ = kAsyncBufferSize;
+    buf_arr_[i].num_ele_ = 0;
+    buf_arr_[i].rq.reset();
+  }
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
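
The SyncFlush hunk above is the heart of the final-flush fix: the old code waited only on the current writer's request, so a blocking flush could return while earlier buffers were still in flight and rows could be missing when the fetch phase started. Below is a minimal, self-contained sketch of that drain-all-buffers pattern; Writer and DrainAll are hypothetical stand-ins (not the MindSpore API), and std::future plays the role of the real remote-request handle.

// drain_sketch.cc -- illustrative only; compile with -std=c++17 -pthread
#include <array>
#include <future>
#include <optional>

constexpr int kNumAsyncBuffer = 2;  // mirrors the double buffering above

struct Writer {
  std::optional<std::future<bool>> rq;  // pending server ack; false == write failed
};

// On a blocking (final) flush, wait for *every* in-flight buffer, not just the
// current one, and latch any failure so the caller still sees it after draining.
bool DrainAll(std::array<Writer, kNumAsyncBuffer> *buf_arr) {
  bool flush_ok = true;  // plays the role of flush_rc_
  for (auto &w : *buf_arr) {
    if (w.rq) {
      if (!w.rq->get()) flush_ok = false;  // record the error but keep draining
      w.rq.reset();                        // the buffer can now be reused
    }
  }
  return flush_ok;
}

int main() {
  std::array<Writer, kNumAsyncBuffer> bufs;
  bufs[0].rq = std::async(std::launch::async, [] { return true; });
  bufs[1].rq = std::async(std::launch::async, [] { return true; });
  return DrainAll(&bufs) ? 0 : 1;
}

The real code records the Status of each Wait() into flush_rc_ rather than a bool, which is exactly the state the new Reset() clears before a rerun.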

mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h (+3 -1)

@@ -338,6 +338,9 @@ class CacheClient {
     /// \brief Release the shared memory during shutdown
     /// \note but needs comm layer to be alive.
     Status ReleaseBuffer();
+    /// \brief Reset the AsyncBufferStream into its initial state
+    /// \return Status object
+    Status Reset();
 
    private:
     Status flush_rc_;
@@ -347,7 +350,6 @@ class CacheClient {
     int64_t offset_addr_;
     AsyncWriter buf_arr_[kNumAsyncBuffer];
     int32_t cur_;
-    std::atomic<int64_t> next_addr_;
   };
   std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
 };
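
Together with the CreateCache() hunk in cache_client.cc, the new Reset() gives the client a clean slate when it reattaches to an existing cache after an interrupted run. A hypothetical, self-contained sketch of that reset-before-reuse lifecycle follows; Buffer and Stream are illustrative names, not the real types, and the capacity constant is a made-up placeholder.

// reset_sketch.cc -- illustrative only; compile with -std=c++17
#include <array>
#include <cstdint>
#include <iostream>

constexpr int kNumAsyncBuffer = 2;
constexpr int64_t kAsyncBufferSize = 16 << 20;  // placeholder capacity, not MindSpore's value

struct Buffer {
  int64_t bytes_avail = kAsyncBufferSize;
  int32_t num_ele = 0;
  bool pending = false;  // stands in for the outstanding rq handle
};

struct Stream {
  std::array<Buffer, kNumAsyncBuffer> buf_arr;
  int32_t cur = 0;
  bool flush_error = false;  // stands in for the latched flush_rc_

  // Mirrors AsyncBufferStream::Reset(): wipe the cursor, the latched status,
  // and every buffer's bookkeeping so a rerun starts from the initial state.
  void Reset() {
    cur = 0;
    flush_error = false;
    for (auto &b : buf_arr) {
      b.bytes_avail = kAsyncBufferSize;
      b.num_ele = 0;
      b.pending = false;
    }
  }
};

int main() {
  Stream s;
  // Pretend a previous run was interrupted mid-flush.
  s.cur = 1;
  s.flush_error = true;
  s.buf_arr[0].num_ele = 42;
  s.Reset();  // what CreateCache() now does on the duplicate-key (cache exists) path
  std::cout << (s.flush_error ? "stale" : "clean") << '\n';  // prints "clean"
  return 0;
}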


tests/ut/python/cachetests/cachetest_py.sh (+30 -0)

@@ -127,6 +127,16 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat"
 HandleRcExit $? 0 0
 
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_map.py" "test_cache_map_interrupt_and_rerun"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
 # Run two parallel pipelines (sharing cache)
 for i in $(seq 1 2)
 do
@@ -321,6 +331,26 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1
 HandleRcExit $? 0 0
 
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_all_rows_cached"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_interrupt_and_rerun"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
 for i in $(seq 1 3)
 do
   test_name="test_cache_nomap_multiple_cache${i}"


tests/ut/python/dataset/test_cache_map.py (+54 -8)

@@ -760,11 +760,11 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id="1", size=0)
-    assert "Argument session_id with value 1 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument session_id with value 1 is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=None, size=0)
-    assert "Argument session_id with value None is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument session_id with value None is not of type" in str(info.value)
 
     with pytest.raises(ValueError) as info:
         ds.DatasetCache(session_id=1, size=-1)
@@ -772,19 +772,19 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size="1")
-    assert "Argument size with value 1 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument size with value 1 is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=None)
-    assert "Argument size with value None is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument size with value None is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, spilling="illegal")
-    assert "Argument spilling with value illegal is not of type [<class 'bool'>]" in str(info.value)
+    assert "Argument spilling with value illegal is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as err:
         ds.DatasetCache(session_id=1, size=0, hostname=50052)
-    assert "Argument hostname with value 50052 is not of type [<class 'str'>]" in str(err.value)
+    assert "Argument hostname with value 50052 is not of type" in str(err.value)
 
     with pytest.raises(RuntimeError) as err:
         ds.DatasetCache(session_id=1, size=0, hostname="illegal")
@@ -796,11 +796,11 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, port="illegal")
-    assert "Argument port with value illegal is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument port with value illegal is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, port="50052")
-    assert "Argument port with value 50052 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument port with value 50052 is not of type" in str(info.value)
 
     with pytest.raises(ValueError) as err:
         ds.DatasetCache(session_id=1, size=0, port=0)
@@ -2110,6 +2110,52 @@ def test_cache_map_nested_repeat():
     logger.info('test_cache_map_nested_repeat Ended.\n')
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_map_interrupt_and_rerun():
+    """
+    Test interrupt a running pipeline and then re-use the same cache to run another pipeline
+
+        cache
+          |
+       Cifar10
+    """
+
+    logger.info("Test cache map interrupt and rerun")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+
+    num_iter = 0
+    with pytest.raises(AttributeError) as e:
+        for _ in iter1:
+            num_iter += 1
+            if num_iter == 10:
+                iter1.stop()
+    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)
+
+    num_epoch = 2
+    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
+    epoch_count = 0
+    for _ in range(num_epoch):
+        num_iter = 0
+        for _ in iter2:
+            num_iter += 1
+        logger.info("Number of data in ds1: {} ".format(num_iter))
+        assert num_iter == 10000
+        epoch_count += 1
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == 10000
+
+    logger.info("test_cache_map_interrupt_and_rerun Ended.\n")
+
+
 if __name__ == '__main__':
     # This is just a list of tests, don't try to run these tests with 'python test_cache_map.py'
     # since cache server is required to be brought up first


tests/ut/python/dataset/test_cache_nomap.py (+93 -0)

@@ -1193,6 +1193,58 @@ def test_cache_nomap_server_stop():
     logger.info("test_cache_nomap_server_stop Ended.\n")
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_interrupt_and_rerun():
+    """
+    Test interrupt a running pipeline and then re-use the same cache to run another pipeline
+
+          Cache
+            |
+      RandomDataset
+    """
+
+    logger.info("Test cache nomap interrupt and rerun")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    schema = ds.Schema()
+    schema.add_column('image', de_type=mstype.uint8,
+                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
+    schema.add_column('label', de_type=mstype.uint8, shape=[1])
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # User-created sampler here
+    ds1 = ds.RandomDataset(schema=schema, total_rows=10000, num_parallel_workers=4, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+
+    num_iter = 0
+    with pytest.raises(AttributeError) as e:
+        for _ in iter1:
+            num_iter += 1
+            if num_iter == 10:
+                iter1.stop()
+    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)
+
+    num_epoch = 2
+    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
+    epoch_count = 0
+    for _ in range(num_epoch):
+        num_iter = 0
+        for _ in iter2:
+            num_iter += 1
+        logger.info("Number of data in ds1: {} ".format(num_iter))
+        assert num_iter == 10000
+        epoch_count += 1
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == 10000
+
+    logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n")
+
+
 @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
 def test_cache_nomap_epoch_ctrl1():
     """
@@ -2262,6 +2314,47 @@ def test_cache_nomap_pyfunc_function():
     logger.info("test_cache_nomap_pyfunc_function Ended.\n")
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_all_rows_cached():
+    """
+    Make sure all rows are cached before we switch to the fetching phase
+
+          Cache
+            |
+      RandomDataset
+    """
+
+    logger.info("Test cache nomap all rows cached")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    schema = ds.Schema()
+    schema.add_column('image', de_type=mstype.uint8,
+                      shape=[450, 450, 3])
+    schema.add_column('label', de_type=mstype.uint8, shape=[1])
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # easier to reproduce the problem with 271 total rows
+    num_total_rows = 271
+    # User-created sampler here
+    ds1 = ds.RandomDataset(schema=schema, total_rows=num_total_rows, num_parallel_workers=4, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+
+    num_iter = 0
+    for _ in iter1:
+        num_iter += 1
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == num_total_rows
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == num_total_rows
+
+    logger.info("test_cache_nomap_all_rows_cached Ended.\n")
+
+
 if __name__ == '__main__':
     # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py'
     # since cache server is required to be brought up first

