From: @lixiachen
Tag: tags/v1.2.0-rc1
| @@ -42,14 +42,6 @@ int main(int argc, char **argv) { | |||
| google::InitGoogleLogging(argv[0]); | |||
| #endif | |||
| // Create default spilling dir | |||
| ds::Path spill_dir = ds::Path(ds::DefaultSpillDir()); | |||
| rc = spill_dir.CreateDirectories(); | |||
| if (!rc.IsOk()) { | |||
| std::cerr << rc.ToString() << std::endl; | |||
| return 1; | |||
| } | |||
| if (argc == 1) { | |||
| args.Help(); | |||
| return 0; | |||
| @@ -48,7 +48,7 @@ CacheAdminArgHandler::CacheAdminArgHandler() | |||
| log_level_(kDefaultLogLevel), | |||
| memory_cap_ratio_(kMemoryCapRatio), | |||
| hostname_(kCfgDefaultCacheHost), | |||
| spill_dir_(DefaultSpillDir()), | |||
| spill_dir_(""), | |||
| command_id_(CommandId::kCmdUnknown) { | |||
| // Initialize the command mappings | |||
| arg_map_["-h"] = ArgValue::kArgHost; | |||
| @@ -334,32 +334,7 @@ Status CacheAdminArgHandler::RunCommand() { | |||
| break; | |||
| } | |||
| case CommandId::kCmdStop: { | |||
| CacheClientGreeter comm(hostname_, port_, 1); | |||
| RETURN_IF_NOT_OK(comm.ServiceStart()); | |||
| SharedMessage msg; | |||
| RETURN_IF_NOT_OK(msg.Create()); | |||
| auto rq = std::make_shared<ServerStopRequest>(msg.GetMsgQueueId()); | |||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | |||
| Status rc = rq->Wait(); | |||
| if (rc.IsError()) { | |||
| msg.RemoveResourcesOnExit(); | |||
| if (rc.IsNetWorkError()) { | |||
| std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already."; | |||
| return Status(StatusCode::kNetWorkError, errMsg); | |||
| } | |||
| return rc; | |||
| } | |||
| // OK return code only means the server acknowledge our request but we still | |||
| // have to wait for its complete shutdown because the server will shutdown | |||
| // the comm layer as soon as the request is received, and we need to wait | |||
| // on the message queue instead. | |||
| // The server will send a message back and remove the queue and we will then wake up. But on the safe | |||
| // side, we will also set up an alarm and kill this process if we hang on | |||
| // the message queue. | |||
| alarm(60); | |||
| Status dummy_rc; | |||
| (void)msg.ReceiveStatus(&dummy_rc); | |||
| std::cout << "Cache server on port " << std::to_string(port_) << " has been stopped successfully." << std::endl; | |||
| RETURN_IF_NOT_OK(StopServer(command_id_)); | |||
| break; | |||
| } | |||
| case CommandId::kCmdGenerateSession: { | |||
| @@ -430,6 +405,36 @@ Status CacheAdminArgHandler::RunCommand() { | |||
| return Status::OK(); | |||
| } | |||
| Status CacheAdminArgHandler::StopServer(CommandId command_id) { | |||
| CacheClientGreeter comm(hostname_, port_, 1); | |||
| RETURN_IF_NOT_OK(comm.ServiceStart()); | |||
| SharedMessage msg; | |||
| RETURN_IF_NOT_OK(msg.Create()); | |||
| auto rq = std::make_shared<ServerStopRequest>(msg.GetMsgQueueId()); | |||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | |||
| Status rc = rq->Wait(); | |||
| if (rc.IsError()) { | |||
| msg.RemoveResourcesOnExit(); | |||
| if (rc.IsNetWorkError()) { | |||
| std::string errMsg = "Server on port " + std::to_string(port_) + " is not up or has been shutdown already."; | |||
| return Status(StatusCode::kNetWorkError, errMsg); | |||
| } | |||
| return rc; | |||
| } | |||
| // An OK return code only means the server acknowledged our request; we still | |||
| // have to wait for its complete shutdown, because the server shuts down | |||
| // the comm layer as soon as the request is received, so we must wait | |||
| // on the message queue instead. | |||
| // The server will send a message back and remove the queue, and we will then wake up. To be on the | |||
| // safe side, we also set up an alarm so this process is killed if we hang on | |||
| // the message queue. | |||
| alarm(60); | |||
| Status dummy_rc; | |||
| (void)msg.ReceiveStatus(&dummy_rc); | |||
| std::cout << "Cache server on port " << std::to_string(port_) << " has been stopped successfully." << std::endl; | |||
| return Status::OK(); | |||
| } | |||
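The comment block above describes a small watchdog protocol: arm `alarm(60)` before blocking on the SysV message queue so that, if the server never posts back, the default `SIGALRM` disposition terminates the process. A minimal sketch of that pattern, assuming a hypothetical message layout and helper name:

```cpp
#include <sys/msg.h>
#include <sys/types.h>
#include <unistd.h>

// Hypothetical message layout; the real reply format lives in SharedMessage.
struct StatusMsg {
  long mtype;
  char body[64];
};

bool WaitForServerExit(int msg_queue_id) {
  alarm(60);  // watchdog: the default SIGALRM action kills us if msgrcv() blocks forever
  StatusMsg msg{};
  // Blocks until the server posts a status or removes the queue (msgrcv then
  // fails with errno == EIDRM), or until the alarm fires.
  ssize_t n = msgrcv(msg_queue_id, &msg, sizeof(msg.body), 0, 0);
  alarm(0);  // disarm the watchdog once we wake up
  return n >= 0;
}
```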
| Status CacheAdminArgHandler::StartServer(CommandId command_id) { | |||
| // There currently does not exist any "install path" or method to identify which path the installed binaries will | |||
| // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the | |||
| @@ -462,7 +467,6 @@ Status CacheAdminArgHandler::StartServer(CommandId command_id) { | |||
| // fork the child process to become the daemon | |||
| pid_t pid; | |||
| pid = fork(); | |||
| // failed to fork | |||
| if (pid < 0) { | |||
| std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno); | |||
| @@ -538,7 +542,7 @@ void CacheAdminArgHandler::Help() { | |||
| std::cerr << " [[-h | --hostname] <hostname>] Default is " << kCfgDefaultCacheHost << ".\n"; | |||
| std::cerr << " [[-p | --port] <port number>] Default is " << kCfgDefaultCachePort << ".\n"; | |||
| std::cerr << " [[-w | --workers] <number of workers>] Default is " << kDefaultNumWorkers << ".\n"; | |||
| std::cerr << " [[-s | --spilldir] <spilling directory>] Default is " << DefaultSpillDir() << ".\n"; | |||
| std::cerr << " [[-s | --spilldir] <spilling directory>] Default is no spilling.\n"; | |||
| std::cerr << " [[-l | --loglevel] <log level>] Default is 1 (warning level).\n"; | |||
| std::cerr << " [--destroy_session | -d] <session id>\n"; | |||
| std::cerr << " [[-p | --port] <port number>]\n"; | |||
| @@ -79,6 +79,8 @@ class CacheAdminArgHandler { | |||
| Status StartServer(CommandId command_id); | |||
| Status StopServer(CommandId command_id); | |||
| Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, | |||
| CommandId command_id = CommandId::kCmdUnknown); | |||
| @@ -91,9 +91,6 @@ using worker_id_t = int32_t; | |||
| using numa_id_t = int32_t; | |||
| using cpu_id_t = int32_t; | |||
| /// Return the default spill dir for cache | |||
| inline std::string DefaultSpillDir() { return kDefaultPathPrefix; } | |||
| /// Return the default log dir for cache | |||
| inline std::string DefaultLogDir() { return kDefaultPathPrefix + std::string("/log"); } | |||
| @@ -125,6 +125,33 @@ class BaseRequest { | |||
| /// \return Status object | |||
| Status Wait(); | |||
| /// \brief Return whether the request is of row request type | |||
| /// \return True if the request is a row-related request | |||
| bool IsRowRequest() const { | |||
| return type_ == RequestType::kBatchCacheRows || type_ == RequestType::kBatchFetchRows || | |||
| type_ == RequestType::kInternalCacheRow || type_ == RequestType::kInternalFetchRow || | |||
| type_ == RequestType::kCacheRow; | |||
| } | |||
| /// \brief Return whether the request is of admin request type | |||
| /// \return True if the request is an admin-related request | |||
| bool IsAdminRequest() const { | |||
| return type_ == RequestType::kCreateCache || type_ == RequestType::kDestroyCache || | |||
| type_ == RequestType::kGetStat || type_ == RequestType::kGetCacheState || | |||
| type_ == RequestType::kAllocateSharedBlock || type_ == RequestType::kFreeSharedBlock || | |||
| type_ == RequestType::kCacheSchema || type_ == RequestType::kFetchSchema || | |||
| type_ == RequestType::kBuildPhaseDone || type_ == RequestType::kToggleWriteMode || | |||
| type_ == RequestType::kConnectReset || type_ == RequestType::kStopService || | |||
| type_ == RequestType::kHeartBeat || type_ == RequestType::kGetCacheMissKeys; | |||
| } | |||
| /// \brief Return whether the request is of session request type | |||
| /// \return True if the request is a session-related request | |||
| bool IsSessionRequest() const { | |||
| return type_ == RequestType::kGenerateSessionId || type_ == RequestType::kDropSession || | |||
| type_ == RequestType::kListSessions; | |||
| } | |||
| protected: | |||
| CacheRequest rq_; // This is what we send to the server | |||
| CacheReply reply_; // This is what the server sends back | |||
| @@ -155,6 +155,42 @@ CacheService *CacheServer::GetService(connection_id_type id) const { | |||
| return nullptr; | |||
| } | |||
| // We would like to protect ourselves from over-allocating. We go over the existing caches | |||
| // and calculate how much memory has been consumed so far. | |||
| Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) { | |||
| auto end = all_caches_.end(); | |||
| auto it = all_caches_.begin(); | |||
| auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_; | |||
| int64_t max_avail = avail_mem; | |||
| while (it != end) { | |||
| auto &cs = it->second; | |||
| CacheService::ServiceStat stat; | |||
| RETURN_IF_NOT_OK(cs->GetStat(&stat)); | |||
| int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz; | |||
| max_avail -= mem_consumed; | |||
| if (max_avail <= 0) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||
| } | |||
| ++it; | |||
| } | |||
| // If some caches are already using memory, make a reasonable decision on whether we should return | |||
| // out of memory. | |||
| if (max_avail < avail_mem) { | |||
| int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit. | |||
| if (req_mem > max_avail) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||
| } else if (req_mem == 0) { | |||
| // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than | |||
| // 85% of our limit, fail this request. | |||
| if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||
| } | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
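To make the 85% rule above concrete, here is a worked example with hypothetical figures (64 GiB of system memory and a memory cap ratio of 0.8; both numbers are assumptions for illustration only):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical figures: 64 GiB system memory, memory_cap_ratio_ = 0.8.
  double avail_mem = 64.0 * (1LL << 30) * 0.8;                     // ~51.2 GiB cap
  int64_t consumed = 45LL * (1LL << 30);                           // existing caches hold 45 GiB
  int64_t max_avail = static_cast<int64_t>(avail_mem) - consumed;  // ~6.2 GiB left

  int64_t cache_mem_sz = 0;                   // 0 means "unlimited, up to the cap"
  int64_t req_mem = cache_mem_sz * 1048576L;  // requests are expressed in MB
  bool oom = (req_mem > max_avail) ||
             (req_mem == 0 &&
              static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15f);
  // Only ~12% of the cap remains, below the 15% floor, so this prints kOutOfMemory.
  std::cout << (oom ? "kOutOfMemory" : "OK") << std::endl;
  return 0;
}
```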
| Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info"); | |||
| std::string cookie; | |||
| @@ -186,55 +222,25 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| if (spill && top_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Server is not set up with spill support."); | |||
| } | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| flatbuffers::Offset<flatbuffers::String> off_cookie; | |||
| flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list; | |||
| // Before creating the cache, first check if this is a request for shared usage of an existing cache. | |||
| // If two CreateService requests come in with identical connection_id, we need to serialize the creation. | |||
| // The first create will succeed and be given a special cookie. | |||
| UniqueLock lck(&rwLock_); | |||
| bool duplicate = false; | |||
| CacheService *curr_cs = GetService(connection_id); | |||
| if (curr_cs != nullptr) { | |||
| duplicate = true; | |||
| client_id = curr_cs->num_clients_.fetch_add(1); | |||
| MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " + | |||
| std::to_string(connection_id) + " to create cache service"; | |||
| } | |||
| // Early exit if we are doing global shutdown | |||
| if (global_shutdown_) { | |||
| return Status::OK(); | |||
| } | |||
| // We would like to protect ourselves from over allocating too much. We will go over existing cache | |||
| // and calculate how much we have consumed so far. | |||
| auto end = all_caches_.end(); | |||
| auto it = all_caches_.begin(); | |||
| bool duplicate = false; | |||
| auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_; | |||
| int64_t max_avail = avail_mem; | |||
| while (it != end) { | |||
| if (it->first == connection_id) { | |||
| duplicate = true; | |||
| break; | |||
| } else { | |||
| auto &cs = it->second; | |||
| CacheService::ServiceStat stat; | |||
| RETURN_IF_NOT_OK(cs->GetStat(&stat)); | |||
| int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz; | |||
| max_avail -= mem_consumed; | |||
| if (max_avail <= 0) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||
| } | |||
| } | |||
| ++it; | |||
| } | |||
| if (it == end) { | |||
| // If we have some cache using some memory already, make a reasonable decision if we should return | |||
| // out of memory. | |||
| if (max_avail < avail_mem) { | |||
| int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit. | |||
| if (req_mem > max_avail) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||
| } else if (req_mem == 0) { | |||
| // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than | |||
| // 85% of our limit, fail this request. | |||
| if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||
| } | |||
| } | |||
| } | |||
| if (!duplicate) { | |||
| RETURN_IF_NOT_OK(GlobalMemoryCheck(cache_mem_sz)); | |||
| std::unique_ptr<CacheService> cs; | |||
| try { | |||
| cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id); | |||
| @@ -245,12 +251,8 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| } else { | |||
| duplicate = true; | |||
| client_id = it->second->num_clients_.fetch_add(1); | |||
| MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " + | |||
| std::to_string(connection_id) + " to create cache service"; | |||
| } | |||
| // Shuffle the worker threads. But we need to release the locks or we will deadlock when calling | |||
| // the following function | |||
| lck.Unlock(); | |||
| @@ -258,6 +260,9 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| auto numa_id = client_id % GetNumaNodeCount(); | |||
| std::vector<cpu_id_t> cpu_list = hw_info_->GetCpuList(numa_id); | |||
| // Send back the data | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| flatbuffers::Offset<flatbuffers::String> off_cookie; | |||
| flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list; | |||
| off_cookie = fbb.CreateString(cookie); | |||
| off_cpu_list = fbb.CreateVector(cpu_list); | |||
| CreateCacheReplyMsgBuilder bld(fbb); | |||
| @@ -376,6 +381,57 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) { | |||
| return rc; | |||
| } | |||
| Status CacheServer::InternalCacheRow(CacheRequest *rq, CacheReply *reply) { | |||
| // Look into the flag to see where we can find the data and call the appropriate method. | |||
| auto flag = rq->flag(); | |||
| Status rc; | |||
| if (BitTest(flag, kDataIsInSharedMemory)) { | |||
| rc = FastCacheRow(rq, reply); | |||
| // This is an internal request and is not tied to an rpc, but we need to post the status because | |||
| // a thread is waiting on the completion of this request. | |||
| try { | |||
| int64_t addr = strtol(rq->buf_data(3).data(), nullptr, 10); | |||
| auto *bw = reinterpret_cast<BatchWait *>(addr); | |||
| // Check if the object is still around. | |||
| auto bwObj = bw->GetBatchWait(); | |||
| if (bwObj.lock()) { | |||
| RETURN_IF_NOT_OK(bw->Set(rc)); | |||
| } | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| } else { | |||
| rc = CacheRow(rq, reply); | |||
| } | |||
| return rc; | |||
| } | |||
| Status CacheServer::InternalFetchRow(CacheRequest *rq) { | |||
| auto connection_id = rq->connection_id(); | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| Status rc; | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(0).data())); | |||
| // This is an internal request and is not tied to an rpc, but we need to post the status because | |||
| // a thread is waiting on the completion of this request. | |||
| try { | |||
| int64_t addr = strtol(rq->buf_data(1).data(), nullptr, 10); | |||
| auto *bw = reinterpret_cast<BatchWait *>(addr); | |||
| // Check if the object is still around. | |||
| auto bwObj = bw->GetBatchWait(); | |||
| if (bwObj.lock()) { | |||
| RETURN_IF_NOT_OK(bw->Set(rc)); | |||
| } | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| return rc; | |||
| } | |||
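Both internal handlers above rely on the same trick: the requester smuggles the address of a `BatchWait` through the request buffer as a decimal string, and the server casts it back and posts the status only if the waiter is still alive. A self-contained sketch of that round trip, using stand-in types (the real `BatchWait` and `Status` are defined elsewhere in the source; the interface assumed here mirrors the calls in the diff):

```cpp
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>

// Stand-ins for the source types: BatchWait hands out a weak_ptr to itself via
// GetBatchWait() and accepts a completion status via Set(), as used above.
struct Status {
  int code = 0;
};
struct BatchWait : std::enable_shared_from_this<BatchWait> {
  std::weak_ptr<BatchWait> GetBatchWait() { return weak_from_this(); }
  void Set(const Status &rc) { result = rc; }
  Status result;
};

// Requester side: serialize the waiter's address as a decimal string.
std::string EncodeAddr(BatchWait *bw) { return std::to_string(reinterpret_cast<int64_t>(bw)); }

// Server side: recover the pointer and post only if the waiter is still alive.
void PostIfAlive(const std::string &buf, const Status &rc) {
  int64_t addr = std::strtol(buf.c_str(), nullptr, 10);
  auto *bw = reinterpret_cast<BatchWait *>(addr);
  if (bw->GetBatchWait().lock()) {  // weak_ptr guard against a destroyed waiter
    bw->Set(rc);
  }
}

int main() {
  auto bw = std::make_shared<BatchWait>();       // must be shared-owned for the guard to hold
  PostIfAlive(EncodeAddr(bw.get()), Status{0});  // posts: the waiter is still alive
  return bw->result.code;
}
```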
| Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer()); | |||
| @@ -741,40 +797,24 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) { | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) { | |||
| bool internal_request = false; | |||
| Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request) { | |||
| auto &rq = cache_req->rq_; | |||
| auto &reply = cache_req->reply_; | |||
| // Except for creating a new session, we expect cs to be non-null. | |||
| switch (cache_req->type_) { | |||
| case BaseRequest::RequestType::kCacheRow: | |||
| case BaseRequest::RequestType::kInternalCacheRow: { | |||
| // Look into the flag to see where we can find the data and | |||
| // call the appropriate method. | |||
| auto flag = rq.flag(); | |||
| if (BitTest(flag, kDataIsInSharedMemory)) { | |||
| case BaseRequest::RequestType::kCacheRow: { | |||
| // Look into the flag to see where we can find the data and call the appropriate method. | |||
| if (BitTest(rq.flag(), kDataIsInSharedMemory)) { | |||
| cache_req->rc_ = FastCacheRow(&rq, &reply); | |||
| internal_request = (cache_req->type_ == BaseRequest::RequestType::kInternalCacheRow); | |||
| if (internal_request) { | |||
| // This is an internal request and is not tied to rpc. But need to post because there | |||
| // is a thread waiting on the completion of this request. | |||
| try { | |||
| int64_t addr = strtol(rq.buf_data(3).data(), nullptr, 10); | |||
| auto *bw = reinterpret_cast<BatchWait *>(addr); | |||
| // Check if the object is still around. | |||
| auto bwObj = bw->GetBatchWait(); | |||
| if (bwObj.lock()) { | |||
| RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_))); | |||
| } | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| } | |||
| } else { | |||
| cache_req->rc_ = CacheRow(&rq, &reply); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kInternalCacheRow: { | |||
| *internal_request = true; | |||
| cache_req->rc_ = InternalCacheRow(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kBatchCacheRows: { | |||
| cache_req->rc_ = BatchCacheRows(&rq); | |||
| break; | |||
| @@ -784,31 +824,46 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) { | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kInternalFetchRow: { | |||
| internal_request = true; | |||
| auto connection_id = rq.connection_id(); | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data())); | |||
| // This is an internal request and is not tied to rpc. But need to post because there | |||
| // is a thread waiting on the completion of this request. | |||
| try { | |||
| int64_t addr = strtol(rq.buf_data(1).data(), nullptr, 10); | |||
| auto *bw = reinterpret_cast<BatchWait *>(addr); | |||
| // Check if the object is still around. | |||
| auto bwObj = bw->GetBatchWait(); | |||
| if (bwObj.lock()) { | |||
| RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_))); | |||
| } | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| } | |||
| *internal_request = true; | |||
| cache_req->rc_ = InternalFetchRow(&rq); | |||
| break; | |||
| } | |||
| default: | |||
| std::string errMsg("Internal error, request type is not row request: "); | |||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | |||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) { | |||
| auto &rq = cache_req->rq_; | |||
| auto &reply = cache_req->reply_; | |||
| switch (cache_req->type_) { | |||
| case BaseRequest::RequestType::kDropSession: { | |||
| cache_req->rc_ = DestroySession(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kGenerateSessionId: { | |||
| cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kListSessions: { | |||
| cache_req->rc_ = ListSessions(&reply); | |||
| break; | |||
| } | |||
| default: | |||
| std::string errMsg("Internal error, request type is not session request: "); | |||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | |||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) { | |||
| auto &rq = cache_req->rq_; | |||
| auto &reply = cache_req->reply_; | |||
| switch (cache_req->type_) { | |||
| case BaseRequest::RequestType::kCreateCache: { | |||
| cache_req->rc_ = CreateService(&rq, &reply); | |||
| break; | |||
| @@ -837,14 +892,6 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) { | |||
| cache_req->rc_ = BuildPhaseDone(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kDropSession: { | |||
| cache_req->rc_ = DestroySession(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kGenerateSessionId: { | |||
| cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kAllocateSharedBlock: { | |||
| cache_req->rc_ = AllocateSharedMemory(&rq, &reply); | |||
| break; | |||
| @@ -868,40 +915,45 @@ Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) { | |||
| cache_req->rc_ = ToggleWriteMode(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kListSessions: { | |||
| cache_req->rc_ = ListSessions(&reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kConnectReset: { | |||
| cache_req->rc_ = ConnectReset(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kGetCacheState: { | |||
| auto connection_id = rq.connection_id(); | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto state = cs->GetState(); | |||
| reply.set_result(std::to_string(static_cast<int8_t>(state))); | |||
| cache_req->rc_ = Status::OK(); | |||
| } | |||
| cache_req->rc_ = GetCacheState(&rq, &reply); | |||
| break; | |||
| } | |||
| default: | |||
| std::string errMsg("Unknown request type : "); | |||
| std::string errMsg("Internal error, request type is not admin request: "); | |||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | |||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) { | |||
| bool internal_request = false; | |||
| // Except for creating a new session, we expect cs to be non-null. | |||
| if (cache_req->IsRowRequest()) { | |||
| RETURN_IF_NOT_OK(ProcessRowRequest(cache_req, &internal_request)); | |||
| } else if (cache_req->IsSessionRequest()) { | |||
| RETURN_IF_NOT_OK(ProcessSessionRequest(cache_req)); | |||
| } else if (cache_req->IsAdminRequest()) { | |||
| RETURN_IF_NOT_OK(ProcessAdminRequest(cache_req)); | |||
| } else { | |||
| std::string errMsg("Unknown request type : "); | |||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | |||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| // Notify it is done, and move on to the next request. | |||
| Status2CacheReply(cache_req->rc_, &reply); | |||
| Status2CacheReply(cache_req->rc_, &cache_req->reply_); | |||
| cache_req->st_ = CacheServerRequest::STATE::FINISH; | |||
| // We will re-tag the request back to the grpc queue. Once it comes back from the client, | |||
| // the CacheServerRequest, i.e. the pointer cache_req, will be freed | |||
| if (!internal_request && !global_shutdown_) { | |||
| cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); | |||
| cache_req->responder_.Finish(cache_req->reply_, grpc::Status::OK, cache_req); | |||
| } else { | |||
| // We can free up the request now. | |||
| RETURN_IF_NOT_OK(ReturnRequestTag(cache_req)); | |||
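The re-tagging comment above is the producer half of the async-gRPC completion-queue contract. A rough sketch of the consumer half, under the assumption that a worker thread drains a `grpc::ServerCompletionQueue` (the recycle hook below is hypothetical; only `Finish()` appears in this diff):

```cpp
#include <grpcpp/grpcpp.h>

struct CacheServerRequest;                       // defined in the source
void ReturnRequestTag(CacheServerRequest *req);  // hypothetical recycle hook

void DrainCompletionQueue(grpc::ServerCompletionQueue *cq) {
  void *tag = nullptr;
  bool ok = false;
  // Finish(reply_, grpc::Status::OK, cache_req) re-tags the request onto this
  // queue; the tag pops back out once the reply has reached the client, and
  // only then is it safe to recycle the CacheServerRequest behind it.
  while (cq->Next(&tag, &ok)) {
    ReturnRequestTag(static_cast<CacheServerRequest *>(tag));
  }
}
```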
| @@ -1084,6 +1136,20 @@ Status CacheServer::FreeSharedMemory(CacheRequest *rq) { | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto state = cs->GetState(); | |||
| reply->set_result(std::to_string(static_cast<int8_t>(state))); | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| Status CacheServer::RpcRequest(worker_id_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id)); | |||
| @@ -1213,7 +1279,7 @@ Status CacheServer::Builder::SanityCheck() { | |||
| } | |||
| CacheServer::Builder::Builder() | |||
| : top_(DefaultSpillDir()), | |||
| : top_(""), | |||
| num_workers_(std::thread::hardware_concurrency() / 2), | |||
| port_(50052), | |||
| shared_memory_sz_in_gb_(kDefaultSharedMemorySize), | |||
| @@ -201,6 +201,22 @@ class CacheServer : public Service { | |||
| /// \brief Return the memory cap ratio | |||
| float GetMemoryCapRatio() const { return memory_cap_ratio_; } | |||
| /// \brief Function to handle a row request | |||
| /// \param[in] cache_req A row request to handle | |||
| /// \param[out] internal_request Indicator of whether the request is an internal request | |||
| /// \return Status object | |||
| Status ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request); | |||
| /// \brief Function to handle an admin request | |||
| /// \param[in] cache_req An admin request to handle | |||
| /// \return Status object | |||
| Status ProcessAdminRequest(CacheServerRequest *cache_req); | |||
| /// \brief Function to handle a session request | |||
| /// \param[in] cache_req A session request to handle | |||
| /// \return Status object | |||
| Status ProcessSessionRequest(CacheServerRequest *cache_req); | |||
| /// \brief How a request is handled. | |||
| /// \note that it can be processed immediately by a grpc thread or routed to a server thread | |||
| /// which is pinned to some numa node core. | |||
| @@ -256,6 +272,12 @@ class CacheServer : public Service { | |||
| /// \return Pointer to cache service. Null if not found | |||
| CacheService *GetService(connection_id_type id) const; | |||
| /// \brief Go over the existing cache services and calculate how much memory has been consumed so far. A new | |||
| /// cache service can only be created if there is still enough available memory left | |||
| /// \param cache_mem_sz Requested memory for a new cache service | |||
| /// \return Status object | |||
| Status GlobalMemoryCheck(uint64_t cache_mem_sz); | |||
| /// \brief Create a cache service. We allow multiple clients to create the same cache service. | |||
| /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given | |||
| /// a special unique cookie. | |||
| @@ -314,6 +336,12 @@ class CacheServer : public Service { | |||
| /// \return Status object | |||
| Status GetStat(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Internal function to get cache state | |||
| /// \param rq | |||
| /// \param reply | |||
| /// \return Status object | |||
| Status GetCacheState(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Cache a schema request | |||
| /// \param rq | |||
| /// \return Status object | |||
| @@ -411,6 +439,9 @@ class CacheServer : public Service { | |||
| /// \return Status object | |||
| Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out); | |||
| Status BatchCacheRows(CacheRequest *rq); | |||
| Status InternalFetchRow(CacheRequest *rq); | |||
| Status InternalCacheRow(CacheRequest *rq, CacheReply *reply); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -15,7 +15,8 @@ | |||
| # ============================================================================ | |||
| # source the globals and functions for use with cache testing | |||
| SKIP_ADMIN_COUNTER=false | |||
| export SKIP_ADMIN_COUNTER=false | |||
| declare failed_tests | |||
| . cachetest_lib.sh | |||
| echo | |||
| @@ -15,7 +15,8 @@ | |||
| # ============================================================================ | |||
| # source the globals and functions for use with cache testing | |||
| SKIP_ADMIN_COUNTER=true | |||
| export SKIP_ADMIN_COUNTER=true | |||
| declare session_id failed_tests | |||
| . cachetest_lib.sh | |||
| echo | |||
| @@ -28,8 +29,10 @@ UT_TEST_DIR="${BUILD_PATH}/mindspore/tests/ut/cpp" | |||
| DateStamp=$(date +%Y%m%d_%H%M%S); | |||
| CPP_TEST_LOG_OUTPUT="/tmp/ut_tests_cache_${DateStamp}.log" | |||
| # Start a basic cache server to be used for all tests | |||
| StartServer | |||
| # Start cache server with a spilling path to be used for all tests | |||
| cmd="${CACHE_ADMIN} --start -s /tmp" | |||
| CacheAdminCmd "${cmd}" 0 | |||
| sleep 1 | |||
| HandleRcExit $? 1 1 | |||
| # Set the environment variable to enable these pytests | |||
| @@ -15,7 +15,8 @@ | |||
| # ============================================================================ | |||
| # source the globals and functions for use with cache testing | |||
| SKIP_ADMIN_COUNTER=true | |||
| export SKIP_ADMIN_COUNTER=true | |||
| declare session_id failed_tests | |||
| . cachetest_lib.sh | |||
| echo | |||
| @@ -84,10 +85,6 @@ export SESSION_ID=$session_id | |||
| PytestCmd "test_cache_map.py" "test_cache_map_running_twice2" | |||
| HandleRcExit $? 0 0 | |||
| # Set size parameter of DatasetCache to a extra small value | |||
| PytestCmd "test_cache_map.py" "test_cache_map_extra_small_size" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_no_image" | |||
| HandleRcExit $? 0 0 | |||
| @@ -255,15 +252,6 @@ export SESSION_ID=$session_id | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_running_twice2" | |||
| HandleRcExit $? 0 0 | |||
| # Set size parameter of DatasetCache to a extra small value | |||
| GetSession | |||
| HandleRcExit $? 1 1 | |||
| export SESSION_ID=$session_id | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_extra_small_size" 1 | |||
| HandleRcExit $? 0 0 | |||
| DestroySession $session_id | |||
| HandleRcExit $? 1 1 | |||
| # Run two parallel pipelines (sharing cache) | |||
| for i in $(seq 1 2) | |||
| do | |||
| @@ -366,7 +354,7 @@ HandleRcExit $? 1 1 | |||
| export SESSION_ID=$session_id | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_session_destroy" & | |||
| pid=("$!") | |||
| pid=$! | |||
| sleep 10 | |||
| DestroySession $session_id | |||
| @@ -381,7 +369,7 @@ HandleRcExit $? 1 1 | |||
| export SESSION_ID=$session_id | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_server_stop" & | |||
| pid=("$!") | |||
| pid=$! | |||
| sleep 10 | |||
| StopServer | |||
| @@ -417,6 +405,26 @@ HandleRcExit $? 0 0 | |||
| StopServer | |||
| HandleRcExit $? 0 1 | |||
| # Start cache server with a spilling path | |||
| cmd="${CACHE_ADMIN} --start -s /tmp" | |||
| CacheAdminCmd "${cmd}" 0 | |||
| sleep 1 | |||
| HandleRcExit $? 0 0 | |||
| GetSession | |||
| HandleRcExit $? 1 1 | |||
| export SESSION_ID=$session_id | |||
| # Set size parameter of mappable DatasetCache to an extra small value | |||
| PytestCmd "test_cache_map.py" "test_cache_map_extra_small_size" 1 | |||
| HandleRcExit $? 0 0 | |||
| # Set size parameter of non-mappable DatasetCache to an extra small value | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_extra_small_size" 1 | |||
| HandleRcExit $? 0 0 | |||
| StopServer | |||
| HandleRcExit $? 0 1 | |||
| unset RUN_CACHE_TEST | |||
| unset SESSION_ID | |||
| @@ -57,7 +57,7 @@ def test_cache_map_basic1(): | |||
| else: | |||
| session_id = 1 | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -91,7 +91,7 @@ def test_cache_map_basic2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -115,7 +115,7 @@ def test_cache_map_basic3(): | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -155,7 +155,7 @@ def test_cache_map_basic4(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -189,7 +189,7 @@ def test_cache_map_basic5(): | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -225,7 +225,7 @@ def test_cache_map_failure1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -269,7 +269,7 @@ def test_cache_map_failure2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -310,7 +310,7 @@ def test_cache_map_failure3(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -351,7 +351,7 @@ def test_cache_map_failure4(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -391,7 +391,7 @@ def test_cache_map_failure5(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| data = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -432,7 +432,7 @@ def test_cache_map_failure6(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| columns_list = ["id", "file_name", "label_name", "img_data", "label_data"] | |||
| num_readers = 1 | |||
| @@ -478,7 +478,7 @@ def test_cache_map_failure7(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| data = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| data = data.map((lambda x: x), ["data"], cache=some_cache) | |||
| @@ -514,7 +514,7 @@ def test_cache_map_failure8(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -554,7 +554,7 @@ def test_cache_map_failure9(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -596,7 +596,7 @@ def test_cache_map_failure10(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -616,6 +616,37 @@ def test_cache_map_failure10(): | |||
| logger.info('test_cache_failure10 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_failure11(): | |||
| """ | |||
| Test setting spilling=True when cache server is started without spilling support (failure) | |||
| Cache(spilling=True) | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache failure 11") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| with pytest.raises(RuntimeError) as e: | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "Unexpected error. Server is not set up with spill support" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure11 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_split1(): | |||
| """ | |||
| @@ -641,7 +672,7 @@ def test_cache_map_split1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -692,7 +723,7 @@ def test_cache_map_split2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 9 records | |||
| ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) | |||
| @@ -725,27 +756,27 @@ def test_cache_map_parameter_check(): | |||
| logger.info("Test cache map parameter check") | |||
| with pytest.raises(ValueError) as info: | |||
| ds.DatasetCache(session_id=-1, size=0, spilling=True) | |||
| ds.DatasetCache(session_id=-1, size=0) | |||
| assert "Input is not within the required interval" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id="1", size=0, spilling=True) | |||
| ds.DatasetCache(session_id="1", size=0) | |||
| assert "Argument session_id with value 1 is not of type (<class 'int'>,)" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=None, size=0, spilling=True) | |||
| ds.DatasetCache(session_id=None, size=0) | |||
| assert "Argument session_id with value None is not of type (<class 'int'>,)" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| ds.DatasetCache(session_id=1, size=-1, spilling=True) | |||
| ds.DatasetCache(session_id=1, size=-1) | |||
| assert "Input size must be greater than 0" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=1, size="1", spilling=True) | |||
| ds.DatasetCache(session_id=1, size="1") | |||
| assert "Argument size with value 1 is not of type (<class 'int'>,)" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=1, size=None, spilling=True) | |||
| ds.DatasetCache(session_id=1, size=None) | |||
| assert "Argument size with value None is not of type (<class 'int'>,)" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| @@ -753,31 +784,31 @@ def test_cache_map_parameter_check(): | |||
| assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value) | |||
| with pytest.raises(TypeError) as err: | |||
| ds.DatasetCache(session_id=1, size=0, spilling=True, hostname=50052) | |||
| ds.DatasetCache(session_id=1, size=0, hostname=50052) | |||
| assert "Argument hostname with value 50052 is not of type (<class 'str'>,)" in str(err.value) | |||
| with pytest.raises(RuntimeError) as err: | |||
| ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal") | |||
| ds.DatasetCache(session_id=1, size=0, hostname="illegal") | |||
| assert "now cache client has to be on the same host with cache server" in str(err.value) | |||
| with pytest.raises(RuntimeError) as err: | |||
| ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2") | |||
| ds.DatasetCache(session_id=1, size=0, hostname="127.0.0.2") | |||
| assert "now cache client has to be on the same host with cache server" in str(err.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal") | |||
| ds.DatasetCache(session_id=1, size=0, port="illegal") | |||
| assert "Argument port with value illegal is not of type (<class 'int'>,)" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052") | |||
| ds.DatasetCache(session_id=1, size=0, port="50052") | |||
| assert "Argument port with value 50052 is not of type (<class 'int'>,)" in str(info.value) | |||
| with pytest.raises(ValueError) as err: | |||
| ds.DatasetCache(session_id=1, size=0, spilling=True, port=0) | |||
| ds.DatasetCache(session_id=1, size=0, port=0) | |||
| assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value) | |||
| with pytest.raises(ValueError) as err: | |||
| ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536) | |||
| ds.DatasetCache(session_id=1, size=0, port=65536) | |||
| assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value) | |||
| with pytest.raises(TypeError) as err: | |||
| @@ -807,7 +838,7 @@ def test_cache_map_running_twice1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -850,7 +881,7 @@ def test_cache_map_running_twice2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -998,7 +1029,7 @@ def test_cache_map_parallel_pipeline1(shard): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache) | |||
| @@ -1035,7 +1066,7 @@ def test_cache_map_parallel_pipeline2(shard): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard)) | |||
| @@ -1072,7 +1103,7 @@ def test_cache_map_parallel_workers(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4) | |||
| @@ -1109,7 +1140,7 @@ def test_cache_map_server_workers_1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -1146,7 +1177,7 @@ def test_cache_map_server_workers_100(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -1183,7 +1214,7 @@ def test_cache_map_num_connections_1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -1220,7 +1251,7 @@ def test_cache_map_num_connections_100(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -1257,7 +1288,7 @@ def test_cache_map_prefetch_size_1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -1294,7 +1325,7 @@ def test_cache_map_prefetch_size_100(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -1335,7 +1366,7 @@ def test_cache_map_to_device(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -1366,7 +1397,7 @@ def test_cache_map_epoch_ctrl1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -1406,7 +1437,7 @@ def test_cache_map_epoch_ctrl2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| @@ -1452,7 +1483,7 @@ def test_cache_map_epoch_ctrl3(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -1495,7 +1526,7 @@ def test_cache_map_coco1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 6 records | |||
| ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True, | |||
| @@ -1531,7 +1562,7 @@ def test_cache_map_coco2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 6 records | |||
| ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True) | |||
| @@ -1566,7 +1597,7 @@ def test_cache_map_mnist1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache) | |||
| num_epoch = 4 | |||
| @@ -1599,7 +1630,7 @@ def test_cache_map_mnist2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| @@ -1633,7 +1664,7 @@ def test_cache_map_celeba1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 4 records | |||
| ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache) | |||
| @@ -1668,7 +1699,7 @@ def test_cache_map_celeba2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 4 records | |||
| ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) | |||
| @@ -1703,7 +1734,7 @@ def test_cache_map_manifest1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 4 records | |||
| ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache) | |||
| @@ -1738,7 +1769,7 @@ def test_cache_map_manifest2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 4 records | |||
| ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True) | |||
| @@ -1773,7 +1804,7 @@ def test_cache_map_cifar1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) | |||
| num_epoch = 4 | |||
| @@ -1806,7 +1837,7 @@ def test_cache_map_cifar2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| @@ -1841,7 +1872,7 @@ def test_cache_map_cifar3(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=1) | |||
| ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache) | |||
| @@ -1875,7 +1906,7 @@ def test_cache_map_cifar4(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) | |||
| ds1 = ds1.shuffle(10) | |||
| @@ -1907,7 +1938,7 @@ def test_cache_map_voc1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 9 records | |||
| ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache) | |||
| @@ -1942,7 +1973,7 @@ def test_cache_map_voc2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 9 records | |||
| ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) | |||
| @@ -1987,7 +2018,7 @@ def test_cache_map_python_sampler1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache) | |||
| @@ -2023,7 +2054,7 @@ def test_cache_map_python_sampler2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler()) | |||
| @@ -2061,7 +2092,7 @@ def test_cache_map_nested_repeat(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| @@ -62,7 +62,7 @@ def test_cache_nomap_basic1(): | |||
| schema.add_column('label', de_type=mstype.uint8, shape=[1]) | |||
| # Create a cache; arbitrary session_id for now | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # User-created sampler here | |||
| ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache) | |||
| @@ -96,7 +96,7 @@ def test_cache_nomap_basic2(): | |||
| schema.add_column('label', de_type=mstype.uint8, shape=[1]) | |||
| # Create a cache; arbitrary session_id for now | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # Sampler arg not given directly; however, any of these args will auto-generate an appropriate sampler: | |||
| # num_samples, shuffle, num_shards, shard_id | |||
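A sketch of the auto-generated sampler behaviour the comment above describes, reusing the schema and cache objects built in this hunk; the specific num_samples value is illustrative only:

    # No sampler object is passed; num_samples alone is enough for the
    # pipeline to generate an appropriate sampler internally (the same holds
    # for shuffle, num_shards and shard_id).
    ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_samples=4,
                           num_parallel_workers=4, cache=some_cache)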
| @@ -134,7 +134,7 @@ def test_cache_nomap_basic3(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"]) | |||
| @@ -183,7 +183,7 @@ def test_cache_nomap_basic4(): | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache | |||
| # in the picture. This causes a shuffle op to be injected over the TFRecord leaf. For clarity, | |||
| # this test will explicitly give the global option, even though it's the default in Python. | |||
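For reference, a sketch of the explicit form the comment describes, using the same DATA_DIR/SCHEMA_DIR and cache names as the surrounding tests:

    # Explicit global shuffle, even though it is already the Python default,
    # so the intent is preserved regardless of cache-related shuffle injection.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"],
                             shuffle=ds.Shuffle.GLOBAL, cache=some_cache)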
| @@ -231,7 +231,7 @@ def test_cache_nomap_basic5(): | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"]) | |||
| @@ -270,7 +270,7 @@ def test_cache_nomap_basic6(): | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # With only 3 records sharded into 3, we expect only 1 record returned for this shard. | |||
| # However, the sharding will be done by the sampler, not by the TFRecord leaf node. | |||
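A sketch of the sampler-side sharding described above; the num_shards/shard_id form also appears verbatim in later hunks of this file:

    # 3 records split across 3 shards: this pipeline sees exactly 1 record.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=0,
                             cache=some_cache)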
| @@ -313,7 +313,7 @@ def test_cache_nomap_basic7(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache) | |||
| @@ -344,7 +344,7 @@ def test_cache_nomap_basic8(): | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) | |||
| @@ -371,7 +371,7 @@ def test_cache_nomap_basic9(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # Contact the server to get the statistics. This should fail because we have not used this cache | |||
| # in any pipeline, so there is no cache to get stats on. | |||
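A hedged sketch of the failure path this comment describes; pytest and the GetStat method name are assumptions based on how these cache tests are typically written, not something shown in this hunk:

    import pytest

    # The cache was never attached to a pipeline, so asking the server for
    # its statistics is expected to raise.
    with pytest.raises(RuntimeError):
        some_cache.GetStat()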
| @@ -404,7 +404,7 @@ def test_cache_nomap_allowed_share1(): | |||
| ds.config.set_seed(1) | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=32) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=32) | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| @@ -446,7 +446,7 @@ def test_cache_nomap_allowed_share2(): | |||
| ds.config.set_seed(1) | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| @@ -488,7 +488,7 @@ def test_cache_nomap_allowed_share3(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"] | |||
| ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache) | |||
| @@ -529,7 +529,7 @@ def test_cache_nomap_allowed_share4(): | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| @@ -572,7 +572,7 @@ def test_cache_nomap_disallowed_share1(): | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| # This dataset has 3 records in it only | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| decode_op = c_vision.Decode() | |||
| rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0) | |||
| @@ -615,7 +615,7 @@ def test_cache_nomap_running_twice1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -658,7 +658,7 @@ def test_cache_nomap_running_twice2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) | |||
| @@ -763,7 +763,7 @@ def test_cache_nomap_parallel_pipeline1(shard): | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache) | |||
| @@ -799,7 +799,7 @@ def test_cache_nomap_parallel_pipeline2(shard): | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard)) | |||
| @@ -835,7 +835,7 @@ def test_cache_nomap_parallel_workers(): | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4) | |||
| @@ -872,7 +872,7 @@ def test_cache_nomap_server_workers_1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -909,7 +909,7 @@ def test_cache_nomap_server_workers_100(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) | |||
| @@ -946,7 +946,7 @@ def test_cache_nomap_num_connections_1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -983,7 +983,7 @@ def test_cache_nomap_num_connections_100(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) | |||
| @@ -1020,7 +1020,7 @@ def test_cache_nomap_prefetch_size_1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -1057,7 +1057,7 @@ def test_cache_nomap_prefetch_size_100(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) | |||
| @@ -1098,7 +1098,7 @@ def test_cache_nomap_to_device(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -1134,7 +1134,7 @@ def test_cache_nomap_session_destroy(): | |||
| shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) | |||
| schema.add_column('label', de_type=mstype.uint8, shape=[1]) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # User-created sampler here | |||
| ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache) | |||
| @@ -1172,7 +1172,7 @@ def test_cache_nomap_server_stop(): | |||
| shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) | |||
| schema.add_column('label', de_type=mstype.uint8, shape=[1]) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # User-created sampler here | |||
| ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache) | |||
| @@ -1206,7 +1206,7 @@ def test_cache_nomap_epoch_ctrl1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) | |||
| @@ -1246,7 +1246,7 @@ def test_cache_nomap_epoch_ctrl2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -1292,7 +1292,7 @@ def test_cache_nomap_epoch_ctrl3(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) | |||
| @@ -1339,7 +1339,7 @@ def test_cache_nomap_epoch_ctrl4(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -1381,8 +1381,8 @@ def test_cache_nomap_multiple_cache1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| train_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| eval_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 12 records in it | |||
| train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) | |||
| @@ -1425,8 +1425,8 @@ def test_cache_nomap_multiple_cache2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| text_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| image_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| text_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -1470,8 +1470,8 @@ def test_cache_nomap_multiple_cache3(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| tf_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| tf_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| image_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -1515,7 +1515,7 @@ def test_cache_nomap_multiple_cache_train(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| train_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 12 records in it | |||
| train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) | |||
| @@ -1553,7 +1553,7 @@ def test_cache_nomap_multiple_cache_eval(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| eval_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset only has 3 records in it | |||
| eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -1591,7 +1591,7 @@ def test_cache_nomap_clue1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # With only 3 records sharded into 3, we expect only 1 record returned for this shard. | |||
| # However, the sharding will be done by the sampler, not by the CLUE leaf node. | |||
| @@ -1630,7 +1630,7 @@ def test_cache_nomap_clue2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2) | |||
| ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache) | |||
| @@ -1666,7 +1666,7 @@ def test_cache_nomap_csv1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # With only 3 records sharded into 3, we expect only 1 record returned for this shard. | |||
| # However, the sharding will be done by the sampler, not by the CSV leaf node. | |||
| @@ -1706,7 +1706,7 @@ def test_cache_nomap_csv2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2) | |||
| @@ -1743,7 +1743,7 @@ def test_cache_nomap_textfile1(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # With only 3 records sharded into 3, we expect only 1 record returned for this shard. | |||
| # However, the sharding will be done by the sampler, not by the text file leaf node. | |||
| @@ -1788,7 +1788,7 @@ def test_cache_nomap_textfile2(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2) | |||
| tokenizer = text.PythonTokenizer(my_tokenizer) | |||
| @@ -1828,7 +1828,7 @@ def test_cache_nomap_nested_repeat(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| @@ -1867,7 +1867,7 @@ def test_cache_nomap_get_repeat_count(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| @@ -1902,7 +1902,7 @@ def test_cache_nomap_long_file_list(): | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False) | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=1) | |||
| ds1 = ds.TFRecordDataset([DATA_DIR[0] for _ in range(0, 1000)], SCHEMA_DIR, columns_list=["image"], | |||
| cache=some_cache) | |||
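Finally, on the size argument visible in the last two hunks: size=0 appears throughout as the uncapped form, while size=1 is used where a test wants a deliberately tiny cache. A sketch of the two forms, assuming size is a memory cap with 0 meaning no explicit cap; the unit (MB) is inferred from usage here, not from anything stated in this diff:

    unlimited_cache = ds.DatasetCache(session_id=session_id, size=0)  # no explicit memory cap
    tiny_cache = ds.DatasetCache(session_id=session_id, size=1)       # deliberately tiny cap (MB, assumed)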