diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
index b052f8f1f5..c2cac7cf68 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
@@ -179,6 +179,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
       return Status(StatusCode::kMDDuplicateKey, __LINE__, __FILE__,
                     "Not an error and we should bypass the build phase");
     }
+    if (async_buffer_stream_) {
+      // Reset the async buffer stream to its initial state. Any stale status and data will be cleaned up.
+      RETURN_IF_NOT_OK(async_buffer_stream_->Reset());
+    }
   } else {
     cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating so save our crc in the client
     // Now execute the cache create request using this identifier and other configs
@@ -343,7 +347,7 @@ bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
   }
 }
 
-CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0), next_addr_(0) {}
+CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0) {}
 
 CacheClient::AsyncBufferStream::~AsyncBufferStream() {
   (void)vg_.ServiceStop();
@@ -426,8 +430,16 @@ Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) {
     // If we are asked to wait, say this is the final flush, just wait for its completion.
     bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
     if (blocking) {
-      flush_rc_ = asyncWriter->rq->Wait();
-      asyncWriter->rq.reset();
+      // Make sure we are done with all the buffers, not just the current one.
+      for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+        if (buf_arr_[i].rq) {
+          Status rc = buf_arr_[i].rq->Wait();
+          if (rc.IsError()) {
+            flush_rc_ = rc;
+          }
+          buf_arr_[i].rq.reset();
+        }
+      }
     }
     // Prepare for the next buffer.
     cur_ = (cur_ + 1) % kNumAsyncBuffer;
@@ -458,5 +470,17 @@ Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std::vector<ReadableSlice> &v) {
   ++num_ele_;
   return Status::OK();
 }
+
+Status CacheClient::AsyncBufferStream::Reset() {
+  // Clean up any previous running state so the stream is ready for a new run.
+  cur_ = 0;
+  flush_rc_ = Status::OK();
+  for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+    buf_arr_[i].bytes_avail_ = kAsyncBufferSize;
+    buf_arr_[i].num_ele_ = 0;
+    buf_arr_[i].rq.reset();
+  }
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
index ae51df1714..7e906ad234 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
@@ -338,6 +338,9 @@ class CacheClient {
     /// \brief Release the shared memory during shutdown
     /// \note but needs comm layer to be alive.
     Status ReleaseBuffer();
+    /// \brief Reset the AsyncBufferStream into its initial state
+    /// \return Status object
+    Status Reset();
 
    private:
     Status flush_rc_;
@@ -347,7 +350,6 @@ class CacheClient {
     int64_t offset_addr_;
     AsyncWriter buf_arr_[kNumAsyncBuffer];
     int32_t cur_;
-    std::atomic<int64_t> next_addr_;
   };
   std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
 };
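Why SyncFlush() now drains every buffer: the stream rotates through kNumAsyncBuffer writers, so at the moment a final, blocking flush is issued on the current buffer, the previous buffer can still have a write request in flight. The old code waited only on asyncWriter->rq, which could return before that earlier request completed and could silently drop its error status. Reset() then restores the same drained, error-free state eagerly, so re-running CreateCache() on an existing cache does not inherit a stale flush_rc_ or a leftover in-flight request. Below is a minimal, self-contained sketch of the same rotate/drain/reset pattern; Buf, Stream, kNumBuf, and the std::future standing in for the rq handle are illustrative names, not the real MindSpore types:

#include <array>
#include <future>

constexpr int kNumBuf = 2;  // mirrors kNumAsyncBuffer: the buffers rotate

// One slot of the double buffer. A std::future stands in for the real request
// handle (rq); it is valid only while a write is in flight.
struct Buf {
  std::future<bool> rq;
};

struct Stream {
  std::array<Buf, kNumBuf> buf_arr;
  int cur = 0;
  bool flush_ok = true;

  // Kick off an async "write" on the current buffer, then rotate to the next.
  void AsyncFlush() {
    buf_arr[cur].rq = std::async(std::launch::async, [] { return true; });
    cur = (cur + 1) % kNumBuf;
  }

  // A blocking flush must drain *every* slot: after rotation the previous
  // buffer may still be in flight, so waiting on the current handle alone can
  // leave a request unfinished and its error status unobserved.
  void BlockingFlush() {
    for (auto &b : buf_arr) {
      if (b.rq.valid()) {
        flush_ok = b.rq.get() && flush_ok;  // keep any failure ever seen
      }
    }
  }

  // Return to the initial state so a new run inherits no stale status and no
  // leftover in-flight request; this is the role of CacheClient's Reset().
  void Reset() {
    cur = 0;
    flush_ok = true;
    for (auto &b : buf_arr) {
      b.rq = std::future<bool>();  // drop any stale handle
    }
  }
};

int main() {
  Stream s;
  s.AsyncFlush();
  s.AsyncFlush();     // both buffers are now (potentially) in flight
  s.BlockingFlush();  // drain both of them, not just buf_arr[cur]
  s.Reset();          // ready for a fresh pipeline run
  return s.flush_ok ? 0 : 1;
}

Folding every failed wait into one status mirrors the new loop in SyncFlush(), where any failed Wait() overwrites flush_rc_ so the caller sees that the flush went wrong.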
diff --git a/tests/ut/python/cachetests/cachetest_py.sh b/tests/ut/python/cachetests/cachetest_py.sh
index 5f7f0d03c2..de08bce4ee 100755
--- a/tests/ut/python/cachetests/cachetest_py.sh
+++ b/tests/ut/python/cachetests/cachetest_py.sh
@@ -127,6 +127,16 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat"
 HandleRcExit $? 0 0
 
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_map.py" "test_cache_map_interrupt_and_rerun"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
 # Run two parallel pipelines (sharing cache)
 for i in $(seq 1 2)
 do
@@ -321,6 +331,26 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1
 HandleRcExit $? 0 0
 
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_all_rows_cached"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_interrupt_and_rerun"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
 for i in $(seq 1 3)
 do
   test_name="test_cache_nomap_multiple_cache${i}"
diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py
index 0c94c6e069..9d4a3e1260 100644
--- a/tests/ut/python/dataset/test_cache_map.py
+++ b/tests/ut/python/dataset/test_cache_map.py
@@ -760,11 +760,11 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id="1", size=0)
-    assert "Argument session_id with value 1 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument session_id with value 1 is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=None, size=0)
-    assert "Argument session_id with value None is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument session_id with value None is not of type" in str(info.value)
 
     with pytest.raises(ValueError) as info:
         ds.DatasetCache(session_id=1, size=-1)
@@ -772,19 +772,19 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size="1")
-    assert "Argument size with value 1 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument size with value 1 is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=None)
-    assert "Argument size with value None is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument size with value None is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, spilling="illegal")
-    assert "Argument spilling with value illegal is not of type [<class 'bool'>]" in str(info.value)
+    assert "Argument spilling with value illegal is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as err:
         ds.DatasetCache(session_id=1, size=0, hostname=50052)
-    assert "Argument hostname with value 50052 is not of type [<class 'str'>]" in str(err.value)
+    assert "Argument hostname with value 50052 is not of type" in str(err.value)
 
     with pytest.raises(RuntimeError) as err:
         ds.DatasetCache(session_id=1, size=0, hostname="illegal")
@@ -796,11 +796,11 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, port="illegal")
-    assert "Argument port with value illegal is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument port with value illegal is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, port="50052")
-    assert "Argument port with value 50052 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument port with value 50052 is not of type" in str(info.value)
 
     with pytest.raises(ValueError) as err:
         ds.DatasetCache(session_id=1, size=0, port=0)
@@ -2110,6 +2110,52 @@ def test_cache_map_nested_repeat():
     logger.info('test_cache_map_nested_repeat Ended.\n')
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_map_interrupt_and_rerun():
+    """
+    Test interrupting a running pipeline and then re-using the same cache in a new pipeline
+
+       cache
+         |
+      Cifar10
+    """
+
+    logger.info("Test cache map interrupt and rerun")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+
+    num_iter = 0
+    with pytest.raises(AttributeError) as e:
+        for _ in iter1:
+            num_iter += 1
+            if num_iter == 10:
+                iter1.stop()
+    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)
+
+    num_epoch = 2
+    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
+    epoch_count = 0
+    for _ in range(num_epoch):
+        num_iter = 0
+        for _ in iter2:
+            num_iter += 1
+        logger.info("Number of data in ds1: {} ".format(num_iter))
+        assert num_iter == 10000
+        epoch_count += 1
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == 10000
+
+    logger.info("test_cache_map_interrupt_and_rerun Ended.\n")
+
+
 if __name__ == '__main__':
     # This is just a list of tests, don't try to run these tests with 'python test_cache_map.py'
     # since cache server is required to be brought up first
logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n") + + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_epoch_ctrl1(): """ @@ -2262,6 +2314,47 @@ def test_cache_nomap_pyfunc_function(): logger.info("test_cache_nomap_pyfunc_function Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_all_rows_cached(): + """ + Make sure all rows are cached before we switch to the fetching phase + + Cache + | + RandomDataset + """ + + logger.info("Test cache nomap all rows cached") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + schema = ds.Schema() + schema.add_column('image', de_type=mstype.uint8, + shape=[450, 450, 3]) + schema.add_column('label', de_type=mstype.uint8, shape=[1]) + + some_cache = ds.DatasetCache(session_id=session_id, size=0) + + # easier to reproduce the problem with 271 total rows + num_total_rows = 271 + # User-created sampler here + ds1 = ds.RandomDataset(schema=schema, total_rows=num_total_rows, num_parallel_workers=4, cache=some_cache) + iter1 = ds1.create_dict_iterator() + + num_iter = 0 + for _ in iter1: + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == num_total_rows + + cache_stat = some_cache.GetStat() + assert cache_stat.num_mem_cached == num_total_rows + + logger.info("test_cache_nomap_all_rows_cached Ended.\n") + + if __name__ == '__main__': # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py' # since cache server is required to be brought up first