|
|
|
@@ -14,6 +14,7 @@ |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <unistd.h> |
|
|
|
#include <iomanip> |
|
|
|
#include "minddata/dataset/engine/cache/cache_client.h" |
|
|
|
#include "minddata/dataset/engine/cache/cache_request.h" |
|
|
|
@@ -394,7 +395,6 @@ CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_ |
|
|
|
|
|
|
|
// Destructor: stops the vg_ service, sets writer_wp_, then calls ReleaseBuffer(). |
// The results of ServiceStop() and ReleaseBuffer() are explicitly discarded with (void). |
CacheClient::AsyncBufferStream::~AsyncBufferStream() { |
|
|
|
(void)vg_.ServiceStop(); |
|
|
|
// NOTE(review): the enclosing hunk header (-394,7 +395,6) drops exactly one row; this Set() call is |
// presumably the removed line. Confirm against the original diff before relying on it. |
writer_wp_.Set(); |
|
|
|
(void)ReleaseBuffer(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -424,12 +424,9 @@ Status CacheClient::AsyncBufferStream::Init(CacheClient *cc) { |
|
|
|
// We only need to set the pointer during init. Other fields will be set dynamically. |
|
|
|
buf_arr_[i].buffer_ = reinterpret_cast<void *>(start + i * kAsyncBufferSize); |
|
|
|
} |
|
|
|
buf_arr_[0].begin_addr_ = 0; |
|
|
|
buf_arr_[0].end_addr_ = 0; |
|
|
|
buf_arr_[0].bytes_avail_ = kAsyncBufferSize; |
|
|
|
buf_arr_[0].num_ele_ = 0; |
|
|
|
RETURN_IF_NOT_OK(vg_.ServiceStart()); |
|
|
|
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async flush", std::bind(&CacheClient::AsyncBufferStream::AsyncFlush, this))); |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -448,127 +445,66 @@ Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) { |
|
|
|
if (sz > kAsyncBufferSize) { |
|
|
|
return Status(StatusCode::kNotImplementedYet); |
|
|
|
} |
|
|
|
// Find out where we are going to write in the (logical) buffer stream without acquiring the lock |
|
|
|
// but only use the atomic variable. |
|
|
|
auto write_addr = next_addr_.fetch_add(sz); |
|
|
|
Status rc; |
|
|
|
do { |
|
|
|
SharedLock lock(&mux_); |
|
|
|
// Check error from the server side while we have the lock; |
|
|
|
RETURN_IF_NOT_OK(flush_rc_); |
|
|
|
AsyncWriter *asyncWriter = &buf_arr_[cur_]; |
|
|
|
rc = asyncWriter->Write(write_addr, sz, v); |
|
|
|
if (rc.get_code() == StatusCode::kNoSpace) { |
|
|
|
// If no space, wake up the async flush thread |
|
|
|
writer_wp_.Clear(); |
|
|
|
flush_wp_.Set(); |
|
|
|
// Let go of the lock before we wait. |
|
|
|
lock.Unlock(); |
|
|
|
// Wait for the next window |
|
|
|
RETURN_IF_NOT_OK(writer_wp_.Wait()); |
|
|
|
} |
|
|
|
} while (rc.get_code() == StatusCode::kNoSpace); |
|
|
|
return rc; |
|
|
|
std::unique_lock<std::mutex> lock(mux_); |
|
|
|
// Check error from the server side while we have the lock; |
|
|
|
RETURN_IF_NOT_OK(flush_rc_); |
|
|
|
AsyncWriter *asyncWriter = &buf_arr_[cur_]; |
|
|
|
if (asyncWriter->bytes_avail_ < sz) { |
|
|
|
// Flush and switch to a new buffer while we have the lock. |
|
|
|
RETURN_IF_NOT_OK(SyncFlush(AsyncFlushFlag::kCallerHasXLock)); |
|
|
|
// Refresh the pointer after we switch |
|
|
|
asyncWriter = &buf_arr_[cur_]; |
|
|
|
} |
|
|
|
RETURN_IF_NOT_OK(asyncWriter->Write(sz, v)); |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// NOTE(review): this span is diff residue. It contains two signatures, SyncFlush(bool blocking) here and |
// SyncFlush(AsyncFlushFlag flag) further down, with their bodies interleaved and no +/- markers, so the |
// braces do not balance as shown. The comments below describe individual statements only; which rows |
// belong to which version must be confirmed against the original diff. |
Status CacheClient::AsyncBufferStream::SyncFlush(bool blocking) { |
|
|
|
bool retry = false; |
|
|
|
do { |
|
|
|
// Take mux_ through the project's UniqueLock wrapper (presumably exclusive; verify in its header). |
UniqueLock lock(&mux_); |
|
|
|
flush_wp_.Clear(); |
|
|
|
auto *asyncWriter = &buf_arr_[cur_]; |
|
|
|
retry = false; |
|
|
|
// Because the clients are copying async, we need to wait until all of them have written. |
// The test passes when bytes_avail_ equals the buffer size minus the span (end_addr_ - begin_addr_). |
|
|
|
if (kAsyncBufferSize - (asyncWriter->end_addr_ - asyncWriter->begin_addr_) == asyncWriter->bytes_avail_) { |
|
|
|
if (asyncWriter->num_ele_) { |
|
|
|
// Build a BatchCacheRowsRequest from cc_, this buffer's offset and its element count, then push it via cc_. |
asyncWriter->rq.reset( |
|
|
|
new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_)); |
|
|
|
flush_rc_ = cc_->PushRequest(asyncWriter->rq); |
|
|
|
if (flush_rc_.IsOk()) { |
|
|
|
// If we are asked to wait (for example on the final flush), just wait for the request to complete. |
|
|
|
if (blocking) { |
|
|
|
flush_rc_ = asyncWriter->rq->Wait(); |
|
|
|
asyncWriter->rq.reset(); |
|
|
|
} |
|
|
|
// Prepare for the next buffer which will start from the end addr of the previous buffer. |
|
|
|
int64_t previous_end_addr = asyncWriter->end_addr_; |
|
|
|
// Advance cur_ to the next buffer slot, wrapping around at kNumAsyncBuffer. |
cur_ = (cur_ + 1) % kNumAsyncBuffer; |
|
|
|
asyncWriter = &buf_arr_[cur_]; |
|
|
|
// Update the cur_ while we have the lock. |
|
|
|
// Before we do anything, make sure the cache server is done with this buffer, or we will corrupt its content. |
|
|
|
// We can also pick up any error from the previous flush here. |
|
|
|
if (asyncWriter->rq) { |
|
|
|
// Save the result into a common area, so worker can see it and quit. |
|
|
|
flush_rc_ = asyncWriter->rq->Wait(); |
|
|
|
asyncWriter->rq.reset(); |
|
|
|
} |
|
|
|
// Reset the bookkeeping of the newly selected buffer; its address range starts empty at previous_end_addr. |
asyncWriter->bytes_avail_ = kAsyncBufferSize; |
|
|
|
asyncWriter->num_ele_ = 0; |
|
|
|
asyncWriter->begin_addr_ = previous_end_addr; |
|
|
|
asyncWriter->end_addr_ = previous_end_addr; |
|
|
|
} |
|
|
|
// Second signature in this span: takes an AsyncFlushFlag value instead of a bool. |
Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) { |
|
|
|
// Construct the lock without acquiring mux_; it is locked below only when the caller does not already hold it. |
std::unique_lock lock(mux_, std::defer_lock_t()); |
|
|
|
bool callerHasXLock = (flag & AsyncFlushFlag::kCallerHasXLock) == AsyncFlushFlag::kCallerHasXLock; |
|
|
|
if (!callerHasXLock) { |
|
|
|
lock.lock(); |
|
|
|
} |
|
|
|
auto *asyncWriter = &buf_arr_[cur_]; |
|
|
|
if (asyncWriter->num_ele_) { |
|
|
|
// Build a BatchCacheRowsRequest from cc_, this buffer's offset and its element count, then push it via cc_. |
asyncWriter->rq.reset( |
|
|
|
new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_)); |
|
|
|
flush_rc_ = cc_->PushRequest(asyncWriter->rq); |
|
|
|
if (flush_rc_.IsOk()) { |
|
|
|
// If we are asked to wait (for example on the final flush), just wait for the request to complete. |
|
|
|
bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking; |
|
|
|
if (blocking) { |
|
|
|
flush_rc_ = asyncWriter->rq->Wait(); |
|
|
|
asyncWriter->rq.reset(); |
|
|
|
} |
|
|
|
} else { |
|
|
|
// Some clients are late and aren't done yet. Let go of the lock. |
|
|
|
lock.Unlock(); |
|
|
|
// An interrupt ends the retry loop and is recorded in flush_rc_ as kInterrupted. |
if (this_thread::is_interrupted()) { |
|
|
|
retry = false; |
|
|
|
flush_rc_ = Status(StatusCode::kInterrupted); |
|
|
|
} else { |
|
|
|
retry = true; |
|
|
|
// Prepare for the next buffer. |
|
|
|
cur_ = (cur_ + 1) % kNumAsyncBuffer; |
|
|
|
asyncWriter = &buf_arr_[cur_]; |
|
|
|
// Update the cur_ while we have the lock. |
|
|
|
// Before we do anything, make sure the cache server is done with this buffer, or we will corrupt its content. |
|
|
|
// We can also pick up any error from the previous flush here. |
|
|
|
if (asyncWriter->rq) { |
|
|
|
// Save the result into a common area, so worker can see it and quit. |
|
|
|
flush_rc_ = asyncWriter->rq->Wait(); |
|
|
|
asyncWriter->rq.reset(); |
|
|
|
} |
|
|
|
writer_wp_.Set(); |
|
|
|
std::this_thread::yield(); |
|
|
|
asyncWriter->bytes_avail_ = kAsyncBufferSize; |
|
|
|
asyncWriter->num_ele_ = 0; |
|
|
|
} |
|
|
|
} while (retry); |
|
|
|
// Wake up any writer that is waiting. |
|
|
|
writer_wp_.Set(); |
|
|
|
} |
|
|
|
// Report the most recent flush status recorded in flush_rc_. |
return flush_rc_; |
|
|
|
} |
|
|
|
|
|
|
|
// NOTE(review): this span is diff residue. It contains two signatures, Write(write_addr, sz, v) here and |
// Write(sz, v) further down, with their bodies interleaved and no +/- markers (note the two consecutive |
// `dest` declarations). The comments below describe individual statements only; which rows belong to |
// which version must be confirmed against the original diff. |
Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t write_addr, int64_t sz, |
|
|
|
const std::vector<ReadableSlice> &v) { |
|
|
|
// Translate the logical stream address into offsets relative to this buffer: where the write starts |
|
|
|
// and where it ends. |
|
|
|
auto rel_write_addr = write_addr - begin_addr_; |
|
|
|
auto rel_end_addr = rel_write_addr + sz; |
|
|
|
// If not enough space, time to flush and swap. |
|
|
|
if (rel_end_addr > kAsyncBufferSize) { |
|
|
|
return Status(StatusCode::kNoSpace); |
|
|
|
} |
|
|
|
// Second signature in this span: takes only the size and the slices, without a logical write address. |
Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std::vector<ReadableSlice> &v) { |
|
|
|
// Callers are expected to have made room already; a larger sz is reported as an unexpected error. |
CHECK_FAIL_RETURN_UNEXPECTED(sz <= bytes_avail_, "Programming error"); |
|
|
|
// Copy every slice into the buffer in order. |
for (auto &p : v) { |
|
|
|
auto write_sz = p.GetSize(); |
|
|
|
// Destination computed from the relative write address. |
WritableSlice dest(reinterpret_cast<char *>(buffer_) + rel_write_addr, write_sz); |
|
|
|
// Destination computed from the bytes already consumed (kAsyncBufferSize - bytes_avail_). |
WritableSlice dest(reinterpret_cast<char *>(buffer_) + kAsyncBufferSize - bytes_avail_, write_sz); |
|
|
|
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p)); |
|
|
|
bytes_avail_ -= write_sz; |
|
|
|
rel_write_addr += write_sz; |
|
|
|
} |
|
|
|
// The slice sizes must add up to exactly sz, otherwise the write cursor will not land on rel_end_addr. |
CHECK_FAIL_RETURN_UNEXPECTED(rel_write_addr == rel_end_addr, "Programming error"); |
|
|
|
++num_ele_; |
|
|
|
// Raise end_addr_ to our end address unless it is already at least that large. |
// NOTE(review): if end_addr_ is a std::atomic, compare_exchange_weak already refreshes `expected` on |
// failure, so the explicit reload inside the loop looks redundant but harmless. Confirm the member type. |
|
|
|
int64_t new_end_addr = write_addr + sz; |
|
|
|
int64_t expected = end_addr_; |
|
|
|
while (expected < new_end_addr) { |
|
|
|
if (!end_addr_.compare_exchange_weak(expected, new_end_addr)) { |
|
|
|
expected = end_addr_; |
|
|
|
} |
|
|
|
} |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(end_addr_ >= new_end_addr, "Programming error"); |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// Flush loop: repeatedly waits on flush_wp_, then calls SyncFlush(), until a wait fails, an interrupt is |
// detected, or SyncFlush() returns a status other than OK / out-of-memory / no-space. |
// Returns the status that ended the loop (or the failing status from one of the RETURN_IF_* macros). |
Status CacheClient::AsyncBufferStream::AsyncFlush() { |
|
|
|
// NOTE(review): presumably signals the task manager that this task has started. Confirm Post() semantics. |
TaskManager::FindMe()->Post(); |
|
|
|
Status rc; |
|
|
|
do { |
|
|
|
// Wait until flush_wp_ is signalled; a failed wait returns immediately from this function. |
RETURN_IF_NOT_OK(flush_wp_.Wait()); |
|
|
|
RETURN_IF_INTERRUPTED(); |
|
|
|
// NOTE(review): called with no argument, so this relies on a default declared elsewhere. Verify in the header. |
rc = SyncFlush(); |
|
|
|
// Keep looping on success or on resource errors (out of memory / no space); any other error ends the loop. |
|
|
|
} while (rc.IsOk() || rc.IsOutofMemory() || rc.IsNoSpace()); |
|
|
|
// Make sure we wake up workers waiting for us. |
|
|
|
writer_wp_.Set(); |
|
|
|
return rc; |
|
|
|
} |
|
|
|
} // namespace dataset |
|
|
|
} // namespace mindspore |