From: @jkl_lee Reviewed-by: @mikef, @nsyca Signed-off-by: @nsyca Tags: v1.1.0
@@ -14,34 +14,96 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) {
CachedSharedMemory::CachedSharedMemory(int32_t port, size_t val_in_GB)
: shared_memory_sz_in_gb_(val_in_GB), port_(port), num_numa_nodes_(-1), sub_pool_sz_(-1) {
// We create the shared memory and we will destroy it. All other clients just detach.
shm_.RemoveResourcesOnExit();
}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {}
CachedSharedMemory::~CachedSharedMemory() = default;
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
if (ba == nullptr) {
return Status(StatusCode::kOutOfMemory);
}
// Transfer the ownership of this pointer. Any future error in the processing will be
// handled by the destructor of *out.
(*out).reset(ba);
Status CachedSharedMemory::Init() {
CacheServer &cs = CacheServer::GetInstance();
num_numa_nodes_ = cs.GetNumaNodeCount();
// Generate the ftok key from the port.
SharedMemory::shm_key_t shm_key;
RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
ba->shm_.SetPublicKey(shm_key);
RETURN_IF_NOT_OK(PortToFtok(port_, &shm_key));
shm_.SetPublicKey(shm_key);
// Value is in GB. Convert into bytes.
int64_t sz = val_in_GB * 1073741824L;
RETURN_IF_NOT_OK(ba->shm_.Create(sz));
ba->impl_ = std::make_unique<ArenaImpl>(ba->shm_.SharedMemoryBaseAddr(), sz);
int64_t shm_mem_sz = shared_memory_sz_in_gb_ * 1073741824L;
RETURN_IF_NOT_OK(shm_.Create(shm_mem_sz));
MS_LOG(INFO) << "Creation of shared memory successful. Shared memory key " << shm_.GetKey();
// Interleave the memory.
cs.GetHWControl()->InterleaveMemory(shm_.SharedMemoryBaseAddr(), shm_mem_sz);
// We will create a number of sub-pools out of the shared memory to reduce latch contention
int32_t num_of_pools = num_numa_nodes_;
if (num_numa_nodes_ == 1) {
num_of_pools = shared_memory_sz_in_gb_ * 2;
}
sub_pool_sz_ = shm_mem_sz / num_of_pools;
// If each subpool is too small, readjust the number of pools
constexpr int64_t min_subpool_sz = 512 * 1048576L;
if (sub_pool_sz_ < min_subpool_sz) {
sub_pool_sz_ = min_subpool_sz;
num_of_pools = shm_mem_sz / min_subpool_sz;
}
shm_pool_.reserve(num_of_pools);
for (auto i = 0; i < num_of_pools; ++i) {
void *ptr = static_cast<char *>(shm_.SharedMemoryBaseAddr()) + i * sub_pool_sz_;
shm_pool_.push_back(std::make_unique<ArenaImpl>(ptr, sub_pool_sz_));
}
mux_ = std::make_unique<std::mutex[]>(num_of_pools);
return Status::OK();
}
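Reviewer note: the sub-pool sizing above is worth a worked example. A standalone sketch (not part of the patch; the inputs are illustrative) that mirrors the arithmetic in Init():

#include <cstdint>
#include <iostream>

// Mirrors the pool-count logic in CachedSharedMemory::Init() above.
void ComputePools(int32_t num_numa_nodes, int64_t shm_sz_in_gb) {
  int64_t shm_mem_sz = shm_sz_in_gb * 1073741824L;
  int32_t num_of_pools = num_numa_nodes;
  if (num_numa_nodes == 1) {
    // On a single-node box, split by size instead to reduce latch contention.
    num_of_pools = shm_sz_in_gb * 2;
  }
  int64_t sub_pool_sz = shm_mem_sz / num_of_pools;
  constexpr int64_t min_subpool_sz = 512 * 1048576L;  // 512MB floor
  if (sub_pool_sz < min_subpool_sz) {
    sub_pool_sz = min_subpool_sz;
    num_of_pools = shm_mem_sz / min_subpool_sz;
  }
  std::cout << num_of_pools << " pools of " << (sub_pool_sz >> 20) << "MB\n";
}

int main() {
  ComputePools(1, 4);  // prints: 8 pools of 512MB
  ComputePools(4, 4);  // prints: 4 pools of 1024MB
  return 0;
}

So with the default 4 GB segment on a single-socket machine the server ends up with eight 512 MB arenas, each behind its own mutex.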
Status CachedSharedMemory::CreateArena(std::unique_ptr<CachedSharedMemory> *out, int32_t port, size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
auto mem_pool = std::unique_ptr<CachedSharedMemory>(new CachedSharedMemory(port, val_in_GB));
RETURN_IF_NOT_OK(mem_pool->Init());
*out = std::move(mem_pool);
return Status::OK();
}
Status CachedSharedMemory::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
Status rc;
RETURN_UNEXPECTED_IF_NULL(p);
auto begin_slot = client_id % shm_pool_.size();
auto slot = begin_slot;
do {
std::unique_lock<std::mutex> lock(mux_[slot]);
rc = shm_pool_[slot]->Allocate(sz, p);
if (rc.IsOutofMemory()) {
slot = (slot + 1) % shm_pool_.size();
}
} while (rc.IsError() && slot != begin_slot);
if (rc.IsError()) {
return rc;
}
return Status::OK();
}
void CachedSharedMemory::DeallocateSharedMemory(int32_t client_id, void *p) {
auto begin_slot = client_id % shm_pool_.size();
auto slot = begin_slot;
auto start_addr = static_cast<char *>(SharedMemoryBaseAddr());
bool found = false;
do {
auto ptr = start_addr + slot * sub_pool_sz_;
if (ptr <= p && p < (ptr + sub_pool_sz_)) {
std::unique_lock<std::mutex> lock(mux_[slot]);
shm_pool_[slot]->Deallocate(p);
found = true;
break;
} else {
slot = (slot + 1) % shm_pool_.size();
}
} while (slot != begin_slot);
if (!found) {
MS_LOG(ERROR) << "Programming error. Can't find the arena that the pointer " << p << " comes from";
}
}
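Reviewer note: the slot selection above spreads clients over the sub-pools (allocation starts at client_id % num_pools and probes round-robin; deallocation locates the owning arena from the pointer's offset). A standalone illustration, with a flat byte array standing in for the real shared memory segment:

#include <cstdint>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  constexpr size_t kNumPools = 4;
  constexpr int64_t kSubPoolSz = 1024;
  std::vector<char> shm(kNumPools * kSubPoolSz);
  char *base = shm.data();

  int32_t client_id = 6;
  size_t begin_slot = client_id % kNumPools;  // client 6 starts probing at slot 2
  std::cout << "start probing at slot " << begin_slot << "\n";

  // Pretend slot 2 handed out this pointer.
  char *p = base + begin_slot * kSubPoolSz + 100;

  // Deallocate: find which sub-pool owns p by address range.
  for (size_t slot = 0; slot < kNumPools; ++slot) {
    char *lo = base + slot * kSubPoolSz;
    if (lo <= p && p < lo + kSubPoolSz) {
      std::cout << "pointer belongs to slot " << slot << "\n";
      break;
    }
  }
  return 0;
}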
} // namespace dataset
} // namespace mindspore
@@ -18,25 +18,29 @@
#include <memory>
#include <mutex>
#include <vector>
#include <string>
#include <utility>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
class CachedSharedMemoryArena : public MemoryPool {
/// This is like a CircularPool but each arena is in shared memory and
/// possibly bound to a numa socket.
class CachedSharedMemory {
public:
// Disable copy and assignment constructor
CachedSharedMemoryArena(const CachedSharedMemoryArena &) = delete;
CachedSharedMemoryArena &operator=(const CachedSharedMemoryArena &) = delete;
~CachedSharedMemoryArena() override;
CachedSharedMemory(const CachedSharedMemory &) = delete;
CachedSharedMemory &operator=(const CachedSharedMemory &) = delete;
~CachedSharedMemory();
/// \brief Create an Arena in shared memory
/// \param[out] out Pointer to a unique_ptr
/// \param port Port used to derive the shared memory key
/// \param val_in_GB Size of shared memory in gigabytes
/// \return Status object
static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
static Status CreateArena(std::unique_ptr<CachedSharedMemory> *out, int32_t port, size_t val_in_GB);
/// \brief This returns where we attach to the shared memory.
/// Some gRPC requests will ask for a shared memory block, and
@@ -44,45 +48,29 @@ class CachedSharedMemoryArena : public MemoryPool {
/// in the client. So instead we will return an address relative
/// to the base address of the shared memory where we attach to.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return impl_->get_base_addr(); }
/// As a derived class of MemoryPool, we have to implement the following
/// But we simply transfer the call to the implementation class
Status Allocate(size_t size, void **pVoid) override {
std::unique_lock<std::mutex> lock(mux_);
return impl_->Allocate(size, pVoid);
}
Status Reallocate(void **pVoid, size_t old_sz, size_t new_sz) override {
std::unique_lock<std::mutex> lock(mux_);
return impl_->Reallocate(pVoid, old_sz, new_sz);
}
void Deallocate(void *pVoid) override {
std::unique_lock<std::mutex> lock(mux_);
impl_->Deallocate(pVoid);
}
uint64_t get_max_size() const override { return impl_->get_max_size(); }
int PercentFree() const override {
std::unique_lock<std::mutex> lock(mux_);
return impl_->PercentFree();
}
/// \brief Dump the memory allocation block.
friend std::ostream &operator<<(std::ostream &os, const CachedSharedMemoryArena &s) {
os << *(s.impl_);
return os;
}
const void *SharedMemoryBaseAddr() const { return shm_.SharedMemoryBaseAddr(); }
void *SharedMemoryBaseAddr() { return shm_.SharedMemoryBaseAddr(); }
/// \brief Get the key of the shared memory
SharedMemory::shm_key_t GetKey() const { return shm_.GetKey(); }
/// \brief Allocate shared memory for a given pipeline
Status AllocateSharedMemory(int32_t client_id, size_t sz, void **p);
/// \brief Deallocate shared memory for a given pipeline
void DeallocateSharedMemory(int32_t client_id, void *p);
private:
mutable std::mutex mux_;
int32_t val_in_GB_;
int32_t shared_memory_sz_in_gb_;
int32_t port_;
SharedMemory shm_;
std::unique_ptr<ArenaImpl> impl_;
std::vector<std::unique_ptr<ArenaImpl>> shm_pool_;
std::unique_ptr<std::mutex[]> mux_;
int32_t num_numa_nodes_;
int64_t sub_pool_sz_;
/// Private constructor. Not to be called directly.
CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
CachedSharedMemory(int32_t port, size_t val_in_GB);
Status Init();
};
} // namespace dataset
} // namespace mindspore
@@ -19,6 +19,7 @@
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#include "minddata/dataset/util/bit.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
@@ -71,6 +72,10 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool
CacheClient::~CacheClient() {
cache_miss_keys_wp_.Set();
// Manually release the async buffer because we need the comm layer.
if (async_buffer_stream_) {
async_buffer_stream_->ReleaseBuffer();
}
if (client_id_ != -1) {
try {
// Send a message to the server, saying I am done.
@@ -132,6 +137,42 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
return Status::OK();
}
Status CacheClient::AsyncWriteRow(const TensorRow &row) {
if (async_buffer_stream_ == nullptr) {
return Status(StatusCode::kNotImplementedYet);
}
RETURN_IF_NOT_OK(async_buffer_stream_->AsyncWrite(row));
return Status::OK();
}
Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
if (async_buffer_stream_ == nullptr) {
return Status(StatusCode::kNotImplementedYet);
} else {
Status rc;
std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
auto num_rows = in->NumRows();
if (num_rows > 0) {
for (auto i = 0; i < num_rows; ++i) {
TensorRow row;
RETURN_IF_NOT_OK(in->PopRow(&row));
rc = AsyncWriteRow(row);
if (rc.get_code() == StatusCode::kNotImplementedYet) {
tensor_table->push_back(row);
} else if (rc.IsError()) {
return rc;
}
}
}
// If not all of them can be sent async, return what's left back to the caller.
if (!tensor_table->empty()) {
in->set_tensor_table(std::move(tensor_table));
return Status(StatusCode::kNotImplementedYet);
}
}
return Status::OK();
}
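Reviewer note on the contract here: kNotImplementedYet doubles as a "could not take this asynchronously" signal, and any rows that could not be sent are handed back inside `in`. A minimal caller-side sketch of the intended fallback (the wrapper function itself is hypothetical; AsyncWriteBuffer and WriteBuffer are as declared in this class):

// Hypothetical caller, assuming the CacheClient API in this patch.
Status WriteWithFallback(CacheClient *cc, std::unique_ptr<DataBuffer> &&in) {
  Status rc = cc->AsyncWriteBuffer(std::move(in));
  if (rc.get_code() == StatusCode::kNotImplementedYet) {
    // Rows that couldn't go async were returned inside 'in'; push them
    // through the synchronous rpc path instead.
    rc = cc->WriteBuffer(std::move(in));
  }
  return rc;
}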
Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
auto rq = std::make_shared<BatchFetchRequest>(this, row_id);
@@ -141,7 +182,7 @@ Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable
Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
// Free the memory by sending a request back to the server.
if (mem_addr != -1) {
auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, mem_addr);
auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, client_id_, mem_addr);
Status rc2 = PushRequest(mfree_req);
// But we won't wait for the result for the sake of performance.
if (rc.IsOk() && rc2.IsError()) {
@@ -211,6 +252,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
if (success) {
// Attach to shared memory for local client
RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
if (local_bypass_) {
async_buffer_stream_ = std::make_shared<AsyncBufferStream>();
RETURN_IF_NOT_OK(async_buffer_stream_->Init(this));
}
}
// We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
// CacheOp to bypass the build phase.
@@ -240,6 +285,17 @@ Status CacheClient::GetStat(CacheServiceStat *stat) {
return Status::OK();
}
Status CacheClient::GetState(int8_t *out) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(out);
CHECK_FAIL_RETURN_UNEXPECTED(server_connection_id_ != 0, "GetState called but the cache is not in use yet.");
auto rq = std::make_shared<GetCacheStateRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
*out = rq->GetState();
return Status::OK();
}
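Reviewer note: client-side usage of the new GetState is a simple push/wait/read sequence. A hedged sketch (the wrapper function is hypothetical; CacheServiceState is the enum defined in cache_common.h in this patch):

// Hypothetical caller of GetState(); error handling delegated to the macros.
Status WarnIfStillBuilding(CacheClient *cc) {
  int8_t state = 0;
  RETURN_IF_NOT_OK(cc->GetState(&state));
  if (static_cast<CacheServiceState>(state) == CacheServiceState::kBuildPhase) {
    MS_LOG(INFO) << "Cache is still in its build phase";
  }
  return Status::OK();
}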
Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
SharedLock lck(&mux_);
auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
@@ -334,5 +390,181 @@ bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
return it != gap_.end();
}
}
CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0), next_addr_(0) {}
CacheClient::AsyncBufferStream::~AsyncBufferStream() {
(void)vg_.ServiceStop();
writer_wp_.Set();
(void)ReleaseBuffer();
}
Status CacheClient::AsyncBufferStream::ReleaseBuffer() {
if (offset_addr_ != -1) {
auto mfree_req =
std::make_shared<FreeSharedBlockRequest>(cc_->server_connection_id_, cc_->GetClientId(), offset_addr_);
offset_addr_ = -1;
RETURN_IF_NOT_OK(cc_->PushRequest(mfree_req));
RETURN_IF_NOT_OK(mfree_req->Wait());
}
return Status::OK();
}
Status CacheClient::AsyncBufferStream::Init(CacheClient *cc) {
cc_ = cc;
// Allocate shared memory from the server
auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(cc_->server_connection_id_, cc_->GetClientId(),
kAsyncBufferSize * kNumAsyncBuffer);
RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
RETURN_IF_NOT_OK(mem_rq->Wait());
offset_addr_ = mem_rq->GetAddr();
// Now we need to add that to the base address of where we attach.
auto base = cc->SharedMemoryBaseAddr();
auto start = reinterpret_cast<int64_t>(base) + offset_addr_;
for (auto i = 0; i < kNumAsyncBuffer; ++i) {
// We only need to set the pointer during init. Other fields will be set dynamically.
buf_arr_[i].buffer_ = reinterpret_cast<void *>(start + i * kAsyncBufferSize);
}
buf_arr_[0].begin_addr_ = 0;
buf_arr_[0].end_addr_ = 0;
buf_arr_[0].bytes_avail_ = kAsyncBufferSize;
buf_arr_[0].num_ele_ = 0;
RETURN_IF_NOT_OK(vg_.ServiceStart());
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async flush", std::bind(&CacheClient::AsyncBufferStream::AsyncFlush, this)));
return Status::OK();
}
Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
std::vector<ReadableSlice> v;
v.reserve(row.size() + 1);
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
int64_t sz = fbb->GetSize();
v.emplace_back(fbb->GetBufferPointer(), sz);
for (const auto &ts : row) {
sz += ts->SizeInBytes();
v.emplace_back(ts->GetBuffer(), ts->SizeInBytes());
}
// If the size is too big, tell the user to send it directly.
if (sz > kAsyncBufferSize) {
return Status(StatusCode::kNotImplementedYet);
}
// Find out where we are going to write in the (logical) buffer stream without acquiring the lock
// but only use the atomic variable.
auto write_addr = next_addr_.fetch_add(sz);
Status rc;
do {
SharedLock lock(&mux_);
// Check error from the server side while we have the lock.
RETURN_IF_NOT_OK(flush_rc_);
AsyncWriter *asyncWriter = &buf_arr_[cur_];
rc = asyncWriter->Write(write_addr, sz, v);
if (rc.get_code() == StatusCode::kNoSpace) {
// If no space, wake up the async flush thread
writer_wp_.Clear();
flush_wp_.Set();
// Let go of the lock before we wait.
lock.Unlock();
// Wait for the next window
RETURN_IF_NOT_OK(writer_wp_.Wait());
}
} while (rc.get_code() == StatusCode::kNoSpace);
return rc;
}
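Reviewer note: the interesting trick in AsyncWrite is that writers reserve their slice of the logical stream with a single fetch_add and only take the shared lock to copy, so reservations never overlap. A standalone sketch of that reservation pattern (not part of the patch):

#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <thread>
#include <vector>

// Each writer claims [addr, addr + sz) with one fetch_add, then copies
// into its own slice without racing with other writers.
int main() {
  constexpr int64_t kBufSz = 1 << 20;
  std::vector<char> buffer(kBufSz);
  std::atomic<int64_t> next_addr{0};

  auto writer = [&](char tag, int64_t sz) {
    int64_t write_addr = next_addr.fetch_add(sz);  // unique, non-overlapping
    if (write_addr + sz <= kBufSz) {
      std::memset(buffer.data() + write_addr, tag, sz);
    }
  };

  std::thread t1(writer, 'a', 100);
  std::thread t2(writer, 'b', 200);
  t1.join();
  t2.join();
  std::cout << "bytes reserved: " << next_addr.load() << "\n";  // prints 300
  return 0;
}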
Status CacheClient::AsyncBufferStream::SyncFlush(bool blocking) {
bool retry = false;
do {
UniqueLock lock(&mux_);
flush_wp_.Clear();
auto *asyncWriter = &buf_arr_[cur_];
retry = false;
// Because the clients are copying async, we need to wait until all of them have written.
if (kAsyncBufferSize - (asyncWriter->end_addr_ - asyncWriter->begin_addr_) == asyncWriter->bytes_avail_) {
if (asyncWriter->num_ele_) {
asyncWriter->rq.reset(
new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
flush_rc_ = cc_->PushRequest(asyncWriter->rq);
if (flush_rc_.IsOk()) {
// If we are asked to wait, say this is the final flush, just wait for its completion.
if (blocking) {
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
// Prepare for the next buffer which will start from the end addr of the previous buffer.
int64_t previous_end_addr = asyncWriter->end_addr_;
cur_ = (cur_ + 1) % kNumAsyncBuffer;
asyncWriter = &buf_arr_[cur_];
// Update the cur_ while we have the lock.
// Before we do anything, make sure the cache server is done with this buffer, or we will corrupt its content.
// We can also pick up any error from the previous flush.
if (asyncWriter->rq) {
// Save the result into a common area, so workers can see it and quit.
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
asyncWriter->bytes_avail_ = kAsyncBufferSize;
asyncWriter->num_ele_ = 0;
asyncWriter->begin_addr_ = previous_end_addr;
asyncWriter->end_addr_ = previous_end_addr;
}
}
} else {
// Some clients are late and aren't done yet. Let go of the lock.
lock.Unlock();
retry = true;
writer_wp_.Set();
std::this_thread::yield();
}
} while (retry);
// Wake up any writer that is waiting.
writer_wp_.Set();
return flush_rc_;
}
Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t write_addr, int64_t sz,
const std::vector<ReadableSlice> &v) {
// Map our logical address to the real physical address in the buffer, i.e. where we start and
// where we end.
auto rel_write_addr = write_addr - begin_addr_;
auto rel_end_addr = rel_write_addr + sz;
// If not enough space, time to flush and swap.
if (rel_end_addr > kAsyncBufferSize) {
return Status(StatusCode::kNoSpace);
}
for (auto &p : v) {
auto write_sz = p.GetSize();
WritableSlice dest(reinterpret_cast<char *>(buffer_) + rel_write_addr, write_sz);
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p));
bytes_avail_ -= write_sz;
rel_write_addr += write_sz;
}
CHECK_FAIL_RETURN_UNEXPECTED(rel_write_addr == rel_end_addr, "Programming error");
++num_ele_;
// Update the end_addr if ours is better
int64_t new_end_addr = write_addr + sz;
int64_t expected = end_addr_;
while (expected < new_end_addr) {
if (!end_addr_.compare_exchange_weak(expected, new_end_addr)) {
expected = end_addr_;
}
}
CHECK_FAIL_RETURN_UNEXPECTED(end_addr_ >= new_end_addr, "Programming error");
return Status::OK();
}
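Reviewer note: the compare_exchange loop above is the classic atomic "fetch-max" idiom: keep proposing new_end_addr until either we install it or another writer has already pushed the watermark past us. Isolated for clarity (standalone, not part of the patch); note that compare_exchange_weak already refreshes `expected` on failure, so the explicit reload in the patch is belt-and-braces:

#include <atomic>
#include <cstdint>
#include <iostream>

// Atomically raise end_addr to new_end_addr, but never lower it.
void RaiseWatermark(std::atomic<int64_t> &end_addr, int64_t new_end_addr) {
  int64_t expected = end_addr.load();
  while (expected < new_end_addr) {
    // On failure, compare_exchange_weak reloads the current value into
    // 'expected' and we re-test the loop condition.
    if (end_addr.compare_exchange_weak(expected, new_end_addr)) {
      break;
    }
  }
}

int main() {
  std::atomic<int64_t> end_addr{0};
  RaiseWatermark(end_addr, 100);
  RaiseWatermark(end_addr, 50);          // no-op: never lowers
  std::cout << end_addr.load() << "\n";  // prints 100
  return 0;
}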
Status CacheClient::AsyncBufferStream::AsyncFlush() {
TaskManager::FindMe()->Post();
Status rc;
do {
RETURN_IF_NOT_OK(flush_wp_.Wait());
RETURN_IF_INTERRUPTED();
rc = SyncFlush();
// We quit on any error other than a resource error.
} while (rc.IsOk() || rc.IsOutofMemory() || rc.IsNoSpace());
// Make sure we wake up workers waiting for us.
writer_wp_.Set();
return rc;
}
} // namespace dataset
} // namespace mindspore
@@ -38,6 +38,8 @@
#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/cond_var.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
@@ -50,6 +52,7 @@ class CacheClient {
friend class CreateCacheRequest;
friend class CacheRowRequest;
friend class BatchFetchRequest;
friend class BatchCacheRowsRequest;
/// \brief A builder to help creating a CacheClient object
class Builder {
@@ -180,6 +183,11 @@ class CacheClient {
/// \return Status object
Status GetStat(CacheServiceStat *);
/// \brief Get the state of a cache server
/// \param[in/out] out Pointer to an int8_t
/// \return Status object
Status GetState(int8_t *);
/// \brief Cache the schema at the cache server
/// \param map The unordered map of the schema
/// \return Status object
@@ -230,6 +238,7 @@ class CacheClient {
int32_t GetPort() const { return port_; }
int32_t GetNumConnections() const { return num_connections_; }
int32_t GetPrefetchSize() const { return prefetch_size_; }
int32_t GetClientId() const { return client_id_; }
/// MergeOp will notify us when the server can't cache any more rows.
/// We will stop any attempt to fetch any rows that are most likely
@@ -250,6 +259,20 @@ class CacheClient {
return false;
}
// Default size of the async write buffer
constexpr static int64_t kAsyncBufferSize = 16 * 1048576L; // 16M
constexpr static int32_t kNumAsyncBuffer = 2;
/// Force a final flush to the cache server. Must be called when receiving eoe.
Status FlushAsyncWriteBuffer() {
if (async_buffer_stream_) {
return async_buffer_stream_->SyncFlush(true);
}
return Status::OK();
}
Status AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in);
private:
mutable RWLock mux_;
uint64_t cache_mem_sz_;
@@ -288,6 +311,62 @@ class CacheClient {
std::set<row_id_type> gap_;
};
std::unique_ptr<CacheMissKeys> cache_miss_keys_;
/// A data stream of back-to-back serialized tensor rows.
class AsyncBufferStream {
public:
AsyncBufferStream();
~AsyncBufferStream();
/// \brief Initialize an async write buffer
Status Init(CacheClient *cc);
/// A worker will call the API AsyncWrite to put a TensorRow into the data stream.
/// A background thread will stream the data to the cache server.
/// The result of calling AsyncWrite is not immediately known; it may be the
/// result of some previous flush.
/// \note Need to call SyncFlush to do the final flush.
Status AsyncWrite(const TensorRow &row);
Status SyncFlush(bool blocking = false);
/// This maps a physical shared memory to the data stream.
class AsyncWriter {
public:
friend class AsyncBufferStream;
Status Write(int64_t start_addr, int64_t sz, const std::vector<ReadableSlice> &v);
private:
std::shared_ptr<BatchCacheRowsRequest> rq;
void *buffer_;
int32_t num_ele_; // How many tensor rows in this buffer
int64_t begin_addr_; // Start of logical address of the data stream
std::atomic<int64_t> end_addr_; // End of the logical address of the data stream
std::atomic<int64_t> bytes_avail_; // Number of bytes remaining
};
/// \brief Release the shared memory during shutdown
/// \note Needs the comm layer to be alive.
Status ReleaseBuffer();
private:
Status flush_rc_;
WaitPost writer_wp_;
WaitPost flush_wp_;
RWLock mux_;
TaskGroup vg_;
CacheClient *cc_;
int64_t offset_addr_;
AsyncWriter buf_arr_[kNumAsyncBuffer];
int32_t cur_;
std::atomic<int64_t> next_addr_;
/// \brief Entry point of the async flush thread.
Status AsyncFlush();
};
std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
/// \brief Serialize a TensorRow into the async buffer.
Status AsyncWriteRow(const TensorRow &row);
};
} // namespace dataset
} // namespace mindspore
@@ -37,6 +37,10 @@ namespace dataset {
/// For a small amount of data, we won't get any benefit from the shared memory method
/// because it requires two rpc requests.
constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
/// \brief Default size (in GB) of shared memory we are going to create
constexpr static int32_t kDefaultSharedMemorySize = 4;
/// \brief Memory Cap ratio used by the server
constexpr static float kDefaultMemoryCapRatio = 0.8;
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
@@ -46,7 +50,15 @@ constexpr static uint32_t kDataIsInSharedMemory = 2;
constexpr static int32_t kSharedMessageSize = 2048;
/// \brief State of CacheService at the server.
enum class CacheServiceState : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };
enum class CacheServiceState : int8_t {
kNone = 0,
kBuildPhase = 1,
kFetchPhase = 2,
kNoLocking = 3,
kOutOfMemory = 4,
kNoSpace = 5,
kError = 127
};
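Reviewer note: the switch to an explicit int8_t base matters because the state now travels over the wire as a decimal string (the server serializes with std::to_string, and GetCacheStateRequest::PostReply parses it with std::stoi). A standalone round-trip sketch of that encoding, with a local copy of the enum for illustration:

#include <cstdint>
#include <iostream>
#include <string>

// Local mirror of the new CacheServiceState enum, for a self-contained demo.
enum class CacheServiceState : int8_t {
  kNone = 0, kBuildPhase = 1, kFetchPhase = 2, kNoLocking = 3,
  kOutOfMemory = 4, kNoSpace = 5, kError = 127
};

int main() {
  auto state = CacheServiceState::kBuildPhase;
  std::string wire = std::to_string(static_cast<int8_t>(state));   // server side
  auto parsed = static_cast<CacheServiceState>(std::stoi(wire));   // client side
  std::cout << "round-trip ok: " << (parsed == state) << "\n";     // prints 1
  return 0;
}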
/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object
@@ -23,8 +23,7 @@
namespace mindspore {
namespace dataset {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
: port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb), shm_key_(-1) {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port) : port_(port) {
// Setup a path for unix socket.
unix_socket_ = PortToUnixSocketPath(port);
// We can't generate the ftok key yet until the unix_socket_ is created
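Reviewer note: PortToFtok, whose body is not shown in this diff, presumably derives a System V IPC key from the per-port unix socket path, which is why the path must exist first. For reference, a plain ftok-based sketch of that idea; the path template and the project id 'd' are assumptions, not the real helper's internals:

#include <sys/ipc.h>
#include <cstdint>
#include <cstdio>
#include <string>

int main() {
  int32_t port = 50052;
  // Hypothetical per-port socket path; the real scheme lives in PortToUnixSocketPath.
  std::string unix_socket = "/tmp/cache_server_p" + std::to_string(port);
  key_t key = ftok(unix_socket.c_str(), 'd');  // returns -1 if the path doesn't exist yet
  std::printf("key = %d\n", static_cast<int>(key));
  return 0;
}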
@@ -70,14 +69,6 @@ Status CacheServerGreeterImpl::Run() {
server_ = builder.BuildAndStart();
if (server_) {
MS_LOG(INFO) << "Server listening on " << server_address;
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
shm_key_ = shm_pool_->GetKey();
MS_LOG(INFO) << "Creation of local socket and shared memory successful. Shared memory key " << shm_key_;
auto cs = CacheServer::GetInstance().GetHWControl();
// This shared memory is a hot memory and we will interleave among all the numa nodes.
cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
#endif
} else {
std::string errMsg = "Fail to start server. ";
if (port_tcpip != port_) {
@@ -147,14 +138,18 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
// Now we pass the address of this instance to CacheServer's main loop.
MS_LOG(DEBUG) << "Handle request " << *this;
// We will distribute the request evenly (or randomly) over all the numa nodes.
// The exception is BatchFetch which we need to pre-process here.
if (type_ == BaseRequest::RequestType::kBatchFetchRows) {
rc_ = cs.BatchFetchRows(&rq_, &reply_);
if (!rc_.IsInterrupted()) {
Status2CacheReply(rc_, &reply_);
st_ = CacheServerRequest::STATE::FINISH;
responder_.Finish(reply_, grpc::Status::OK, this);
} else {
// The exceptions are BatchFetch and BatchCache, which we need to pre-process here.
// Some requests are also urgent, and we want to process them here too.
if (type_ == BaseRequest::RequestType::kBatchFetchRows || type_ == BaseRequest::RequestType::kBatchCacheRows ||
type_ == BaseRequest::RequestType::kStopService || type_ == BaseRequest::RequestType::kAllocateSharedBlock ||
type_ == BaseRequest::RequestType::kFreeSharedBlock) {
cs.ProcessRequest(this);
// For cache_admin --stop, ProcessRequest is just acknowledging that we received the request. Now
// we call the real function.
if (type_ == BaseRequest::RequestType::kStopService) {
cs.GlobalShutdown();
return Status(StatusCode::kInterrupted);
} else if (rc_.IsInterrupted()) {
return rc_;
}
} else {
@@ -191,10 +186,12 @@ Status CacheServerGreeterImpl::MonitorUnixSocket() {
// If the unix socket is recreated for whatever reason, this server instance will be stale and
// no other process can communicate with us. In this case we need to shut down ourselves.
if (p.Exists()) {
auto &cs = CacheServer::GetInstance();
SharedMemory::shm_key_t key;
RETURN_IF_NOT_OK(PortToFtok(port_, &key));
if (key != shm_key_) {
std::string errMsg = "Detecting unix socket has changed. Previous key " + std::to_string(shm_key_) +
auto shm_key = cs.GetKey();
if (key != shm_key) {
std::string errMsg = "Detected that the unix socket has changed. Previous key " + std::to_string(shm_key) +
". New key " + std::to_string(key) + ". Shutting down server";
MS_LOG(ERROR) << errMsg;
RETURN_STATUS_UNEXPECTED(errMsg);
@@ -18,12 +18,14 @@
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
@@ -75,7 +77,7 @@ class CacheServerGreeterImpl final {
friend class CacheServer;
public:
explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb);
explicit CacheServerGreeterImpl(int32_t port);
virtual ~CacheServerGreeterImpl();
/// \brief Brings up gRPC server
/// \return none
@@ -83,24 +85,18 @@ class CacheServerGreeterImpl final {
/// \brief Entry function to handle cache server request
Status HandleRequest(int32_t worker_id);
/// Return the shared memory pool.
/// \return Return the shared memory pool
CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }
/// \brief Monitor the status of the unix socket in case it is gone.
Status MonitorUnixSocket();
/// \brief This shuts down the comm layer
void Shutdown();
private:
int32_t port_;
size_t shm_pool_sz_in_gb_;
std::string unix_socket_;
CacheServerGreeter::AsyncService svc_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_;
std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
SharedMemory::shm_key_t shm_key_;
};
} // namespace dataset
} // namespace mindspore
@@ -209,6 +209,14 @@ void CacheServerHW::InterleaveMemory(void *ptr, size_t sz) {
#endif
}
void CacheServerHW::AssignToNode(numa_id_t numa_id, void *ptr, size_t sz) {
#ifdef NUMA_ENABLED
if (numa_enabled()) {
numa_tonode_memory(ptr, sz, numa_id);
}
#endif
}
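Reviewer note: numa_tonode_memory is the libnuma counterpart of the interleave policy already used here; instead of striping pages across all nodes it binds the range to a single node, which is what lets each sub-pool stay local to one socket. A hedged usage sketch of the two placements (requires libnuma, build with -lnuma; illustrative only, not part of the patch):

#include <numa.h>
#include <cstdio>
#include <cstdlib>

int main() {
  if (numa_available() == -1) return 0;  // no NUMA support on this box
  size_t sz = 1 << 20;
  void *p = std::malloc(sz);
  // Policy applies when the pages are first touched:
  numa_interleave_memory(p, sz, numa_all_nodes_ptr);  // stripe a hot region
  // ...or, for a per-socket sub-pool, pin the block to one node instead:
  numa_tonode_memory(p, sz, /*node=*/0);
  std::free(p);
  return 0;
}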
bool CacheServerHW::numa_enabled() {
#ifdef NUMA_ENABLED
return (numa_available() != -1);
@@ -63,6 +63,9 @@ class CacheServerHW {
/// \brief Interleave a given memory block. Used by shared memory only.
static void InterleaveMemory(void *ptr, size_t sz);
/// \brief Assign a given memory block to a numa node. Used by shared memory only.
void AssignToNode(numa_id_t numa_id, void *ptr, size_t sz);
/// \brief Set default memory policy.
static Status SetDefaultMemoryPolicy(CachePoolPolicy);
@@ -53,7 +53,7 @@ Status SharedMessage::Create() {
Status SharedMessage::SendStatus(const Status &rc) {
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
StatusMsgBuf msg{
CacheMsgBuf msg{
1,
};
msg.body.status.err_code = static_cast<int32_t>(rc.get_code());
@@ -71,7 +71,7 @@ Status SharedMessage::SendStatus(const Status &rc) {
Status SharedMessage::ReceiveStatus(Status *rc) {
RETURN_UNEXPECTED_IF_NULL(rc);
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
struct CacheMsgBuf msg {};
auto err = msgrcv(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), 0, MSG_NOERROR);
if (err == -1) {
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
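Reviewer note on the renamed CacheMsgBuf: SendStatus/ReceiveStatus wrap the System V msgsnd/msgrcv pair, with the struct's first field acting as the mandatory message type. A minimal standalone round-trip shaped like that struct (simplified body carrying just an error code; illustrative, not the real layout):

#include <sys/ipc.h>
#include <sys/msg.h>
#include <cstdint>
#include <cstdio>

// Same shape as CacheMsgBuf: a long mtype followed by the payload.
struct MsgBuf {
  long mtype;
  int32_t err_code;
};

int main() {
  int qid = msgget(IPC_PRIVATE, IPC_CREAT | 0600);
  if (qid == -1) { std::perror("msgget"); return 1; }
  MsgBuf out{1, 42};
  msgsnd(qid, &out, sizeof(out.err_code), 0);            // size excludes mtype
  MsgBuf in{};
  msgrcv(qid, &in, sizeof(in.err_code), 0, MSG_NOERROR);
  std::printf("received err_code = %d\n", in.err_code);  // prints 42
  msgctl(qid, IPC_RMID, nullptr);                        // clean up the queue
  return 0;
}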
@@ -28,7 +28,7 @@
namespace mindspore {
namespace dataset {
/// A message queue structure between the parent and the child process
struct StatusMsgBuf {
struct CacheMsgBuf {
int64_t mtype;
union {
char mtext[1];
@@ -168,12 +168,14 @@ Path CachePool::GetSpillPath() const {
}
CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
tree_->LockShared(); // Prevent any node split while we search.
CacheStat cs{-1, -1, 0, 0, 0, 0};
int64_t total_sz = 0;
if (tree_->begin() != tree_->end()) {
cs.min_key = tree_->begin().key();
cs.max_key = cs.min_key; // will adjust later.
for (auto it = tree_->begin(); it != tree_->end(); ++it) {
it.LockShared();
total_sz += it.value().sz;
if (it.value().ptr != nullptr) {
++cs.num_mem_cached;
@@ -190,6 +192,7 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
}
}
cs.max_key = cur_key;
it.Unlock();
}
}
if (total_sz > 0) {
@@ -199,6 +202,7 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
cs.average_cache_sz = 1;
}
}
tree_->Unlock();
return cs;
}
@@ -58,7 +58,7 @@ Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const Te
if (sent_using_local_bypass) {
MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
// Allocate shared memory from the server
auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), sz_);
auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), cc->GetClientId(), sz_);
RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
RETURN_IF_NOT_OK(mem_rq->Wait());
addr_ = mem_rq->GetAddr();
@@ -305,6 +305,15 @@ Status GetStatRequest::PostReply() {
return Status::OK();
}
Status GetCacheStateRequest::PostReply() {
try {
cache_service_state_ = std::stoi(reply_.result());
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
Status ListSessionsRequest::PostReply() {
auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
auto session_vector = msg->sessions();
@@ -333,5 +342,13 @@ Status ServerStopRequest::PostReply() {
return Status::OK();
}
BatchCacheRowsRequest::BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele)
: BaseRequest(RequestType::kBatchCacheRows) {
rq_.set_connection_id(cc->server_connection_id_);
rq_.set_client_id(cc->client_id_);
rq_.add_buf_data(cc->cookie());
rq_.add_buf_data(std::to_string(addr));
rq_.add_buf_data(std::to_string(num_ele));
}
} // namespace dataset
} // namespace mindspore
@@ -78,6 +78,9 @@ class BaseRequest {
kListSessions = 16,
kConnectReset = 17,
kInternalFetchRow = 18,
kBatchCacheRows = 19,
kInternalCacheRow = 20,
kGetCacheState = 21,
// Add new request before it.
kRequestUnknown = 32767
};
@@ -133,10 +136,11 @@ class BaseRequest {
class FreeSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr)
explicit FreeSharedBlockRequest(connection_id_type connection_id, int32_t client_id, int64_t addr)
: BaseRequest(RequestType::kFreeSharedBlock) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(addr));
rq_.set_client_id(client_id);
}
~FreeSharedBlockRequest() override = default;
};
@@ -178,7 +182,7 @@ class CacheRowRequest : public BaseRequest {
/// the shared memory by sending another request. The following function will generate a suitable
/// request for the CacheClient to send.
std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() {
return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), rq_.client_id(), addr_);
}
private:
@@ -271,6 +275,24 @@ class GetStatRequest : public BaseRequest {
CacheServiceStat stat_{};
};
/// \brief Get the state of a cache service
class GetCacheStateRequest : public BaseRequest {
public:
friend class CacheServer;
explicit GetCacheStateRequest(connection_id_type connection_id)
: BaseRequest(RequestType::kGetCacheState), cache_service_state_(0) {
rq_.set_connection_id(connection_id);
}
~GetCacheStateRequest() override = default;
Status PostReply() override;
auto GetState() const { return cache_service_state_; }
private:
int8_t cache_service_state_;
};
/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
public:
@@ -367,10 +389,11 @@ class ListSessionsRequest : public BaseRequest {
class AllocateSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz)
explicit AllocateSharedBlockRequest(connection_id_type connection_id, int32_t client_id, size_t requestedSz)
: BaseRequest(RequestType::kAllocateSharedBlock) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(requestedSz));
rq_.set_client_id(client_id);
}
~AllocateSharedBlockRequest() override = default;
@@ -420,6 +443,13 @@ class ConnectResetRequest : public BaseRequest {
return Status::OK();
}
};
class BatchCacheRowsRequest : public BaseRequest {
public:
friend class CacheServer;
explicit BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele);
~BatchCacheRowsRequest() override = default;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_
@@ -106,14 +106,18 @@ Status CacheServer::DoServiceStart() {
RETURN_IF_NOT_OK(free_list_->Register(&vg_));
// Start the comm layer
try {
comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_);
comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_);
RETURN_IF_NOT_OK(comm_layer_->Run());
// Bring up a thread to monitor the unix socket in case it is removed.
auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemory::CreateArena(&shm_, port_, shared_memory_sz_in_gb_));
// Bring up a thread to monitor the unix socket in case it is removed. But it must be done
// after we have created the unix socket.
auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
#endif
// Spawn a few threads to serve the real request.
auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
for (auto i = 0; i < num_workers_; ++i) {
@@ -350,11 +354,12 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
auto client_id = rq->client_id();
CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
auto *base = SharedMemoryBaseAddr();
// Ensure we got 3 pieces of data coming in
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() == 3, "Incomplete data");
// First piece of data is the cookie and is required
@@ -381,8 +386,10 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
// Return the block to the shared memory.
shared_pool->Deallocate(p);
// Return the block to the shared memory only if it is not an internal request.
if (static_cast<BaseRequest::RequestType>(rq->type()) == BaseRequest::RequestType::kCacheRow) {
DeallocateSharedMemory(client_id, p);
}
return rc;
}
@@ -450,6 +457,7 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuil
Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
auto client_id = rq->client_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
@@ -490,14 +498,13 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0);
if (local_bypass) {
// We will use shared memory
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
auto *base = SharedMemoryBaseAddr();
void *q = nullptr;
RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q));
RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, mem_sz, &q));
WritableSlice dest(q, mem_sz);
Status rc = BatchFetch(fbb, &dest);
if (rc.IsError()) {
shared_pool->Deallocate(q);
DeallocateSharedMemory(client_id, q);
return rc;
}
// We can't return the absolute address which makes no sense to the client.
@@ -597,7 +604,7 @@ Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
// First piece of data is the cookie
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
auto &cookie = rq->buf_data(0);
// We can only allow to switch phase is the cookie match.
// We can only allow switching phase if the cookie matches.
if (cookie == cs->cookie()) {
RETURN_IF_NOT_OK(cs->BuildPhaseDone());
} else {
@@ -713,6 +720,203 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {
return Status::OK();
}
Status CacheServer::BatchCacheRows(CacheRequest *rq) {
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == 3, "Expect three pieces of data");
int32_t numQ = GetNumGrpcWorkers();
auto rng = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
int32_t qID = distribution(rng);
std::vector<CacheServerRequest *> cache_rq_list;
try {
auto &cookie = rq->buf_data(0);
auto connection_id = rq->connection_id();
auto client_id = rq->client_id();
int64_t offset_addr;
int32_t num_elem;
auto *base = SharedMemoryBaseAddr();
offset_addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
num_elem = strtol(rq->buf_data(2).data(), nullptr, 10);
cache_rq_list.reserve(num_elem);
// Get a set of free requests and push them into the queues.
for (auto i = 0; i < num_elem; ++i) {
auto start = reinterpret_cast<int64_t>(p);
auto msg = GetTensorRowHeaderMsg(p);
p += msg->size_of_this();
for (auto k = 0; k < msg->column()->size(); ++k) {
p += msg->data_sz()->Get(k);
}
CacheServerRequest *cache_rq;
RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
cache_rq_list.push_back(cache_rq);
// Fill in details.
cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
cache_rq->rq_.set_connection_id(connection_id);
cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
cache_rq->rq_.set_client_id(client_id);
cache_rq->rq_.set_flag(kDataIsInSharedMemory);
cache_rq->rq_.add_buf_data(cookie);
cache_rq->rq_.add_buf_data(std::to_string(start - reinterpret_cast<int64_t>(base)));
cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(p - start)));
RETURN_IF_NOT_OK(PushRequest(GetRandomWorker(), cache_rq));
}
// Now wait for all of them to come back.
Status rc;
for (CacheServerRequest *cache_rq : cache_rq_list) {
RETURN_IF_NOT_OK(cache_rq->Wait());
if (cache_rq->rc_.IsError() && !cache_rq->rc_.IsInterrupted() && rc.IsOk()) {
rc = cache_rq->rc_;
}
RETURN_IF_NOT_OK(ReturnRequestTag(cache_rq));
}
return rc;
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
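Reviewer note: the pointer walk above relies on each serialized row being a header that knows its own size (size_of_this) followed by its column payloads (data_sz). A simplified standalone sketch of that back-to-back record walk, with a plain fixed-layout header standing in for the real flatbuffer GetTensorRowHeaderMsg (payload sizes kept 8-byte aligned for the demo):

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for the flatbuffer tensor-row header: it records its own size
// and the size of the payload that follows it.
struct Hdr {
  int64_t size_of_this;
  int64_t data_sz;
};

int main() {
  // Build two back-to-back records in one buffer.
  std::vector<char> buf;
  for (int64_t payload : {8, 16}) {
    Hdr h{sizeof(Hdr), payload};
    buf.insert(buf.end(), reinterpret_cast<char *>(&h),
               reinterpret_cast<char *>(&h) + sizeof(h));
    buf.insert(buf.end(), payload, 'x');
  }
  // Walk them the way BatchCacheRows does: mark start, skip header, skip data.
  char *p = buf.data();
  for (int i = 0; i < 2; ++i) {
    char *start = p;
    auto *h = reinterpret_cast<Hdr *>(p);
    p += h->size_of_this;
    p += h->data_sz;
    std::cout << "record " << i << " spans " << (p - start) << " bytes\n";
  }
  return 0;
}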
void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
bool internal_request = false;
auto &rq = cache_req->rq_;
auto &reply = cache_req->reply_;
// Except for creating a new session, we expect cs is not null.
switch (cache_req->type_) {
case BaseRequest::RequestType::kCacheRow:
case BaseRequest::RequestType::kInternalCacheRow: {
// Look into the flag to see where we can find the data and
// call the appropriate method.
auto flag = rq.flag();
if (BitTest(flag, kDataIsInSharedMemory)) {
cache_req->rc_ = FastCacheRow(&rq, &reply);
internal_request = (cache_req->type_ == BaseRequest::RequestType::kInternalCacheRow);
} else {
cache_req->rc_ = CacheRow(&rq, &reply);
}
break;
}
case BaseRequest::RequestType::kBatchCacheRows: {
cache_req->rc_ = BatchCacheRows(&rq);
break;
}
case BaseRequest::RequestType::kBatchFetchRows: {
cache_req->rc_ = BatchFetchRows(&rq, &reply);
break;
}
case BaseRequest::RequestType::kInternalFetchRow: {
internal_request = true;
auto connection_id = rq.connection_id();
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
}
break;
}
case BaseRequest::RequestType::kCreateCache: {
cache_req->rc_ = CreateService(&rq, &reply);
break;
}
case BaseRequest::RequestType::kGetCacheMissKeys: {
cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
break;
}
case BaseRequest::RequestType::kDestroyCache: {
cache_req->rc_ = DestroyCache(&rq);
break;
}
case BaseRequest::RequestType::kGetStat: {
cache_req->rc_ = GetStat(&rq, &reply);
break;
}
case BaseRequest::RequestType::kCacheSchema: {
cache_req->rc_ = CacheSchema(&rq);
break;
}
case BaseRequest::RequestType::kFetchSchema: {
cache_req->rc_ = FetchSchema(&rq, &reply);
break;
}
case BaseRequest::RequestType::kBuildPhaseDone: {
cache_req->rc_ = BuildPhaseDone(&rq);
break;
}
case BaseRequest::RequestType::kDropSession: {
cache_req->rc_ = DestroySession(&rq);
break;
}
case BaseRequest::RequestType::kGenerateSessionId: {
cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
break;
}
case BaseRequest::RequestType::kAllocateSharedBlock: {
cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
break;
}
case BaseRequest::RequestType::kFreeSharedBlock: {
cache_req->rc_ = FreeSharedMemory(&rq);
break;
}
case BaseRequest::RequestType::kStopService: {
// This command shuts down everything.
// But we first reply back to the client that we received the request.
// The real shutdown work will be done by the caller.
cache_req->rc_ = AcknowledgeShutdown(cache_req);
break;
}
case BaseRequest::RequestType::kHeartBeat: {
cache_req->rc_ = Status::OK();
break;
}
case BaseRequest::RequestType::kToggleWriteMode: {
cache_req->rc_ = ToggleWriteMode(&rq);
break;
}
case BaseRequest::RequestType::kListSessions: {
cache_req->rc_ = ListSessions(&reply);
break;
}
case BaseRequest::RequestType::kConnectReset: {
cache_req->rc_ = ConnectReset(&rq);
break;
}
case BaseRequest::RequestType::kGetCacheState: {
auto connection_id = rq.connection_id();
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto state = cs->GetState();
reply.set_result(std::to_string(static_cast<int8_t>(state)));
cache_req->rc_ = Status::OK();
}
break;
}
default:
std::string errMsg("Unknown request type : ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
// Notify it is done, and move on to the next request.
Status2CacheReply(cache_req->rc_, &reply);
cache_req->st_ = CacheServerRequest::STATE::FINISH;
// We will re-tag the request back to the grpc queue. Once it comes back from the client,
// the CacheServerRequest, i.e. the pointer cache_req, will be freed.
if (!internal_request) {
cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
} else {
// This is an internal request and is not tied to rpc. But we need to post because there
// is a thread waiting on the completion of this request.
cache_req->wp_.Set();
}
}
/// \brief This is the main loop that the cache server thread(s) run.
/// Each thread will pop a request and send the result back to the client using grpc
/// \return
| @@ -722,121 +926,9 @@ Status CacheServer::ServerRequest(worker_id_t worker_id) { | |||
| auto &my_que = cache_q_->operator[](worker_id); | |||
| // Loop forever until we are interrupted or shutdown. | |||
| while (!global_shutdown_) { | |||
| bool internal_request = false; | |||
| CacheServerRequest *cache_req = nullptr; | |||
| RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); | |||
| auto &rq = cache_req->rq_; | |||
| auto &reply = cache_req->reply_; | |||
| // Except for creating a new session, we expect cs is not null. | |||
| switch (cache_req->type_) { | |||
| case BaseRequest::RequestType::kCacheRow: { | |||
| // Look into the flag to see where we can find the data and | |||
| // call the appropriate method. | |||
| auto flag = rq.flag(); | |||
| if (BitTest(flag, kDataIsInSharedMemory)) { | |||
| cache_req->rc_ = FastCacheRow(&rq, &reply); | |||
| } else { | |||
| cache_req->rc_ = CacheRow(&rq, &reply); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kInternalFetchRow: { | |||
| internal_request = true; | |||
| auto connection_id = rq.connection_id(); | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data())); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kCreateCache: { | |||
| cache_req->rc_ = CreateService(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kGetCacheMissKeys: { | |||
| cache_req->rc_ = GetCacheMissKeys(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kDestroyCache: { | |||
| cache_req->rc_ = DestroyCache(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kGetStat: { | |||
| cache_req->rc_ = GetStat(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kCacheSchema: { | |||
| cache_req->rc_ = CacheSchema(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kFetchSchema: { | |||
| cache_req->rc_ = FetchSchema(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kBuildPhaseDone: { | |||
| cache_req->rc_ = BuildPhaseDone(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kDropSession: { | |||
| cache_req->rc_ = DestroySession(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kGenerateSessionId: { | |||
| cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kAllocateSharedBlock: { | |||
| cache_req->rc_ = AllocateSharedMemory(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kFreeSharedBlock: { | |||
| cache_req->rc_ = FreeSharedMemory(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kStopService: { | |||
| // This command shuts down everything. | |||
| cache_req->rc_ = GlobalShutdown(cache_req); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kHeartBeat: { | |||
| cache_req->rc_ = Status::OK(); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kToggleWriteMode: { | |||
| cache_req->rc_ = ToggleWriteMode(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kListSessions: { | |||
| cache_req->rc_ = ListSessions(&reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kConnectReset: { | |||
| cache_req->rc_ = ConnectReset(&rq); | |||
| break; | |||
| } | |||
| default: | |||
| std::string errMsg("Unknown request type : "); | |||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | |||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| // Notify it is done, and move on to the next request. | |||
| Status2CacheReply(cache_req->rc_, &reply); | |||
| cache_req->st_ = CacheServerRequest::STATE::FINISH; | |||
| // We will re-tag the request back to the grpc queue. Once it comes back from the client, | |||
| // the CacheServerRequest, i.e. the pointer cache_req, will be freed. | |||
| if (!global_shutdown_) { | |||
| if (!internal_request) { | |||
| cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); | |||
| } else { | |||
| // This is an internal request and is not tied to rpc. But we need to post because there | |||
| // is a thread waiting on the completion of this request. | |||
| cache_req->wp_.Set(); | |||
| } | |||
| } | |||
| ProcessRequest(cache_req); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
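With the dispatch moved into ProcessRequest, each ServerRequest loop is a plain queue consumer. A hypothetical launcher for such loops — the real server pins its worker threads to numa cores through its own task framework, so this std::thread version is only a sketch:

```cpp
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical launcher: one consumer thread per worker queue, each running
// a ServerRequest-style loop until the shutdown flag flips. The real server
// uses its own task framework and pins these threads to numa cores.
template <typename ServerT>
void LaunchWorkers(ServerT *server, int32_t num_workers) {
  std::vector<std::thread> workers;
  workers.reserve(num_workers);
  for (int32_t id = 0; id < num_workers; ++id) {
    workers.emplace_back([server, id] { (void)server->ServerRequest(id); });
  }
  for (auto &t : workers) {
    t.join();
  }
}
```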
| @@ -869,6 +961,11 @@ CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int | |||
| MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build " | |||
| "that is compiled with numa support for more optimal performance"; | |||
| } | |||
| // Cap the shared memory size if the requested value exceeds the default maximum. | |||
| if (shared_memory_sz_in_gb_ > kDefaultSharedMemorySize) { | |||
| MS_LOG(INFO) << "Shared memory size is readjust to " << kDefaultSharedMemorySize << " GB."; | |||
| shared_memory_sz_in_gb_ = kDefaultSharedMemorySize; | |||
| } | |||
| } | |||
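The cap above could equally be written as a clamp; a one-line sketch of the same rule (minus the INFO log), assuming kDefaultSharedMemorySize is the upper bound in GB:

```cpp
#include <algorithm>
#include <cstddef>

// Same rule as the constructor above, expressed as a clamp (the real code
// also logs the adjustment): never exceed the default maximum, in GB.
size_t CapSharedMemorySize(size_t requested_gb, size_t max_gb) {
  return std::min(requested_gb, max_gb);
}
```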
| Status CacheServer::Run(int msg_qid) { | |||
| @@ -965,24 +1062,34 @@ session_id_type CacheServer::GenerateSessionID() { | |||
| } | |||
| Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) { | |||
| auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10); | |||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||
| void *p = nullptr; | |||
| RETURN_IF_NOT_OK(shared_pool->Allocate(requestedSz, &p)); | |||
| // We can't return the absolute address which makes no sense to the client. | |||
| // Instead we return the difference. | |||
| auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base); | |||
| reply->set_result(std::to_string(difference)); | |||
| auto client_id = rq->client_id(); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set"); | |||
| try { | |||
| auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10); | |||
| void *p = nullptr; | |||
| RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, requestedSz, &p)); | |||
| auto *base = SharedMemoryBaseAddr(); | |||
| // We can't return the absolute address, which makes no sense to the client. | |||
| // Instead we return the difference. | |||
| auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base); | |||
| reply->set_result(std::to_string(difference)); | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
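The offset scheme works because server and client attach the same System V segment at different virtual addresses, so only the difference from the base is portable between processes. A hypothetical client-side decode of the reply produced above:

```cpp
#include <cstdint>
#include <string>

// The server replied with (p - server_base) as a decimal string; the client
// recovers a usable pointer by adding that offset to its own attach address.
void *OffsetToLocalPointer(void *client_base, const std::string &reply_result) {
  int64_t offset = std::stoll(reply_result);
  return static_cast<char *>(client_base) + offset;
}
```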
| Status CacheServer::FreeSharedMemory(CacheRequest *rq) { | |||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||
| auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10); | |||
| auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr); | |||
| shared_pool->Deallocate(p); | |||
| auto client_id = rq->client_id(); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set"); | |||
| auto *base = SharedMemoryBaseAddr(); | |||
| try { | |||
| auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10); | |||
| auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr); | |||
| DeallocateSharedMemory(client_id, p); | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
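FreeSharedMemory is the mirror image: the client converts its local pointer back into an offset before sending the request. A sketch of that client-side encode, under the same assumptions as above:

```cpp
#include <cstdint>
#include <string>

// Hypothetical client-side encode for a free request: send the offset from
// the client's own base address, never the raw pointer.
std::string LocalPointerToOffset(void *client_base, void *p) {
  int64_t offset = static_cast<char *>(p) - static_cast<char *>(client_base);
  return std::to_string(offset);
}
```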
| @@ -992,7 +1099,7 @@ Status CacheServer::RpcRequest(worker_id_t worker_id) { | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) { | |||
| Status CacheServer::AcknowledgeShutdown(CacheServerRequest *cache_req) { | |||
| auto *rq = &cache_req->rq_; | |||
| auto *reply = &cache_req->reply_; | |||
| if (!rq->buf_data().empty()) { | |||
| @@ -1008,9 +1115,10 @@ Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) { | |||
| } | |||
| } | |||
| reply->set_result("OK"); | |||
| Status2CacheReply(cache_req->rc_, reply); | |||
| cache_req->st_ = CacheServerRequest::STATE::FINISH; | |||
| cache_req->responder_.Finish(*reply, grpc::Status::OK, cache_req); | |||
| return Status::OK(); | |||
| } | |||
| void CacheServer::GlobalShutdown() { | |||
| // Let's shut down in the proper order. | |||
| bool expected = false; | |||
| if (global_shutdown_.compare_exchange_strong(expected, true)) { | |||
| @@ -1032,7 +1140,6 @@ Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) { | |||
| it = all_caches_.erase(it); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
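The compare_exchange_strong guard above makes GlobalShutdown idempotent: only the first caller flips the flag and runs the teardown, while concurrent or repeated calls return without doing anything. A minimal standalone sketch of the pattern:

```cpp
#include <atomic>

std::atomic<bool> shutting_down{false};

// Only the first caller performs teardown; every later or concurrent caller
// fails the exchange (expected becomes true) and returns immediately.
void ShutdownOnce() {
  bool expected = false;
  if (shutting_down.compare_exchange_strong(expected, true)) {
    // ... stop rpc threads, drain queues, destroy caches ...
  }
}
```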
| worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const { | |||
| @@ -1053,6 +1160,12 @@ worker_id_t CacheServer::GetRandomWorker() const { | |||
| return dist(gen); | |||
| } | |||
| Status CacheServer::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) { | |||
| return shm_->AllocateSharedMemory(client_id, sz, p); | |||
| } | |||
| void CacheServer::DeallocateSharedMemory(int32_t client_id, void *p) { shm_->DeallocateSharedMemory(client_id, p); } | |||
| Status CacheServer::Builder::IpcResourceCleanup() { | |||
| Status rc; | |||
| SharedMemory::shm_key_t shm_key; | |||
| @@ -1124,8 +1237,8 @@ CacheServer::Builder::Builder() | |||
| : top_("/tmp"), | |||
| num_workers_(std::thread::hardware_concurrency() / 2), | |||
| port_(50052), | |||
| shared_memory_sz_in_gb_(4), | |||
| memory_cap_ratio_(0.8) { | |||
| shared_memory_sz_in_gb_(kDefaultSharedMemorySize), | |||
| memory_cap_ratio_(kDefaultMemoryCapRatio) { | |||
| if (num_workers_ == 0) { | |||
| num_workers_ = 1; | |||
| } | |||
| @@ -25,12 +25,14 @@ | |||
| #include <chrono> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <set> | |||
| #include <thread> | |||
| #include "minddata/dataset/engine/cache/cache_arena.h" | |||
| #include "minddata/dataset/engine/cache/cache_hw.h" | |||
| #include "minddata/dataset/engine/cache/cache_numa.h" | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| @@ -196,15 +198,31 @@ class CacheServer : public Service { | |||
| /// \brief Check if we bind threads to numa cores | |||
| bool IsNumaAffinityOn() const { return numa_affinity_; } | |||
| /// \brief Internal function to do row batch fetch | |||
| /// \param rq Request | |||
| /// \param reply Reply | |||
| /// \return Status object | |||
| Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Return the memory cap ratio | |||
| float GetMemoryCapRatio() const { return memory_cap_ratio_; } | |||
| /// \brief How a request is handled. | |||
| /// \note It can be processed immediately by a grpc thread or routed to a server thread | |||
| /// which is pinned to some numa node core. | |||
| void ProcessRequest(CacheServerRequest *cache_req); | |||
| void GlobalShutdown(); | |||
| /// \brief This returns where we attach to the shared memory. | |||
| /// Some gRPC requests will ask for a shared memory block, and | |||
| /// we can't return the absolute address as this makes no sense | |||
| /// in the client. So instead we will return an address relative | |||
| /// to the base address of the shared memory to which we attach. | |||
| /// \return Base address of the shared memory. | |||
| const void *SharedMemoryBaseAddr() const { return shm_->SharedMemoryBaseAddr(); } | |||
| /// \brief Return the public key of the shared memory. | |||
| int32_t GetKey() const { return shm_->GetKey(); } | |||
| Status AllocateSharedMemory(int32_t client_id, size_t sz, void **p); | |||
| void DeallocateSharedMemory(int32_t client_id, void *p); | |||
| private: | |||
| static std::once_flag init_instance_flag_; | |||
| static CacheServer *instance_; | |||
| @@ -228,6 +246,7 @@ class CacheServer : public Service { | |||
| std::map<worker_id_t, Task *> numa_tasks_; | |||
| bool numa_affinity_; | |||
| std::vector<int32_t> shutdown_qIDs_; | |||
| std::unique_ptr<CachedSharedMemory> shm_; | |||
| /// \brief Constructor | |||
| /// \param spill_path Top directory for spilling buffers to. | |||
| @@ -315,7 +334,7 @@ class CacheServer : public Service { | |||
| /// \brief A proper shutdown of the server | |||
| /// \return Status object | |||
| Status GlobalShutdown(CacheServerRequest *); | |||
| Status AcknowledgeShutdown(CacheServerRequest *cache_req); | |||
| /// \brief Find keys that will be cache misses | |||
| /// \return Status object | |||
| @@ -332,12 +351,19 @@ class CacheServer : public Service { | |||
| /// \brief Connect request by a pipeline | |||
| Status ConnectReset(CacheRequest *rq); | |||
| /// \brief Internal function to do row batch fetch | |||
| /// \param rq Request | |||
| /// \param reply Reply | |||
| /// \return Status object | |||
| Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Main function to fetch rows in batch. The output is a contiguous block of memory which will be decoded | |||
| /// by the CacheClient. A cache miss is not an error and will be coded in the output to mark an empty row. | |||
| /// \param[in] v A vector of row id. | |||
| /// \param[out] out A contiguous memory buffer that holds the requested rows. | |||
| /// \return Status object | |||
| Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out); | |||
| Status BatchCacheRows(CacheRequest *rq); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -68,10 +68,11 @@ Status CacheService::DoServiceStop() { | |||
| Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) { | |||
| SharedLock rw(&rw_lock_); | |||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | |||
| if (st_ == CacheServiceState::kFetchPhase) { | |||
| if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) { | |||
| // For this kind of cache service, once we move from the build phase into the fetch phase, we can't | |||
| // allow others to cache more rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " + | |||
| std::to_string(static_cast<int>(st_.load()))); | |||
| } | |||
| if (st_ == CacheServiceState::kNoLocking) { | |||
| // We ignore this write request once we turn off locking on the B+ tree. So we will just | |||
| @@ -119,6 +120,16 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||
| if (rc == Status(StatusCode::kDuplicateKey)) { | |||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | |||
| } else { | |||
| if (HasBuildPhase()) { | |||
| // For a cache service that has a build phase, record the error in the state | |||
| // so other clients can be aware of the new state. There is nothing one can | |||
| // do to resume other than to drop the cache. | |||
| if (rc.IsNoSpace()) { | |||
| st_ = CacheServiceState::kNoSpace; | |||
| } else if (rc.IsOutofMemory()) { | |||
| st_ = CacheServiceState::kOutOfMemory; | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| return Status::OK(); | |||
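The state write above is deliberately one-way: once a build-phase cache hits an out-of-resource error there is no recovery short of dropping the cache, so the failure is made visible to every other client instead of being re-triggered write after write. A condensed sketch of the sticky-error rule, with the enum values assumed from this diff:

```cpp
#include <atomic>
#include <cstdint>

// Enumerator list assumed from the states referenced in this diff.
enum class CacheServiceState : int8_t { kNone, kBuildPhase, kFetchPhase, kNoSpace, kOutOfMemory };

// Record the first fatal resource error; the state is never downgraded back
// to a healthy phase, so later requests fail fast with a meaningful reason.
void RecordFatalError(std::atomic<CacheServiceState> *st, bool no_space) {
  st->store(no_space ? CacheServiceState::kNoSpace : CacheServiceState::kOutOfMemory);
}
```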
| @@ -130,10 +141,11 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||
| Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) { | |||
| SharedLock rw(&rw_lock_); | |||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | |||
| if (st_ == CacheServiceState::kFetchPhase) { | |||
| if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) { | |||
| // For this kind of cache service, once we move from the build phase into the fetch phase, we can't | |||
| // allow others to cache more rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " + | |||
| std::to_string(static_cast<int>(st_.load()))); | |||
| } | |||
| if (st_ == CacheServiceState::kNoLocking) { | |||
| // We ignore this write request once we turn off locking on the B+ tree. So we will just | |||
| @@ -161,6 +173,16 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ | |||
| if (rc == Status(StatusCode::kDuplicateKey)) { | |||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | |||
| } else { | |||
| if (HasBuildPhase()) { | |||
| // For a cache service that has a build phase, record the error in the state | |||
| // so other clients can be aware of the new state. There is nothing one can | |||
| // do to resume other than to drop the cache. | |||
| if (rc.IsNoSpace()) { | |||
| st_ = CacheServiceState::kNoSpace; | |||
| } else if (rc.IsOutofMemory()) { | |||
| st_ = CacheServiceState::kOutOfMemory; | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| return Status::OK(); | |||
| @@ -202,16 +224,17 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) { | |||
| SharedLock rw(&rw_lock_); | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| out->stat_ = cp_->GetStat(); | |||
| out->state_ = static_cast<ServiceStat::state_type>(st_); | |||
| out->state_ = static_cast<ServiceStat::state_type>(st_.load()); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v, | |||
| const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) { | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == CacheServiceState::kBuildPhase) { | |||
| if (HasBuildPhase() && st_ != CacheServiceState::kFetchPhase) { | |||
| // For this kind of cache service, we can't fetch until we are done caching all the rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " + | |||
| std::to_string(static_cast<int>(st_.load()))); | |||
| } | |||
| std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v; | |||
| datalocator_v.reserve(v.size()); | |||
| @@ -271,7 +294,8 @@ Status CacheService::FetchSchema(std::string *out) const { | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == CacheServiceState::kBuildPhase) { | |||
| // For this kind of cache service, we can't fetch until we are done caching all the rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " + | |||
| std::to_string(static_cast<int>(st_.load()))); | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| // We are going to use std::string to allocate and hold the result which will be eventually | |||
| @@ -292,6 +316,7 @@ Status CacheService::BuildPhaseDone() { | |||
| UniqueLock rw(&rw_lock_); | |||
| st_ = CacheServiceState::kFetchPhase; | |||
| cp_->SetLocking(false); | |||
| MS_LOG(WARNING) << "Locking mode is switched off."; | |||
| return Status::OK(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); | |||
| @@ -91,6 +91,8 @@ class CacheService : public Service { | |||
| /// \param[in/out] A pointer to a pre-allocated ServiceStat structure | |||
| /// \return Status Object | |||
| Status GetStat(ServiceStat *); | |||
| /// \brief Return the current state | |||
| CacheServiceState GetState() const { return st_.load(); } | |||
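GetState can be answered lock-free because st_ becomes std::atomic in this change (see the member diff below); writers still mutate it under the service's rw lock, but with a plain enum member this concurrent read would be a data race. A minimal illustration:

```cpp
#include <atomic>

// Reader side: safe without holding the service's rw lock, because the load
// is atomic. The writer side elsewhere does st.store(next_state) under its
// own lock; the atomic only guarantees the read/write itself is race-free.
template <typename StateT>
StateT PeekState(const std::atomic<StateT> &st) {
  return st.load();
}
```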
| /// \brief Cache schema | |||
| /// \param buf A Google Flatbuffer that contains the schema | |||
| /// \param len size of the buffer | |||
| @@ -131,7 +133,7 @@ class CacheService : public Service { | |||
| bool generate_id_; | |||
| std::string cookie_; | |||
| std::atomic<int32_t> num_clients_; | |||
| CacheServiceState st_; | |||
| std::atomic<CacheServiceState> st_; | |||
| std::string schema_; | |||
| std::shared_ptr<NumaMemoryPool> numa_pool_; | |||
| // We also cache the result from calling FindKeysMiss because it is expensive. Besides user make | |||
| @@ -427,6 +427,7 @@ Status CachePerfRun::Run() { | |||
| } | |||
| // Now we create the children knowing both sets of message queues are constructed. | |||
| auto start_tick = std::chrono::steady_clock::now(); | |||
| for (auto i = 0; i < num_pipelines_; ++i) { | |||
| auto pid = fork(); | |||
| if (pid == 0) { | |||
| @@ -502,6 +503,10 @@ Status CachePerfRun::Run() { | |||
| // Wait until all pipelines finish the first epoch. | |||
| RETURN_IF_NOT_OK(pipeline_wp_.Wait()); | |||
| auto end_tick = std::chrono::steady_clock::now(); | |||
| int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(end_tick - start_tick).count(); | |||
| std::cout << "Epoch one (build phase) elapsed time " << elapse_time << " seconds" << std::endl; | |||
| std::cout << "Epoch one (build phase) per pipeline per worker summary. Buffer size = " << cfg_.rows_per_buffer() | |||
| << std::endl; | |||
| @@ -543,6 +548,7 @@ Status CachePerfRun::Run() { | |||
| epoch_sync_cnt_ = 0; | |||
| pipeline_wp_.Clear(); | |||
| epoch_results_.clear(); | |||
| start_tick = std::chrono::steady_clock::now(); | |||
| // Signal each pipeline to start | |||
| for (auto msg_qid : msg_send_lists_) { | |||
| CachePerfMsg msg; | |||
| @@ -551,6 +557,9 @@ Status CachePerfRun::Run() { | |||
| } | |||
| // Wait for the child to finish | |||
| RETURN_IF_NOT_OK(pipeline_wp_.Wait()); | |||
| end_tick = std::chrono::steady_clock::now(); | |||
| elapse_time = std::chrono::duration_cast<std::chrono::seconds>(end_tick - start_tick).count(); | |||
| std::cout << "Epoch " << epoch_num << " elapsed time " << elapse_time << " seconds" << std::endl; | |||
| std::cout << "Epoch " << epoch_num | |||
| << " (read phase) per pipeline per worker summary. Buffer size = " << cc_->GetPrefetchSize() << std::endl; | |||
| PrintEpochSummary(); | |||
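The repeated start/stop tick pairs above could be folded into a small RAII timer; a sketch of that alternative (not part of this change):

```cpp
#include <chrono>
#include <iostream>
#include <string>
#include <utility>

// Prints the elapsed wall time for a scope when it exits, using the same
// steady_clock measurement as the perf run above.
class ScopedTimer {
 public:
  explicit ScopedTimer(std::string label)
      : label_(std::move(label)), start_(std::chrono::steady_clock::now()) {}
  ~ScopedTimer() {
    auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
        std::chrono::steady_clock::now() - start_).count();
    std::cout << label_ << " elapsed time " << elapsed << " seconds" << std::endl;
  }

 private:
  std::string label_;
  std::chrono::steady_clock::time_point start_;
};
```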
| @@ -238,6 +238,9 @@ Status CachePipelineRun::RunFirstEpoch() { | |||
| RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking)); | |||
| } | |||
| // Final flush | |||
| cc_->FlushAsyncWriteBuffer(); | |||
| // Send a message saying epoch one done for this pipeline. | |||
| EpochDone proto; | |||
| proto.set_pipeline(my_pipeline_); | |||
| @@ -291,7 +294,7 @@ Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) { | |||
| buffer->set_tensor_table(std::move(tensor_table)); | |||
| // Measure the time to call WriteBuffer | |||
| auto start_tick = std::chrono::steady_clock::now(); | |||
| rc = cc_->WriteBuffer(std::move(buffer)); | |||
| rc = cc_->AsyncWriteBuffer(std::move(buffer)); | |||
| auto end_tick = std::chrono::steady_clock::now(); | |||
| if (rc.IsError()) { | |||
| if (rc.IsOutofMemory() || rc.IsNoSpace()) { | |||
| @@ -122,6 +122,17 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | |||
| if (db_ptr->eoe()) { | |||
| // Ignore it. | |||
| MS_LOG(DEBUG) << "Ignore eoe"; | |||
| // However we need to flush any leftover rows from the async write buffer. Any error | |||
| // we get will just stop caching, but the pipeline will continue. | |||
| Status rc; | |||
| if ((rc = cache_client_->FlushAsyncWriteBuffer()).IsError()) { | |||
| cache_missing_rows_ = false; | |||
| if (rc.IsOutofMemory() || rc.IsNoSpace()) { | |||
| cache_client_->ServerRunningOutOfResources(); | |||
| } else { | |||
| MS_LOG(INFO) << "Async row flushing not successful: " << rc.ToString(); | |||
| } | |||
| } | |||
| } else { | |||
| while (db_ptr->NumRows() > 0) { | |||
| TensorRow row; | |||
| @@ -143,6 +154,9 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | |||
| rc = rq->AsyncSendCacheRequest(cache_client_, row); | |||
| if (rc.IsOk()) { | |||
| RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); | |||
| } else if (rc.IsOutofMemory() || rc.IsNoSpace()) { | |||
| cache_missing_rows_ = false; | |||
| cache_client_->ServerRunningOutOfResources(); | |||
| } | |||
| } | |||
| } | |||
| @@ -309,17 +323,25 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha | |||
| if (st_.compare_exchange_strong(expected, State::kDirty)) { | |||
| // We will do a deep copy but write directly into CacheRequest protobuf or shared memory | |||
| Status rc; | |||
| cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get()); | |||
| rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); | |||
| if (rc.IsOk()) { | |||
| // Send the request async. The cleaner will check the return code. | |||
| rc = cc->PushRequest(cleaner_copy_); | |||
| rc = cc->AsyncWriteRow(row); | |||
| if (rc.get_code() == StatusCode::kNotImplementedYet) { | |||
| cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get()); | |||
| rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); | |||
| if (rc.IsOk()) { | |||
| // Send the request async. The cleaner will check the return code. | |||
| rc = cc->PushRequest(cleaner_copy_); | |||
| } | |||
| } else if (rc.IsOk()) { | |||
| // Set the state to clean even though it still sits in the cache client async buffer. | |||
| // The cleaner will then ignore it once the state is clean. | |||
| st_ = State::kClean; | |||
| } | |||
| if (rc.IsError()) { | |||
| // Clean up the shared pointer and reset the state back to empty | |||
| cleaner_copy_.reset(); | |||
| st_ = State::kEmpty; | |||
| } | |||
| return rc; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
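The status-code probe above doubles as feature detection: AsyncWriteRow reports kNotImplementedYet when no shared-memory path exists, and the caller falls back to the slower serialized rpc. The same shape appears again in CacheOp::CacheAllRows below; a condensed sketch of the pattern, where SlowSerializedWrite is a hypothetical stand-in for the PushRequest path:

```cpp
// Try the fast path first, treat kNotImplementedYet as "unsupported here,
// not an error", and only then pay for the serialized rpc path.
Status WriteRowWithFallback(CacheClient *cc, const TensorRow &row) {
  Status rc = cc->AsyncWriteRow(row);
  if (rc.get_code() == StatusCode::kNotImplementedYet) {
    rc = SlowSerializedWrite(cc, row);  // hypothetical stand-in
  }
  return rc;
}
```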
| @@ -109,7 +109,14 @@ Status CacheOp::CacheAllRows(int32_t worker_id) { | |||
| RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); | |||
| while (!db_ptr->eof()) { | |||
| if (!db_ptr->eoe()) { | |||
| RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr))); | |||
| Status rc; | |||
| // Do the async write if we are attached to the shared memory. | |||
| rc = cache_client_->AsyncWriteBuffer(std::move(db_ptr)); | |||
| if (rc.get_code() == StatusCode::kNotImplementedYet) { | |||
| RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr))); | |||
| } else if (rc.IsError()) { | |||
| return rc; | |||
| } | |||
| } else { | |||
| // In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up | |||
| // as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the | |||
| @@ -139,21 +146,41 @@ Status CacheOp::WaitForCachingAllRows() { | |||
| RETURN_IF_NOT_OK(rows_cache_done_.Wait()); | |||
| // Move from build phase to fetch phase if we are the one to fill the cache | |||
| if (phase_ == Phase::kBuildPhase) { | |||
| RETURN_IF_NOT_OK(cache_client_->FlushAsyncWriteBuffer()); // One more flush | |||
| RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone()); | |||
| // Move to the next phase | |||
| phase_ = Phase::kFetchPhase; | |||
| } | |||
| // Get statistics from the server, and if we are not the one to create the cache, | |||
| // If we are not the one to create the cache, | |||
| // wait until the state changes from build phase to fetch phase. | |||
| CacheServiceStat stat{}; | |||
| bool BuildPhaseDone = true; | |||
| do { | |||
| RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); | |||
| BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase); | |||
| if (!BuildPhaseDone) { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(100)); | |||
| int8_t out; | |||
| RETURN_IF_NOT_OK(cache_client_->GetState(&out)); | |||
| auto state = static_cast<CacheServiceState>(out); | |||
| switch (state) { | |||
| case CacheServiceState::kBuildPhase: | |||
| // Do nothing. Continue to wait. | |||
| BuildPhaseDone = false; | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(100)); | |||
| break; | |||
| case CacheServiceState::kFetchPhase: | |||
| BuildPhaseDone = true; | |||
| break; | |||
| case CacheServiceState::kOutOfMemory: | |||
| return Status(StatusCode::kOutOfMemory, "Cache server is running out of memory"); | |||
| case CacheServiceState::kNoSpace: | |||
| return Status(StatusCode::kNoSpace, "Cache server is running out of spill storage"); | |||
| case CacheServiceState::kNone: | |||
| case CacheServiceState::kError: | |||
| default: | |||
| RETURN_STATUS_UNEXPECTED("Unexpected state: " + std::to_string(out)); | |||
| } | |||
| } while (!BuildPhaseDone); | |||
| // The build phase is done. Get statistics from the server so we can | |||
| // compute the row id range of the cache. | |||
| CacheServiceStat stat{}; | |||
| RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); | |||
| const row_id_type min_key = stat.min_row_id; | |||
| const row_id_type max_key = stat.max_row_id; | |||
| num_rows_ = max_key - min_key + 1; | |||
| @@ -148,6 +148,12 @@ class BPlusTree { | |||
| acquire_lock_ = on_off; | |||
| } | |||
| void LockShared() { rw_lock_.LockShared(); } | |||
| void LockExclusive() { rw_lock_.LockExclusive(); } | |||
| void Unlock() { rw_lock_.Unlock(); } | |||
| private: | |||
| // Abstract class of a node (leaf or inner) | |||
| class BaseNode { | |||
| @@ -409,6 +415,21 @@ class BPlusTree { | |||
| bool operator==(const Iterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); } | |||
| bool operator!=(const Iterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); } | |||
| void LockShared() { | |||
| cur_->rw_lock_.LockShared(); | |||
| locked_ = true; | |||
| } | |||
| void LockExclusive() { | |||
| cur_->rw_lock_.LockExclusive(); | |||
| locked_ = true; | |||
| } | |||
| void Unlock() { | |||
| cur_->rw_lock_.Unlock(); | |||
| locked_ = false; | |||
| } | |||
| private: | |||
| typename BPlusTree::LeafNode *cur_; | |||
| slot_type slot_; | |||
| @@ -458,6 +479,21 @@ class BPlusTree { | |||
| bool operator==(const ConstIterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); } | |||
| bool operator!=(const ConstIterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); } | |||
| void LockShared() { | |||
| cur_->rw_lock_.LockShared(); | |||
| locked_ = true; | |||
| } | |||
| void LockExclusive() { | |||
| cur_->rw_lock_.LockExclusive(); | |||
| locked_ = true; | |||
| } | |||
| void Unlock() { | |||
| cur_->rw_lock_.Unlock(); | |||
| locked_ = false; | |||
| } | |||
| private: | |||
| const typename BPlusTree::LeafNode *cur_; | |||
| slot_type slot_; | |||
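Together with SetLocking(false) on the cache pool (the kNoLocking switch above), these accessors let a caller manage latching itself: internal per-operation locks are turned off, and one shared lock is held for the whole traversal. A usage sketch, assuming begin()/end() iterators and a value() accessor that this diff does not show:

```cpp
// Hypothetical caller-managed read-only scan: one tree-wide shared lock for
// the entire traversal instead of per-node latching on every operation.
template <typename TreeT, typename Fn>
void ScanShared(TreeT *tree, Fn &&fn) {
  tree->LockShared();
  for (auto it = tree->begin(); it != tree->end(); ++it) {
    fn(it.value());
  }
  tree->Unlock();
}
```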