From: @lixiachen
Reviewed-by: @robingrosman, @mikef
Signed-off-by: @robingrosman
@@ -179,6 +179,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
       return Status(StatusCode::kMDDuplicateKey, __LINE__, __FILE__,
                     "Not an error and we should bypass the build phase");
     }
+    if (async_buffer_stream_) {
+      // Reset the async buffer stream to its initial state. Any stale status and data would get cleaned up.
+      RETURN_IF_NOT_OK(async_buffer_stream_->Reset());
+    }
   } else {
     cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating so save our crc in the client
     // Now execute the cache create request using this identifier and other configs
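When a pipeline is re-run against an existing cache, CreateCache takes the kMDDuplicateKey path above, so the client-side AsyncBufferStream may still carry state from the interrupted run. A minimal standalone sketch of why that stale status must be cleared before reuse (the Status/BufferStream types below are illustrative stand-ins, not the MindSpore API):

```cpp
#include <iostream>
#include <string>

// Illustrative stand-ins for the real Status / AsyncBufferStream types.
struct Status {
  std::string msg;  // empty string means OK
  bool IsError() const { return !msg.empty(); }
};

struct BufferStream {
  Status flush_rc_;                         // sticky status carried across flushes
  Status SyncFlush() { return flush_rc_; }  // a later flush surfaces any earlier failure
  void Reset() { flush_rc_ = Status{}; }    // forget everything from the previous run
};

int main() {
  BufferStream stream;
  stream.flush_rc_ = Status{"interrupted"};  // first pipeline was torn down mid-flush
  // A second pipeline reuses the same cache client; without Reset() the stale
  // error would fail the very first write of the new run.
  stream.Reset();
  std::cout << (stream.SyncFlush().IsError() ? "fail" : "ok") << '\n';  // prints "ok"
  return 0;
}
```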
@@ -343,7 +347,7 @@ bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
   }
 }
 
-CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0), next_addr_(0) {}
+CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0) {}
 
 CacheClient::AsyncBufferStream::~AsyncBufferStream() {
   (void)vg_.ServiceStop();
@@ -426,8 +430,16 @@ Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) {
     // If we are asked to wait, say this is the final flush, just wait for its completion.
     bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
     if (blocking) {
-      flush_rc_ = asyncWriter->rq->Wait();
-      asyncWriter->rq.reset();
+      // Make sure we are done with all the buffers
+      for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+        if (buf_arr_[i].rq) {
+          Status rc = buf_arr_[i].rq->Wait();
+          if (rc.IsError()) {
+            flush_rc_ = rc;
+          }
+          buf_arr_[i].rq.reset();
+        }
+      }
     }
     // Prepare for the next buffer.
     cur_ = (cur_ + 1) % kNumAsyncBuffer;
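The old code waited only on the buffer currently being flushed; with double buffering, the other buffer may still have a round trip in flight, and its completion or error would be missed. A self-contained sketch of the drain-everything pattern, using std::future in place of the internal rq handles (names are illustrative, not the MindSpore API):

```cpp
#include <array>
#include <future>
#include <iostream>
#include <optional>

constexpr int kNumAsyncBuffer = 2;

struct Writer {
  std::optional<std::future<bool>> rq;  // pending server round trip, if any
};

int main() {
  std::array<Writer, kNumAsyncBuffer> buf_arr;
  buf_arr[0].rq = std::async(std::launch::async, [] { return true; });   // earlier buffer
  buf_arr[1].rq = std::async(std::launch::async, [] { return false; });  // current buffer

  bool ok = true;
  // A blocking (final) flush must wait on *all* buffers, not just the current
  // one, so that no in-flight failure is silently dropped.
  for (auto &w : buf_arr) {
    if (w.rq) {
      ok = w.rq->get() && ok;  // Wait() plus error capture in the real code
      w.rq.reset();            // hand the buffer back for reuse
    }
  }
  std::cout << (ok ? "flush ok" : "flush failed") << '\n';  // prints "flush failed"
  return 0;
}
```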
@@ -458,5 +470,17 @@ Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std:
   ++num_ele_;
   return Status::OK();
 }
+
+Status CacheClient::AsyncBufferStream::Reset() {
+  // Clean up previous running state to be prepared for a new run.
+  cur_ = 0;
+  flush_rc_ = Status::OK();
+  for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+    buf_arr_[i].bytes_avail_ = kAsyncBufferSize;
+    buf_arr_[i].num_ele_ = 0;
+    buf_arr_[i].rq.reset();
+  }
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
@@ -338,6 +338,9 @@ class CacheClient {
     /// \brief Release the shared memory during shutdown
     /// \note but needs comm layer to be alive.
     Status ReleaseBuffer();
+    /// \brief Reset the AsyncBufferStream into its initial state
+    /// \return Status object
+    Status Reset();
 
    private:
     Status flush_rc_;
@@ -347,7 +350,6 @@ class CacheClient {
     int64_t offset_addr_;
     AsyncWriter buf_arr_[kNumAsyncBuffer];
     int32_t cur_;
-    std::atomic<int64_t> next_addr_;
   };
   std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
 };
@@ -127,6 +127,16 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat"
 HandleRcExit $? 0 0
+
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_map.py" "test_cache_map_interrupt_and_rerun"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
 
 # Run two parallel pipelines (sharing cache)
 for i in $(seq 1 2)
 do
@@ -321,6 +331,26 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1
 HandleRcExit $? 0 0
+
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_all_rows_cached"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_interrupt_and_rerun"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
 
 for i in $(seq 1 3)
 do
   test_name="test_cache_nomap_multiple_cache${i}"
@@ -760,11 +760,11 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id="1", size=0)
-    assert "Argument session_id with value 1 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument session_id with value 1 is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=None, size=0)
-    assert "Argument session_id with value None is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument session_id with value None is not of type" in str(info.value)
 
     with pytest.raises(ValueError) as info:
         ds.DatasetCache(session_id=1, size=-1)
@@ -772,19 +772,19 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size="1")
-    assert "Argument size with value 1 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument size with value 1 is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=None)
-    assert "Argument size with value None is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument size with value None is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, spilling="illegal")
-    assert "Argument spilling with value illegal is not of type [<class 'bool'>]" in str(info.value)
+    assert "Argument spilling with value illegal is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as err:
         ds.DatasetCache(session_id=1, size=0, hostname=50052)
-    assert "Argument hostname with value 50052 is not of type [<class 'str'>]" in str(err.value)
+    assert "Argument hostname with value 50052 is not of type" in str(err.value)
 
     with pytest.raises(RuntimeError) as err:
         ds.DatasetCache(session_id=1, size=0, hostname="illegal")
@@ -796,11 +796,11 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, port="illegal")
-    assert "Argument port with value illegal is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument port with value illegal is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, port="50052")
-    assert "Argument port with value 50052 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument port with value 50052 is not of type" in str(info.value)
 
     with pytest.raises(ValueError) as err:
         ds.DatasetCache(session_id=1, size=0, port=0)
@@ -2110,6 +2110,52 @@ def test_cache_map_nested_repeat():
 
     logger.info('test_cache_map_nested_repeat Ended.\n')
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_map_interrupt_and_rerun():
+    """
+    Test interrupt a running pipeline and then re-use the same cache to run another pipeline
+
+       cache
+         |
+      Cifar10
+    """
+    logger.info("Test cache map interrupt and rerun")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
+
+    iter1 = ds1.create_dict_iterator()
+    num_iter = 0
+    with pytest.raises(AttributeError) as e:
+        for _ in iter1:
+            num_iter += 1
+            if num_iter == 10:
+                iter1.stop()
+    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)
+
+    num_epoch = 2
+    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
+    epoch_count = 0
+    for _ in range(num_epoch):
+        num_iter = 0
+        for _ in iter2:
+            num_iter += 1
+        logger.info("Number of data in ds1: {} ".format(num_iter))
+        assert num_iter == 10000
+        epoch_count += 1
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == 10000
+
+    logger.info("test_cache_map_interrupt_and_rerun Ended.\n")
+
+
 if __name__ == '__main__':
     # This is just a list of tests, don't try to run these tests with 'python test_cache_map.py'
     # since cache server is required to be brought up first
@@ -1193,6 +1193,58 @@ def test_cache_nomap_server_stop():
 
     logger.info("test_cache_nomap_server_stop Ended.\n")
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_interrupt_and_rerun():
+    """
+    Test interrupt a running pipeline and then re-use the same cache to run another pipeline
+
+       Cache
+         |
+     RandomDataset
+    """
+    logger.info("Test cache nomap interrupt and rerun")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    schema = ds.Schema()
+    schema.add_column('image', de_type=mstype.uint8,
+                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
+    schema.add_column('label', de_type=mstype.uint8, shape=[1])
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # User-created sampler here
+    ds1 = ds.RandomDataset(schema=schema, total_rows=10000, num_parallel_workers=4, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+    num_iter = 0
+    with pytest.raises(AttributeError) as e:
+        for _ in iter1:
+            num_iter += 1
+            if num_iter == 10:
+                iter1.stop()
+    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)
+
+    num_epoch = 2
+    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
+    epoch_count = 0
+    for _ in range(num_epoch):
+        num_iter = 0
+        for _ in iter2:
+            num_iter += 1
+        logger.info("Number of data in ds1: {} ".format(num_iter))
+        assert num_iter == 10000
+        epoch_count += 1
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == 10000
+
+    logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n")
+
+
 @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
 def test_cache_nomap_epoch_ctrl1():
     """
@@ -2262,6 +2314,47 @@ def test_cache_nomap_pyfunc_function():
 
     logger.info("test_cache_nomap_pyfunc_function Ended.\n")
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_all_rows_cached():
+    """
+    Make sure all rows are cached before we switch to the fetching phase
+
+       Cache
+         |
+     RandomDataset
+    """
+    logger.info("Test cache nomap all rows cached")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    schema = ds.Schema()
+    schema.add_column('image', de_type=mstype.uint8,
+                      shape=[450, 450, 3])
+    schema.add_column('label', de_type=mstype.uint8, shape=[1])
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # easier to reproduce the problem with 271 total rows
+    num_total_rows = 271
+    # User-created sampler here
+    ds1 = ds.RandomDataset(schema=schema, total_rows=num_total_rows, num_parallel_workers=4, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+    num_iter = 0
+    for _ in iter1:
+        num_iter += 1
+
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == num_total_rows
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == num_total_rows
+
+    logger.info("test_cache_nomap_all_rows_cached Ended.\n")
+
+
 if __name__ == '__main__':
     # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py'
     # since cache server is required to be brought up first
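Why 271 rows in the test above makes the problem easy to reproduce: rows travel to the cache server in fixed-size async buffers, and a row count that is not a multiple of what one buffer holds strands a partial buffer that only a final blocking flush delivers. A back-of-the-envelope sketch (the 32-rows-per-buffer capacity is an assumed figure for illustration, not the real buffer geometry):

```cpp
#include <iostream>

int main() {
  // Assumed capacity for illustration only; the real geometry depends on
  // kAsyncBufferSize and the per-row payload size.
  const int rows_per_buffer = 32;
  const int total_rows = 271;

  int full_buffers = total_rows / rows_per_buffer;  // 8 buffers flush as they fill
  int tail_rows = total_rows % rows_per_buffer;     // 15 rows sit in a partial buffer

  std::cout << "auto-flushed rows: " << full_buffers * rows_per_buffer << '\n'  // 256
            << "need final flush:  " << tail_rows << '\n';                      // 15
  // Without a blocking final flush that drains every buffer, only 256 of the
  // 271 rows would reach the cache server, and num_mem_cached would fall short.
  return 0;
}
```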