@@ -62,7 +62,7 @@ if (ENABLE_CACHE)
endif ()
add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES})
target_link_libraries(cache_admin _c_dataengine _c_mindrecord mindspore::protobuf ${PYTHON_LIBRARIES} pthread)
if (USE_GLOG)
target_link_libraries(cache_admin mindspore::glog)
@@ -27,6 +27,8 @@
#include <vector>
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/core/constants.h"
@@ -325,9 +327,33 @@ Status CacheAdminArgHandler::RunCommand() {
Help();
break;
}
case CommandId::kCmdStart:
case CommandId::kCmdStart: {
RETURN_IF_NOT_OK(StartServer(command_id_));
break;
}
case CommandId::kCmdStop: {
RETURN_IF_NOT_OK(StartStopServer(command_id_));
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
SharedMessage msg;
RETURN_IF_NOT_OK(msg.Create());
auto rq = std::make_shared<ServerStopRequest>(msg.GetMsgQueueId());
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
Status rc = rq->Wait();
if (rc.IsError()) {
msg.RemoveResourcesOnExit();
if (rc.IsNetWorkError()) {
std::string errMsg = "Server is not up or has been shutdown already.";
return Status(StatusCode::kNetWorkError, errMsg);
}
return rc;
}
// An OK return code only means the server acknowledged our request; we still
// have to wait for its complete shutdown. The server shuts down the comm layer
// as soon as the request is received, so we wait on the message queue instead.
// The server will remove the queue, and we will then wake up.
Status dummy_rc;
(void)msg.ReceiveStatus(&dummy_rc);
break;
}
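
// --- Editor's note: a minimal sketch (not part of the patch) of the wait
// mechanism used above. SharedMessage is assumed to wrap a SysV message queue
// (msgget/msgsnd/msgrcv); the calls below are the standard SysV APIs, not
// MindSpore ones. A blocked msgrcv() fails with errno EIDRM once the queue is
// removed, which is how cache_admin wakes up when the server destroys the
// queue at the end of its shutdown.
#include <sys/ipc.h>
#include <sys/msg.h>
#include <cerrno>
#include <iostream>

struct StatusMsg {
  long mtype;       // SysV requires a positive message type as the first field
  char mtext[256];  // serialized status payload
};

void WaitForServerShutdown(int msg_qid) {
  StatusMsg msg{};
  // Block until a status message arrives or the queue itself is removed.
  if (msgrcv(msg_qid, &msg, sizeof(msg.mtext), 0, 0) == -1 && errno == EIDRM) {
    // The server removed the queue in its destructor: shutdown is complete.
    std::cout << "Server has shut down.\n";
  }
}
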
case CommandId::kCmdGenerateSession: {
@@ -396,7 +422,7 @@ Status CacheAdminArgHandler::RunCommand() {
return Status::OK();
}
Status CacheAdminArgHandler::StartStopServer(CommandId command_id) {
Status CacheAdminArgHandler::StartServer(CommandId command_id) {
// There is currently no "install path" or method to identify which path the installed binaries
// reside in. As a temporary approach, we assume that the server binary lives in the same path as the
// cache_admin binary (this process).
@@ -477,23 +503,15 @@ Status CacheAdminArgHandler::StartStopServer(CommandId command_id) {
std::string memory_cap_ratio_string = std::to_string(memory_cap_ratio_);
char *argv[9];
if (command_id == CommandId::kCmdStart) {
argv[0] = cache_server_binary.data();
argv[1] = spill_dir_.data();
argv[2] = workers_string.data();
argv[3] = port_string.data();
argv[4] = shared_memory_string.data();
argv[5] = minloglevel_string.data();
argv[6] = daemonize_string.data();
argv[7] = memory_cap_ratio_string.data();
argv[8] = nullptr;
} else {
// We are doing a --stop. Change the name to '-' and we also need the port number.
// The rest we don't need.
argv[0] = std::string("-").data();
argv[1] = port_string.data();
argv[2] = nullptr;
}
argv[0] = cache_server_binary.data();
argv[1] = spill_dir_.data();
argv[2] = workers_string.data();
argv[3] = port_string.data();
argv[4] = shared_memory_string.data();
argv[5] = minloglevel_string.data();
argv[6] = daemonize_string.data();
argv[7] = memory_cap_ratio_string.data();
argv[8] = nullptr;
// Now exec the binary
execv(cache_server_binary.data(), argv);
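
// --- Editor's note: a minimal sketch (not part of the patch) of the execv()
// contract relied on here. execv() replaces the current process image on
// success and never returns; it returns -1 with errno set only on failure, so
// any code after the call is the error path. perror/exit below are illustrative.
#include <unistd.h>
#include <cstdio>
#include <cstdlib>

void ExecServer(char *binary, char **argv) {
  execv(binary, argv);           // does not return on success
  std::perror("execv failed");   // reached only if the exec itself failed
  std::exit(EXIT_FAILURE);
}
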
@@ -509,17 +527,27 @@ void CacheAdminArgHandler::Help() {
std::cerr << "Syntax:\n";
std::cerr << " cache_admin [--start | --stop]\n";
std::cerr << " [ [-h | --hostname] <hostname> ]\n";
std::cerr << " Default is " << kCfgDefaultCacheHost << ".\n";
std::cerr << " [ [-p | --port] <port number> ]\n";
std::cerr << " Possible values are in range [1025..65535].\n";
std::cerr << " Default is " << kCfgDefaultCachePort << ".\n";
std::cerr << " [ [-g | --generate_session] ]\n";
std::cerr << " [ [-d | --destroy_session] <session id> ]\n";
std::cerr << " [ [-w | --workers] <number of workers> ]\n";
std::cerr << " Possible values are in range [1...max(100, Number of CPU)].\n";
std::cerr << " Default is " << kDefaultNumWorkers << ".\n";
std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n";
std::cerr << " Default is " << kDefaultSpillDir << ".\n";
std::cerr << " [ [-l | --minloglevel] <log level> ]\n";
std::cerr << " Possible values are 0, 1, 2 and 3.\n";
std::cerr << " Default is 1 (info level).\n";
std::cerr << " [ --list_sessions ]\n";
// Do not expose these options to the user via help or documentation, but the options do exist to aid with
// development and tuning.
// std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n";
// std::cerr << " Default is " << kDefaultSharedMemorySizeInGB << " (in GB).\n";
// std::cerr << " [ [-r | --memory_cap_ratio] <float percent value>]\n";
// std::cerr << " Default is " << kMemoryCapRatio << ".\n";
std::cerr << " [--help]" << std::endl;
}
} // namespace dataset
@@ -78,7 +78,7 @@ class CacheAdminArgHandler {
kArgNumArgs = 14 // Must be the last position to provide a count
};
Status StartStopServer(CommandId);
Status StartServer(CommandId command_id);
Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
@@ -21,11 +21,7 @@ CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
// We create the shared memory and we will destroy it. All other clients just detach.
shm_.RemoveResourcesOnExit();
}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {}
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
size_t val_in_GB) {
@@ -72,6 +72,9 @@ class CachedSharedMemoryArena : public MemoryPool {
return os;
}
/// \brief Get the key of the shared memory
SharedMemory::shm_key_t GetKey() const { return shm_.GetKey(); }
private:
mutable std::mutex mux_;
int32_t val_in_GB_;
@@ -13,10 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <chrono>
#include <limits>
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/task_manager.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
#else
@@ -25,7 +27,7 @@
namespace mindspore {
namespace dataset {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
: port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) {
: port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb), shm_key_(-1) {
// Setup a path for unix socket.
unix_socket_ = PortToUnixSocketPath(port);
// We can't generate the ftok key until the unix_socket_ is created
@@ -73,7 +75,8 @@ Status CacheServerGreeterImpl::Run() {
MS_LOG(INFO) << "Server listening on " << server_address;
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
MS_LOG(INFO) << "Creation of local socket and shared memory successful";
shm_key_ = shm_pool_->GetKey();
MS_LOG(INFO) << "Creation of local socket and shared memory successful. Shared memory key " << shm_key_;
auto cs = CacheServer::GetInstance().GetHWControl();
// This shared memory is hot memory, and we will interleave it among all the numa nodes.
cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
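
// --- Editor's note: a minimal sketch (not part of the patch) of interleaving
// a hot region across NUMA nodes. The libnuma calls below are real; the
// assumption that CacheServerHW::InterleaveMemory wraps libnuma is the
// editor's. Note that 1073741824L == 1 GiB, so the size argument above is
// simply "GB converted to bytes".
#include <numa.h>  // link with -lnuma

void InterleaveHotRegion(void *base, size_t size_in_bytes) {
  if (numa_available() == -1) return;  // no NUMA support on this host
  // Spread the pages of [base, base + size) round-robin over all allowed nodes
  // so that no single node's memory bus becomes the bottleneck.
  numa_interleave_memory(base, size_in_bytes, numa_all_nodes_ptr);
}
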
@@ -181,5 +184,32 @@ void CacheServerRequest::Print(std::ostream &out) const {
out << " ";
BaseRequest::Print(out);
}
Status CacheServerGreeterImpl::MonitorUnixSocket() {
TaskManager::FindMe()->Post();
#if CACHE_LOCAL_CLIENT
Path p(unix_socket_);
do {
RETURN_IF_INTERRUPTED();
// If the unix socket is recreated for whatever reason, this server instance becomes stale and
// no other process can communicate with us. In that case we need to shut ourselves down.
if (p.Exists()) {
SharedMemory::shm_key_t key;
RETURN_IF_NOT_OK(PortToFtok(port_, &key));
if (key != shm_key_) {
std::string errMsg = "Detected that the unix socket has changed. Previous key " + std::to_string(shm_key_) +
". New key " + std::to_string(key) + ". Shutting down server";
MS_LOG(ERROR) << errMsg;
RETURN_STATUS_UNEXPECTED(errMsg);
}
} else {
MS_LOG(WARNING) << "Unix socket is removed.";
TaskManager::WakeUpWatchDog();
}
std::this_thread::sleep_for(std::chrono::seconds(5));
} while (true);
#endif
return Status::OK();
}
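
// --- Editor's note: a minimal sketch (not part of the patch) of why the key
// comparison above detects a recreated socket. ftok() derives the key from the
// file's st_dev/st_ino plus a project id, so deleting and recreating the same
// path yields a different key. PortToFtok presumably wraps this; the project
// id 'S' below is illustrative.
#include <sys/ipc.h>

key_t KeyForUnixSocket(const char *socket_path) {
  // Same path + same project id, but a new inode => a new key.
  return ftok(socket_path, 'S');  // returns (key_t)-1 if the path does not exist
}
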
} // namespace dataset
} // namespace mindspore
@@ -87,6 +87,9 @@ class CacheServerGreeterImpl final {
/// \return Return the shared memory pool
CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }
/// \brief Monitor the status of the unix socket in case it is gone.
Status MonitorUnixSocket();
void Shutdown();
private:
@@ -97,6 +100,7 @@ class CacheServerGreeterImpl final {
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_;
std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
SharedMemory::shm_key_t shm_key_;
};
} // namespace dataset
} // namespace mindspore
@@ -160,6 +160,9 @@ class SharedMemory : public BaseIPC {
/// \brief Set the public key
void SetPublicKey(key_t public_key) { shm_key_ = public_key; }
/// \brief Retrieve the key
shm_key_t GetKey() const { return shm_key_; }
/// \brief This returns where we attach to the shared memory.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
@@ -27,116 +27,6 @@
#include "minddata/dataset/engine/cache/cache_ipc.h"
namespace ds = mindspore::dataset;
/// Send a synchronous command to the local server using tcp/ip.
/// We aren't using any client code because this binary is not necessarily linked with the client library,
/// so we just call grpc directly.
/// \param port tcp/ip port to use
/// \param type Type of command.
/// \param out grpc result
/// \return Status object
ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::CacheRequest *rq, ds::CacheReply *reply,
grpc::Status *out) {
if (rq == nullptr) {
return ds::Status(ds::StatusCode::kUnexpectedError, "pointer rq is null");
}
if (reply == nullptr) {
return ds::Status(ds::StatusCode::kUnexpectedError, "pointer reply is null");
}
if (out == nullptr) {
return ds::Status(ds::StatusCode::kUnexpectedError, "pointer out is null");
}
const std::string hostname = "127.0.0.1";
auto unix_socket = ds::PortToUnixSocketPath(port);
#if CACHE_LOCAL_CLIENT
const std::string target = "unix://" + unix_socket;
#else
const std::string target = hostname + ":" + std::to_string(port);
#endif
try {
rq->set_type(static_cast<int16_t>(type));
rq->set_client_id(-1);
rq->set_flag(0);
grpc::ChannelArguments args;
grpc::ClientContext ctx;
grpc::CompletionQueue cq;
// Standard async rpc call
std::shared_ptr<grpc::Channel> channel =
grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
std::unique_ptr<ds::CacheServerGreeter::Stub> stub = ds::CacheServerGreeter::NewStub(channel);
std::unique_ptr<grpc::ClientAsyncResponseReader<ds::CacheReply>> rpc =
stub->PrepareAsyncCacheServerRequest(&ctx, *rq, &cq);
rpc->StartCall();
// We need to pass a tag, but since this is the only request in the completion queue
// we just pass a dummy.
int64_t dummy;
void *tag;
bool success;
rpc->Finish(reply, out, &dummy);
// Now we wait on the completion queue synchronously.
auto r = cq.Next(&tag, &success);
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
if (!success || tag != &dummy) {
std::string errMsg = "Unexpected programming error ";
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
if (out->ok()) {
return ds::Status(static_cast<ds::StatusCode>(reply->rc()), reply->msg());
} else {
auto error_code = out->error_code();
std::string errMsg = out->error_message() + ". GRPC Code " + std::to_string(error_code);
return ds::Status(ds::StatusCode::kNetWorkError, errMsg);
}
} else {
std::string errMsg = "Unexpected queue rc = " + std::to_string(r);
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
} catch (const std::exception &e) {
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what());
}
}
/// Stop the server
/// \param argv
/// \return Status object
ds::Status StopServer(int argc, char **argv) {
ds::Status rc;
ds::CacheServer::Builder builder;
std::string errMsg;
if (argc != 2) {
return ds::Status(ds::StatusCode::kSyntaxError);
}
int32_t port = strtol(argv[1], nullptr, 10);
// We will go through the builder to do some sanity checks. We only need the port number
// to shut down the server. Null the root directory so the sanity code doesn't write anything
// to the spill directory.
builder.SetPort(port).SetRootDirectory("");
// Part of the sanity check is to check the shared memory. If the server is up and running, we expect
// the return code to be kDuplicate.
rc = builder.SanityCheck();
if (rc.IsOk()) {
errMsg = "Server is not up or has been shutdown already.";
return ds::Status(ds::StatusCode::kUnexpectedError, errMsg);
} else if (rc.get_code() != ds::StatusCode::kDuplicateKey) {
// Not OK, and no duplicate; just return the rc, whatever it is.
return rc;
} else {
// Now we get some work to do. We will send a tcp/ip request to the given port.
// This binary is not linked with the client-side code, so we will just call grpc directly.
ds::CacheRequest rq;
ds::CacheReply reply;
grpc::Status grpc_rc;
rc = SendSyncCommand(port, ds::BaseRequest::RequestType::kStopService, &rq, &reply, &grpc_rc);
// The request is like a self-destruct message; the server will not send anything back and
// will shut down all incoming requests. So if all goes well, we should expect a network
// error with GRPC code 14 (UNAVAILABLE).
auto err_code = grpc_rc.error_code();
if (rc.get_code() != ds::StatusCode::kNetWorkError || err_code != grpc::StatusCode::UNAVAILABLE) {
return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__);
}
}
return ds::Status::OK();
}
/// Start the server
/// \param argv
/// \return Status object
@@ -235,15 +125,8 @@ ds::Status StartServer(int argc, char **argv) {
}
int main(int argc, char **argv) {
ds::Status rc;
ds::CacheServer::Builder builder;
// This executable is not meant to be called directly; it should be invoked by the cache_admin executable.
if (strcmp(argv[0], "-") == 0) {
rc = StopServer(argc, argv);
} else {
rc = StartServer(argc, argv);
}
ds::Status rc = StartServer(argc, argv);
// Check result
if (rc.IsError()) {
auto errCode = rc.get_code();
@@ -20,6 +20,7 @@
#include <unistd.h>
#endif
#include <cstdlib>
#include <cstring>
#include <thread>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/cache/cache_client.h"
@@ -326,5 +327,11 @@ Status ListSessionsRequest::PostReply() {
return Status::OK();
}
Status ServerStopRequest::PostReply() {
CHECK_FAIL_RETURN_UNEXPECTED(strcmp(reply_.result().data(), "OK") == 0, "Not the right response");
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
@@ -394,6 +394,15 @@ class ToggleWriteModeRequest : public BaseRequest {
~ToggleWriteModeRequest() override = default;
};
class ServerStopRequest : public BaseRequest {
public:
friend class CacheServer;
explicit ServerStopRequest(int32_t qID) : BaseRequest(RequestType::kStopService) {
rq_.add_buf_data(std::to_string(qID));
}
Status PostReply() override;
};
class ConnectResetRequest : public BaseRequest {
public:
friend class CacheServer;
@@ -108,6 +108,9 @@ Status CacheServer::DoServiceStart() {
try {
comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_);
RETURN_IF_NOT_OK(comm_layer_->Run());
// Bring up a thread to monitor the unix socket in case it is removed.
auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
@@ -154,6 +157,15 @@ Status CacheServer::DoServiceStop() {
}
++it;
}
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
// Finally wake up cache_admin if it is waiting
for (int32_t qID : shutdown_qIDs_) {
SharedMessage msg(qID);
msg.RemoveResourcesOnExit();
// Let msg go out of scope, which will destroy the queue.
}
return rc;
}
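
// --- Editor's note: a minimal sketch (not part of the patch) of what
// "destroying the queue" does. Removing a SysV queue with IPC_RMID immediately
// wakes every process blocked in msgrcv() on that queue with errno EIDRM,
// which is how the waiting cache_admin learns that shutdown has finished.
// SharedMessage::RemoveResourcesOnExit is assumed to arrange exactly this
// msgctl call in the destructor.
#include <sys/ipc.h>
#include <sys/msg.h>

void DestroyQueueAndWakeWaiters(int msg_qid) {
  // After this call the queue id is invalid; blocked readers get EIDRM.
  (void)msgctl(msg_qid, IPC_RMID, nullptr);
}
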
@@ -374,6 +386,68 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
return rc;
}
Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
RETURN_UNEXPECTED_IF_NULL(out);
int32_t numQ = GetNumGrpcWorkers();
auto rng = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
int32_t qID = distribution(rng);
std::vector<CacheServerRequest *> cache_rq_list;
auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
const auto num_elements = p->rows()->size();
auto connection_id = p->connection_id();
cache_rq_list.reserve(num_elements);
int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
offset_array[0] = data_offset;
for (auto i = 0; i < num_elements; ++i) {
auto data_locator = p->rows()->Get(i);
auto node_id = data_locator->node_id();
size_t sz = data_locator->size();
void *source_addr = reinterpret_cast<void *>(data_locator->addr());
auto key = data_locator->key();
// Please read the comment in CacheServer::BatchFetchRows where we allocate
// the buffer big enough so that each thread (which we are going to dispatch) will
// not run into false sharing problems. We are going to round up sz to 4k.
auto sz_4k = round_up_4K(sz);
offset_array[i + 1] = offset_array[i] + sz_4k;
if (sz > 0) {
WritableSlice row_data(*out, offset_array[i], sz);
// Get a request and send to the proper worker (at some numa node) to do the fetch.
worker_id_t worker_id = IsNumaAffinityOn() ? GetWorkerByNumaId(node_id) : GetRandomWorker();
CacheServerRequest *cache_rq;
RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
cache_rq_list.push_back(cache_rq);
// Set up all the necessary fields.
cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
cache_rq->rq_.set_connection_id(connection_id);
cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
auto dest_addr = row_data.GetMutablePointer();
flatbuffers::FlatBufferBuilder fb2;
FetchRowMsgBuilder bld(fb2);
bld.add_key(key);
bld.add_size(sz);
bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
auto offset = bld.Finish();
fb2.Finish(offset);
cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
RETURN_IF_NOT_OK(PushRequest(worker_id, cache_rq));
}
}
// Now wait for all of them to come back.
Status rc;
for (CacheServerRequest *rq : cache_rq_list) {
RETURN_IF_NOT_OK(rq->Wait());
if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) {
rc = rq->rc_;
}
RETURN_IF_NOT_OK(ReturnRequestTag(rq));
}
return rc;
}
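
// --- Editor's note: a minimal sketch (not part of the patch) of a standard
// power-of-two round-up, which is presumably what round_up_4K does. Giving
// each row a 4K-aligned slot keeps rows dispatched to different worker threads
// from sharing cache lines (and pages), per the false-sharing comment above.
#include <cstdint>

constexpr int64_t RoundUp4K(int64_t sz) {
  constexpr int64_t kAlign = 4096;           // 4 KiB
  return (sz + kAlign - 1) & ~(kAlign - 1);  // valid because kAlign is a power of two
}
// Example: RoundUp4K(1) == 4096, RoundUp4K(4096) == 4096, RoundUp4K(4097) == 8192.
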
Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
// Hold the shared lock to prevent the cache from being dropped.
@@ -394,6 +468,9 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
}
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb));
// Let go of the shared lock. We don't need to interact with the CacheService anymore,
// and we shouldn't be holding any lock while we may be waiting a long time for the rows to come back.
lck.Unlock();
auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
int64_t mem_sz = sizeof(int64_t) * (sz + 1);
for (auto i = 0; i < sz; ++i) {
@@ -418,7 +495,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
void *q = nullptr;
RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q));
WritableSlice dest(q, mem_sz);
Status rc = cs->BatchFetch(fbb, &dest);
Status rc = BatchFetch(fbb, &dest);
if (rc.IsError()) {
shared_pool->Deallocate(q);
return rc;
@@ -439,7 +516,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
return Status(StatusCode::kOutOfMemory);
}
WritableSlice dest(mem.data(), mem_sz);
RETURN_IF_NOT_OK(cs->BatchFetch(fbb, &dest));
RETURN_IF_NOT_OK(BatchFetch(fbb, &dest));
reply->set_result(std::move(mem));
}
}
@@ -721,7 +798,7 @@ Status CacheServer::ServerRequest(worker_id_t worker_id) {
}
case BaseRequest::RequestType::kStopService: {
// This command shuts down everything.
cache_req->rc_ = GlobalShutdown();
cache_req->rc_ = GlobalShutdown(cache_req);
break;
}
case BaseRequest::RequestType::kHeartBeat: {
@@ -914,7 +991,25 @@ Status CacheServer::RpcRequest(worker_id_t worker_id) {
return Status::OK();
}
Status CacheServer::GlobalShutdown() {
Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
auto *rq = &cache_req->rq_;
auto *reply = &cache_req->reply_;
if (!rq->buf_data().empty()) {
// cache_admin sends us a message queue ID, and we will destroy the
// queue in our destructor, which will wake up cache_admin.
// But we don't want cache_admin to just block itself blindly,
// so we send back an ack before shutting down the comm layer.
try {
int32_t qID = std::stoi(rq->buf_data(0));
shutdown_qIDs_.push_back(qID);
} catch (const std::exception &e) {
// ignore it.
}
}
reply->set_result("OK");
Status2CacheReply(cache_req->rc_, reply);
cache_req->st_ = CacheServerRequest::STATE::FINISH;
cache_req->responder_.Finish(*reply, grpc::Status::OK, cache_req);
// Let's shutdown in proper order.
bool expected = false;
if (global_shutdown_.compare_exchange_strong(expected, true)) {
@@ -939,7 +1034,7 @@ Status CacheServer::GlobalShutdown() {
return Status::OK();
}
worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) {
worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const {
auto num_numa_nodes = GetNumaNodeCount();
MS_ASSERT(numa_id < num_numa_nodes);
auto num_workers_per_node = GetNumWorkers() / num_numa_nodes;
@@ -951,7 +1046,7 @@ worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) {
return worker_id;
}
worker_id_t CacheServer::GetRandomWorker() {
worker_id_t CacheServer::GetRandomWorker() const {
std::mt19937 gen = GetRandomDevice();
std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1);
return dist(gen);
@@ -187,11 +187,11 @@ class CacheServer : public Service {
/// \brief Assign a worker by a numa id
/// \return worker id
worker_id_t GetWorkerByNumaId(numa_id_t node_id);
worker_id_t GetWorkerByNumaId(numa_id_t node_id) const;
/// \brief Randomly pick a worker
/// \return worker id
worker_id_t GetRandomWorker();
worker_id_t GetRandomWorker() const;
/// \brief Check if we bind threads to numa cores
bool IsNumaAffinityOn() const { return numa_affinity_; }
@@ -227,6 +227,7 @@ class CacheServer : public Service {
std::shared_ptr<CacheServerHW> hw_info_;
std::map<worker_id_t, Task *> numa_tasks_;
bool numa_affinity_;
std::vector<int32_t> shutdown_qIDs_;
/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
@@ -314,7 +315,7 @@ class CacheServer : public Service {
/// \brief A proper shutdown of the server
/// \return Status object
Status GlobalShutdown();
Status GlobalShutdown(CacheServerRequest *);
/// \brief Find keys that will be cache misses
/// \return Status object
@@ -330,6 +331,13 @@ class CacheServer : public Service {
/// \brief Connect request by a pipeline
Status ConnectReset(CacheRequest *rq);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory region which will be decoded
/// by the CacheClient. A cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] fbb Flat buffer holding the batch of data locators for the rows to fetch.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out);
};
} // namespace dataset
} // namespace mindspore
@@ -209,6 +209,10 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) {
Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) {
SharedLock rw(&rw_lock_);
if (st_ == CacheServiceState::kBuildPhase) {
// For this kind of cache service, we can't fetch until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v;
datalocator_v.reserve(v.size());
for (auto row_id : v) {
@@ -225,76 +229,6 @@ Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::
return Status::OK();
}
Status CacheService::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
SharedLock rw(&rw_lock_);
if (st_ == CacheServiceState::kBuildPhase) {
// For this kind of cache service, we can't fetch until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
CacheServer &cs = CacheServer::GetInstance();
int32_t numQ = cs.GetNumGrpcWorkers();
auto rng = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
int32_t qID = distribution(rng);
std::vector<CacheServerRequest *> cache_rq_list;
auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
const auto num_elements = p->rows()->size();
auto connection_id = p->connection_id();
cache_rq_list.reserve(num_elements);
int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
offset_array[0] = data_offset;
for (auto i = 0; i < num_elements; ++i) {
auto data_locator = p->rows()->Get(i);
auto node_id = data_locator->node_id();
size_t sz = data_locator->size();
void *source_addr = reinterpret_cast<void *>(data_locator->addr());
auto key = data_locator->key();
// Please read the comment in CacheServer::BatchFetchRows where we allocate
// the buffer big enough so that each thread (which we are going to dispatch) will
// not run into false sharing problems. We are going to round up sz to 4k.
auto sz_4k = round_up_4K(sz);
offset_array[i + 1] = offset_array[i] + sz_4k;
if (sz > 0) {
WritableSlice row_data(*out, offset_array[i], sz);
// Get a request and send to the proper worker (at some numa node) to do the fetch.
worker_id_t worker_id = cs.IsNumaAffinityOn() ? cs.GetWorkerByNumaId(node_id) : cs.GetRandomWorker();
CacheServerRequest *cache_rq;
RETURN_IF_NOT_OK(cs.GetFreeRequestTag(qID++ % numQ, &cache_rq));
cache_rq_list.push_back(cache_rq);
// Set up all the necessary fields.
cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
cache_rq->rq_.set_connection_id(connection_id);
cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
auto dest_addr = row_data.GetMutablePointer();
flatbuffers::FlatBufferBuilder fb2;
FetchRowMsgBuilder bld(fb2);
bld.add_key(key);
bld.add_size(sz);
bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
auto offset = bld.Finish();
fb2.Finish(offset);
cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
RETURN_IF_NOT_OK(cs.PushRequest(worker_id, cache_rq));
}
}
// Now wait for all of them to come back. Let go of the shared lock. We shouldn't be holding
// any lock while we may be waiting a long time.
rw.Unlock();
Status rc;
for (CacheServerRequest *rq : cache_rq_list) {
RETURN_IF_NOT_OK(rq->Wait());
if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) {
rc = rq->rc_;
}
RETURN_IF_NOT_OK(cs.ReturnRequestTag(rq));
}
return rc;
}
Status CacheService::InternalFetchRow(const FetchRowMsg *p) {
RETURN_UNEXPECTED_IF_NULL(p);
SharedLock rw(&rw_lock_);
@@ -75,13 +75,6 @@ class CacheService : public Service {
Status PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
const std::shared_ptr<flatbuffers::FlatBufferBuilder> &);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] v A vector of row id.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &, WritableSlice *out) const;
/// \brief Getter function
/// \return Spilling path
Path GetSpillPath() const;
@@ -87,6 +87,7 @@ class WritableSlice : public ReadableSlice {
public:
friend class StorageContainer;
friend class CacheService;
friend class CacheServer;
/// \brief Default constructor
WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}
/// \brief This form of a constructor takes a pointer and its size.
@@ -42,7 +42,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheCApiSamplerNull) {
// Create an ImageFolder Dataset, this folder_path only has 2 images in it
std::string folder_path = datasets_root_path_ + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, nullptr, {}, {}, some_cache);
EXPECT_EQ(ds, nullptr);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
// Now the parameter check for ImageFolderNode would fail and we would end up with a nullptr iter.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestCacheOp, DISABLED_TestCacheImageFolderCApi) {
@@ -121,7 +121,7 @@ HandleRcExit $? 0 1
# find a port that is occupied using netstat
if [ -x "$(command -v netstat)" ]; then
port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1)
if [ ${port} -gt 1025 ]; then
# start cache server with occupied port
cmd="${CACHE_ADMIN} --start -p ${port}"
@@ -171,7 +171,12 @@ HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w 101"
num_cpu=$(grep -c processor /proc/cpuinfo)
if [ $num_cpu -lt 100 ]; then
cmd="${CACHE_ADMIN} --start -w 101"
else
cmd="${CACHE_ADMIN} --start -w $((num_cpu + 1))"
fi
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w 9999999"
@@ -1,10 +0,0 @@
~/cache/cache_admin --start
session_id=$(~/cache/cache_admin -g | awk '{print $NF}')
export SESSION_ID=${session_id}
pytest dataset/test_cache_nomap.py::test_cache_nomap_server_stop &
pid=("$!")
sleep 2
~/cache/cache_admin --stop
sleep 1
wait ${pid}