
!10144 Fix cache_server --stop

From: @lixiachen
Reviewed-by: @mikef, @nsyca
Signed-off-by: @nsyca
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit af1072cee9
8 changed files with 68 additions and 221 deletions:

  1. +7  -0    mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt
  2. +23 -26   mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc
  3. +3  -10   mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h
  4. +2  -2    mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc
  5. +0  -110  mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.h
  6. +1  -0    mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h
  7. +28 -69   mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
  8. +4  -4    mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h

+7 -0  mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt

@@ -33,6 +33,13 @@ if (ENABLE_CACHE)
     storage_manager.cc
     storage_container.cc)
 
+  if (ENABLE_ASAN)
+    target_compile_options(engine-cache-server PRIVATE -fsanitize=address)
+    target_compile_options(engine-cache-server PRIVATE -fno-omit-frame-pointer)
+    target_compile_options(engine-cache-server PRIVATE -ggdb)
+    set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_LINKER_FLAGS_DEBUG} -fno-omit-frame-pointer -fsanitize=address")
+  endif()
+
 add_executable(cache_server cache_main.cc)
 if (ENABLE_GPU)
   target_link_libraries(cache_server

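The ASAN block above is not incidental: the rest of this commit closes a use-after-free window around BatchWait (see cache_server.cc below). As a hedged illustration only, not repository code, this is the class of defect -fsanitize=address surfaces:

#include <iostream>

// Standalone sketch: build with
//   g++ -fsanitize=address -fno-omit-frame-pointer -ggdb uaf.cc
// and ASAN aborts with a heap-use-after-free report at the marked line.
int main() {
  int *counter = new int(42);
  delete counter;                      // freed here...
  std::cout << *counter << std::endl;  // ...read here: heap-use-after-free
  return 0;
}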

+23 -26  mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc

@@ -37,12 +37,10 @@ void CacheServerGreeterImpl::Shutdown() {
 // Always shutdown the completion queue after the server.
 if (cq_) {
   cq_->Shutdown();
-  // We need to drain the queue. All the tag is coming from
-  // the Services pool which will be shutdown as well. So we
-  // ignore the tag.
   void *tag;
   bool success;
   while (cq_->Next(&tag, &success)) {
+    delete reinterpret_cast<CacheServerRequest *>(tag);
   }
 }
}
@@ -93,30 +91,41 @@ Status CacheServerGreeterImpl::HandleRequest(int32_t worker_id) {
 // and inject them into the grpc queue.
 CacheServerRequest *p;
 // Get a free tag from my free list.
-RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(worker_id, &p));
+RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(&p));
 RETURN_IF_NOT_OK((*p)(&svc_, cq_.get()));
 do {
   auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
   // Set a timeout for one second. Check for interrupt if we need to do early exit.
   auto r = cq_->AsyncNext(&tag, &success, deadline);
   if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
+    auto rq = static_cast<CacheServerRequest *>(tag);
     if (success) {
-      auto rq = static_cast<CacheServerRequest *>(tag);
-      RETURN_IF_NOT_OK((*rq)(&svc_, cq_.get()));
+      if (rq->st_ == CacheServerRequest::STATE::PROCESS) {
+        RETURN_IF_NOT_OK((*rq)(&svc_, cq_.get()));
+      } else if (rq->st_ == CacheServerRequest::STATE::FINISH) {
+        MS_LOG(DEBUG) << *rq << " Finished.";
+        if (rq->type_ == BaseRequest::RequestType::kStopService) {
+          // For cache_admin --stop, ProcessRequest is just acknowledging we receive the request. Now
+          // we call the real function.
+          auto &cs = CacheServer::GetInstance();
+          cs.GlobalShutdown();
+        }
+        RETURN_IF_NOT_OK(CacheServer::ReturnRequestTag(rq));
+      }
+    } else {
+      RETURN_IF_NOT_OK(CacheServer::ReturnRequestTag(rq));
     }
   } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
     // If we are interrupted, exit. Otherwise wait again.
     RETURN_IF_INTERRUPTED();
   } else {
     // Queue is drained.
     break;
   }
-} while (true);
+} while (!this_thread::is_interrupted());
 return Status::OK();
}

Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq) {
-  auto myQID = getQid();
   if (st_ == STATE::CREATE) {
     st_ = STATE::PROCESS;
     svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this);
@@ -129,7 +138,7 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
 // We can round robin, use the qid or even use the worker id. We will use the free list queue
 // where the current request comes from.
 CacheServerRequest *next_rq;
-RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq));
+RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(&next_rq));
 RETURN_IF_NOT_OK((*next_rq)(svc, cq));
 // Now we continue with the current request.
 // First thing we need to extract the type from the incoming request.
@@ -144,25 +153,13 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
     type_ == BaseRequest::RequestType::kStopService || type_ == BaseRequest::RequestType::kAllocateSharedBlock ||
     type_ == BaseRequest::RequestType::kFreeSharedBlock) {
   cs.ProcessRequest(this);
-  // For cache_admin --stop, ProcessRequest is just acknowledging we receive the request. Now
-  // we call the real function.
-  if (type_ == BaseRequest::RequestType::kStopService) {
-    cs.GlobalShutdown();
-    return Status(StatusCode::kInterrupted);
-  } else if (rc_.IsInterrupted()) {
-    return rc_;
-  }
+  // WARNING. After we call ProcessRequest, the memory of 'this' is being recycled by ReturnRequestTag
+  // asynchronously. Further access of 'this' is unpredictable.
 } else {
-  // When the number of grpc workers is the same as the server workers, we will use this queue id
-  // and push to the corresponding queue.
-  bool random = cs.GetNumWorkers() != cs.GetNumGrpcWorkers();
-  worker_id_t worker_id = random ? cs.GetRandomWorker() : myQID;
-  RETURN_IF_NOT_OK(cs.PushRequest(worker_id, this));
+  RETURN_IF_NOT_OK(cs.PushRequest(cs.GetRandomWorker(), this));
 }
} else if (st_ == STATE::FINISH) {
-  MS_LOG(DEBUG) << *this << " Finished.";
-  // Return back to the free list.
-  RETURN_IF_NOT_OK(CacheServer::ReturnRequestTag(this));
+  // We don't have logic here but moved to the caller.
}
return Status::OK();
}

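The loop in HandleRequest above is the standard gRPC async completion-queue pattern: every tag cycles CREATE, then PROCESS, then FINISH, and after this change the poller alone decides when a tag is recycled. A minimal compilable sketch of that pattern, with illustrative types (only the grpc:: calls are real API; everything else is a stand-in):

#include <grpcpp/grpcpp.h>

#include <atomic>
#include <chrono>

// Illustrative tag type; the real one wraps request/response/responder state.
enum class State : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
struct Tag {
  State st = State::CREATE;
};

// Poll one completion queue with a one-second deadline, mirroring the loop
// above: the poller (not the tag) owns recycling, so FINISH tags die here.
void Poll(grpc::ServerCompletionQueue *cq, const std::atomic<bool> &interrupted) {
  void *raw = nullptr;
  bool ok = false;
  do {
    auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
    auto r = cq->AsyncNext(&raw, &ok, deadline);
    if (r == grpc::CompletionQueue::GOT_EVENT) {
      auto *tag = static_cast<Tag *>(raw);
      if (ok && tag->st == State::PROCESS) {
        // run the RPC body; when the responder finishes, the same tag
        // re-enters the queue in FINISH state
      } else {
        delete tag;  // FINISH, or a broken event: reclaim the tag
      }
    } else if (r == grpc::CompletionQueue::TIMEOUT) {
      // periodic wake-up purely so the interrupt flag gets checked
    } else {
      break;  // SHUTDOWN: Shutdown() was called and the queue is drained
    }
  } while (!interrupted.load());
}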

+3 -10  mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h

@@ -38,12 +38,10 @@ class CacheServerRequest : public BaseRequest {
 public:
  friend class CacheServer;
  friend class CacheService;
+ friend class CacheServerGreeterImpl;
  enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
- explicit CacheServerRequest(int32_t queue_id)
-     : BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown),
-       qid_(queue_id),
-       st_(STATE::CREATE),
-       responder_(&ctx_) {}
+ CacheServerRequest()
+     : BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown), st_(STATE::CREATE), responder_(&ctx_) {}
 
  ~CacheServerRequest() override = default;

@@ -58,12 +56,7 @@ class CacheServerRequest : public BaseRequest {
 /// \param out
 void Print(std::ostream &out) const override;
 
-/// \brief Getter of the queue id
-/// \return The queue where the request should go to
-int32_t getQid() const { return qid_; }
-
 private:
-int32_t qid_;
 Status rc_;
 STATE st_;
 grpc::ServerContext ctx_;


+2 -2  mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc

@@ -115,8 +115,8 @@ Status CacheServerHW::GetNumaNodeInfo() {
 const char kCpuList[] = "cpulist";
 auto r = std::regex("[0-9]*-[0-9]*");
 for (Path p : numa_nodes_) {
-  auto node_dir = p.Basename().data();
-  numa_id_t numa_node = strtol(node_dir + strlen(kNodeName), nullptr, 10);
+  auto node_dir = p.Basename();
+  numa_id_t numa_node = strtol(node_dir.data() + strlen(kNodeName), nullptr, 10);
   Path f = p / kCpuList;
   std::ifstream fs(f.toString());
   CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + f.toString());

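The two-line change above fixes a dangling pointer: assuming Path::Basename() returns a std::string by value, calling .data() directly on the temporary leaves the pointer aimed at freed storage once the statement ends. A standalone sketch of the bug shape and the fix, with a hypothetical Basename() stand-in:

#include <iostream>
#include <string>

// Hypothetical stand-in for Path::Basename(), assumed to return by value.
std::string Basename() { return "node0"; }

int main() {
  // Old shape (bug): the temporary std::string dies at the end of the
  // full expression, so `dangling` points at freed storage.
  const char *dangling = Basename().data();
  (void)dangling;  // any later read would be undefined behavior

  // New shape (fix): keep the owning string alive in a named local first.
  std::string node_dir = Basename();
  const char *safe = node_dir.data();  // valid while node_dir is in scope
  std::cout << safe << std::endl;
  return 0;
}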

+0 -110  mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.h

@@ -28,116 +28,6 @@

namespace mindspore {
namespace dataset {
/// \brief An allocator but for a particular numa node.
template <typename T>
class NumaAllocator {
public:
explicit NumaAllocator(numa_id_t node_id, CachePoolPolicy policy)
: policy_(policy), numa_enabled_(false), node_id_(node_id) {
#ifdef NUMA_ENABLED
numa_enabled_ = numa_available() != -1;
#endif
}
~NumaAllocator() = default;

template <typename U>
explicit NumaAllocator(NumaAllocator<U> const &rhs)
: policy_(rhs.policy_), numa_enabled_(rhs.numa_enabled_), node_id_(rhs.node_id_) {}

template <typename U>
bool operator==(Allocator<U> const &rhs) const {
return node_id_ == rhs.node_id_;
}

template <typename U>
bool operator!=(Allocator<U> const &rhs) const {
return node_id_ != rhs.node_id_;
}

template <typename U>
friend class NumaAllocator;

using value_type = T;
using pointer = T *;
using const_pointer = const T *;
using reference = T &;
using const_reference = const T &;
using size_type = uint64_t;
using difference_type = std::ptrdiff_t;

template <typename U>
struct rebind {
using other = Allocator<U>;
};

using propagate_on_container_copy_assignment = std::true_type;
using propagate_on_container_move_assignment = std::true_type;
using propagate_on_container_swap = std::true_type;

/// Allocate memory on this node only. Return nullptr if no memory on this numa node.
/// \note. This version will not throw if we can't allocate memory from this node.
/// User must check if the pointer returned is null or not.
pointer allocate(std::size_t n) noexcept {
auto sz = n * sizeof(T);
void *p = nullptr;
#ifdef NUMA_ENABLED
if (numa_enabled_) {
switch (policy_) {
case kPreferred:
numa_set_preferred(node_id_);
p = numa_alloc(sz);
break;
case kLocal:
p = numa_alloc_local(sz);
break;
case kInterleave:
p = numa_alloc_interleaved(sz);
break;
case kOnNode:
p = numa_alloc_onnode(sz, node_id_);
break;
case kNone:
default:
p = numa_alloc(sz);
break;
}
} else {
p = malloc(sz);
}
#else
p = malloc(sz);
#endif
return reinterpret_cast<pointer>(p);
}

/// Free a memory allocated on this node.
void deallocate(pointer p, std::size_t n) noexcept {
#ifdef NUMA_ENABLED
if (numa_enabled_) {
numa_free(p, n * sizeof(T));
} else {
free(p);
}
#else
free(p);
#endif
}

/// \brief Allow one to change to another numa node
void SetNodeId(numa_id_t node_id) { node_id_ = node_id; }

/// \brif Getter for node_id;
numa_id_t GetNodeId() const { return node_id_; }

/// \brief Getter for policy
CachePoolPolicy GetPolicy() const { return policy_; }

private:
CachePoolPolicy policy_;
bool numa_enabled_;
numa_id_t node_id_;
};

/// \brief A NumaMemoryPool is like a CircularPool but all the arenas have already been allocated
/// and each one comes from a numa socket. Memory is allocated using OnNode policy. That is,
/// it is solely comes from one particular numa node, and is not interleaved.

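For context on what is being deleted: allocate() above is noexcept and signals failure by returning nullptr, so callers had to check the pointer themselves. A hedged fragment showing that contract (assumes the NumaAllocator / CachePoolPolicy definitions shown above, removed by this commit, are still in scope):

#include <cstdint>

// Fragment, not a standalone program: relies on the old cache_numa.h.
void Demo() {
  NumaAllocator<int64_t> alloc(/*node_id=*/0, CachePoolPolicy::kOnNode);
  int64_t *buf = alloc.allocate(1024);  // 1024 * sizeof(int64_t) bytes on numa node 0
  if (buf != nullptr) {                 // allocate() is noexcept: failure comes back as nullptr
    for (int64_t i = 0; i < 1024; ++i) buf[i] = i;  // node-local memory (or malloc fallback)
    alloc.deallocate(buf, 1024);                    // pairs with allocate(): numa_free / free
  }
}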

+1 -0  mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h

@@ -91,6 +91,7 @@ class BaseRequest {
 friend class CacheClientRequestTag;
 friend class CacheClient;
 friend class CacheService;
+friend class CacheServerGreeterImpl;
 
 /// \brief Base class of a cache server request
 /// \param type Type of the request


+28 -69  mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc

@@ -76,34 +76,7 @@ Status CacheServer::DoServiceStart() {
 // But technically they don't have to be the same.
 num_grpc_workers_ = num_workers_;
 MS_LOG(DEBUG) << "Number of gprc workers is set to " << num_grpc_workers_;
-// For the grpc completion queue to work, we need to allocate some
-// tags which in our case are instances of CacheServerQuest.
-// They got recycled and we will allocate them in advance and push
-// them into some free list. We need more (two or three times) the
-// size of the cache_q. While each worker is working on a CacheSerRequest,
-// we need some extra running injecting in the the qrpc completion queue.
-const int32_t kMultiplier = 2;
-int ratio = num_workers_ / num_grpc_workers_;
-if (num_workers_ % num_grpc_workers_) ++ratio;
-const int32_t free_list_capacity = kMultiplier * (kQueCapacity + 1) * ratio;
-free_list_ = std::make_shared<QueueList<CacheServerRequest *>>();
-free_list_->Init(num_grpc_workers_, free_list_capacity);
-tag_.reserve(num_grpc_workers_);
-// Now we populate all free list. Round robin the free list among the numa nodes.
-for (auto m = 0; m < num_grpc_workers_; ++m) {
-  NumaAllocator<CacheServerRequest> alloc(m % num_numa_nodes, CachePoolPolicy::kPreferred);
-  // Ideally we allocate all the free list in one malloc. But we will allocate one segment
-  // at a time so that we can change the numa policy easily per grpc worker.
-  auto my_tag = std::make_unique<MemGuard<CacheServerRequest, NumaAllocator<CacheServerRequest>>>(alloc);
-  // Allocate the tag and assign it the current queue
-  RETURN_IF_NOT_OK(my_tag->allocate(free_list_capacity, m));
-  for (int i = 0; i < free_list_capacity; ++i) {
-    RETURN_IF_NOT_OK(free_list_->operator[](m)->Add((*my_tag)[i]));
-  }
-  tag_.push_back(std::move(my_tag));
-}
 RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
-RETURN_IF_NOT_OK(free_list_->Register(&vg_));
 // Start the comm layer
 try {
   comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_);
comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_);
@@ -396,14 +369,10 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {

 Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
   RETURN_UNEXPECTED_IF_NULL(out);
-  int32_t numQ = GetNumGrpcWorkers();
-  auto rng = GetRandomDevice();
-  std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
-  int32_t qID = distribution(rng);
   auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
   const auto num_elements = p->rows()->size();
   auto connection_id = p->connection_id();
-  auto batch_wait = std::make_unique<BatchWait>(num_elements);
+  auto batch_wait = std::make_shared<BatchWait>(num_elements);
   int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
   auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
   offset_array[0] = data_offset;
@@ -423,7 +392,7 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuil
 // Get a request and send to the proper worker (at some numa node) to do the fetch.
 worker_id_t worker_id = IsNumaAffinityOn() ? GetWorkerByNumaId(node_id) : GetRandomWorker();
 CacheServerRequest *cache_rq;
-RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
+RETURN_IF_NOT_OK(GetFreeRequestTag(&cache_rq));
 // Set up all the necessarily field.
 cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
 cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
@@ -719,10 +688,6 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {

 Status CacheServer::BatchCacheRows(CacheRequest *rq) {
   CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == 3, "Expect three pieces of data");
-  int32_t numQ = GetNumGrpcWorkers();
-  auto rng = GetRandomDevice();
-  std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
-  int32_t qID = distribution(rng);
   try {
     auto &cookie = rq->buf_data(0);
     auto connection_id = rq->connection_id();
@@ -733,7 +698,7 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
   offset_addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
   auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
   num_elem = strtol(rq->buf_data(2).data(), nullptr, 10);
-  auto batch_wait = std::make_unique<BatchWait>(num_elem);
+  auto batch_wait = std::make_shared<BatchWait>(num_elem);
   // Get a set of free request and push into the queues.
   for (auto i = 0; i < num_elem; ++i) {
     auto start = reinterpret_cast<int64_t>(p);
@@ -743,7 +708,7 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
     p += msg->data_sz()->Get(k);
   }
   CacheServerRequest *cache_rq;
-  RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
+  RETURN_IF_NOT_OK(GetFreeRequestTag(&cache_rq));
   // Fill in details.
   cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
   cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
@@ -787,7 +752,11 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
 try {
   int64_t addr = strtol(rq.buf_data(3).data(), nullptr, 10);
   auto *bw = reinterpret_cast<BatchWait *>(addr);
-  RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
+  // Check if the object is still around.
+  auto bwObj = bw->GetBatchWait();
+  if (bwObj.lock()) {
+    RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
+  }
 } catch (const std::exception &e) {
   RETURN_STATUS_UNEXPECTED(e.what());
 }
@@ -820,7 +789,11 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
 try {
   int64_t addr = strtol(rq.buf_data(1).data(), nullptr, 10);
   auto *bw = reinterpret_cast<BatchWait *>(addr);
-  RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
+  // Check if the object is still around.
+  auto bwObj = bw->GetBatchWait();
+  if (bwObj.lock()) {
+    RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
+  }
 } catch (const std::exception &e) {
   RETURN_STATUS_UNEXPECTED(e.what());
 }
@@ -918,7 +891,7 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
 cache_req->st_ = CacheServerRequest::STATE::FINISH;
 // We will re-tag the request back to the grpc queue. Once it comes back from the client,
 // the CacheServerRequest, i.e. the pointer cache_req, will be free
-if (!internal_request) {
+if (!internal_request && !global_shutdown_) {
   cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
 } else {
   // We can free up the request now.
@@ -994,28 +967,26 @@ Status CacheServer::Run(int msg_qid) {
 // note that after we have sent the initial status using the msg_qid, parent process will exit and
 // remove it. So we can't use it again.
 RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking));
+// Shutdown the grpc queue. No longer accept any new comer.
+comm_layer_->Shutdown();
+// The next thing to do drop all the caches.
+RETURN_IF_NOT_OK(ServiceStop());
 return Status::OK();
}

-Status CacheServer::GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q) {
+Status CacheServer::GetFreeRequestTag(CacheServerRequest **q) {
   RETURN_UNEXPECTED_IF_NULL(q);
-  CacheServer &cs = CacheServer::GetInstance();
-  CacheServerRequest *p;
-  RETURN_IF_NOT_OK(cs.free_list_->operator[](queue_id)->PopFront(&p));
+  auto *p = new (std::nothrow) CacheServerRequest();
+  if (p == nullptr) {
+    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
+  }
   *q = p;
   return Status::OK();
}

Status CacheServer::ReturnRequestTag(CacheServerRequest *p) {
   RETURN_UNEXPECTED_IF_NULL(p);
-  int32_t myQID = p->getQid();
-  // Free any memory from the protobufs
-  p->~CacheServerRequest();
-  // Re-initialize the memory
-  new (p) CacheServerRequest(myQID);
-  // Now we return it back to free list.
-  CacheServer &cs = CacheServer::GetInstance();
-  RETURN_IF_NOT_OK(cs.free_list_->operator[](myQID)->Add(p));
+  delete p;
   return Status::OK();
}

@@ -1134,22 +1105,10 @@ void CacheServer::GlobalShutdown() {
 bool expected = false;
 if (global_shutdown_.compare_exchange_strong(expected, true)) {
   MS_LOG(WARNING) << "Shutting down server.";
-  // Shutdown the grpc queue. No longer accept any new comer.
-  // The threads we spawn to work on the grpc queue will exit themselves once
-  // they notice the queue has been shutdown.
-  comm_layer_->Shutdown();
-  // Now we interrupt any threads that are waiting on cache_q_
+  // Interrupt all the threads and queues. We will leave the shutdown
+  // of the comm layer after we have joined all the threads and will
+  // be done by the master thread.
   vg_.interrupt_all();
-  // The next thing to do drop all the caches.
-  UniqueLock lck(&rwLock_);
-  for (auto it = all_caches_.begin(); it != all_caches_.end();) {
-    auto id = it->first;
-    MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
-    // Wait for all outstanding work to be finished.
-    auto &cs = it->second;
-    UniqueLock cs_lock(&cs->rw_lock_);
-    it = all_caches_.erase(it);
-  }
 }
}


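With the comm-layer shutdown moved out of GlobalShutdown(), the function can now be reached from more than one path (the gRPC FINISH handler above, plus any signal path), so the compare_exchange_strong guard is what keeps teardown idempotent. A minimal standalone sketch of that idiom, not repository code:

#include <atomic>
#include <iostream>

std::atomic<bool> global_shutdown{false};

// Only the first caller wins the CAS and performs teardown; every later,
// possibly concurrent, call observes the flag already set and is a no-op.
void GlobalShutdown() {
  bool expected = false;
  if (global_shutdown.compare_exchange_strong(expected, true)) {
    std::cout << "Shutting down server." << std::endl;
    // interrupt worker threads here; the master thread joins them and
    // shuts the comm layer down afterwards
  }
}

int main() {
  GlobalShutdown();  // runs the teardown
  GlobalShutdown();  // no-op: the flag is already set
  return 0;
}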

+4 -4  mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h

@@ -165,7 +165,7 @@ class CacheServer : public Service {
 /// \brief Get a free tag
 /// \param q[in] pointer to a pointer to a CacheServerRequest
 /// \return Status object
-static Status GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q);
+static Status GetFreeRequestTag(CacheServerRequest **q);

/// \brief Return a tag to the free list
/// \param p[in] pointer to already finished CacheServerRequest tag
@@ -232,8 +232,6 @@ class CacheServer : public Service {
 cache_index all_caches_;
 std::set<session_id_type> active_sessions_;
 std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_;
-std::shared_ptr<QueueList<CacheServerRequest *>> free_list_;
-std::vector<std::unique_ptr<MemGuard<CacheServerRequest, NumaAllocator<CacheServerRequest>>>> tag_;
 std::shared_ptr<CacheServerGreeterImpl> comm_layer_;
 TaskGroup vg_;
 int32_t num_workers_;
@@ -359,13 +357,15 @@ class CacheServer : public Service {
 /// So we will let the server thread return the free tag immediately but the put
 /// the return code in this following structure. GRPC thread must wait until all
 /// the rc come back.
-class BatchWait {
+class BatchWait : public std::enable_shared_from_this<BatchWait> {
  public:
   explicit BatchWait(int n) : expected_(n), num_back_(0) {
+    expected_ = n;
     rc_lists_.reserve(expected_);
   }
 
+  std::weak_ptr<BatchWait> GetBatchWait() { return weak_from_this(); }
 
   Status Set(Status rc) {
     CHECK_FAIL_RETURN_UNEXPECTED(expected_ > num_back_, "Programming error");
     std::unique_lock<std::mutex> lck(mux_);

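GetBatchWait() is the heart of the fix: workers receive only a raw BatchWait * (decoded from a string, as seen in ProcessRequest above), and locking the weak_ptr before Set() lets them skip a BatchWait whose owner already gave up. A minimal sketch of the pattern, not repository code; note it narrows rather than fully removes the race, which is why the ASAN build option was added alongside it:

#include <iostream>
#include <memory>

// Liveness handshake: the requester owns the object through a shared_ptr;
// a worker holding only a raw pointer locks a weak_ptr before touching it.
class BatchWait : public std::enable_shared_from_this<BatchWait> {
 public:
  std::weak_ptr<BatchWait> GetBatchWait() { return weak_from_this(); }
  void Set(int rc) { std::cout << "rc=" << rc << std::endl; }
};

void Worker(BatchWait *bw) {
  if (auto alive = bw->GetBatchWait().lock()) {
    bw->Set(0);  // safe: `alive` keeps the object around for this scope
  }
  // else: the requester already destroyed it, so the result is dropped
}

int main() {
  auto owner = std::make_shared<BatchWait>();  // weak_from_this() requires shared_ptr ownership
  Worker(owner.get());
  return 0;
}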
