Browse Source

!10898 Fix occasional run task error during 8p training with cache

From: @lixiachen
Reviewed-by: @liucunwei,@pandoublefeng
Signed-off-by: @liucunwei
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
c85590c9f5
4 changed files with 57 additions and 125 deletions
  1. +46
    -110
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
  2. +9
    -15
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
  3. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc
  4. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h

+ 46
- 110
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <unistd.h>
#include <iomanip>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_request.h"
@@ -394,7 +395,6 @@ CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_

CacheClient::AsyncBufferStream::~AsyncBufferStream() {
(void)vg_.ServiceStop();
writer_wp_.Set();
(void)ReleaseBuffer();
}

@@ -424,12 +424,9 @@ Status CacheClient::AsyncBufferStream::Init(CacheClient *cc) {
// We only need to set the pointer during init. Other fields will be set dynamically.
buf_arr_[i].buffer_ = reinterpret_cast<void *>(start + i * kAsyncBufferSize);
}
buf_arr_[0].begin_addr_ = 0;
buf_arr_[0].end_addr_ = 0;
buf_arr_[0].bytes_avail_ = kAsyncBufferSize;
buf_arr_[0].num_ele_ = 0;
RETURN_IF_NOT_OK(vg_.ServiceStart());
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async flush", std::bind(&CacheClient::AsyncBufferStream::AsyncFlush, this)));
return Status::OK();
}

@@ -448,127 +445,66 @@ Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
if (sz > kAsyncBufferSize) {
return Status(StatusCode::kNotImplementedYet);
}
// Find out where we are going to write in the (logical) buffer stream without acquiring the lock
// but only use the atomic variable.
auto write_addr = next_addr_.fetch_add(sz);
Status rc;
do {
SharedLock lock(&mux_);
// Check error from the server side while we have the lock;
RETURN_IF_NOT_OK(flush_rc_);
AsyncWriter *asyncWriter = &buf_arr_[cur_];
rc = asyncWriter->Write(write_addr, sz, v);
if (rc.get_code() == StatusCode::kNoSpace) {
// If no space, wake up the async flush thread
writer_wp_.Clear();
flush_wp_.Set();
// Let go of the lock before we wait.
lock.Unlock();
// Wait for the next window
RETURN_IF_NOT_OK(writer_wp_.Wait());
}
} while (rc.get_code() == StatusCode::kNoSpace);
return rc;
std::unique_lock<std::mutex> lock(mux_);
// Check error from the server side while we have the lock;
RETURN_IF_NOT_OK(flush_rc_);
AsyncWriter *asyncWriter = &buf_arr_[cur_];
if (asyncWriter->bytes_avail_ < sz) {
// Flush and switch to a new buffer while we have the lock.
RETURN_IF_NOT_OK(SyncFlush(AsyncFlushFlag::kCallerHasXLock));
// Refresh the pointer after we switch
asyncWriter = &buf_arr_[cur_];
}
RETURN_IF_NOT_OK(asyncWriter->Write(sz, v));
return Status::OK();
}

