From: @lixiachen Reviewed-by: @robingrosman,@mikef Signed-off-by: @robingrosman pull/14822/MERGE
| @@ -179,6 +179,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { | |||
| return Status(StatusCode::kMDDuplicateKey, __LINE__, __FILE__, | |||
| "Not an error and we should bypass the build phase"); | |||
| } | |||
| if (async_buffer_stream_) { | |||
| // Reset the async buffer stream to its initial state. Any stale status and data would get cleaned up. | |||
| RETURN_IF_NOT_OK(async_buffer_stream_->Reset()); | |||
| } | |||
| } else { | |||
| cinfo_.set_crc(tree_crc); // It's really a new cache we're creating so save our crc in the client | |||
| // Now execute the cache create request using this identifier and other configs | |||
| @@ -343,7 +347,7 @@ bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) { | |||
| } | |||
| } | |||
| CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0), next_addr_(0) {} | |||
| CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0) {} | |||
| CacheClient::AsyncBufferStream::~AsyncBufferStream() { | |||
| (void)vg_.ServiceStop(); | |||
| @@ -426,8 +430,16 @@ Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) { | |||
| // If we are asked to wait, say this is the final flush, just wait for its completion. | |||
| bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking; | |||
| if (blocking) { | |||
| flush_rc_ = asyncWriter->rq->Wait(); | |||
| asyncWriter->rq.reset(); | |||
| // Make sure we are done with all the buffers | |||
| for (auto i = 0; i < kNumAsyncBuffer; ++i) { | |||
| if (buf_arr_[i].rq) { | |||
| Status rc = buf_arr_[i].rq->Wait(); | |||
| if (rc.IsError()) { | |||
| flush_rc_ = rc; | |||
| } | |||
| buf_arr_[i].rq.reset(); | |||
| } | |||
| } | |||
| } | |||
| // Prepare for the next buffer. | |||
| cur_ = (cur_ + 1) % kNumAsyncBuffer; | |||
| @@ -458,5 +470,17 @@ Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std: | |||
| ++num_ele_; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClient::AsyncBufferStream::Reset() { | |||
| // Clean up previous running state to be prepared for a new run. | |||
| cur_ = 0; | |||
| flush_rc_ = Status::OK(); | |||
| for (auto i = 0; i < kNumAsyncBuffer; ++i) { | |||
| buf_arr_[i].bytes_avail_ = kAsyncBufferSize; | |||
| buf_arr_[i].num_ele_ = 0; | |||
| buf_arr_[i].rq.reset(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -338,6 +338,9 @@ class CacheClient { | |||
| /// \brief Release the shared memory during shutdown | |||
| /// \note but needs comm layer to be alive. | |||
| Status ReleaseBuffer(); | |||
| /// \brief Reset the AsyncBufferStream into its initial state | |||
| /// \return Status object | |||
| Status Reset(); | |||
| private: | |||
| Status flush_rc_; | |||
| @@ -347,7 +350,6 @@ class CacheClient { | |||
| int64_t offset_addr_; | |||
| AsyncWriter buf_arr_[kNumAsyncBuffer]; | |||
| int32_t cur_; | |||
| std::atomic<int64_t> next_addr_; | |||
| }; | |||
| std::shared_ptr<AsyncBufferStream> async_buffer_stream_; | |||
| }; | |||
| @@ -127,6 +127,16 @@ HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat" | |||
| HandleRcExit $? 0 0 | |||
| GetSession | |||
| HandleRcExit $? 1 1 | |||
| export SESSION_ID=$session_id | |||
| PytestCmd "test_cache_map.py" "test_cache_map_interrupt_and_rerun" | |||
| HandleRcExit $? 0 0 | |||
| DestroySession $session_id | |||
| HandleRcExit $? 1 1 | |||
| # Run two parallel pipelines (sharing cache) | |||
| for i in $(seq 1 2) | |||
| do | |||
| @@ -321,6 +331,26 @@ HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1 | |||
| HandleRcExit $? 0 0 | |||
| GetSession | |||
| HandleRcExit $? 1 1 | |||
| export SESSION_ID=$session_id | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_all_rows_cached" | |||
| HandleRcExit $? 0 0 | |||
| DestroySession $session_id | |||
| HandleRcExit $? 1 1 | |||
| GetSession | |||
| HandleRcExit $? 1 1 | |||
| export SESSION_ID=$session_id | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_interrupt_and_rerun" | |||
| HandleRcExit $? 0 0 | |||
| DestroySession $session_id | |||
| HandleRcExit $? 1 1 | |||
| for i in $(seq 1 3) | |||
| do | |||
| test_name="test_cache_nomap_multiple_cache${i}" | |||
| @@ -760,11 +760,11 @@ def test_cache_map_parameter_check(): | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id="1", size=0) | |||
| assert "Argument session_id with value 1 is not of type [<class 'int'>]" in str(info.value) | |||
| assert "Argument session_id with value 1 is not of type" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=None, size=0) | |||
| assert "Argument session_id with value None is not of type [<class 'int'>]" in str(info.value) | |||
| assert "Argument session_id with value None is not of type" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| ds.DatasetCache(session_id=1, size=-1) | |||
| @@ -772,19 +772,19 @@ def test_cache_map_parameter_check(): | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=1, size="1") | |||
| assert "Argument size with value 1 is not of type [<class 'int'>]" in str(info.value) | |||
| assert "Argument size with value 1 is not of type" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=1, size=None) | |||
| assert "Argument size with value None is not of type [<class 'int'>]" in str(info.value) | |||
| assert "Argument size with value None is not of type" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=1, size=0, spilling="illegal") | |||
| assert "Argument spilling with value illegal is not of type [<class 'bool'>]" in str(info.value) | |||
| assert "Argument spilling with value illegal is not of type" in str(info.value) | |||
| with pytest.raises(TypeError) as err: | |||
| ds.DatasetCache(session_id=1, size=0, hostname=50052) | |||
| assert "Argument hostname with value 50052 is not of type [<class 'str'>]" in str(err.value) | |||
| assert "Argument hostname with value 50052 is not of type" in str(err.value) | |||
| with pytest.raises(RuntimeError) as err: | |||
| ds.DatasetCache(session_id=1, size=0, hostname="illegal") | |||
| @@ -796,11 +796,11 @@ def test_cache_map_parameter_check(): | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=1, size=0, port="illegal") | |||
| assert "Argument port with value illegal is not of type [<class 'int'>]" in str(info.value) | |||
| assert "Argument port with value illegal is not of type" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=1, size=0, port="50052") | |||
| assert "Argument port with value 50052 is not of type [<class 'int'>]" in str(info.value) | |||
| assert "Argument port with value 50052 is not of type" in str(info.value) | |||
| with pytest.raises(ValueError) as err: | |||
| ds.DatasetCache(session_id=1, size=0, port=0) | |||
| @@ -2110,6 +2110,52 @@ def test_cache_map_nested_repeat(): | |||
| logger.info('test_cache_map_nested_repeat Ended.\n') | |||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_interrupt_and_rerun():
    """
    Test interrupt a running pipeline and then re-use the same cache to run another pipeline

       cache
         |
      Cifar10

    Requires the SESSION_ID environment variable to hold the cache session id;
    raises RuntimeError when it is missing.
    """
    logger.info("Test cache map interrupt and rerun")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
    iter1 = ds1.create_dict_iterator()
    num_iter = 0
    # First run: stop the iterator after 10 rows while still looping over it.
    # The test expects this to surface as an AttributeError with the message checked below.
    with pytest.raises(AttributeError) as e:
        for _ in iter1:
            num_iter += 1
            if num_iter == 10:
                iter1.stop()
    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)

    # Second run: the same dataset (and therefore the same cache) must still deliver
    # the complete data in every epoch.
    num_epoch = 2
    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
    epoch_count = 0
    for _ in range(num_epoch):
        num_iter = 0
        for _ in iter2:
            num_iter += 1
        logger.info("Number of data in ds1: {} ".format(num_iter))
        assert num_iter == 10000
        epoch_count += 1
    # Previously epoch_count was accumulated but never verified.
    assert epoch_count == num_epoch

    cache_stat = some_cache.GetStat()
    assert cache_stat.num_mem_cached == 10000

    logger.info("test_cache_map_interrupt_and_rerun Ended.\n")
