diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc index 8feaf2cf3b..0995c4ea48 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc @@ -42,14 +42,6 @@ int main(int argc, char **argv) { google::InitGoogleLogging(argv[0]); #endif - // Create default spilling dir - ds::Path spill_dir = ds::Path(ds::DefaultSpillDir()); - rc = spill_dir.CreateDirectories(); - if (!rc.IsOk()) { - std::cerr << rc.ToString() << std::endl; - return 1; - } - if (argc == 1) { args.Help(); return 0; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc index a0c33daea0..cf735224c8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc @@ -48,7 +48,7 @@ CacheAdminArgHandler::CacheAdminArgHandler() log_level_(kDefaultLogLevel), memory_cap_ratio_(kMemoryCapRatio), hostname_(kCfgDefaultCacheHost), - spill_dir_(DefaultSpillDir()), + spill_dir_(""), command_id_(CommandId::kCmdUnknown) { // Initialize the command mappings arg_map_["-h"] = ArgValue::kArgHost; @@ -334,32 +334,7 @@ Status CacheAdminArgHandler::RunCommand() { break; } case CommandId::kCmdStop: { - CacheClientGreeter comm(hostname_, port_, 1); - RETURN_IF_NOT_OK(comm.ServiceStart()); - SharedMessage msg; - RETURN_IF_NOT_OK(msg.Create()); - auto rq = std::make_shared(msg.GetMsgQueueId()); - RETURN_IF_NOT_OK(comm.HandleRequest(rq)); - Status rc = rq->Wait(); - if (rc.IsError()) { - msg.RemoveResourcesOnExit(); - if (rc.IsNetWorkError()) { - std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already."; - return Status(StatusCode::kNetWorkError, errMsg); - } - return rc; - } - // OK return code only means the server acknowledge our request but we still - // have to wait for its complete shutdown because the server will shutdown - // the comm layer as soon as the request is received, and we need to wait - // on the message queue instead. - // The server will send a message back and remove the queue and we will then wake up. But on the safe - // side, we will also set up an alarm and kill this process if we hang on - // the message queue. - alarm(60); - Status dummy_rc; - (void)msg.ReceiveStatus(&dummy_rc); - std::cout << "Cache server on port " << std::to_string(port_) << " has been stopped successfully." << std::endl; + RETURN_IF_NOT_OK(StopServer(command_id_)); break; } case CommandId::kCmdGenerateSession: { @@ -430,6 +405,36 @@ Status CacheAdminArgHandler::RunCommand() { return Status::OK(); } +Status CacheAdminArgHandler::StopServer(CommandId command_id) { + CacheClientGreeter comm(hostname_, port_, 1); + RETURN_IF_NOT_OK(comm.ServiceStart()); + SharedMessage msg; + RETURN_IF_NOT_OK(msg.Create()); + auto rq = std::make_shared(msg.GetMsgQueueId()); + RETURN_IF_NOT_OK(comm.HandleRequest(rq)); + Status rc = rq->Wait(); + if (rc.IsError()) { + msg.RemoveResourcesOnExit(); + if (rc.IsNetWorkError()) { + std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already."; + return Status(StatusCode::kNetWorkError, errMsg); + } + return rc; + } + // OK return code only means the server acknowledge our request but we still + // have to wait for its complete shutdown because the server will shutdown + // the comm layer as soon as the request is received, and we need to wait + // on the message queue instead. + // The server will send a message back and remove the queue and we will then wake up. But on the safe + // side, we will also set up an alarm and kill this process if we hang on + // the message queue. + alarm(60); + Status dummy_rc; + (void)msg.ReceiveStatus(&dummy_rc); + std::cout << "Cache server on port " << std::to_string(port_) << " has been stopped successfully." << std::endl; + return Status::OK(); +} + Status CacheAdminArgHandler::StartServer(CommandId command_id) { // There currently does not exist any "install path" or method to identify which path the installed binaries will // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the @@ -462,7 +467,6 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) { // fork the child process to become the daemon pid_t pid; pid = fork(); - // failed to fork if (pid < 0) { std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno); @@ -538,7 +542,7 @@ void CacheAdminArgHandler::Help() { std::cerr << " [[-h | --hostname] ] Default is " << kCfgDefaultCacheHost << ".\n"; std::cerr << " [[-p | --port] ] Default is " << kCfgDefaultCachePort << ".\n"; std::cerr << " [[-w | --workers] ] Default is " << kDefaultNumWorkers << ".\n"; - std::cerr << " [[-s | --spilldir] ] Default is " << DefaultSpillDir() << ".\n"; + std::cerr << " [[-s | --spilldir] ] Default is no spilling.\n"; std::cerr << " [[-l | --loglevel] ] Default is 1 (warning level).\n"; std::cerr << " [--destroy_session | -d] \n"; std::cerr << " [[-p | --port] ]\n"; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h index e06eb07fcd..2779c76b78 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h @@ -79,6 +79,8 @@ class CacheAdminArgHandler { Status StartServer(CommandId command_id); + Status StopServer(CommandId command_id); + Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, CommandId command_id = CommandId::kCmdUnknown); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h index 101b15d369..637bbe38c8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h @@ -91,9 +91,6 @@ using worker_id_t = int32_t; using numa_id_t = int32_t; using cpu_id_t = int32_t; -/// Return the default spill dir for cache -inline std::string DefaultSpillDir() { return kDefaultPathPrefix; } - /// Return the default log dir for cache inline std::string DefaultLogDir() { return kDefaultPathPrefix + std::string("/log"); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h index 7edbe415f8..6ff6bdb16f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h @@ -125,6 +125,33 @@ class BaseRequest { /// \return Status object Status Wait(); + /// \brief Return if the request is of row request type + /// \return True if the request is row-related request + bool IsRowRequest() const { + return type_ == RequestType::kBatchCacheRows || type_ == RequestType::kBatchFetchRows || + type_ == RequestType::kInternalCacheRow || type_ == RequestType::kInternalFetchRow || + type_ == RequestType::kCacheRow; + } + + /// \brief Return if the request is of admin request type + /// \return True if the request is admin-related request + bool IsAdminRequest() const { + return type_ == RequestType::kCreateCache || type_ == RequestType::kDestroyCache || + type_ == RequestType::kGetStat || type_ == RequestType::kGetCacheState || + type_ == RequestType::kAllocateSharedBlock || type_ == RequestType::kFreeSharedBlock || + type_ == RequestType::kCacheSchema || type_ == RequestType::kFetchSchema || + type_ == RequestType::kBuildPhaseDone || type_ == RequestType::kToggleWriteMode || + type_ == RequestType::kConnectReset || type_ == RequestType::kStopService || + type_ == RequestType::kHeartBeat || type_ == RequestType::kGetCacheMissKeys; + } + + /// \brief Return if the request is of session request type + /// \return True if the request is session-related request + bool IsSessionRequest() const { + return type_ == RequestType::kGenerateSessionId || type_ == RequestType::kDropSession || + type_ == RequestType::kListSessions; + } + protected: CacheRequest rq_; // This is what we send to the server CacheReply reply_; // This is what the server send back diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc index 039d7f1e49..31d9b936d6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc @@ -155,6 +155,42 @@ CacheService *CacheServer::GetService(connection_id_type id) const { return nullptr; } +// We would like to protect ourselves from over allocating too much. We will go over existing cache +// and calculate how much we have consumed so far. +Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) { + auto end = all_caches_.end(); + auto it = all_caches_.begin(); + auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_; + int64_t max_avail = avail_mem; + while (it != end) { + auto &cs = it->second; + CacheService::ServiceStat stat; + RETURN_IF_NOT_OK(cs->GetStat(&stat)); + int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz; + max_avail -= mem_consumed; + if (max_avail <= 0) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); + } + ++it; + } + + // If we have some cache using some memory already, make a reasonable decision if we should return + // out of memory. + if (max_avail < avail_mem) { + int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit. + if (req_mem > max_avail) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); + } else if (req_mem == 0) { + // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than + // 85% of our limit, fail this request. + if (static_cast(max_avail) / static_cast(avail_mem) <= 0.15) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); + } + } + } + return Status::OK(); +} + Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info"); std::string cookie; @@ -186,55 +222,25 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { if (spill && top_.empty()) { RETURN_STATUS_UNEXPECTED("Server is not set up with spill support."); } - flatbuffers::FlatBufferBuilder fbb; - flatbuffers::Offset off_cookie; - flatbuffers::Offset> off_cpu_list; // Before creating the cache, first check if this is a request for a shared usage of an existing cache // If two CreateService come in with identical connection_id, we need to serialize the create. // The first create will be successful and be given a special cookie. UniqueLock lck(&rwLock_); + bool duplicate = false; + CacheService *curr_cs = GetService(connection_id); + if (curr_cs != nullptr) { + duplicate = true; + client_id = curr_cs->num_clients_.fetch_add(1); + MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " + + std::to_string(connection_id) + " to create cache service"; + } // Early exit if we are doing global shutdown if (global_shutdown_) { return Status::OK(); } - // We would like to protect ourselves from over allocating too much. We will go over existing cache - // and calculate how much we have consumed so far. - auto end = all_caches_.end(); - auto it = all_caches_.begin(); - bool duplicate = false; - auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_; - int64_t max_avail = avail_mem; - while (it != end) { - if (it->first == connection_id) { - duplicate = true; - break; - } else { - auto &cs = it->second; - CacheService::ServiceStat stat; - RETURN_IF_NOT_OK(cs->GetStat(&stat)); - int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz; - max_avail -= mem_consumed; - if (max_avail <= 0) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); - } - } - ++it; - } - if (it == end) { - // If we have some cache using some memory already, make a reasonable decision if we should return - // out of memory. - if (max_avail < avail_mem) { - int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit. - if (req_mem > max_avail) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); - } else if (req_mem == 0) { - // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than - // 85% of our limit, fail this request. - if (static_cast(max_avail) / static_cast(avail_mem) <= 0.15) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); - } - } - } + + if (!duplicate) { + RETURN_IF_NOT_OK(GlobalMemoryCheck(cache_mem_sz)); std::unique_ptr cs; try { cs = std::make_unique(cache_mem_sz, spill ? top_ : "", generate_id); @@ -245,12 +251,8 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { } catch (const std::bad_alloc &e) { return Status(StatusCode::kOutOfMemory); } - } else { - duplicate = true; - client_id = it->second->num_clients_.fetch_add(1); - MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " + - std::to_string(connection_id) + " to create cache service"; } + // Shuffle the worker threads. But we need to release the locks or we will deadlock when calling // the following function lck.Unlock(); @@ -258,6 +260,9 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { auto numa_id = client_id % GetNumaNodeCount(); std::vector cpu_list = hw_info_->GetCpuList(numa_id); // Send back the data + flatbuffers::FlatBufferBuilder fbb; + flatbuffers::Offset off_cookie; + flatbuffers::Offset> off_cpu_list; off_cookie = fbb.CreateString(cookie); off_cpu_list = fbb.CreateVector(cpu_list); CreateCacheReplyMsgBuilder bld(fbb); @@ -376,6 +381,57 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) { return rc; } +Status CacheServer::InternalCacheRow(CacheRequest *rq, CacheReply *reply) { + // Look into the flag to see where we can find the data and call the appropriate method. + auto flag = rq->flag(); + Status rc; + if (BitTest(flag, kDataIsInSharedMemory)) { + rc = FastCacheRow(rq, reply); + // This is an internal request and is not tied to rpc. But need to post because there + // is a thread waiting on the completion of this request. + try { + int64_t addr = strtol(rq->buf_data(3).data(), nullptr, 10); + auto *bw = reinterpret_cast(addr); + // Check if the object is still around. + auto bwObj = bw->GetBatchWait(); + if (bwObj.lock()) { + RETURN_IF_NOT_OK(bw->Set(rc)); + } + } catch (const std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + } else { + rc = CacheRow(rq, reply); + } + return rc; +} + +Status CacheServer::InternalFetchRow(CacheRequest *rq) { + auto connection_id = rq->connection_id(); + SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); + Status rc; + if (cs == nullptr) { + std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } + rc = cs->InternalFetchRow(flatbuffers::GetRoot(rq->buf_data(0).data())); + // This is an internal request and is not tied to rpc. But need to post because there + // is a thread waiting on the completion of this request. + try { + int64_t addr = strtol(rq->buf_data(1).data(), nullptr, 10); + auto *bw = reinterpret_cast(addr); + // Check if the object is still around. + auto bwObj = bw->GetBatchWait(); + if (bwObj.lock()) { + RETURN_IF_NOT_OK(bw->Set(rc)); + } + } catch (const std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + return rc; +} + Status CacheServer::BatchFetch(const std::shared_ptr &fbb, WritableSlice *out) { RETURN_UNEXPECTED_IF_NULL(out); auto p = flatbuffers::GetRoot(fbb->GetBufferPointer()); @@ -741,40 +797,24 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) { return Status::OK(); } -Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) { - bool internal_request = false; +Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request) { auto &rq = cache_req->rq_; auto &reply = cache_req->reply_; - // Except for creating a new session, we expect cs is not null. switch (cache_req->type_) { - case BaseRequest::RequestType::kCacheRow: - case BaseRequest::RequestType::kInternalCacheRow: { - // Look into the flag to see where we can find the data and - // call the appropriate method. - auto flag = rq.flag(); - if (BitTest(flag, kDataIsInSharedMemory)) { + case BaseRequest::RequestType::kCacheRow: { + // Look into the flag to see where we can find the data and call the appropriate method. + if (BitTest(rq.flag(), kDataIsInSharedMemory)) { cache_req->rc_ = FastCacheRow(&rq, &reply); - internal_request = (cache_req->type_ == BaseRequest::RequestType::kInternalCacheRow); - if (internal_request) { - // This is an internal request and is not tied to rpc. But need to post because there - // is a thread waiting on the completion of this request. - try { - int64_t addr = strtol(rq.buf_data(3).data(), nullptr, 10); - auto *bw = reinterpret_cast(addr); - // Check if the object is still around. - auto bwObj = bw->GetBatchWait(); - if (bwObj.lock()) { - RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_))); - } - } catch (const std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - } } else { cache_req->rc_ = CacheRow(&rq, &reply); } break; } + case BaseRequest::RequestType::kInternalCacheRow: { + *internal_request = true; + cache_req->rc_ = InternalCacheRow(&rq, &reply); + break; + } case BaseRequest::RequestType::kBatchCacheRows: { cache_req->rc_ = BatchCacheRows(&rq); break; @@ -784,31 +824,46 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) { break; } case BaseRequest::RequestType::kInternalFetchRow: { - internal_request = true; - auto connection_id = rq.connection_id(); - SharedLock lck(&rwLock_); - CacheService *cs = GetService(connection_id); - if (cs == nullptr) { - std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; - cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot(rq.buf_data(0).data())); - // This is an internal request and is not tied to rpc. But need to post because there - // is a thread waiting on the completion of this request. - try { - int64_t addr = strtol(rq.buf_data(1).data(), nullptr, 10); - auto *bw = reinterpret_cast(addr); - // Check if the object is still around. - auto bwObj = bw->GetBatchWait(); - if (bwObj.lock()) { - RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_))); - } - } catch (const std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - } + *internal_request = true; + cache_req->rc_ = InternalFetchRow(&rq); + break; + } + default: + std::string errMsg("Internal error, request type is not row request: "); + errMsg += std::to_string(static_cast(cache_req->type_)); + cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } + return Status::OK(); +} + +Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) { + auto &rq = cache_req->rq_; + auto &reply = cache_req->reply_; + switch (cache_req->type_) { + case BaseRequest::RequestType::kDropSession: { + cache_req->rc_ = DestroySession(&rq); + break; + } + case BaseRequest::RequestType::kGenerateSessionId: { + cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply); break; } + case BaseRequest::RequestType::kListSessions: { + cache_req->rc_ = ListSessions(&reply); + break; + } + default: + std::string errMsg("Internal error, request type is not session request: "); + errMsg += std::to_string(static_cast(cache_req->type_)); + cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } + return Status::OK(); +} + +Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) { + auto &rq = cache_req->rq_; + auto &reply = cache_req->reply_; + switch (cache_req->type_) { case BaseRequest::RequestType::kCreateCache: { cache_req->rc_ = CreateService(&rq, &reply); break; @@ -837,14 +892,6 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) { cache_req->rc_ = BuildPhaseDone(&rq); break; } - case BaseRequest::RequestType::kDropSession: { - cache_req->rc_ = DestroySession(&rq); - break; - } - case BaseRequest::RequestType::kGenerateSessionId: { - cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply); - break; - } case BaseRequest::RequestType::kAllocateSharedBlock: { cache_req->rc_ = AllocateSharedMemory(&rq, &reply); break; @@ -868,40 +915,45 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) { cache_req->rc_ = ToggleWriteMode(&rq); break; } - case BaseRequest::RequestType::kListSessions: { - cache_req->rc_ = ListSessions(&reply); - break; - } case BaseRequest::RequestType::kConnectReset: { cache_req->rc_ = ConnectReset(&rq); break; } case BaseRequest::RequestType::kGetCacheState: { - auto connection_id = rq.connection_id(); - SharedLock lck(&rwLock_); - CacheService *cs = GetService(connection_id); - if (cs == nullptr) { - std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; - cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto state = cs->GetState(); - reply.set_result(std::to_string(static_cast(state))); - cache_req->rc_ = Status::OK(); - } + cache_req->rc_ = GetCacheState(&rq, &reply); break; } default: - std::string errMsg("Unknown request type : "); + std::string errMsg("Internal error, request type is not admin request: "); errMsg += std::to_string(static_cast(cache_req->type_)); cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); } + return Status::OK(); +} + +Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) { + bool internal_request = false; + + // Except for creating a new session, we expect cs is not null. + if (cache_req->IsRowRequest()) { + RETURN_IF_NOT_OK(ProcessRowRequest(cache_req, &internal_request)); + } else if (cache_req->IsSessionRequest()) { + RETURN_IF_NOT_OK(ProcessSessionRequest(cache_req)); + } else if (cache_req->IsAdminRequest()) { + RETURN_IF_NOT_OK(ProcessAdminRequest(cache_req)); + } else { + std::string errMsg("Unknown request type : "); + errMsg += std::to_string(static_cast(cache_req->type_)); + cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } + // Notify it is done, and move on to the next request. - Status2CacheReply(cache_req->rc_, &reply); + Status2CacheReply(cache_req->rc_, &cache_req->reply_); cache_req->st_ = CacheServerRequest::STATE::FINISH; // We will re-tag the request back to the grpc queue. Once it comes back from the client, // the CacheServerRequest, i.e. the pointer cache_req, will be free if (!internal_request && !global_shutdown_) { - cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); + cache_req->responder_.Finish(cache_req->reply_, grpc::Status::OK, cache_req); } else { // We can free up the request now. RETURN_IF_NOT_OK(ReturnRequestTag(cache_req)); @@ -1084,6 +1136,20 @@ Status CacheServer::FreeSharedMemory(CacheRequest *rq) { return Status::OK(); } +Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) { + auto connection_id = rq->connection_id(); + SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); + if (cs == nullptr) { + std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto state = cs->GetState(); + reply->set_result(std::to_string(static_cast(state))); + return Status::OK(); + } +} + Status CacheServer::RpcRequest(worker_id_t worker_id) { TaskManager::FindMe()->Post(); RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id)); @@ -1213,7 +1279,7 @@ Status CacheServer::Builder::SanityCheck() { } CacheServer::Builder::Builder() - : top_(DefaultSpillDir()), + : top_(""), num_workers_(std::thread::hardware_concurrency() / 2), port_(50052), shared_memory_sz_in_gb_(kDefaultSharedMemorySize), diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h index 011a8f5a5d..7fc5049df9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h @@ -201,6 +201,22 @@ class CacheServer : public Service { /// \brief Return the memory cap ratio float GetMemoryCapRatio() const { return memory_cap_ratio_; } + /// \brief Function to handle a row request + /// \param[in] cache_req A row request to handle + /// \param[out] internal_request Indicator if the request is an internal request + /// \return Status object + Status ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request); + + /// \brief Function to handle an admin request + /// \param[in] cache_req An admin request to handle + /// \return Status object + Status ProcessAdminRequest(CacheServerRequest *cache_req); + + /// \brief Function to handle a session request + /// \param[in] cache_req A session request to handle + /// \return Status object + Status ProcessSessionRequest(CacheServerRequest *cache_req); + /// \brief How a request is handled. /// \note that it can be process immediately by a grpc thread or routed to a server thread /// which is pinned to some numa node core. @@ -256,6 +272,12 @@ class CacheServer : public Service { /// \return Pointer to cache service. Null if not found CacheService *GetService(connection_id_type id) const; + /// \brief Going over existing cache service and calculate how much we have consumed so far, a new cache service + /// can only be created if there is still enough avail memory left + /// \param cache_mem_sz Requested memory for a new cache service + /// \return Status object + Status GlobalMemoryCheck(uint64_t cache_mem_sz); + /// \brief Create a cache service. We allow multiple clients to create the same cache service. /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given /// a special unique cookie. @@ -314,6 +336,12 @@ class CacheServer : public Service { /// \return Status object Status GetStat(CacheRequest *rq, CacheReply *reply); + /// \brief Internal function to get cache state + /// \param rq + /// \param reply + /// \return Status object + Status GetCacheState(CacheRequest *rq, CacheReply *reply); + /// \brief Cache a schema request /// \param rq /// \return Status object @@ -411,6 +439,9 @@ class CacheServer : public Service { /// \return Status object Status BatchFetch(const std::shared_ptr &fbb, WritableSlice *out); Status BatchCacheRows(CacheRequest *rq); + + Status InternalFetchRow(CacheRequest *rq); + Status InternalCacheRow(CacheRequest *rq, CacheReply *reply); }; } // namespace dataset } // namespace mindspore diff --git a/tests/ut/python/cachetests/cachetest_args.sh b/tests/ut/python/cachetests/cachetest_args.sh index 82c4414ad1..8148926946 100755 --- a/tests/ut/python/cachetests/cachetest_args.sh +++ b/tests/ut/python/cachetests/cachetest_args.sh @@ -15,7 +15,8 @@ # ============================================================================ # source the globals and functions for use with cache testing -SKIP_ADMIN_COUNTER=false +export SKIP_ADMIN_COUNTER=false +declare failed_tests . cachetest_lib.sh echo diff --git a/tests/ut/python/cachetests/cachetest_cpp.sh b/tests/ut/python/cachetests/cachetest_cpp.sh index 875ebbd5b5..b1121f56bd 100755 --- a/tests/ut/python/cachetests/cachetest_cpp.sh +++ b/tests/ut/python/cachetests/cachetest_cpp.sh @@ -15,7 +15,8 @@ # ============================================================================ # source the globals and functions for use with cache testing -SKIP_ADMIN_COUNTER=true +export SKIP_ADMIN_COUNTER=true +declare session_id failed_tests . cachetest_lib.sh echo @@ -28,8 +29,10 @@ UT_TEST_DIR="${BUILD_PATH}/mindspore/tests/ut/cpp" DateStamp=$(date +%Y%m%d_%H%M%S); CPP_TEST_LOG_OUTPUT="/tmp/ut_tests_cache_${DateStamp}.log" -# Start a basic cache server to be used for all tests -StartServer +# start cache server with a spilling path to be used for all tests +cmd="${CACHE_ADMIN} --start -s /tmp" +CacheAdminCmd "${cmd}" 0 +sleep 1 HandleRcExit $? 1 1 # Set the environment variable to enable these pytests diff --git a/tests/ut/python/cachetests/cachetest_py.sh b/tests/ut/python/cachetests/cachetest_py.sh index 2a7aaedd16..b7867fa527 100755 --- a/tests/ut/python/cachetests/cachetest_py.sh +++ b/tests/ut/python/cachetests/cachetest_py.sh @@ -15,7 +15,8 @@ # ============================================================================ # source the globals and functions for use with cache testing -SKIP_ADMIN_COUNTER=true +export SKIP_ADMIN_COUNTER=true +declare session_id failed_tests . cachetest_lib.sh echo @@ -84,10 +85,6 @@ export SESSION_ID=$session_id PytestCmd "test_cache_map.py" "test_cache_map_running_twice2" HandleRcExit $? 0 0 -# Set size parameter of DatasetCache to a extra small value -PytestCmd "test_cache_map.py" "test_cache_map_extra_small_size" 1 -HandleRcExit $? 0 0 - PytestCmd "test_cache_map.py" "test_cache_map_no_image" HandleRcExit $? 0 0 @@ -255,15 +252,6 @@ export SESSION_ID=$session_id PytestCmd "test_cache_nomap.py" "test_cache_nomap_running_twice2" HandleRcExit $? 0 0 -# Set size parameter of DatasetCache to a extra small value -GetSession -HandleRcExit $? 1 1 -export SESSION_ID=$session_id -PytestCmd "test_cache_nomap.py" "test_cache_nomap_extra_small_size" 1 -HandleRcExit $? 0 0 -DestroySession $session_id -HandleRcExit $? 1 1 - # Run two parallel pipelines (sharing cache) for i in $(seq 1 2) do @@ -366,7 +354,7 @@ HandleRcExit $? 1 1 export SESSION_ID=$session_id PytestCmd "test_cache_nomap.py" "test_cache_nomap_session_destroy" & -pid=("$!") +pid=$! sleep 10 DestroySession $session_id @@ -381,7 +369,7 @@ HandleRcExit $? 1 1 export SESSION_ID=$session_id PytestCmd "test_cache_nomap.py" "test_cache_nomap_server_stop" & -pid=("$!") +pid=$! sleep 10 StopServer @@ -417,6 +405,26 @@ HandleRcExit $? 0 0 StopServer HandleRcExit $? 0 1 +# start cache server with a spilling path +cmd="${CACHE_ADMIN} --start -s /tmp" +CacheAdminCmd "${cmd}" 0 +sleep 1 +HandleRcExit $? 0 0 + +GetSession +HandleRcExit $? 1 1 +export SESSION_ID=$session_id + +# Set size parameter of mappable DatasetCache to a extra small value +PytestCmd "test_cache_map.py" "test_cache_map_extra_small_size" 1 +HandleRcExit $? 0 0 +# Set size parameter of non-mappable DatasetCache to a extra small value +PytestCmd "test_cache_nomap.py" "test_cache_nomap_extra_small_size" 1 +HandleRcExit $? 0 0 + +StopServer +HandleRcExit $? 0 1 + unset RUN_CACHE_TEST unset SESSION_ID diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index b885bc203c..6f5f11fd89 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -57,7 +57,7 @@ def test_cache_map_basic1(): else: session_id = 1 - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -91,7 +91,7 @@ def test_cache_map_basic2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -115,7 +115,7 @@ def test_cache_map_basic3(): session_id = int(os.environ['SESSION_ID']) else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -155,7 +155,7 @@ def test_cache_map_basic4(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -189,7 +189,7 @@ def test_cache_map_basic5(): session_id = int(os.environ['SESSION_ID']) else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -225,7 +225,7 @@ def test_cache_map_failure1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -269,7 +269,7 @@ def test_cache_map_failure2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -310,7 +310,7 @@ def test_cache_map_failure3(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -351,7 +351,7 @@ def test_cache_map_failure4(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -391,7 +391,7 @@ def test_cache_map_failure5(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it data = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -432,7 +432,7 @@ def test_cache_map_failure6(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) columns_list = ["id", "file_name", "label_name", "img_data", "label_data"] num_readers = 1 @@ -478,7 +478,7 @@ def test_cache_map_failure7(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) data = ds.GeneratorDataset(generator_1d, ["data"]) data = data.map((lambda x: x), ["data"], cache=some_cache) @@ -514,7 +514,7 @@ def test_cache_map_failure8(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -554,7 +554,7 @@ def test_cache_map_failure9(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -596,7 +596,7 @@ def test_cache_map_failure10(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -616,6 +616,37 @@ def test_cache_map_failure10(): logger.info('test_cache_failure10 Ended.\n') +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_failure11(): + """ + Test set spilling=true when cache server is started without spilling support (failure) + + Cache(spilling=true) + | + ImageFolder + + """ + logger.info("Test cache failure 11") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + assert "Unexpected error. Server is not set up with spill support" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_failure11 Ended.\n') + + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_split1(): """ @@ -641,7 +672,7 @@ def test_cache_map_split1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -692,7 +723,7 @@ def test_cache_map_split2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 9 records ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) @@ -725,27 +756,27 @@ def test_cache_map_parameter_check(): logger.info("Test cache map parameter check") with pytest.raises(ValueError) as info: - ds.DatasetCache(session_id=-1, size=0, spilling=True) + ds.DatasetCache(session_id=-1, size=0) assert "Input is not within the required interval" in str(info.value) with pytest.raises(TypeError) as info: - ds.DatasetCache(session_id="1", size=0, spilling=True) + ds.DatasetCache(session_id="1", size=0) assert "Argument session_id with value 1 is not of type (,)" in str(info.value) with pytest.raises(TypeError) as info: - ds.DatasetCache(session_id=None, size=0, spilling=True) + ds.DatasetCache(session_id=None, size=0) assert "Argument session_id with value None is not of type (,)" in str(info.value) with pytest.raises(ValueError) as info: - ds.DatasetCache(session_id=1, size=-1, spilling=True) + ds.DatasetCache(session_id=1, size=-1) assert "Input size must be greater than 0" in str(info.value) with pytest.raises(TypeError) as info: - ds.DatasetCache(session_id=1, size="1", spilling=True) + ds.DatasetCache(session_id=1, size="1") assert "Argument size with value 1 is not of type (,)" in str(info.value) with pytest.raises(TypeError) as info: - ds.DatasetCache(session_id=1, size=None, spilling=True) + ds.DatasetCache(session_id=1, size=None) assert "Argument size with value None is not of type (,)" in str(info.value) with pytest.raises(TypeError) as info: @@ -753,31 +784,31 @@ def test_cache_map_parameter_check(): assert "Argument spilling with value illegal is not of type (,)" in str(info.value) with pytest.raises(TypeError) as err: - ds.DatasetCache(session_id=1, size=0, spilling=True, hostname=50052) + ds.DatasetCache(session_id=1, size=0, hostname=50052) assert "Argument hostname with value 50052 is not of type (,)" in str(err.value) with pytest.raises(RuntimeError) as err: - ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal") + ds.DatasetCache(session_id=1, size=0, hostname="illegal") assert "now cache client has to be on the same host with cache server" in str(err.value) with pytest.raises(RuntimeError) as err: - ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2") + ds.DatasetCache(session_id=1, size=0, hostname="127.0.0.2") assert "now cache client has to be on the same host with cache server" in str(err.value) with pytest.raises(TypeError) as info: - ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal") + ds.DatasetCache(session_id=1, size=0, port="illegal") assert "Argument port with value illegal is not of type (,)" in str(info.value) with pytest.raises(TypeError) as info: - ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052") + ds.DatasetCache(session_id=1, size=0, port="50052") assert "Argument port with value 50052 is not of type (,)" in str(info.value) with pytest.raises(ValueError) as err: - ds.DatasetCache(session_id=1, size=0, spilling=True, port=0) + ds.DatasetCache(session_id=1, size=0, port=0) assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value) with pytest.raises(ValueError) as err: - ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536) + ds.DatasetCache(session_id=1, size=0, port=65536) assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value) with pytest.raises(TypeError) as err: @@ -807,7 +838,7 @@ def test_cache_map_running_twice1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -850,7 +881,7 @@ def test_cache_map_running_twice2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -998,7 +1029,7 @@ def test_cache_map_parallel_pipeline1(shard): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache) @@ -1035,7 +1066,7 @@ def test_cache_map_parallel_pipeline2(shard): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard)) @@ -1072,7 +1103,7 @@ def test_cache_map_parallel_workers(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4) @@ -1109,7 +1140,7 @@ def test_cache_map_server_workers_1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -1146,7 +1177,7 @@ def test_cache_map_server_workers_100(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -1183,7 +1214,7 @@ def test_cache_map_num_connections_1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1) + some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -1220,7 +1251,7 @@ def test_cache_map_num_connections_100(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100) + some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -1257,7 +1288,7 @@ def test_cache_map_prefetch_size_1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1) + some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -1294,7 +1325,7 @@ def test_cache_map_prefetch_size_100(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100) + some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -1335,7 +1366,7 @@ def test_cache_map_to_device(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -1366,7 +1397,7 @@ def test_cache_map_epoch_ctrl1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -1406,7 +1437,7 @@ def test_cache_map_epoch_ctrl2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) @@ -1452,7 +1483,7 @@ def test_cache_map_epoch_ctrl3(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) @@ -1495,7 +1526,7 @@ def test_cache_map_coco1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 6 records ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True, @@ -1531,7 +1562,7 @@ def test_cache_map_coco2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 6 records ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True) @@ -1566,7 +1597,7 @@ def test_cache_map_mnist1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache) num_epoch = 4 @@ -1599,7 +1630,7 @@ def test_cache_map_mnist2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10) resize_op = c_vision.Resize((224, 224)) @@ -1633,7 +1664,7 @@ def test_cache_map_celeba1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 4 records ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache) @@ -1668,7 +1699,7 @@ def test_cache_map_celeba2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 4 records ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) @@ -1703,7 +1734,7 @@ def test_cache_map_manifest1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 4 records ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache) @@ -1738,7 +1769,7 @@ def test_cache_map_manifest2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 4 records ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True) @@ -1773,7 +1804,7 @@ def test_cache_map_cifar1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) num_epoch = 4 @@ -1806,7 +1837,7 @@ def test_cache_map_cifar2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10) resize_op = c_vision.Resize((224, 224)) @@ -1841,7 +1872,7 @@ def test_cache_map_cifar3(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False) + some_cache = ds.DatasetCache(session_id=session_id, size=1) ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache) @@ -1875,7 +1906,7 @@ def test_cache_map_cifar4(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) ds1 = ds1.shuffle(10) @@ -1907,7 +1938,7 @@ def test_cache_map_voc1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 9 records ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache) @@ -1942,7 +1973,7 @@ def test_cache_map_voc2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 9 records ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) @@ -1987,7 +2018,7 @@ def test_cache_map_python_sampler1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache) @@ -2023,7 +2054,7 @@ def test_cache_map_python_sampler2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler()) @@ -2061,7 +2092,7 @@ def test_cache_map_nested_repeat(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This DATA_DIR only has 2 images in it ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py index 3e6f847479..48b17f6b6c 100644 --- a/tests/ut/python/dataset/test_cache_nomap.py +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -62,7 +62,7 @@ def test_cache_nomap_basic1(): schema.add_column('label', de_type=mstype.uint8, shape=[1]) # create a cache. arbitrary session_id for now - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # User-created sampler here ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache) @@ -96,7 +96,7 @@ def test_cache_nomap_basic2(): schema.add_column('label', de_type=mstype.uint8, shape=[1]) # create a cache. arbitrary session_id for now - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # sampler arg not given directly, however any of these args will auto-generate an appropriate sampler: # num_samples, shuffle, num_shards, shard_id @@ -134,7 +134,7 @@ def test_cache_nomap_basic3(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) decode_op = c_vision.Decode() ds1 = ds1.map(operations=decode_op, input_columns=["image"]) @@ -183,7 +183,7 @@ def test_cache_nomap_basic4(): raise RuntimeError("Testcase requires SESSION_ID environment variable") # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache # in the picture. This causes a shuffle-injection over the TF. For clarify, this test will # explicitly give the global option, even though it's the default in python. @@ -231,7 +231,7 @@ def test_cache_nomap_basic5(): raise RuntimeError("Testcase requires SESSION_ID environment variable") # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache) decode_op = c_vision.Decode() ds1 = ds1.map(operations=decode_op, input_columns=["image"]) @@ -270,7 +270,7 @@ def test_cache_nomap_basic6(): raise RuntimeError("Testcase requires SESSION_ID environment variable") # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # With only 3 records shard into 3, we expect only 1 record returned for this shard # However, the sharding will be done by the sampler, not by the tf record leaf node @@ -313,7 +313,7 @@ def test_cache_nomap_basic7(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache) @@ -344,7 +344,7 @@ def test_cache_nomap_basic8(): session_id = int(os.environ['SESSION_ID']) else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) @@ -371,7 +371,7 @@ def test_cache_nomap_basic9(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # Contact the server to get the statistics, this should fail because we have not used this cache in any pipeline # so there will not be any cache to get stats on. @@ -404,7 +404,7 @@ def test_cache_nomap_allowed_share1(): ds.config.set_seed(1) # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=32) + some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=32) ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) ds1 = ds1.repeat(4) @@ -446,7 +446,7 @@ def test_cache_nomap_allowed_share2(): ds.config.set_seed(1) # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) decode_op = c_vision.Decode() ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) @@ -488,7 +488,7 @@ def test_cache_nomap_allowed_share3(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"] ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache) @@ -529,7 +529,7 @@ def test_cache_nomap_allowed_share4(): raise RuntimeError("Testcase requires SESSION_ID environment variable") # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) decode_op = c_vision.Decode() ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) @@ -572,7 +572,7 @@ def test_cache_nomap_disallowed_share1(): raise RuntimeError("Testcase requires SESSION_ID environment variable") # This dataset has 3 records in it only - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) decode_op = c_vision.Decode() rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0) @@ -615,7 +615,7 @@ def test_cache_nomap_running_twice1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -658,7 +658,7 @@ def test_cache_nomap_running_twice2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) @@ -763,7 +763,7 @@ def test_cache_nomap_parallel_pipeline1(shard): session_id = int(os.environ['SESSION_ID']) else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache) @@ -799,7 +799,7 @@ def test_cache_nomap_parallel_pipeline2(shard): session_id = int(os.environ['SESSION_ID']) else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard)) @@ -835,7 +835,7 @@ def test_cache_nomap_parallel_workers(): session_id = int(os.environ['SESSION_ID']) else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4) @@ -872,7 +872,7 @@ def test_cache_nomap_server_workers_1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -909,7 +909,7 @@ def test_cache_nomap_server_workers_100(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) @@ -946,7 +946,7 @@ def test_cache_nomap_num_connections_1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1) + some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -983,7 +983,7 @@ def test_cache_nomap_num_connections_100(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100) + some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) @@ -1020,7 +1020,7 @@ def test_cache_nomap_prefetch_size_1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1) + some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -1057,7 +1057,7 @@ def test_cache_nomap_prefetch_size_100(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100) + some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) @@ -1098,7 +1098,7 @@ def test_cache_nomap_to_device(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -1134,7 +1134,7 @@ def test_cache_nomap_session_destroy(): shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) schema.add_column('label', de_type=mstype.uint8, shape=[1]) - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # User-created sampler here ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache) @@ -1172,7 +1172,7 @@ def test_cache_nomap_server_stop(): shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) schema.add_column('label', de_type=mstype.uint8, shape=[1]) - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # User-created sampler here ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache) @@ -1206,7 +1206,7 @@ def test_cache_nomap_epoch_ctrl1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) @@ -1246,7 +1246,7 @@ def test_cache_nomap_epoch_ctrl2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -1292,7 +1292,7 @@ def test_cache_nomap_epoch_ctrl3(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) @@ -1339,7 +1339,7 @@ def test_cache_nomap_epoch_ctrl4(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -1381,8 +1381,8 @@ def test_cache_nomap_multiple_cache1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) - eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + train_cache = ds.DatasetCache(session_id=session_id, size=0) + eval_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 12 records in it train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) @@ -1425,8 +1425,8 @@ def test_cache_nomap_multiple_cache2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) - text_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + image_cache = ds.DatasetCache(session_id=session_id, size=0) + text_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -1470,8 +1470,8 @@ def test_cache_nomap_multiple_cache3(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - tf_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) - image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + tf_cache = ds.DatasetCache(session_id=session_id, size=0) + image_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -1515,7 +1515,7 @@ def test_cache_nomap_multiple_cache_train(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + train_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 12 records in it train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) @@ -1553,7 +1553,7 @@ def test_cache_nomap_multiple_cache_eval(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + eval_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset only has 3 records in it eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -1591,7 +1591,7 @@ def test_cache_nomap_clue1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # With only 3 records shard into 3, we expect only 1 record returned for this shard # However, the sharding will be done by the sampler, not by the clue leaf node @@ -1630,7 +1630,7 @@ def test_cache_nomap_clue2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2) ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache) @@ -1666,7 +1666,7 @@ def test_cache_nomap_csv1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # With only 3 records shard into 3, we expect only 1 record returned for this shard # However, the sharding will be done by the sampler, not by the clue leaf node @@ -1706,7 +1706,7 @@ def test_cache_nomap_csv2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2) @@ -1743,7 +1743,7 @@ def test_cache_nomap_textfile1(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # With only 3 records shard into 3, we expect only 1 record returned for this shard # However, the sharding will be done by the sampler, not by the clue leaf node @@ -1788,7 +1788,7 @@ def test_cache_nomap_textfile2(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2) tokenizer = text.PythonTokenizer(my_tokenizer) @@ -1828,7 +1828,7 @@ def test_cache_nomap_nested_repeat(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) @@ -1867,7 +1867,7 @@ def test_cache_nomap_get_repeat_count(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + some_cache = ds.DatasetCache(session_id=session_id, size=0) # This dataset has 3 records in it only ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) @@ -1902,7 +1902,7 @@ def test_cache_nomap_long_file_list(): else: raise RuntimeError("Testcase requires SESSION_ID environment variable") - some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False) + some_cache = ds.DatasetCache(session_id=session_id, size=1) ds1 = ds.TFRecordDataset([DATA_DIR[0] for _ in range(0, 1000)], SCHEMA_DIR, columns_list=["image"], cache=some_cache)