@@ -62,7 +62,7 @@ if (ENABLE_CACHE)
 endif ()
 add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
-target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES})
+target_link_libraries(cache_admin _c_dataengine _c_mindrecord mindspore::protobuf ${PYTHON_LIBRARIES} pthread)
 if (USE_GLOG)
   target_link_libraries(cache_admin mindspore::glog)
@@ -27,6 +27,8 @@
 #include <vector>
 #include "minddata/dataset/engine/cache/cache_request.h"
 #include "minddata/dataset/engine/cache/cache_client.h"
+#include "minddata/dataset/engine/cache/cache_server.h"
+#include "minddata/dataset/engine/cache/cache_ipc.h"
 #include "minddata/dataset/util/path.h"
 #include "minddata/dataset/core/constants.h"
@@ -325,9 +327,33 @@ Status CacheAdminArgHandler::RunCommand() {
       Help();
       break;
     }
-    case CommandId::kCmdStart:
+    case CommandId::kCmdStart: {
+      RETURN_IF_NOT_OK(StartServer(command_id_));
+      break;
+    }
     case CommandId::kCmdStop: {
-      RETURN_IF_NOT_OK(StartStopServer(command_id_));
+      CacheClientGreeter comm(hostname_, port_, 1);
+      RETURN_IF_NOT_OK(comm.ServiceStart());
+      SharedMessage msg;
+      RETURN_IF_NOT_OK(msg.Create());
+      auto rq = std::make_shared<ServerStopRequest>(msg.GetMsgQueueId());
+      RETURN_IF_NOT_OK(comm.HandleRequest(rq));
+      Status rc = rq->Wait();
+      if (rc.IsError()) {
+        msg.RemoveResourcesOnExit();
+        if (rc.IsNetWorkError()) {
+          std::string errMsg = "Server is not up or has been shutdown already.";
+          return Status(StatusCode::kNetWorkError, errMsg);
+        }
+        return rc;
+      }
+      // An OK return code only means the server acknowledged our request; we still
+      // have to wait for its complete shutdown, because the server shuts down the
+      // comm layer as soon as the request is received. So we wait on the message
+      // queue instead: the server will remove the queue and we will then wake up.
+      Status dummy_rc;
+      (void)msg.ReceiveStatus(&dummy_rc);
       break;
     }
     case CommandId::kCmdGenerateSession: {
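For context on the new stop flow: cache_admin creates a System V message queue, ships the queue id to the server inside ServerStopRequest, and then blocks on the queue rather than on the gRPC reply, since the server tears down its comm layer right after acknowledging. A minimal standalone sketch of that wait, using raw SysV calls instead of the project's SharedMessage wrapper (the helper name is hypothetical):

    #include <sys/ipc.h>
    #include <sys/msg.h>
    #include <cerrno>
    #include <cstdio>

    // Block until the server removes the queue as its final shutdown step.
    // msgrcv() fails with errno == EIDRM once the queue id is removed, which
    // is the "server fully down" signal cache_admin is waiting for.
    int WaitForServerShutdown(int msg_qid) {
      struct { long mtype; char mtext[64]; } buf;
      if (msgrcv(msg_qid, &buf, sizeof(buf.mtext), 0, 0) == -1) {
        if (errno == EIDRM) return 0;  // queue gone: server has exited
        perror("msgrcv");
        return -1;
      }
      return 0;  // server posted a status message before removing the queue
    }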
@@ -396,7 +422,7 @@ Status CacheAdminArgHandler::RunCommand() {
   return Status::OK();
 }
-Status CacheAdminArgHandler::StartStopServer(CommandId command_id) {
+Status CacheAdminArgHandler::StartServer(CommandId command_id) {
   // There currently does not exist any "install path" or method to identify which path the installed binaries will
   // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the
   // cache_admin binary (this process).
@@ -477,23 +503,15 @@ Status CacheAdminArgHandler::StartStopServer(CommandId command_id) {
   std::string memory_cap_ratio_string = std::to_string(memory_cap_ratio_);
   char *argv[9];
-  if (command_id == CommandId::kCmdStart) {
-    argv[0] = cache_server_binary.data();
-    argv[1] = spill_dir_.data();
-    argv[2] = workers_string.data();
-    argv[3] = port_string.data();
-    argv[4] = shared_memory_string.data();
-    argv[5] = minloglevel_string.data();
-    argv[6] = daemonize_string.data();
-    argv[7] = memory_cap_ratio_string.data();
-    argv[8] = nullptr;
-  } else {
-    // We are doing a --stop. Change the name to '-' and we also need the port number.
-    // The rest we don't need.
-    argv[0] = std::string("-").data();
-    argv[1] = port_string.data();
-    argv[2] = nullptr;
-  }
+  argv[0] = cache_server_binary.data();
+  argv[1] = spill_dir_.data();
+  argv[2] = workers_string.data();
+  argv[3] = port_string.data();
+  argv[4] = shared_memory_string.data();
+  argv[5] = minloglevel_string.data();
+  argv[6] = daemonize_string.data();
+  argv[7] = memory_cap_ratio_string.data();
+  argv[8] = nullptr;
   // Now exec the binary
   execv(cache_server_binary.data(), argv);
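A note on the exec contract relied on above: argv must be a NULL-terminated array of pointers, and execv() replaces the current process image, returning only on failure. A tiny sketch (hypothetical wrapper, not code from this patch):

    #include <unistd.h>
    #include <cstdio>

    // On success execv() never returns; anything after it runs only when the
    // exec failed, e.g. the server binary is missing or not executable.
    int LaunchServer(const char *binary, char *const argv[]) {
      execv(binary, argv);
      perror("execv");
      return -1;
    }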
@@ -509,17 +527,27 @@ void CacheAdminArgHandler::Help() {
   std::cerr << "Syntax:\n";
   std::cerr << "   cache_admin [--start | --stop]\n";
   std::cerr << "               [ [-h | --hostname] <hostname> ]\n";
+  std::cerr << "                   Default is " << kCfgDefaultCacheHost << ".\n";
   std::cerr << "               [ [-p | --port] <port number> ]\n";
+  std::cerr << "                   Possible values are in range [1025..65535].\n";
+  std::cerr << "                   Default is " << kCfgDefaultCachePort << ".\n";
   std::cerr << "               [ [-g | --generate_session] ]\n";
   std::cerr << "               [ [-d | --destroy_session] <session id> ]\n";
   std::cerr << "               [ [-w | --workers] <number of workers> ]\n";
+  std::cerr << "                   Possible values are in range [1...max(100, Number of CPU)].\n";
+  std::cerr << "                   Default is " << kDefaultNumWorkers << ".\n";
   std::cerr << "               [ [-s | --spilldir] <spilling directory> ]\n";
+  std::cerr << "                   Default is " << kDefaultSpillDir << ".\n";
   std::cerr << "               [ [-l | --minloglevel] <log level> ]\n";
+  std::cerr << "                   Possible values are 0, 1, 2 and 3.\n";
+  std::cerr << "                   Default is 1 (info level).\n";
   std::cerr << "               [ --list_sessions ]\n";
   // Do not expose these option to the user via help or documentation, but the options do exist to aid with
   // development and tuning.
   // std::cerr << "               [ [-m | --shared_memory_size] <shared memory size> ]\n";
+  // std::cerr << "                   Default is " << kDefaultSharedMemorySizeInGB << " (Gb in unit).\n";
   // std::cerr << "               [ [-r | --memory_cap_ratio] <float percent value>]\n";
+  // std::cerr << "                   Default is " << kMemoryCapRatio << ".\n";
   std::cerr << "               [--help]" << std::endl;
 }
 } // namespace dataset
@@ -78,7 +78,7 @@ class CacheAdminArgHandler {
     kArgNumArgs = 14  // Must be the last position to provide a count
   };
 
-  Status StartStopServer(CommandId);
+  Status StartServer(CommandId command_id);
 
   Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
                    CommandId command_id = CommandId::kCmdUnknown);
@@ -21,11 +21,7 @@ CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
   // We create the shared memory and we will destroy it. All other client just detach only.
   shm_.RemoveResourcesOnExit();
 }
 
-CachedSharedMemoryArena::~CachedSharedMemoryArena() {
-  // Also remove the path we use to generate ftok.
-  Path p(PortToUnixSocketPath(port_));
-  (void)p.Remove();
-}
+CachedSharedMemoryArena::~CachedSharedMemoryArena() {}
 Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
                                             size_t val_in_GB) {
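The unix-socket cleanup leaves the arena destructor here and reappears in CacheServer::DoServiceStop later in this patch; the shared memory segment itself is still reclaimed through shm_.RemoveResourcesOnExit() in the constructor. A rough sketch of that creator-only cleanup pattern for SysV shared memory (a hypothetical wrapper, not the actual SharedMemory class):

    #include <sys/ipc.h>
    #include <sys/shm.h>

    // Only the creating process marks the segment for removal; clients that
    // merely attach skip this. IPC_RMID destroys the segment after the last
    // detach, so attached readers are not pulled out from under.
    class ShmGuard {
     public:
      ShmGuard(int shm_id, bool is_creator) : shm_id_(shm_id), is_creator_(is_creator) {}
      ~ShmGuard() {
        if (is_creator_) (void)shmctl(shm_id_, IPC_RMID, nullptr);
      }

     private:
      int shm_id_;
      bool is_creator_;
    };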
@@ -72,6 +72,9 @@ class CachedSharedMemoryArena : public MemoryPool {
     return os;
   }
 
+  /// \brief Get the shared memory key of the shared memory
+  SharedMemory::shm_key_t GetKey() const { return shm_.GetKey(); }
+
  private:
   mutable std::mutex mux_;
   int32_t val_in_GB_;
@@ -13,10 +13,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <chrono>
 #include <limits>
 #include "minddata/dataset/engine/cache/cache_grpc_server.h"
 #include "minddata/dataset/engine/cache/cache_server.h"
 #include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/task_manager.h"
 #ifndef ENABLE_ANDROID
 #include "utils/log_adapter.h"
 #else
@@ -25,7 +27,7 @@
 namespace mindspore {
 namespace dataset {
 CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
-    : port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) {
+    : port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb), shm_key_(-1) {
   // Setup a path for unix socket.
   unix_socket_ = PortToUnixSocketPath(port);
   // We can't generate the ftok key yet until the unix_socket_ is created
@@ -73,7 +75,8 @@ Status CacheServerGreeterImpl::Run() {
   MS_LOG(INFO) << "Server listening on " << server_address;
 #if CACHE_LOCAL_CLIENT
   RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
-  MS_LOG(INFO) << "Creation of local socket and shared memory successful";
+  shm_key_ = shm_pool_->GetKey();
+  MS_LOG(INFO) << "Creation of local socket and shared memory successful. Shared memory key " << shm_key_;
   auto cs = CacheServer::GetInstance().GetHWControl();
   // This shared memory is a hot memory and we will interleave among all the numa nodes.
   cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
@@ -181,5 +184,32 @@ void CacheServerRequest::Print(std::ostream &out) const {
   out << " ";
   BaseRequest::Print(out);
 }
+
+Status CacheServerGreeterImpl::MonitorUnixSocket() {
+  TaskManager::FindMe()->Post();
+#if CACHE_LOCAL_CLIENT
+  Path p(unix_socket_);
+  do {
+    RETURN_IF_INTERRUPTED();
+    // If the unix socket is recreated for whatever reason, this server instance becomes
+    // stale and no other process can communicate with us, so we must shut ourselves down.
+    if (p.Exists()) {
+      SharedMemory::shm_key_t key;
+      RETURN_IF_NOT_OK(PortToFtok(port_, &key));
+      if (key != shm_key_) {
+        std::string errMsg = "Detected unix socket has changed. Previous key " + std::to_string(shm_key_) +
+                             ". New key " + std::to_string(key) + ". Shutting down server";
+        MS_LOG(ERROR) << errMsg;
+        RETURN_STATUS_UNEXPECTED(errMsg);
+      }
+    } else {
+      MS_LOG(WARNING) << "Unix socket is removed.";
+      TaskManager::WakeUpWatchDog();
+    }
+    std::this_thread::sleep_for(std::chrono::seconds(5));
+  } while (true);
+#endif
+  return Status::OK();
+}
 } // namespace dataset
 } // namespace mindspore
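The stale-server detection above works because ftok(3) derives the key from the socket file's device and inode numbers: if another process deletes and recreates PortToUnixSocketPath(port_), the inode changes and so does the key, which then no longer matches the shm_key_ captured at startup. A minimal sketch of the derivation (the proj_id and helper name are assumptions, not the real PortToFtok):

    #include <sys/types.h>
    #include <sys/ipc.h>
    #include <string>

    // ftok() hashes the file's st_dev/st_ino with proj_id, so a recreated
    // file at the same path almost always yields a different key.
    key_t KeyForSocketPath(const std::string &socket_path) {
      return ftok(socket_path.c_str(), 1);  // returns (key_t)-1 if the path does not exist
    }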
@@ -87,6 +87,9 @@ class CacheServerGreeterImpl final {
   /// \return Return the shared memory pool
   CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }
 
+  /// \brief Monitor the status of the unix socket in case it is gone.
+  Status MonitorUnixSocket();
+
   void Shutdown();
 
  private:
@@ -97,6 +100,7 @@
   std::unique_ptr<grpc::ServerCompletionQueue> cq_;
   std::unique_ptr<grpc::Server> server_;
   std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
+  SharedMemory::shm_key_t shm_key_;
 };
 } // namespace dataset
 } // namespace mindspore
@@ -160,6 +160,9 @@ class SharedMemory : public BaseIPC {
   /// \brief Set the public key
   void SetPublicKey(key_t public_key) { shm_key_ = public_key; }
 
+  /// \brief Retrieve the key
+  shm_key_t GetKey() const { return shm_key_; }
+
   /// \brief This returns where we attach to the shared memory.
   /// \return Base address of the shared memory.
   const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
@@ -27,116 +27,6 @@
 #include "minddata/dataset/engine/cache/cache_ipc.h"
 namespace ds = mindspore::dataset;
 
-/// Send a synchronous command to the local server using tcp/ip.
-/// We aren't using any client code because this binary is not necessarily linked with the client library.
-/// So just using grpc call directly.
-/// \param port tcp/ip port to use
-/// \param type Type of command.
-/// \param out grpc result
-/// \return Status object
-ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::CacheRequest *rq, ds::CacheReply *reply,
-                           grpc::Status *out) {
-  if (rq == nullptr) {
-    return ds::Status(ds::StatusCode::kUnexpectedError, "pointer rq is null");
-  }
-  if (reply == nullptr) {
-    return ds::Status(ds::StatusCode::kUnexpectedError, "pointer reply is null");
-  }
-  if (out == nullptr) {
-    return ds::Status(ds::StatusCode::kUnexpectedError, "pointer out is null");
-  }
-  const std::string hostname = "127.0.0.1";
-  auto unix_socket = ds::PortToUnixSocketPath(port);
-#if CACHE_LOCAL_CLIENT
-  const std::string target = "unix://" + unix_socket;
-#else
-  const std::string target = hostname + ":" + std::to_string(port);
-#endif
-  try {
-    rq->set_type(static_cast<int16_t>(type));
-    rq->set_client_id(-1);
-    rq->set_flag(0);
-    grpc::ChannelArguments args;
-    grpc::ClientContext ctx;
-    grpc::CompletionQueue cq;
-    // Standard async rpc call
-    std::shared_ptr<grpc::Channel> channel =
-      grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
-    std::unique_ptr<ds::CacheServerGreeter::Stub> stub = ds::CacheServerGreeter::NewStub(channel);
-    std::unique_ptr<grpc::ClientAsyncResponseReader<ds::CacheReply>> rpc =
-      stub->PrepareAsyncCacheServerRequest(&ctx, *rq, &cq);
-    rpc->StartCall();
-    // We need to pass a tag. But since this is the only request in the completion queue and so we
-    // just pass a dummy
-    int64_t dummy;
-    void *tag;
-    bool success;
-    rpc->Finish(reply, out, &dummy);
-    // Now we wait on the completion queue synchronously.
-    auto r = cq.Next(&tag, &success);
-    if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
-      if (!success || tag != &dummy) {
-        std::string errMsg = "Unexpected programming error ";
-        return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
-      }
-      if (out->ok()) {
-        return ds::Status(static_cast<ds::StatusCode>(reply->rc()), reply->msg());
-      } else {
-        auto error_code = out->error_code();
-        std::string errMsg = out->error_message() + ". GRPC Code " + std::to_string(error_code);
-        return ds::Status(ds::StatusCode::kNetWorkError, errMsg);
-      }
-    } else {
-      std::string errMsg = "Unexpected queue rc = " + std::to_string(r);
-      return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
-    }
-  } catch (const std::exception &e) {
-    return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what());
-  }
-}
-
-/// Stop the server
-/// \param argv
-/// \return Status object
-ds::Status StopServer(int argc, char **argv) {
-  ds::Status rc;
-  ds::CacheServer::Builder builder;
-  std::string errMsg;
-  if (argc != 2) {
-    return ds::Status(ds::StatusCode::kSyntaxError);
-  }
-  int32_t port = strtol(argv[1], nullptr, 10);
-  // We will go through the builder to do some snaity check. We only need the port number
-  // to shut down the server. Null the root directory as we don't trigger the sanity code to write out anything
-  // to the spill directory.
-  builder.SetPort(port).SetRootDirectory("");
-  // Part of the sanity check is check the shared memory. If the server is up and running, we expect
-  // the return code is kDuplicate.
-  rc = builder.SanityCheck();
-  if (rc.IsOk()) {
-    errMsg = "Server is not up or has been shutdown already.";
-    return ds::Status(ds::StatusCode::kUnexpectedError, errMsg);
-  } else if (rc.get_code() != ds::StatusCode::kDuplicateKey) {
-    // Not OK, and no duplicate, just return the rc whatever it is.
-    return rc;
-  } else {
-    // Now we get some work to do. We will send a tcp/ip request to the given port.
-    // This binary is not linked with client side of code, so we will just call grpc directly.
-    ds::CacheRequest rq;
-    ds::CacheReply reply;
-    grpc::Status grpc_rc;
-    rc = SendSyncCommand(port, ds::BaseRequest::RequestType::kStopService, &rq, &reply, &grpc_rc);
-    // The request is like a self destruct message, the server will not send anything back and
-    // shutdown all incoming request. So we should expect some unexpected network error if
-    // all goes well and we expect to GRPC code 14.
-    auto err_code = grpc_rc.error_code();
-    if (rc.get_code() != ds::StatusCode::kNetWorkError || err_code != grpc::StatusCode::UNAVAILABLE) {
-      return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__);
-    }
-  }
-  return ds::Status::OK();
-}
-
 /// Start the server
 /// \param argv
 /// \return Status object
@@ -235,15 +125,8 @@ ds::Status StartServer(int argc, char **argv) {
 }
 
 int main(int argc, char **argv) {
-  ds::Status rc;
-  ds::CacheServer::Builder builder;
-
   // This executable is not to be called directly, and should be invoked by cache_admin executable.
-  if (strcmp(argv[0], "-") == 0) {
-    rc = StopServer(argc, argv);
-  } else {
-    rc = StartServer(argc, argv);
-  }
+  ds::Status rc = StartServer(argc, argv);
   // Check result
   if (rc.IsError()) {
     auto errCode = rc.get_code();
@@ -20,6 +20,7 @@
 #include <unistd.h>
 #endif
 #include <cstdlib>
+#include <cstring>
 #include <thread>
 #include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/engine/cache/cache_client.h"
@@ -326,5 +327,11 @@ Status ListSessionsRequest::PostReply() {
   return Status::OK();
 }
+
+Status ServerStopRequest::PostReply() {
+  CHECK_FAIL_RETURN_UNEXPECTED(strcmp(reply_.result().data(), "OK") == 0, "Not the right response");
+  return Status::OK();
+}
+
 } // namespace dataset
 } // namespace mindspore
@@ -394,6 +394,15 @@ class ToggleWriteModeRequest : public BaseRequest {
   ~ToggleWriteModeRequest() override = default;
 };
+
+class ServerStopRequest : public BaseRequest {
+ public:
+  friend class CacheServer;
+  explicit ServerStopRequest(int32_t qID) : BaseRequest(RequestType::kStopService) {
+    rq_.add_buf_data(std::to_string(qID));
+  }
+  Status PostReply() override;
+};
 class ConnectResetRequest : public BaseRequest {
  public:
   friend class CacheServer;
@@ -108,6 +108,9 @@ Status CacheServer::DoServiceStart() {
   try {
     comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_);
     RETURN_IF_NOT_OK(comm_layer_->Run());
+    // Bring up a thread to monitor the unix socket in case it is removed.
+    auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
+    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
   } catch (const std::exception &e) {
     RETURN_STATUS_UNEXPECTED(e.what());
   }
@@ -154,6 +157,15 @@ Status CacheServer::DoServiceStop() {
     }
     ++it;
   }
+  // Also remove the path we use to generate ftok.
+  Path p(PortToUnixSocketPath(port_));
+  (void)p.Remove();
+  // Finally wake up cache_admin if it is waiting.
+  for (int32_t qID : shutdown_qIDs_) {
+    SharedMessage msg(qID);
+    msg.RemoveResourcesOnExit();
+    // Let msg go out of scope, which will destroy the queue.
+  }
   return rc;
 }
 
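This is the server half of the handshake sketched earlier: deleting the queue is what unblocks the admin's receive. In raw SysV terms it boils down to a single call (hypothetical helper name):

    #include <sys/msg.h>

    // Mark the queue for deletion. Any process blocked in msgrcv() on this
    // queue wakes up immediately with errno == EIDRM.
    int DestroyQueue(int msg_qid) { return msgctl(msg_qid, IPC_RMID, nullptr); }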
@@ -374,6 +386,68 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
   return rc;
 }
 
+Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
+  RETURN_UNEXPECTED_IF_NULL(out);
+  int32_t numQ = GetNumGrpcWorkers();
+  auto rng = GetRandomDevice();
+  std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
+  int32_t qID = distribution(rng);
+  std::vector<CacheServerRequest *> cache_rq_list;
+  auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
+  const auto num_elements = p->rows()->size();
+  auto connection_id = p->connection_id();
+  cache_rq_list.reserve(num_elements);
+  int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
+  auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
+  offset_array[0] = data_offset;
+  for (auto i = 0; i < num_elements; ++i) {
+    auto data_locator = p->rows()->Get(i);
+    auto node_id = data_locator->node_id();
+    size_t sz = data_locator->size();
+    void *source_addr = reinterpret_cast<void *>(data_locator->addr());
+    auto key = data_locator->key();
+    // Please read the comment in CacheServer::BatchFetchRows where we allocate
+    // the buffer big enough so each thread (which we are going to dispatch) will
+    // not run into a false sharing problem. We are going to round up sz to 4k.
+    auto sz_4k = round_up_4K(sz);
+    offset_array[i + 1] = offset_array[i] + sz_4k;
+    if (sz > 0) {
+      WritableSlice row_data(*out, offset_array[i], sz);
+      // Get a request and send to the proper worker (at some numa node) to do the fetch.
+      worker_id_t worker_id = IsNumaAffinityOn() ? GetWorkerByNumaId(node_id) : GetRandomWorker();
+      CacheServerRequest *cache_rq;
+      RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
+      cache_rq_list.push_back(cache_rq);
+      // Set up all the necessary fields.
+      cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
+      cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
+      cache_rq->rq_.set_connection_id(connection_id);
+      cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
+      auto dest_addr = row_data.GetMutablePointer();
+      flatbuffers::FlatBufferBuilder fb2;
+      FetchRowMsgBuilder bld(fb2);
+      bld.add_key(key);
+      bld.add_size(sz);
+      bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
+      bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
+      auto offset = bld.Finish();
+      fb2.Finish(offset);
+      cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
+      RETURN_IF_NOT_OK(PushRequest(worker_id, cache_rq));
+    }
+  }
+  // Now wait for all of them to come back.
+  Status rc;
+  for (CacheServerRequest *rq : cache_rq_list) {
+    RETURN_IF_NOT_OK(rq->Wait());
+    if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) {
+      rc = rq->rc_;
+    }
+    RETURN_IF_NOT_OK(ReturnRequestTag(rq));
+  }
+  return rc;
+}
+
 Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
   auto connection_id = rq->connection_id();
   // Hold the shared lock to prevent the cache from being dropped.
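The buffer that BatchFetch fills has a simple layout: a header of num_elements + 1 int64 offsets (offset_array[0] points just past the header), followed by each row's payload rounded up to 4 KiB so worker threads writing adjacent rows never touch the same cache line or page. A sketch of the arithmetic, reimplementing round_up_4K as an assumption about its behavior:

    #include <cstdint>

    // Round sz up to the next multiple of 4096 (power-of-two trick).
    inline int64_t RoundUp4K(int64_t sz) { return (sz + 4095) & ~int64_t{4095}; }

    // Row i starts at buf + offset_array[i]; since RoundUp4K(0) == 0, a cache
    // miss leaves offset_array[i + 1] == offset_array[i], i.e. an empty row.
    const char *RowAt(const int64_t *offset_array, const char *buf, int i) {
      return buf + offset_array[i];
    }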
@@ -394,6 +468,9 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
   }
   std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
   RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb));
+  // Let go of the shared lock. We don't need to interact with the CacheService anymore.
+  // We shouldn't be holding any lock while we wait, possibly for a long time, for the rows to come back.
+  lck.Unlock();
   auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
   int64_t mem_sz = sizeof(int64_t) * (sz + 1);
   for (auto i = 0; i < sz; ++i) {
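Dropping the shared lock at this point is deliberate: it only needs to cover the locator construction in PreBatchFetch, and holding even a reader lock across the blocking wait below would stall any writer (such as a cache drop) for the whole fetch. A generic sketch of the discipline, with std::shared_mutex standing in for the project's SharedLock wrapper:

    #include <shared_mutex>

    std::shared_mutex rw_lock;

    void FetchPattern() {
      std::shared_lock<std::shared_mutex> lck(rw_lock);
      // ... snapshot everything needed while holding the reader lock ...
      lck.unlock();  // release before any blocking wait
      // ... dispatch the fetch work and wait for completion lock-free ...
    }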
@@ -418,7 +495,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
     void *q = nullptr;
     RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q));
     WritableSlice dest(q, mem_sz);
-    Status rc = cs->BatchFetch(fbb, &dest);
+    Status rc = BatchFetch(fbb, &dest);
     if (rc.IsError()) {
       shared_pool->Deallocate(q);
       return rc;
@@ -439,7 +516,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
       return Status(StatusCode::kOutOfMemory);
     }
     WritableSlice dest(mem.data(), mem_sz);
-    RETURN_IF_NOT_OK(cs->BatchFetch(fbb, &dest));
+    RETURN_IF_NOT_OK(BatchFetch(fbb, &dest));
     reply->set_result(std::move(mem));
   }
 }
@@ -721,7 +798,7 @@ Status CacheServer::ServerRequest(worker_id_t worker_id) {
     }
     case BaseRequest::RequestType::kStopService: {
       // This command shutdowns everything.
-      cache_req->rc_ = GlobalShutdown();
+      cache_req->rc_ = GlobalShutdown(cache_req);
       break;
     }
     case BaseRequest::RequestType::kHeartBeat: {
@@ -914,7 +991,25 @@ Status CacheServer::RpcRequest(worker_id_t worker_id) {
   return Status::OK();
 }
 
-Status CacheServer::GlobalShutdown() {
+Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
+  auto *rq = &cache_req->rq_;
+  auto *reply = &cache_req->reply_;
+  if (!rq->buf_data().empty()) {
+    // cache_admin sends us a message queue id, and we will destroy that
+    // queue during shutdown, which will wake up cache_admin. But we don't
+    // want cache_admin to just block itself blindly, so we send back an
+    // ack before shutting down the comm layer.
+    try {
+      int32_t qID = std::stoi(rq->buf_data(0));
+      shutdown_qIDs_.push_back(qID);
+    } catch (const std::exception &e) {
+      // Ignore it.
+    }
+  }
+  reply->set_result("OK");
+  Status2CacheReply(cache_req->rc_, reply);
+  cache_req->st_ = CacheServerRequest::STATE::FINISH;
+  cache_req->responder_.Finish(*reply, grpc::Status::OK, cache_req);
   // Let's shutdown in proper order.
   bool expected = false;
   if (global_shutdown_.compare_exchange_strong(expected, true)) {
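Note the ordering here: the gRPC ack is pushed to cache_admin first, and only then does teardown begin. The compare_exchange_strong guard also makes the shutdown idempotent if it is ever reached from more than one path; a minimal illustration:

    #include <atomic>

    std::atomic<bool> global_shutdown{false};

    void ShutdownOnce() {
      bool expected = false;
      // Exactly one caller sees expected == false and wins the exchange;
      // every later caller finds the flag already set and does nothing.
      if (global_shutdown.compare_exchange_strong(expected, true)) {
        // ... perform the actual teardown, exactly once ...
      }
    }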
@@ -939,7 +1034,7 @@
   return Status::OK();
 }
 
-worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) {
+worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const {
   auto num_numa_nodes = GetNumaNodeCount();
   MS_ASSERT(numa_id < num_numa_nodes);
   auto num_workers_per_node = GetNumWorkers() / num_numa_nodes;
@@ -951,7 +1046,7 @@ worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) {
   return worker_id;
 }
 
-worker_id_t CacheServer::GetRandomWorker() {
+worker_id_t CacheServer::GetRandomWorker() const {
   std::mt19937 gen = GetRandomDevice();
   std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1);
   return dist(gen);
@@ -187,11 +187,11 @@ class CacheServer : public Service {
   /// \brief Assign a worker by a numa id
   /// \return worker id
-  worker_id_t GetWorkerByNumaId(numa_id_t node_id);
+  worker_id_t GetWorkerByNumaId(numa_id_t node_id) const;
 
   /// \brief Randomly pick a worker
   /// \return worker id
-  worker_id_t GetRandomWorker();
+  worker_id_t GetRandomWorker() const;
 
   /// \brief Check if we bind threads to numa cores
   bool IsNumaAffinityOn() const { return numa_affinity_; }
@@ -227,6 +227,7 @@ class CacheServer : public Service {
   std::shared_ptr<CacheServerHW> hw_info_;
   std::map<worker_id_t, Task *> numa_tasks_;
   bool numa_affinity_;
+  std::vector<int32_t> shutdown_qIDs_;
 
   /// \brief Constructor
   /// \param spill_path Top directory for spilling buffers to.
@@ -314,7 +315,7 @@ class CacheServer : public Service {
   /// \brief A proper shutdown of the server
   /// \return Status object
-  Status GlobalShutdown();
+  Status GlobalShutdown(CacheServerRequest *);
 
   /// \brief Find keys that will be cache miss
   /// \return Status object
@@ -330,6 +331,13 @@
   /// \brief Connect request by a pipeline
   Status ConnectReset(CacheRequest *rq);
+
+  /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
+  /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
+  /// \param[in] fbb Flatbuffer builder holding the batch data locators for the requested row ids.
+  /// \param[out] out A contiguous memory buffer that holds the requested rows.
+  /// \return Status object
+  Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out);
 
 };
 } // namespace dataset
 } // namespace mindspore
@@ -209,6 +209,10 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) {
 Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
                                    const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) {
   SharedLock rw(&rw_lock_);
+  if (st_ == CacheServiceState::kBuildPhase) {
+    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
+    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
+  }
   std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v;
   datalocator_v.reserve(v.size());
   for (auto row_id : v) {
@@ -225,76 +229,6 @@ Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::
   return Status::OK();
 }
 
-Status CacheService::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) const {
-  RETURN_UNEXPECTED_IF_NULL(out);
-  SharedLock rw(&rw_lock_);
-  if (st_ == CacheServiceState::kBuildPhase) {
-    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
-    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
-  }
-  CacheServer &cs = CacheServer::GetInstance();
-  int32_t numQ = cs.GetNumGrpcWorkers();
-  auto rng = GetRandomDevice();
-  std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
-  int32_t qID = distribution(rng);
-  std::vector<CacheServerRequest *> cache_rq_list;
-  auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
-  const auto num_elements = p->rows()->size();
-  auto connection_id = p->connection_id();
-  cache_rq_list.reserve(num_elements);
-  int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
-  auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
-  offset_array[0] = data_offset;
-  for (auto i = 0; i < num_elements; ++i) {
-    auto data_locator = p->rows()->Get(i);
-    auto node_id = data_locator->node_id();
-    size_t sz = data_locator->size();
-    void *source_addr = reinterpret_cast<void *>(data_locator->addr());
-    auto key = data_locator->key();
-    // Please read the comment in CacheServer::BatchFetchRows where we allocate
-    // the buffer big enough so each thread (which we are going to dispatch) will
-    // not run into false sharing problem. We are going to round up sz to 4k.
-    auto sz_4k = round_up_4K(sz);
-    offset_array[i + 1] = offset_array[i] + sz_4k;
-    if (sz > 0) {
-      WritableSlice row_data(*out, offset_array[i], sz);
-      // Get a request and send to the proper worker (at some numa node) to do the fetch.
-      worker_id_t worker_id = cs.IsNumaAffinityOn() ? cs.GetWorkerByNumaId(node_id) : cs.GetRandomWorker();
-      CacheServerRequest *cache_rq;
-      RETURN_IF_NOT_OK(cs.GetFreeRequestTag(qID++ % numQ, &cache_rq));
-      cache_rq_list.push_back(cache_rq);
-      // Set up all the necessarily field.
-      cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
-      cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
-      cache_rq->rq_.set_connection_id(connection_id);
-      cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
-      auto dest_addr = row_data.GetMutablePointer();
-      flatbuffers::FlatBufferBuilder fb2;
-      FetchRowMsgBuilder bld(fb2);
-      bld.add_key(key);
-      bld.add_size(sz);
-      bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
-      bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
-      auto offset = bld.Finish();
-      fb2.Finish(offset);
-      cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
-      RETURN_IF_NOT_OK(cs.PushRequest(worker_id, cache_rq));
-    }
-  }
-  // Now wait for all of them to come back. Let go of the shared lock. We shouldn't be holding
-  // any lock while we can wait for a long time.
-  rw.Unlock();
-  Status rc;
-  for (CacheServerRequest *rq : cache_rq_list) {
-    RETURN_IF_NOT_OK(rq->Wait());
-    if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) {
-      rc = rq->rc_;
-    }
-    RETURN_IF_NOT_OK(cs.ReturnRequestTag(rq));
-  }
-  return rc;
-}
-
 Status CacheService::InternalFetchRow(const FetchRowMsg *p) {
   RETURN_UNEXPECTED_IF_NULL(p);
   SharedLock rw(&rw_lock_);
@@ -75,13 +75,6 @@ class CacheService : public Service {
   Status PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
                        const std::shared_ptr<flatbuffers::FlatBufferBuilder> &);
-
-  /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
-  /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
-  /// \param[in] v A vector of row id.
-  /// \param[out] out A contiguous memory buffer that holds the requested rows.
-  /// \return Status object
-  Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &, WritableSlice *out) const;
   /// \brief Getter function
   /// \return Spilling path
   Path GetSpillPath() const;
@@ -87,6 +87,7 @@ class WritableSlice : public ReadableSlice {
  public:
   friend class StorageContainer;
   friend class CacheService;
+  friend class CacheServer;
   /// \brief Default constructor
   WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}
   /// \brief This form of a constructor takes a pointer and its size.
@@ -42,7 +42,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheCApiSamplerNull) {
   // Create an ImageFolder Dataset, this folder_path only has 2 images in it
   std::string folder_path = datasets_root_path_ + "/testImageNetData/train/";
   std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, nullptr, {}, {}, some_cache);
-  EXPECT_EQ(ds, nullptr);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  // Now the parameter check for ImageFolderNode would fail and we would end up with a nullptr iter.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_EQ(iter, nullptr);
 }
 
 TEST_F(MindDataTestCacheOp, DISABLED_TestCacheImageFolderCApi) {
@@ -121,7 +121,7 @@ HandleRcExit $? 0 1
 
 # find a port that is occupied using netstat
 if [ -x "$(command -v netstat)" ]; then
-  port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1)
+  port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1)
   if [ ${port} -gt 1025 ]; then
     # start cache server with occupied port
     cmd="${CACHE_ADMIN} --start -p ${port}"
@@ -171,7 +171,12 @@ HandleRcExit $? 0 0
 cmd="${CACHE_ADMIN} --start -w illegal"
 CacheAdminCmd "${cmd}" 1
 HandleRcExit $? 0 0
-cmd="${CACHE_ADMIN} --start -w 101"
+num_cpu=$(grep -c processor /proc/cpuinfo)
+if [ $num_cpu -lt 100 ]; then
+  cmd="${CACHE_ADMIN} --start -w 101"
+else
+  cmd="${CACHE_ADMIN} --start -w $((num_cpu + 1))"
+fi
 CacheAdminCmd "${cmd}" 1
 HandleRcExit $? 0 0
 
 cmd="${CACHE_ADMIN} --start -w 9999999"
@@ -1,10 +0,0 @@
-~/cache/cache_admin --start
-session_id=$(~/cache/cache_admin -g | awk '{print $NF}')
-export SESSION_ID=${session_id}
-pytest dataset/test_cache_nomap.py::test_cache_nomap_server_stop &
-pid=("$!")
-sleep 2
-~/cache/cache_admin --stop
-sleep 1
-wait ${pid}
-