| if __name__ == '__main__': | |||
| # This is just a list of tests, don't try to run these tests with 'python test_cache_map.py' | |||
| # since cache server is required to be brought up first | |||
| @@ -1193,6 +1193,58 @@ def test_cache_nomap_server_stop(): | |||
| logger.info("test_cache_nomap_server_stop Ended.\n") | |||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_interrupt_and_rerun():
    """
    Test interrupt a running pipeline and then re-use the same cache to run another pipeline

       Cache
         |
      RandomDataset

    Requires the SESSION_ID environment variable to hold the cache session id;
    raises RuntimeError when it is missing.
    """
    logger.info("Test cache nomap interrupt and rerun")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # Random data generated from the schema above, read by 4 parallel workers through the cache
    ds1 = ds.RandomDataset(schema=schema, total_rows=10000, num_parallel_workers=4, cache=some_cache)
    iter1 = ds1.create_dict_iterator()
    num_iter = 0
    # First run: stop the iterator after 10 rows while still looping over it.
    # The test expects this to surface as an AttributeError with the message checked below.
    with pytest.raises(AttributeError) as e:
        for _ in iter1:
            num_iter += 1
            if num_iter == 10:
                iter1.stop()
    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)

    # Second run: the same dataset (and therefore the same cache) must still deliver
    # all rows in every epoch.
    num_epoch = 2
    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
    epoch_count = 0
    for _ in range(num_epoch):
        num_iter = 0
        for _ in iter2:
            num_iter += 1
        logger.info("Number of data in ds1: {} ".format(num_iter))
        assert num_iter == 10000
        epoch_count += 1
    # Previously epoch_count was accumulated but never verified.
    assert epoch_count == num_epoch

    cache_stat = some_cache.GetStat()
    assert cache_stat.num_mem_cached == 10000

    logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n")
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_epoch_ctrl1(): | |||
| """ | |||
| @@ -2262,6 +2314,47 @@ def test_cache_nomap_pyfunc_function(): | |||
| logger.info("test_cache_nomap_pyfunc_function Ended.\n") | |||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_all_rows_cached():
    """
    Check that every row has been cached before the pipeline switches to the fetching phase

       Cache
         |
      RandomDataset

    The cache session id is read from the SESSION_ID environment variable;
    a RuntimeError is raised when it is not set.
    """
    logger.info("Test cache nomap all rows cached")
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8, shape=[450, 450, 3])
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # easier to reproduce the problem with 271 total rows
    num_total_rows = 271
    ds1 = ds.RandomDataset(schema=schema, total_rows=num_total_rows, num_parallel_workers=4, cache=some_cache)

    # Drain a single pass over the pipeline and count the rows it yields.
    row_count = sum(1 for _ in ds1.create_dict_iterator())
    logger.info("Number of data in ds1: {} ".format(row_count))
    assert row_count == num_total_rows

    # The cache statistics must report the same number of rows held in memory.
    stats = some_cache.GetStat()
    assert stats.num_mem_cached == num_total_rows

    logger.info("test_cache_nomap_all_rows_cached Ended.\n")
| if __name__ == '__main__': | |||
| # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py' | |||
| # since cache server is required to be brought up first | |||