Status CacheClient::AsyncBufferStream::SyncFlush(bool blocking) {
bool retry = false;
do {
UniqueLock lock(&mux_);
flush_wp_.Clear();
auto *asyncWriter = &buf_arr_[cur_];
retry = false;
// Because the clients are copying async, we need to wait until all of them have written.
if (kAsyncBufferSize - (asyncWriter->end_addr_ - asyncWriter->begin_addr_) == asyncWriter->bytes_avail_) {
if (asyncWriter->num_ele_) {
asyncWriter->rq.reset(
new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
flush_rc_ = cc_->PushRequest(asyncWriter->rq);
if (flush_rc_.IsOk()) {
// If we are asked to wait, say this is the final flush, just wait for its completion.
if (blocking) {
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
// Prepare for the next buffer which will start from the end addr of the previous buffer.
int64_t previous_end_addr = asyncWriter->end_addr_;
cur_ = (cur_ + 1) % kNumAsyncBuffer;
asyncWriter = &buf_arr_[cur_];
// Update the cur_ while we have the lock.
// Before we do anything, make sure the cache server has done with this buffer, or we will corrupt its content
// Also we can also pick up any error from previous flush.
if (asyncWriter->rq) {
// Save the result into a common area, so worker can see it and quit.
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
asyncWriter->bytes_avail_ = kAsyncBufferSize;
asyncWriter->num_ele_ = 0;
asyncWriter->begin_addr_ = previous_end_addr;
asyncWriter->end_addr_ = previous_end_addr;
}
Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) {
std::unique_lock lock(mux_, std::defer_lock_t());
bool callerHasXLock = (flag & AsyncFlushFlag::kCallerHasXLock) == AsyncFlushFlag::kCallerHasXLock;
if (!callerHasXLock) {
lock.lock();
}
auto *asyncWriter = &buf_arr_[cur_];
if (asyncWriter->num_ele_) {
asyncWriter->rq.reset(
new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
flush_rc_ = cc_->PushRequest(asyncWriter->rq);
if (flush_rc_.IsOk()) {
// If we are asked to wait, say this is the final flush, just wait for its completion.
bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
if (blocking) {
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
} else {
// Some clients are late and aren't done yet. Let go of the lock.
lock.Unlock();
if (this_thread::is_interrupted()) {
retry = false;
flush_rc_ = Status(StatusCode::kInterrupted);
} else {
retry = true;
// Prepare for the next buffer.
cur_ = (cur_ + 1) % kNumAsyncBuffer;
asyncWriter = &buf_arr_[cur_];
// Update the cur_ while we have the lock.
// Before we do anything, make sure the cache server has done with this buffer, or we will corrupt its content
// Also we can also pick up any error from previous flush.
if (asyncWriter->rq) {
// Save the result into a common area, so worker can see it and quit.
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
writer_wp_.Set();
std::this_thread::yield();
asyncWriter->bytes_avail_ = kAsyncBufferSize;
asyncWriter->num_ele_ = 0;
}
} while (retry);
// Wake up any writer that is waiting.
writer_wp_.Set();
}
return flush_rc_;
}

Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t write_addr, int64_t sz,
const std::vector<ReadableSlice> &v) {
// Map our logical address to the real physical address in the buffer like where we start and
// where we end.
auto rel_write_addr = write_addr - begin_addr_;
auto rel_end_addr = rel_write_addr + sz;
// If not enough space, time to flush and swap.
if (rel_end_addr > kAsyncBufferSize) {
return Status(StatusCode::kNoSpace);
}
Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std::vector<ReadableSlice> &v) {
CHECK_FAIL_RETURN_UNEXPECTED(sz <= bytes_avail_, "Programming error");
for (auto &p : v) {
auto write_sz = p.GetSize();
WritableSlice dest(reinterpret_cast<char *>(buffer_) + rel_write_addr, write_sz);
WritableSlice dest(reinterpret_cast<char *>(buffer_) + kAsyncBufferSize - bytes_avail_, write_sz);
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p));
bytes_avail_ -= write_sz;
rel_write_addr += write_sz;
}
CHECK_FAIL_RETURN_UNEXPECTED(rel_write_addr == rel_end_addr, "Programming error");
++num_ele_;
// Update the end_addr if ours is better
int64_t new_end_addr = write_addr + sz;
int64_t expected = end_addr_;
while (expected < new_end_addr) {
if (!end_addr_.compare_exchange_weak(expected, new_end_addr)) {
expected = end_addr_;
}
}
CHECK_FAIL_RETURN_UNEXPECTED(end_addr_ >= new_end_addr, "Programming error");
return Status::OK();
}

Status CacheClient::AsyncBufferStream::AsyncFlush() {
TaskManager::FindMe()->Post();
Status rc;
do {
RETURN_IF_NOT_OK(flush_wp_.Wait());
RETURN_IF_INTERRUPTED();
rc = SyncFlush();
// Other than resource error, all other error we quit.
} while (rc.IsOk() || rc.IsOutofMemory() || rc.IsNoSpace());
// Make sure we wake up workers waiting for us.
writer_wp_.Set();
return rc;
}
} // namespace dataset
} // namespace mindspore

+ 9
- 15
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h View File

@@ -259,12 +259,12 @@ class CacheClient {

// Default size of the async write buffer
constexpr static int64_t kAsyncBufferSize = 16 * 1048576L; // 16M
constexpr static int32_t kNumAsyncBuffer = 2;
constexpr static int32_t kNumAsyncBuffer = 3;

/// Force a final flush to the cache server. Must be called when receving eoe.
/// Force a final flush to the cache server. Must be called when receiving eoe.
Status FlushAsyncWriteBuffer() {
if (async_buffer_stream_) {
return async_buffer_stream_->SyncFlush(true);
return async_buffer_stream_->SyncFlush(AsyncBufferStream::AsyncFlushFlag::kFlushBlocking);
}
return Status::OK();
}
@@ -323,21 +323,20 @@ class CacheClient {
/// result of some previous flush.
/// \note Need to call SyncFlush to do the final flush.
Status AsyncWrite(const TensorRow &row);
Status SyncFlush(bool blocking = false);
enum class AsyncFlushFlag : int8_t { kFlushNone = 0, kFlushBlocking = 1, kCallerHasXLock = 1u << 2 };
Status SyncFlush(AsyncFlushFlag flag);

/// This maps a physical shared memory to the data stream.
class AsyncWriter {
public:
friend class AsyncBufferStream;
Status Write(int64_t start_addr, int64_t sz, const std::vector<ReadableSlice> &v);
Status Write(int64_t sz, const std::vector<ReadableSlice> &v);

private:
std::shared_ptr<BatchCacheRowsRequest> rq;
void *buffer_;
int32_t num_ele_; // How many tensor rows in this buffer
int64_t begin_addr_; // Start of logical address of the data stream
std::atomic<int64_t> end_addr_; // End of the logical address of the data stream
std::atomic<int64_t> bytes_avail_; // Number of bytes remain
int32_t num_ele_; // How many tensor rows in this buffer
int64_t bytes_avail_; // Number of bytes remain
};

/// \brief Release the shared memory during shutdown
@@ -346,18 +345,13 @@ class CacheClient {

private:
Status flush_rc_;
WaitPost writer_wp_;
WaitPost flush_wp_;
RWLock mux_;
std::mutex mux_;
TaskGroup vg_;
CacheClient *cc_;
int64_t offset_addr_;
AsyncWriter buf_arr_[kNumAsyncBuffer];
int32_t cur_;
std::atomic<int64_t> next_addr_;

/// \brief Entry point of the async flush thread.
Status AsyncFlush();
};
std::shared_ptr<AsyncBufferStream> async_buffer_stream_;



+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc View File

@@ -131,6 +131,7 @@ Status CacheServerHW::GetNumaNodeInfo() {
while (iter != end) {
auto match = iter->str();
auto pos = match.find_first_of('-');
CHECK_FAIL_RETURN_UNEXPECTED(pos != std::string::npos, "Failed to parse numa node file");
std::string min = match.substr(0, pos);
std::string max = match.substr(pos + 1);
cpu_id_t cpu_min = strtol(min.data(), nullptr, 10);


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h View File

@@ -363,6 +363,7 @@ class CacheServer : public Service {
expected_ = n;
rc_lists_.reserve(expected_);
}
~BatchWait() = default;

std::weak_ptr<BatchWait> GetBatchWait() { return weak_from_this(); }



Loading…
Cancel
Save