| @@ -1,3 +1,4 @@ | |||
| add_subdirectory(perf EXCLUDE_FROM_ALL) | |||
| include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") | |||
| set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") | |||
| ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU}) | |||
| @@ -5,6 +6,18 @@ ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engin | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| # Try to find numa header file and its library | |||
| find_file(NUMA_HDR NAMES "numa.h") | |||
| if (EXISTS ${NUMA_HDR}) | |||
| ADD_DEFINITIONS(-DNUMA_ENABLED) | |||
| MESSAGE("Numa package found") | |||
| endif () | |||
| if (${CMAKE_SYSTEM_NAME} MATCHES "Linux") | |||
| ADD_DEFINITIONS(-DCACHE_LOCAL_CLIENT) | |||
| endif () | |||
| add_library(engine-cache-client OBJECT | |||
| cache_client.cc | |||
| cache_fbb.cc | |||
| @@ -20,8 +33,13 @@ if (ENABLE_CACHE) | |||
| ${CACHE_GRPC_SRCS} | |||
| cache_grpc_server.cc | |||
| cache_arena.cc | |||
| cache_hw.cc | |||
| cache_numa.cc | |||
| cache_pool.cc | |||
| cache_service.cc | |||
| cache_server.cc) | |||
| cache_server.cc | |||
| storage_manager.cc | |||
| storage_container.cc) | |||
| add_executable(cache_server cache_main.cc) | |||
| target_link_libraries(cache_server | |||
| @@ -39,6 +57,10 @@ if (ENABLE_CACHE) | |||
| target_link_libraries(cache_server mindspore::glog) | |||
| endif () | |||
| if (EXISTS ${NUMA_HDR}) | |||
| target_link_libraries(cache_server numa) | |||
| endif () | |||
| add_executable(cache_admin cache_admin.cc cache_admin_arg.cc) | |||
| target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES}) | |||
| @@ -49,7 +71,7 @@ if (ENABLE_CACHE) | |||
| add_dependencies(engine-cache-server generated_engine_files) | |||
| else () | |||
| ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PRTO_HDRS cache_grpc.proto) | |||
| ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PROTO_HDRS cache_grpc.proto) | |||
| target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS}) | |||
| endif () | |||
| @@ -18,6 +18,7 @@ | |||
| #include <sys/stat.h> | |||
| #include <sys/wait.h> | |||
| #include <unistd.h> | |||
| #include <algorithm> | |||
| #include <cerrno> | |||
| #include <iomanip> | |||
| #include <iostream> | |||
| @@ -31,7 +32,9 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| const int32_t CacheAdminArgHandler::kDefaultNumWorkers = std::thread::hardware_concurrency() > 2 | |||
| ? std::thread::hardware_concurrency() / 2 | |||
| : 1; | |||
| const char CacheAdminArgHandler::kServerBinary[] = "cache_server"; | |||
| const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp"; | |||
| @@ -304,8 +307,10 @@ Status CacheAdminArgHandler::Validate() { | |||
| } | |||
| // Additional checks here | |||
| if (num_workers_ < 1 || num_workers_ > 100) | |||
| return Status(StatusCode::kSyntaxError, "Number of workers must be in range of 1 and 100."); | |||
| auto max_num_workers = std::max<int32_t>(std::thread::hardware_concurrency(), 100); | |||
| if (num_workers_ < 1 || num_workers_ > max_num_workers) | |||
| return Status(StatusCode::kSyntaxError, | |||
| "Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + "."); | |||
| if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3)."); | |||
| if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) | |||
| return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1"); | |||
| @@ -354,13 +359,15 @@ Status CacheAdminArgHandler::RunCommand() { | |||
| std::vector<SessionCacheInfo> session_info = rq->GetSessionCacheInfo(); | |||
| if (!session_info.empty()) { | |||
| std::cout << std::setw(12) << "Session" << std::setw(12) << "Cache Id" << std::setw(12) << "Mem cached" | |||
| << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::endl; | |||
| << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::setw(10) << "Numa hit" | |||
| << std::endl; | |||
| for (auto curr_session : session_info) { | |||
| std::string cache_id; | |||
| std::string stat_mem_cached; | |||
| std::string stat_disk_cached; | |||
| std::string stat_avg_cached; | |||
| int32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF); | |||
| std::string stat_numa_hit; | |||
| uint32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF); | |||
| cache_id = (curr_session.connection_id == 0) ? "n/a" : std::to_string(crc); | |||
| stat_mem_cached = | |||
| (curr_session.stats.num_mem_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_mem_cached); | |||
| @@ -368,10 +375,12 @@ Status CacheAdminArgHandler::RunCommand() { | |||
| (curr_session.stats.num_disk_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_disk_cached); | |||
| stat_avg_cached = | |||
| (curr_session.stats.avg_cache_sz == 0) ? "n/a" : std::to_string(curr_session.stats.avg_cache_sz); | |||
| stat_numa_hit = | |||
| (curr_session.stats.num_numa_hit == 0) ? "n/a" : std::to_string(curr_session.stats.num_numa_hit); | |||
| std::cout << std::setw(12) << curr_session.session_id << std::setw(12) << cache_id << std::setw(12) | |||
| << stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached | |||
| << std::endl; | |||
| << std::setw(10) << stat_numa_hit << std::endl; | |||
| } | |||
| } else { | |||
| std::cout << "No active sessions." << std::endl; | |||
| @@ -21,6 +21,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <sstream> | |||
| #include <thread> | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| @@ -29,7 +30,7 @@ namespace dataset { | |||
| class CacheAdminArgHandler { | |||
| public: | |||
| static constexpr int32_t kDefaultNumWorkers = 32; | |||
| static const int32_t kDefaultNumWorkers; | |||
| static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; | |||
| static constexpr int32_t kDefaultLogLevel = 1; | |||
| static constexpr float kMemoryCapRatio = 0.8; | |||
| @@ -17,7 +17,6 @@ | |||
| #include <iomanip> | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/cache/cache_fbb.h" | |||
| #include "minddata/dataset/util/bit.h" | |||
| @@ -59,6 +58,7 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool | |||
| : server_connection_id_(0), | |||
| cache_mem_sz_(cache_mem_sz), | |||
| spill_(spill), | |||
| client_id_(-1), | |||
| local_bypass_(false), | |||
| hostname_(std::move(hostname)), | |||
| port_(port), | |||
| @@ -71,6 +71,22 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool | |||
| CacheClient::~CacheClient() { | |||
| cache_miss_keys_wp_.Set(); | |||
| if (client_id_ != -1) { | |||
| try { | |||
| // Send a message to the server, saying I am done. | |||
| auto rq = std::make_shared<ConnectResetRequest>(server_connection_id_, client_id_); | |||
| Status rc = PushRequest(rq); | |||
| if (rc.IsOk()) { | |||
| rc = rq->Wait(); | |||
| if (rc.IsOk()) { | |||
| MS_LOG(INFO) << "Disconnect from server successful"; | |||
| } | |||
| } | |||
| } catch (const std::exception &e) { | |||
| // Can't do anything in destructor. So just log the error. | |||
| MS_LOG(ERROR) << e.what(); | |||
| } | |||
| } | |||
| (void)comm_->ServiceStop(); | |||
| } | |||
| @@ -85,7 +101,7 @@ void CacheClient::Print(std::ostream &out) const { | |||
| } | |||
| Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const { | |||
| auto rq = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient()); | |||
| auto rq = std::make_shared<CacheRowRequest>(this); | |||
| RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row)); | |||
| RETURN_IF_NOT_OK(PushRequest(rq)); | |||
| RETURN_IF_NOT_OK(rq->Wait()); | |||
| @@ -104,7 +120,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const { | |||
| for (auto i = 0; i < num_rows; ++i) { | |||
| TensorRow row; | |||
| RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); | |||
| arr[i] = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient()); | |||
| arr[i] = std::make_shared<CacheRowRequest>(this); | |||
| RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row)); | |||
| RETURN_IF_NOT_OK(PushRequest(arr[i])); | |||
| } | |||
| @@ -118,7 +134,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const { | |||
| Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| auto rq = std::make_shared<BatchFetchRequest>(server_connection_id_, row_id, SupportLocalClient()); | |||
| auto rq = std::make_shared<BatchFetchRequest>(this, row_id); | |||
| RETURN_IF_NOT_OK(PushRequest(rq)); | |||
| RETURN_IF_NOT_OK(rq->Wait()); | |||
| int64_t mem_addr; | |||
| @@ -167,7 +183,7 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { | |||
| lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock. | |||
| CacheServiceStat stat{}; | |||
| RETURN_IF_NOT_OK(GetStat(&stat)); | |||
| if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) { | |||
| if (stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase)) { | |||
| return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase"); | |||
| } | |||
| } else { | |||
| @@ -183,18 +199,16 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { | |||
| // Start the comm layer to receive reply | |||
| RETURN_IF_NOT_OK(comm_->ServiceStart()); | |||
| // Initiate connection | |||
| auto rq = std::make_shared<CreateCacheRequest>(cinfo_, cache_mem_sz_, createFlag); | |||
| auto rq = std::make_shared<CreateCacheRequest>(this, cinfo_, cache_mem_sz_, createFlag); | |||
| RETURN_IF_NOT_OK(PushRequest(rq)); | |||
| Status rc = rq->Wait(); | |||
| if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) { | |||
| std::string cookie; | |||
| rq->ParseResult(&server_connection_id_, &cookie); | |||
| if (rc.IsOk()) { | |||
| // The 1st guy creating the cache will get a cookie back. | |||
| // But this object may be shared among pipelines and we don't want | |||
| // overwrite it. | |||
| cookie_ = cookie; | |||
| } | |||
| bool success = (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey); | |||
| // If we get kDuplicateKey, it just means we aren't the first one to create the cache, | |||
| // and we will continue to parse the result. | |||
| if (rc.get_code() == StatusCode::kDuplicateKey) { | |||
| RETURN_IF_NOT_OK(rq->PostReply()); | |||
| } | |||
| if (success) { | |||
| // Attach to shared memory for local client | |||
| RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_)); | |||
| } | |||
| @@ -47,6 +47,9 @@ namespace dataset { | |||
| class CacheClient { | |||
| public: | |||
| friend class CacheMergeOp; | |||
| friend class CreateCacheRequest; | |||
| friend class CacheRowRequest; | |||
| friend class BatchFetchRequest; | |||
| /// \brief A builder to help creating a CacheClient object | |||
| class Builder { | |||
| @@ -115,7 +118,7 @@ class CacheClient { | |||
| session_id_type GetSessionId() const { return session_id_; } | |||
| uint64_t GetCacheMemSz() const { return cache_mem_sz_; } | |||
| bool isSpill() const { return spill_; } | |||
| const std::string &getHostname() const { return hostname_; } | |||
| const std::string &GetHostname() const { return hostname_; } | |||
| int32_t GetPort() const { return port_; } | |||
| int32_t GetNumConnections() const { return num_connections_; } | |||
| int32_t GetPrefetchSize() const { return prefetch_size_; } | |||
| @@ -256,8 +259,10 @@ class CacheClient { | |||
| CacheClientInfo cinfo_; | |||
| // The server_connection_id_ is the actual id we use for operations after the cache is built | |||
| connection_id_type server_connection_id_; | |||
| // Some magic cookie returned from the cache server. | |||
| // Some magic cookie/id returned from the cache server. | |||
| std::string cookie_; | |||
| int32_t client_id_; | |||
| std::vector<int32_t> cpu_list_; | |||
| // Comm layer | |||
| bool local_bypass_; | |||
| std::string hostname_; | |||
| @@ -20,11 +20,6 @@ | |||
| /// both client and server side code. Do not put code that is not common here. | |||
| /// There are client and server specific header files. | |||
| // On platform like Windows, we may support only tcp/ip clients | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #define CACHE_LOCAL_CLIENT 1 | |||
| #endif | |||
| #ifdef ENABLE_CACHE | |||
| #include <grpcpp/grpcpp.h> | |||
| #endif | |||
| @@ -50,6 +45,9 @@ constexpr static uint32_t kDataIsInSharedMemory = 2; | |||
| /// \brief Size of each message used in message queue. | |||
| constexpr static int32_t kSharedMessageSize = 2048; | |||
| /// \brief State of CacheService at the server. | |||
| enum class CacheServiceState : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking }; | |||
| /// \brief Convert a Status object into a protobuf | |||
| /// \param rc[in] Status object | |||
| /// \param reply[in/out] pointer to pre-allocated protobuf object | |||
| @@ -61,6 +59,22 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) { | |||
| /// \param port | |||
| /// \return unix socket url | |||
| inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); } | |||
| /// \brief Round up to the next 4k | |||
| inline int64_t round_up_4K(int64_t sz) { | |||
| // Since 4096 is a power of 2, a simple way to round up is add 4095 and mask off all the | |||
| // bits of 4095 | |||
| return static_cast<uint64_t>(sz + 4095) & ~static_cast<uint64_t>(4095); | |||
| } | |||
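The add-and-mask trick in `round_up_4K` generalizes to any power-of-two alignment. Below is a minimal standalone sketch; the helper name `round_up_pow2` and the test values are illustrative, not part of the patch:

```cpp
#include <cassert>
#include <cstdint>

// Round sz up to the next multiple of align, where align is a power of 2.
// Adding (align - 1) carries into the next multiple; masking clears the rest.
inline int64_t round_up_pow2(int64_t sz, int64_t align) {
  assert((align & (align - 1)) == 0);  // power-of-2 check
  return static_cast<uint64_t>(sz + align - 1) & ~static_cast<uint64_t>(align - 1);
}

int main() {
  assert(round_up_pow2(1, 4096) == 4096);
  assert(round_up_pow2(4096, 4096) == 4096);  // already aligned: unchanged
  assert(round_up_pow2(4097, 4096) == 8192);
  return 0;
}
```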
| /// Memory policy | |||
| enum CachePoolPolicy : int8_t { kOnNode, kPreferred, kLocal, kInterleave, kNone }; | |||
| /// Misc typedef | |||
| using worker_id_t = int32_t; | |||
| using numa_id_t = int32_t; | |||
| using cpu_id_t = int32_t; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_ | |||
| @@ -32,12 +32,13 @@ message CacheRequest { | |||
| uint32 flag = 2; | |||
| oneof connect_info { | |||
| // The server_connection_id is the actual id we use for operations after the cache is built | |||
| int64 connection_id = 3; | |||
| uint64 connection_id = 3; | |||
| // But some request like CreateCache we have to use the session id and crc to connect to the server. | |||
| CacheClientInfo connection_info = 4; | |||
| } | |||
| int32 client_id = 5; | |||
| // Everything else is just vector of buffers | |||
| repeated bytes buf_data = 5; | |||
| repeated bytes buf_data = 6; | |||
| } | |||
| message CacheReply { | |||
| @@ -74,6 +74,9 @@ Status CacheServerGreeterImpl::Run() { | |||
| #if CACHE_LOCAL_CLIENT | |||
| RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_)); | |||
| MS_LOG(INFO) << "Creation of local socket and shared memory successful"; | |||
| auto cs = CacheServer::GetInstance().GetHWControl(); | |||
| // This shared memory is hot memory, so we will interleave it among all the numa nodes. | |||
| cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L); | |||
| #endif | |||
| } else { | |||
| std::string errMsg = "Fail to start server. "; | |||
| @@ -127,8 +130,13 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp | |||
| st_ = STATE::PROCESS; | |||
| svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this); | |||
| } else if (st_ == STATE::PROCESS) { | |||
| auto &cs = CacheServer::GetInstance(); | |||
| // Get a new tag and handle the next request before we serve the current request. | |||
| // The tag will be recycled when its state is changed to FINISH | |||
| // The tag will be recycled when its state is changed to FINISH. | |||
| // The number of free list queues is the same as the number of grpc threads. | |||
| // Where we get the free list from doesn't matter (as long as we return it to the right queue). | |||
| // We can round robin, use the qid, or even use the worker id. We will use the free list queue | |||
| // where the current request comes from. | |||
| CacheServerRequest *next_rq; | |||
| RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq)); | |||
| RETURN_IF_NOT_OK((*next_rq)(svc, cq)); | |||
| @@ -138,8 +146,24 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp | |||
| type_ = static_cast<RequestType>(rq_.type()); | |||
| // Now we pass the address of this instance to CacheServer's main loop. | |||
| MS_LOG(DEBUG) << "Handle request " << *this; | |||
| auto &cs = CacheServer::GetInstance(); | |||
| RETURN_IF_NOT_OK(cs.PushRequest(myQID, this)); | |||
| // We will distribute the request evenly (or randomly) over all the numa nodes. | |||
| // The exception is BatchFetch which we need to pre-process here. | |||
| if (type_ == BaseRequest::RequestType::kBatchFetchRows) { | |||
| rc_ = cs.BatchFetchRows(&rq_, &reply_); | |||
| if (!rc_.IsInterrupted()) { | |||
| Status2CacheReply(rc_, &reply_); | |||
| st_ = CacheServerRequest::STATE::FINISH; | |||
| responder_.Finish(reply_, grpc::Status::OK, this); | |||
| } else { | |||
| return rc_; | |||
| } | |||
| } else { | |||
| // When the number of grpc workers is the same as the server workers, we will use this queue id | |||
| // and push to the corresponding queue. | |||
| bool random = cs.GetNumWorkers() != cs.GetNumGrpcWorkers(); | |||
| worker_id_t worker_id = random ? cs.GetRandomWorker() : myQID; | |||
| RETURN_IF_NOT_OK(cs.PushRequest(worker_id, this)); | |||
| } | |||
| } else if (st_ == STATE::FINISH) { | |||
| MS_LOG(DEBUG) << *this << " Finished."; | |||
| // Return back to the free list. | |||
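The worker-routing rule in the PROCESS state above reduces to a small decision: reuse the grpc queue id when the server worker pool and the grpc thread pool line up one-to-one, otherwise spread requests randomly over the workers. A hedged sketch of just that decision; `PickRandomWorker` and `RouteRequest` are stand-ins, not the server's real API:

```cpp
#include <cstdint>
#include <random>

using worker_id_t = int32_t;

// Illustrative stand-in for CacheServer::GetRandomWorker().
worker_id_t PickRandomWorker(int32_t num_workers, std::mt19937 *rng) {
  std::uniform_int_distribution<worker_id_t> dist(0, num_workers - 1);
  return dist(*rng);
}

// Route a request coming in on grpc queue `qid`. When the number of server
// workers equals the number of grpc threads, the mapping is one-to-one and we
// keep the request on the same queue; otherwise spread requests randomly.
worker_id_t RouteRequest(int32_t qid, int32_t num_workers, int32_t num_grpc_workers,
                         std::mt19937 *rng) {
  bool random = (num_workers != num_grpc_workers);
  return random ? PickRandomWorker(num_workers, rng) : qid;
}
```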
| @@ -16,6 +16,7 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_ | |||
| #include <atomic> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| @@ -34,6 +35,7 @@ namespace dataset { | |||
| class CacheServerRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| friend class CacheService; | |||
| enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; | |||
| explicit CacheServerRequest(int32_t queue_id) | |||
| : BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown), | |||
| @@ -0,0 +1,220 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/cache_hw.h" | |||
| #ifdef NUMA_ENABLED | |||
| #include <numa.h> | |||
| #endif | |||
| #include <sched.h> | |||
| #include <cstdlib> | |||
| #include <cstring> | |||
| #include <cctype> | |||
| #include <fstream> | |||
| #include <regex> | |||
| #include <thread> | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CacheServerHW::CacheServerHW() { | |||
| num_cpus_ = std::thread::hardware_concurrency(); | |||
| MS_LOG(DEBUG) << "Number of cpu(s) : " << num_cpus_; | |||
| #ifdef NUMA_ENABLED | |||
| if (numa_enabled()) { | |||
| MS_LOG(WARNING) << "Numa support enabled"; | |||
| for (auto i = 0; i <= numa_max_node(); ++i) { | |||
| int64_t free_avail; | |||
| int64_t mem_avail = numa_node_size(i, &free_avail); | |||
| MS_LOG(INFO) << "Total physical/free RAM in bytes at node " << i << " : " << mem_avail << "/" << free_avail; | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| int64_t CacheServerHW::GetTotalSystemMemory() { | |||
| auto pages = sysconf(_SC_PHYS_PAGES); | |||
| auto page_size = sysconf(_SC_PAGE_SIZE); | |||
| auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size); | |||
| MS_LOG(INFO) << "Total physical RAM in bytes: " << total; | |||
| return total; | |||
| } | |||
| Status CacheServerHW::SetDefaultMemoryPolicy(CachePoolPolicy policy) { | |||
| #ifdef NUMA_ENABLED | |||
| if (numa_enabled()) { | |||
| // Set our default memory policy. | |||
| switch (policy) { | |||
| case kLocal: | |||
| numa_set_localalloc(); | |||
| MS_LOG(DEBUG) << "Setting memory default policy to local node. Low level code may override the setting"; | |||
| break; | |||
| case kInterleave: | |||
| numa_set_interleave_mask(numa_all_nodes_ptr); | |||
| MS_LOG(DEBUG) << "Numa affinity is turned off. Use interleave memory policy as default."; | |||
| break; | |||
| case kOnNode: | |||
| case kPreferred: | |||
| RETURN_STATUS_UNEXPECTED("Unsupported memory policy"); | |||
| break; | |||
| case kNone: | |||
| default: | |||
| // No action taken. | |||
| break; | |||
| } | |||
| } | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServerHW::GetNumaNodeInfo() { | |||
| std::set<Path> numa_nodes_; | |||
| Path node(kSysNodePath); | |||
| auto it = Path::DirIterator::OpenDirectory(&node); | |||
| if (it == nullptr) { | |||
| MS_LOG(WARNING) << "Unable to open directory " << kSysNodePath << ". Skip scanning hardware info"; | |||
| return Status::OK(); | |||
| } | |||
| auto isdigit_string = [](const char *str) -> bool { | |||
| bool r = true; | |||
| for (size_t i = 0; i < strlen(str); ++i) { | |||
| if (!std::isdigit(str[i])) { | |||
| r = false; | |||
| break; | |||
| } | |||
| } | |||
| return r; | |||
| }; | |||
| // Look for names starting with 'node' followed by digits. | |||
| const char kNodeName[] = "node"; | |||
| while (it->hasNext()) { | |||
| auto p = it->next(); | |||
| const std::string entry = p.Basename(); | |||
| const char *name = entry.data(); | |||
| if (strncmp(name, kNodeName, 4) == 0 && isdigit_string(name + strlen(kNodeName))) { | |||
| numa_nodes_.insert(p); | |||
| } | |||
| } | |||
| // There should be at least one. But if none are found, just move on with the | |||
| // rest of the server startup. | |||
| if (numa_nodes_.empty()) { | |||
| MS_LOG(WARNING) << "No numa nodes ? Skip scanning hardware info"; | |||
| return Status::OK(); | |||
| } | |||
| // For each numa node, get a list of CPU that is associated with it. | |||
| const char kCpuList[] = "cpulist"; | |||
| auto r = std::regex("[0-9]*-[0-9]*"); | |||
| for (Path p : numa_nodes_) { | |||
| auto node_dir = p.Basename().data(); | |||
| numa_id_t numa_node = strtol(node_dir + strlen(kNodeName), nullptr, 10); | |||
| Path f = p / kCpuList; | |||
| std::ifstream fs(f.toString()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + f.toString()); | |||
| std::string cpu_string; | |||
| cpu_set_t cpuset; | |||
| CPU_ZERO(&cpuset); | |||
| int32_t cpu_cnt = 0; | |||
| while (getline(fs, cpu_string)) { | |||
| // Now we parse the content of cpu_string. | |||
| std::sregex_iterator iter(cpu_string.begin(), cpu_string.end(), r); | |||
| std::sregex_iterator end; | |||
| while (iter != end) { | |||
| auto match = iter->str(); | |||
| auto pos = match.find_first_of('-'); | |||
| std::string min = match.substr(0, pos); | |||
| std::string max = match.substr(pos + 1); | |||
| cpu_id_t cpu_min = strtol(min.data(), nullptr, 10); | |||
| cpu_id_t cpu_max = strtol(max.data(), nullptr, 10); | |||
| MS_LOG(DEBUG) << "Numa node " << numa_node << " CPU(s) : " << cpu_min << "-" << cpu_max; | |||
| for (int i = cpu_min; i <= cpu_max; ++i) { | |||
| CPU_SET(i, &cpuset); | |||
| ++cpu_cnt; | |||
| } | |||
| ++iter; | |||
| } | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!fs.bad(), "Fail to read file: " + f.toString()); | |||
| fs.close(); | |||
| // Remember which cpus are attached to this numa node. | |||
| numa_cpuset_.emplace(numa_node, cpuset); | |||
| numa_cpu_cnt_.emplace(numa_node, cpu_cnt); | |||
| } | |||
| MS_LOG(DEBUG) << "Number of numa nodes : " << numa_cpuset_.size(); | |||
| return Status::OK(); | |||
| } | |||
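The cpulist parsing in `GetNumaNodeInfo` can be exercised on its own. The sketch below mirrors the regex approach for `min-max` ranges; note that the kernel's cpulist format may also contain single ids (e.g. `0-3,7`), which the `[0-9]+(-[0-9]+)?` pattern used here would also catch — shown as an illustrative extension, not what the patch itself does:

```cpp
#include <iostream>
#include <regex>
#include <string>
#include <vector>

// Parse a Linux cpulist line such as "0-3,7,10-11" into individual cpu ids.
// Accepts both "min-max" ranges and single ids.
std::vector<int> ParseCpuList(const std::string &cpu_string) {
  std::vector<int> cpus;
  std::regex r("([0-9]+)(-([0-9]+))?");
  auto iter = std::sregex_iterator(cpu_string.begin(), cpu_string.end(), r);
  auto end = std::sregex_iterator();
  for (; iter != end; ++iter) {
    int cpu_min = std::stoi((*iter)[1].str());
    // Group 3 is empty for single ids; reuse cpu_min in that case.
    int cpu_max = (*iter)[3].matched ? std::stoi((*iter)[3].str()) : cpu_min;
    for (int i = cpu_min; i <= cpu_max; ++i) cpus.push_back(i);
  }
  return cpus;
}

int main() {
  for (int cpu : ParseCpuList("0-3,7,10-11")) std::cout << cpu << " ";
  std::cout << std::endl;  // prints: 0 1 2 3 7 10 11
  return 0;
}
```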
| Status CacheServerHW::SetAffinity(const Task &tk, numa_id_t numa_node) { | |||
| auto r = numa_cpuset_.find(numa_node); | |||
| if (r != numa_cpuset_.end()) { | |||
| auto err = pthread_setaffinity_np(tk.GetNativeHandle(), sizeof(r->second), &r->second); | |||
| if (err) { | |||
| std::string errMsg = "Unable to set affiity. Errno = " + std::to_string(errno); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Numa node " + std::to_string(numa_node) + " not found"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::vector<cpu_id_t> CacheServerHW::GetCpuList(numa_id_t numa_id) { | |||
| std::vector<cpu_id_t> v; | |||
| auto it = numa_cpuset_.find(numa_id); | |||
| if (it != numa_cpuset_.end()) { | |||
| auto &cpu_set = it->second; | |||
| for (auto i = 0; i < num_cpus_; ++i) { | |||
| if (CPU_ISSET(i, &cpu_set)) { | |||
| v.push_back(i); | |||
| } | |||
| } | |||
| } | |||
| return v; | |||
| } | |||
| numa_id_t CacheServerHW::GetMyNode() const { | |||
| numa_id_t node_id = 0; | |||
| auto cpu = sched_getcpu(); | |||
| #ifdef NUMA_ENABLED | |||
| node_id = numa_node_of_cpu(cpu); | |||
| #else | |||
| bool found = false; | |||
| for (auto it : numa_cpuset_) { | |||
| cpu_set_t &cpu_set = it.second; | |||
| if (CPU_ISSET(cpu, &cpu_set)) { | |||
| node_id = it.first; | |||
| found = true; | |||
| break; | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "cpu id " << cpu << " found : " << std::boolalpha << found; | |||
| #endif | |||
| return node_id; | |||
| } | |||
| void CacheServerHW::InterleaveMemory(void *ptr, size_t sz) { | |||
| #ifdef NUMA_ENABLED | |||
| if (numa_enabled()) { | |||
| numa_interleave_memory(ptr, sz, numa_all_nodes_ptr); | |||
| } | |||
| #endif | |||
| } | |||
| bool CacheServerHW::numa_enabled() { | |||
| #ifdef NUMA_ENABLED | |||
| return (numa_available() != -1); | |||
| #else | |||
| return false; | |||
| #endif | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_ | |||
| #ifdef NUMA_ENABLED | |||
| #include <numa.h> | |||
| #endif | |||
| #include <sched.h> | |||
| #include <stdlib.h> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||
| #include "minddata/dataset/util/memory_pool.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/util/task.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class CacheServerHW { | |||
| public: | |||
| CacheServerHW(); | |||
| ~CacheServerHW() = default; | |||
| /// \brief Get Numa node info without using numa library | |||
| /// \return Status object | |||
| Status GetNumaNodeInfo(); | |||
| /// \brief Set thread affinity | |||
| Status SetAffinity(const Task &tk, numa_id_t numa_node); | |||
| /// \brief Get total number of cpu(s) | |||
| int32_t GetCpuCount() const { return num_cpus_; } | |||
| /// \brief Get total number of numa nodes | |||
| int32_t GetNumaNodeCount() const { return numa_cpuset_.empty() ? 1 : numa_cpuset_.size(); } | |||
| /// \brief Get a list of cpu for a given numa node. | |||
| std::vector<cpu_id_t> GetCpuList(numa_id_t numa_id); | |||
| static bool numa_enabled(); | |||
| /// \brief Return the numa node the current thread is running on. | |||
| numa_id_t GetMyNode() const; | |||
| /// \brief Interleave a given memory block. Used by shared memory only. | |||
| static void InterleaveMemory(void *ptr, size_t sz); | |||
| /// \brief Set default memory policy. | |||
| static Status SetDefaultMemoryPolicy(CachePoolPolicy); | |||
| /// \brief This returns the size (in bytes) of the physical RAM on the machine. | |||
| /// \return the size (in bytes) of the physical RAM on the machine. | |||
| static int64_t GetTotalSystemMemory(); | |||
| private: | |||
| constexpr static char kSysNodePath[] = "/sys/devices/system/node"; | |||
| int32_t num_cpus_; | |||
| std::map<numa_id_t, cpu_set_t> numa_cpuset_; | |||
| std::map<numa_id_t, int32_t> numa_cpu_cnt_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_ | |||
| @@ -54,6 +54,8 @@ ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds:: | |||
| #endif | |||
| try { | |||
| rq->set_type(static_cast<int16_t>(type)); | |||
| rq->set_client_id(-1); | |||
| rq->set_flag(0); | |||
| grpc::ChannelArguments args; | |||
| grpc::ClientContext ctx; | |||
| grpc::CompletionQueue cq; | |||
| @@ -0,0 +1,224 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include <iterator> | |||
| #include <limits> | |||
| #include "minddata/dataset/engine/cache/cache_hw.h" | |||
| #include "minddata/dataset/engine/cache/cache_numa.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| NumaMemoryPool::NumaMemoryPool(std::shared_ptr<CacheServerHW> hw, float memory_cap_ratio) | |||
| : hw_(std::move(hw)), memory_cap_ratio_(memory_cap_ratio) { | |||
| int64_t total_avail = 0; | |||
| // We will create a number of small Arenas to spread out the server threads so there | |||
| // will be less contention. If we link with the numa library, i.e. if | |||
| // NUMA_ENABLED is defined, we will make use of the low level numa library such that | |||
| // each Arena comes solely from one particular socket. | |||
| // The total number of Arenas is capped at the number of cpus. | |||
| auto num_cpus = hw_->GetCpuCount(); | |||
| memory_segments_.reserve(num_cpus); | |||
| arena_list_.reserve(num_cpus); | |||
| mux_ = std::make_unique<std::mutex[]>(num_cpus); | |||
| auto num_memory_nodes = num_cpus; | |||
| int64_t max_avail = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_; | |||
| int64_t arena_sz = max_avail / num_memory_nodes; | |||
| // If arena_sz is too small, raise it (to at least INT32_MAX, rounded up to 4K) and lower the number of Arenas. | |||
| if (arena_sz < std::numeric_limits<int32_t>::max()) { | |||
| arena_sz = round_up_4K(std::numeric_limits<int32_t>::max()); | |||
| num_memory_nodes = max_avail / arena_sz; | |||
| if (num_memory_nodes == 0) { | |||
| num_memory_nodes = 1; | |||
| arena_sz = max_avail; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Creating " << num_memory_nodes << " number of arena. Each one of size " << arena_sz; | |||
| #ifdef NUMA_ENABLED | |||
| if (numa_available() != -1) { | |||
| auto num_numa_nodes = hw_->GetNumaNodeCount(); | |||
| numa_id_t node_id = 0; | |||
| for (auto i = 0; i < num_memory_nodes; ++i) { | |||
| auto success = CreateMultipleArenas(arena_sz, node_id++ % num_numa_nodes, 1); | |||
| total_avail += success * arena_sz; | |||
| } | |||
| } else { | |||
| auto success = CreateMultipleArenas(arena_sz, 0, num_memory_nodes); | |||
| total_avail += success * arena_sz; | |||
| } | |||
| #else | |||
| auto success = CreateMultipleArenas(arena_sz, 0, num_memory_nodes); | |||
| total_avail += success * arena_sz; | |||
| #endif | |||
| memory_cap_ = total_avail; | |||
| MS_LOG(WARNING) << "Memory pool created. Total available memory " << memory_cap_ << " spread in " << nodes_.size() | |||
| << " arenas"; | |||
| int32_t slot = 0; | |||
| // Set up a map for future easy access. | |||
| for (auto node_id : nodes_) { | |||
| numa_map_[node_id].push_back(slot); | |||
| ++slot; | |||
| } | |||
| } | |||
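The arena sizing arithmetic in the constructor is easier to follow with concrete numbers. A sketch under assumed values — a hypothetical 16-cpu, 64 GB box with a 0.8 cap ratio; the 4K round-up is elided:

```cpp
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // Hypothetical machine: 16 cpus, 64 GB RAM, memory_cap_ratio = 0.8.
  int64_t num_cpus = 16;
  int64_t total_ram = 64LL * 1024 * 1024 * 1024;
  int64_t max_avail = static_cast<int64_t>(total_ram * 0.8);  // ~51.2 GB
  int64_t num_memory_nodes = num_cpus;
  int64_t arena_sz = max_avail / num_memory_nodes;            // ~3.2 GB each
  // ~3.2 GB > INT32_MAX (~2 GB), so the "too small" branch is NOT taken and
  // we keep 16 arenas. On a smaller box the branch raises arena_sz to ~2 GB
  // (rounded up to 4K in the real code) and lowers the arena count instead.
  if (arena_sz < std::numeric_limits<int32_t>::max()) {
    arena_sz = std::numeric_limits<int32_t>::max();
    num_memory_nodes = max_avail / arena_sz;
    if (num_memory_nodes == 0) {
      num_memory_nodes = 1;
      arena_sz = max_avail;
    }
  }
  std::cout << num_memory_nodes << " arenas of " << arena_sz << " bytes\n";
  return 0;
}
```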
| int32_t NumaMemoryPool::CreateMultipleArenas(int64_t segment_sz, numa_id_t node_id, int32_t repeat_count) { | |||
| int32_t success = 0; | |||
| for (auto i = 0; i < repeat_count; ++i) { | |||
| #ifdef NUMA_ENABLED | |||
| void *ptr = numa_alloc_onnode(segment_sz, node_id); | |||
| #else | |||
| void *ptr = malloc(segment_sz); | |||
| #endif | |||
| if (ptr != nullptr) { | |||
| memory_segments_.emplace_back(ptr, segment_sz); | |||
| arena_list_.push_back(std::make_unique<ArenaImpl>(ptr, segment_sz)); | |||
| nodes_.push_back(node_id); | |||
| ++success; | |||
| } else { | |||
| // Skip the rest. | |||
| break; | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Allocate " << success << " arenas from node " << node_id; | |||
| return success; | |||
| } | |||
| NumaMemoryPool::~NumaMemoryPool() { | |||
| if (!memory_segments_.empty()) { | |||
| for (auto &s : memory_segments_) { | |||
| #ifdef NUMA_ENABLED | |||
| numa_free(s.first, s.second); | |||
| #else | |||
| free(s.first); | |||
| #endif | |||
| } | |||
| } | |||
| } | |||
| Status NumaMemoryPool::Allocate(size_t n, void **p) { | |||
| RETURN_UNEXPECTED_IF_NULL(p); | |||
| auto mt = GetRandomDevice(); | |||
| Status rc; | |||
| void *ptr = nullptr; | |||
| auto num_segments = memory_segments_.size(); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_segments > 0, "No numa nodes available"); | |||
| if (NumaAware()) { | |||
| auto num_numa_nodes = hw_->GetNumaNodeCount(); | |||
| // We will start from the numa node this worker thread is running on and do a round robin search. | |||
| numa_id_t start = hw_->GetMyNode(); | |||
| numa_id_t node_id = start; | |||
| do { | |||
| auto it = numa_map_.find(node_id); | |||
| if (it != numa_map_.end()) { | |||
| auto &slots = it->second; | |||
| auto num_slots = slots.size(); | |||
| std::uniform_int_distribution<int32_t> distribution(0, num_slots - 1); | |||
| auto start_slot = distribution(mt); | |||
| int32_t inx = start_slot; | |||
| do { | |||
| int32_t k = slots.at(inx); | |||
| std::unique_lock lock_x(mux_[k]); | |||
| auto &impl = arena_list_.at(k); | |||
| rc = impl->Allocate(n, &ptr); | |||
| if (rc.IsOk()) { | |||
| *p = ptr; | |||
| break; | |||
| } else if (rc.IsOutofMemory()) { | |||
| inx = (inx + 1) % num_slots; | |||
| } else { | |||
| return rc; | |||
| } | |||
| } while (inx != start_slot); | |||
| } | |||
| // We are done searching this numa node. If nothing was found, move to the next node. | |||
| if (ptr == nullptr) { | |||
| node_id = (node_id + 1) % num_numa_nodes; | |||
| } else { | |||
| break; | |||
| } | |||
| } while (node_id != start); | |||
| } else { | |||
| // If not numa aware, just randomly pick a slot. | |||
| std::uniform_int_distribution<int32_t> distribution(0, num_segments - 1); | |||
| auto start_slot = distribution(mt); | |||
| int32_t slot = start_slot; | |||
| do { | |||
| std::unique_lock lock_x(mux_[slot]); | |||
| auto &impl = arena_list_.at(slot); | |||
| rc = impl->Allocate(n, &ptr); | |||
| if (rc.IsOk()) { | |||
| *p = ptr; | |||
| break; | |||
| } else if (rc.IsOutofMemory()) { | |||
| // Move to the next arena and continue. | |||
| slot = (slot + 1) % num_segments; | |||
| } else { | |||
| return rc; | |||
| } | |||
| } while (slot != start_slot); | |||
| } | |||
| // Handle the case where we completed a full round robin search without success. | |||
| if (ptr == nullptr) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| return rc; | |||
| } | |||
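A minimal usage sketch of the pool's public interface, under the assumption that the surrounding server headers are available; error handling is abbreviated and the snippet is not taken from the patch:

```cpp
#include <memory>
#include "minddata/dataset/engine/cache/cache_hw.h"
#include "minddata/dataset/engine/cache/cache_numa.h"

using namespace mindspore::dataset;

int main() {
  auto hw = std::make_shared<CacheServerHW>();
  (void)hw->GetNumaNodeInfo();  // best effort; returns OK even without numa nodes
  auto pool = std::make_shared<NumaMemoryPool>(hw, 0.8f /* cap ratio */);

  void *p = nullptr;
  Status rc = pool->Allocate(1 << 20, &p);  // 1 MB, preferring the caller's node
  if (rc.IsOk()) {
    numa_id_t node = pool->FindNode(p);  // which numa node the block landed on
    (void)node;
    pool->Deallocate(p);
  }
  return rc.IsOk() ? 0 : 1;
}
```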
| void NumaMemoryPool::Deallocate(void *p) { | |||
| // Find out which numa slot it comes from. | |||
| auto slot = Locate(p); | |||
| MS_ASSERT(slot != -1); | |||
| std::unique_lock lock_x(mux_[slot]); | |||
| auto &impl = arena_list_.at(slot); | |||
| impl->Deallocate(p); | |||
| } | |||
| int NumaMemoryPool::PercentFree() const { | |||
| int percent_free = 0; | |||
| int num_arena = 0; | |||
| for (auto const &p : arena_list_) { | |||
| percent_free += p->PercentFree(); | |||
| num_arena++; | |||
| } | |||
| if (num_arena) { | |||
| return percent_free / num_arena; | |||
| } else { | |||
| return 100; | |||
| } | |||
| } | |||
| int32_t NumaMemoryPool::Locate(void *p) const { | |||
| int32_t slot = 0; | |||
| char *mem = reinterpret_cast<char *>(p); | |||
| for (slot = 0; slot < memory_segments_.size(); ++slot) { | |||
| auto elem = memory_segments_.at(slot); | |||
| char *q = reinterpret_cast<char *>(elem.first); | |||
| if (mem >= q && mem < q + elem.second) { | |||
| return slot; | |||
| } | |||
| } | |||
| return -1; | |||
| } | |||
| std::vector<numa_id_t> NumaMemoryPool::GetAvailableNodes() const { | |||
| std::vector<numa_id_t> v; | |||
| std::transform(numa_map_.begin(), numa_map_.end(), std::back_inserter(v), | |||
| [](const std::pair<numa_id_t, std::vector<int32_t>> &v) { return v.first; }); | |||
| return v; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,195 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_ | |||
| #include <limits> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/cache/cache_hw.h" | |||
| #include "minddata/dataset/util/arena.h" | |||
| #include "minddata/dataset/util/memory_pool.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief An allocator but for a particular numa node. | |||
| template <typename T> | |||
| class NumaAllocator { | |||
| public: | |||
| explicit NumaAllocator(numa_id_t node_id, CachePoolPolicy policy) | |||
| : policy_(policy), numa_enabled_(false), node_id_(node_id) { | |||
| #ifdef NUMA_ENABLED | |||
| numa_enabled_ = numa_available() != -1; | |||
| #endif | |||
| } | |||
| ~NumaAllocator() = default; | |||
| template <typename U> | |||
| explicit NumaAllocator(NumaAllocator<U> const &rhs) | |||
| : policy_(rhs.policy_), numa_enabled_(rhs.numa_enabled_), node_id_(rhs.node_id_) {} | |||
| template <typename U> | |||
| bool operator==(NumaAllocator<U> const &rhs) const { | |||
| return node_id_ == rhs.node_id_; | |||
| } | |||
| template <typename U> | |||
| bool operator!=(NumaAllocator<U> const &rhs) const { | |||
| return node_id_ != rhs.node_id_; | |||
| } | |||
| template <typename U> | |||
| friend class NumaAllocator; | |||
| using value_type = T; | |||
| using pointer = T *; | |||
| using const_pointer = const T *; | |||
| using reference = T &; | |||
| using const_reference = const T &; | |||
| using size_type = uint64_t; | |||
| using difference_type = std::ptrdiff_t; | |||
| template <typename U> | |||
| struct rebind { | |||
| using other = NumaAllocator<U>; | |||
| }; | |||
| using propagate_on_container_copy_assignment = std::true_type; | |||
| using propagate_on_container_move_assignment = std::true_type; | |||
| using propagate_on_container_swap = std::true_type; | |||
| /// Allocate memory on this node only. Return nullptr if no memory on this numa node. | |||
| /// \note This version will not throw if we can't allocate memory from this node. | |||
| /// The user must check whether the returned pointer is null. | |||
| pointer allocate(std::size_t n) noexcept { | |||
| auto sz = n * sizeof(T); | |||
| void *p = nullptr; | |||
| #ifdef NUMA_ENABLED | |||
| if (numa_enabled_) { | |||
| switch (policy_) { | |||
| case kPreferred: | |||
| numa_set_preferred(node_id_); | |||
| p = numa_alloc(sz); | |||
| break; | |||
| case kLocal: | |||
| p = numa_alloc_local(sz); | |||
| break; | |||
| case kInterleave: | |||
| p = numa_alloc_interleaved(sz); | |||
| break; | |||
| case kOnNode: | |||
| p = numa_alloc_onnode(sz, node_id_); | |||
| break; | |||
| case kNone: | |||
| default: | |||
| p = numa_alloc(sz); | |||
| break; | |||
| } | |||
| } else { | |||
| p = malloc(sz); | |||
| } | |||
| #else | |||
| p = malloc(sz); | |||
| #endif | |||
| return reinterpret_cast<pointer>(p); | |||
| } | |||
| /// Free a memory allocated on this node. | |||
| void deallocate(pointer p, std::size_t n) noexcept { | |||
| #ifdef NUMA_ENABLED | |||
| if (numa_enabled_) { | |||
| numa_free(p, n * sizeof(T)); | |||
| } else { | |||
| free(p); | |||
| } | |||
| #else | |||
| free(p); | |||
| #endif | |||
| } | |||
| /// \brief Allow one to change to another numa node | |||
| void SetNodeId(numa_id_t node_id) { node_id_ = node_id; } | |||
| /// \brief Getter for node_id. | |||
| numa_id_t GetNodeId() const { return node_id_; } | |||
| /// \brief Getter for policy | |||
| CachePoolPolicy GetPolicy() const { return policy_; } | |||
| private: | |||
| CachePoolPolicy policy_; | |||
| bool numa_enabled_; | |||
| numa_id_t node_id_; | |||
| }; | |||
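Because `allocate` is noexcept and returns nullptr on failure, the allocator is safest driven directly rather than through a standard container, which assumes a throwing allocator. A hedged usage sketch, assuming the header above is available:

```cpp
#include <cstddef>
#include "minddata/dataset/engine/cache/cache_numa.h"

using namespace mindspore::dataset;

int main() {
  // Pin allocations to numa node 0 (falls back to malloc without libnuma).
  NumaAllocator<int> alloc(0, CachePoolPolicy::kOnNode);

  const std::size_t n = 1024;
  int *buf = alloc.allocate(n);  // may return nullptr; never throws
  if (buf == nullptr) return 1;
  for (std::size_t i = 0; i < n; ++i) buf[i] = static_cast<int>(i);
  alloc.deallocate(buf, n);
  return 0;
}
```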
| /// \brief A NumaMemoryPool is like a CircularPool but all the arenas have already been allocated | |||
| /// and each one comes from a numa socket. Memory is allocated using OnNode policy. That is, | |||
| /// it comes solely from one particular numa node, and is not interleaved. | |||
| class NumaMemoryPool : public MemoryPool { | |||
| public: | |||
| explicit NumaMemoryPool(std::shared_ptr<CacheServerHW> hw, float memory_cap_ratio); | |||
| ~NumaMemoryPool() override; | |||
| // As a derived class, we override the following functions | |||
| Status Allocate(size_t size, void **pVoid) override; | |||
| void Deallocate(void *pVoid) override; | |||
| Status Reallocate(void **pVoid, size_t old_sz, size_t new_sz) override { RETURN_STATUS_UNEXPECTED("Not supported"); } | |||
| uint64_t get_max_size() const override { return std::numeric_limits<uint64_t>::max(); } | |||
| int PercentFree() const override; | |||
| /// \brief Return if the memory pool is numa aware | |||
| bool NumaAware() const { return CacheServerHW::numa_enabled(); } | |||
| /// \brief Returns all the numa nodes that we are able to allocate memory from. | |||
| std::vector<numa_id_t> GetAvailableNodes() const; | |||
| /// \brief Given a pointer (allocated from this pool), return the numa node where it is located. | |||
| /// \note -1 is returned if not found. | |||
| numa_id_t FindNode(void *p) const { | |||
| auto slot = Locate(p); | |||
| if (slot != -1) { | |||
| return nodes_.at(slot); | |||
| } else { | |||
| return -1; | |||
| } | |||
| } | |||
| /// \brief Return maximum available memory | |||
| int64_t GetAvailableMemory() const { return memory_cap_; } | |||
| private: | |||
| std::shared_ptr<CacheServerHW> hw_; | |||
| float memory_cap_ratio_; | |||
| int64_t memory_cap_; | |||
| std::vector<std::pair<void *, int64_t>> memory_segments_; | |||
| std::vector<std::unique_ptr<ArenaImpl>> arena_list_; | |||
| std::unique_ptr<std::mutex[]> mux_; | |||
| std::vector<numa_id_t> nodes_; | |||
| std::map<numa_id_t, std::vector<int32_t>> numa_map_; | |||
| /// \brief Returns the slot that a given memory block comes from. | |||
| /// \return slot from memory_segments_. -1 if not found. | |||
| int32_t Locate(void *p) const; | |||
| /// If the numa library is not linked, or numa_available() returns -1, we will fall back to this method. | |||
| int32_t CreateMultipleArenas(int64_t segment_sz, numa_id_t node_id, int32_t repeat_count); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_ | |||
| @@ -15,18 +15,14 @@ | |||
| */ | |||
| #include <algorithm> | |||
| #include "utils/ms_utils.h" | |||
| #include "minddata/dataset/util/cache_pool.h" | |||
| #include "minddata/dataset/engine/cache/cache_pool.h" | |||
| #include "minddata/dataset/engine/cache/cache_server.h" | |||
| #include "minddata/dataset/util/services.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CachePool::CachePool(const value_allocator &alloc, bool ourOwnArena, const std::string &root) | |||
| : alloc_(alloc), | |||
| root_(root), | |||
| subfolder_(Services::GetUniqueID()), | |||
| sm_(nullptr), | |||
| tree_(nullptr), | |||
| custom_arena_(ourOwnArena) {} | |||
| CachePool::CachePool(std::shared_ptr<NumaMemoryPool> mp, const std::string &root) | |||
| : mp_(std::move(mp)), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} | |||
| Status CachePool::DoServiceStart() { | |||
| tree_ = std::make_shared<data_index>(); | |||
| @@ -36,10 +32,11 @@ Status CachePool::DoServiceStart() { | |||
| RETURN_IF_NOT_OK(spill.CreateDirectories()); | |||
| sm_ = std::make_shared<StorageManager>(spill); | |||
| RETURN_IF_NOT_OK(sm_->ServiceStart()); | |||
| MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString()); | |||
| MS_LOG(INFO) << "CachePool will use disk folder: " << spill.toString(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CachePool::DoServiceStop() { | |||
| Status rc; | |||
| Status rc2; | |||
| @@ -50,14 +47,14 @@ Status CachePool::DoServiceStop() { | |||
| } | |||
| } | |||
| sm_.reset(); | |||
| // If it is our own arena, skip freeing individual pieces. | |||
| if (!custom_arena_) { | |||
| for (auto &bl : *tree_) { | |||
| if (bl.ptr != nullptr) { | |||
| alloc_.deallocate(bl.ptr, bl.sz); | |||
| } | |||
| value_allocator alloc(mp_); | |||
| for (auto &bl : *tree_) { | |||
| if (bl.ptr != nullptr) { | |||
| alloc.deallocate(bl.ptr, bl.sz); | |||
| } | |||
| } | |||
| tree_.reset(); | |||
| if (!root_.toString().empty()) { | |||
| Path spill = GetSpillPath(); | |||
| @@ -75,8 +72,10 @@ Status CachePool::DoServiceStop() { | |||
| } | |||
| return rc2; | |||
| } | |||
| CachePool::~CachePool() noexcept { (void)ServiceStop(); } | |||
| Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly) { | |||
| Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf) { | |||
| DataLocator bl; | |||
| Status rc; | |||
| size_t sz = 0; | |||
| @@ -85,26 +84,35 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic | |||
| sz += v.GetSize(); | |||
| } | |||
| bl.sz = sz; | |||
| try { | |||
| if (!writeToDiskDirectly) { | |||
| bl.ptr = alloc_.allocate(sz); | |||
| // We will do a piecewise copy. | |||
| WritableSlice dest(bl.ptr, bl.sz); | |||
| size_t pos = 0; | |||
| for (auto &v : buf) { | |||
| WritableSlice out(dest, pos); | |||
| rc = WritableSlice::Copy(&out, v); | |||
| if (rc.IsError()) { | |||
| break; | |||
| } | |||
| pos += v.GetSize(); | |||
| } | |||
| rc = mp_->Allocate(sz, reinterpret_cast<void **>(&bl.ptr)); | |||
| if (rc.IsOk()) { | |||
| // Record which numa node we allocated from. It only makes sense if the policy is kOnNode. | |||
| if (CacheServerHW::numa_enabled()) { | |||
| auto &cs = CacheServer::GetInstance(); | |||
| auto node_id = cs.GetHWControl()->GetMyNode(); | |||
| bl.node_id = mp_->FindNode(bl.ptr); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(bl.node_id != -1, "Allocator is not from numa memory pool"); | |||
| bl.node_hit = (bl.node_id == node_id); | |||
| } | |||
| // We will do a piecewise copy. | |||
| WritableSlice dest(bl.ptr, bl.sz); | |||
| size_t pos = 0; | |||
| for (auto &v : buf) { | |||
| WritableSlice out(dest, pos); | |||
| rc = WritableSlice::Copy(&out, v); | |||
| if (rc.IsError()) { | |||
| alloc_.deallocate(bl.ptr, sz); | |||
| bl.ptr = nullptr; | |||
| return rc; | |||
| break; | |||
| } | |||
| } else if (sm_ != nullptr) { | |||
| pos += v.GetSize(); | |||
| } | |||
| if (rc.IsError()) { | |||
| mp_->Deallocate(bl.ptr); | |||
| bl.ptr = nullptr; | |||
| return rc; | |||
| } | |||
| } else if (rc.IsOutofMemory()) { | |||
| // If no memory, write to disk. | |||
| if (sm_ != nullptr) { | |||
| MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes."; | |||
| RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); | |||
| } else { | |||
| @@ -112,12 +120,8 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic | |||
| // instead. | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } catch (std::bad_alloc &e) { | |||
| if (sm_ != nullptr) { | |||
| RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); | |||
| } else { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } else { | |||
| return rc; | |||
| } | |||
| // Insert into the B+ tree. We may still get out of memory error. So need to catch it. | |||
| try { | |||
| @@ -127,10 +131,13 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic | |||
| } | |||
| // Duplicate key is treated as error and we will also free the memory. | |||
| if (rc.IsError() && bl.ptr != nullptr) { | |||
| alloc_.deallocate(bl.ptr, sz); | |||
| mp_->Deallocate(bl.ptr); | |||
| bl.ptr = nullptr; | |||
| return rc; | |||
| } | |||
| return rc; | |||
| } | |||
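Stripped of the flatbuffers and slice plumbing, the insert path above is a memory-first write with a spill-on-OOM fallback. The shape of that control flow, reduced to a sketch with hypothetical names:

```cpp
#include <functional>

enum class Rc { kOk, kOutOfMemory, kError };

// Memory-first write with spill-to-disk fallback, mirroring the Insert flow:
// try the numa pool; on OOM, fall through to the storage manager if present.
Rc WriteWithSpill(const std::function<Rc()> &write_to_memory,
                  const std::function<Rc()> &write_to_disk, bool has_disk) {
  Rc rc = write_to_memory();
  if (rc == Rc::kOutOfMemory) {
    return has_disk ? write_to_disk() : Rc::kOutOfMemory;
  }
  return rc;  // success, or a non-OOM error that we surface as-is
}
```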
| Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { | |||
| RETURN_UNEXPECTED_IF_NULL(dest); | |||
| auto r = tree_->Search(key); | |||
| @@ -156,13 +163,14 @@ Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *byt | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; } | |||
| Path CachePool::GetSpillPath() const { | |||
| auto spill = Path(root_) / subfolder_; | |||
| return spill; | |||
| } | |||
| CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { | |||
| CacheStat cs{-1, -1, 0, 0, 0}; | |||
| CacheStat cs{-1, -1, 0, 0, 0, 0}; | |||
| int64_t total_sz = 0; | |||
| if (tree_->begin() != tree_->end()) { | |||
| cs.min_key = tree_->begin().key(); | |||
| @@ -174,6 +182,9 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { | |||
| } else { | |||
| ++cs.num_disk_cached; | |||
| } | |||
| if (it.value().node_hit) { | |||
| ++cs.num_numa_hit; | |||
| } | |||
| auto cur_key = it.key(); | |||
| if (GetMissingKeys) { | |||
| for (auto i = cs.max_key + 1; i < cur_key; ++i) { | |||
| @@ -192,49 +203,26 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { | |||
| } | |||
| return cs; | |||
| } | |||
| Status CachePool::Spill(CachePool::DataLocator *dl) { | |||
| if (sm_ == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("No disk storage to spill"); | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(dl); | |||
| RETURN_UNEXPECTED_IF_NULL(dl->ptr); | |||
| if (dl->storage_key == 0) { | |||
| ReadableSlice data(dl->ptr, dl->sz); | |||
| RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data})); | |||
| } | |||
| alloc_.deallocate(dl->ptr, dl->sz); | |||
| dl->ptr = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| Status CachePool::Locate(CachePool::DataLocator *dl) { | |||
| RETURN_UNEXPECTED_IF_NULL(dl); | |||
| if (dl->ptr == nullptr) { | |||
| if (sm_ == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("No disk storage to locate the data"); | |||
| } | |||
| try { | |||
| dl->ptr = alloc_.allocate(dl->sz); | |||
| WritableSlice dest(dl->ptr, dl->sz); | |||
| Status rc = Read(dl->storage_key, &dest); | |||
| if (rc.IsError()) { | |||
| alloc_.deallocate(dl->ptr, dl->sz); | |||
| dl->ptr = nullptr; | |||
| return rc; | |||
| } | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| size_t CachePool::GetSize(CachePool::key_type key) const { | |||
| Status CachePool::GetDataLocator(key_type key, const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, | |||
| flatbuffers::Offset<DataLocatorMsg> *out) const { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| auto r = tree_->Search(key); | |||
| if (r.second) { | |||
| auto &it = r.first; | |||
| return it->sz; | |||
| DataLocatorMsgBuilder bld(*fbb); | |||
| bld.add_key(key); | |||
| bld.add_size(it->sz); | |||
| bld.add_node_id(it->node_id); | |||
| bld.add_addr(reinterpret_cast<int64_t>(it->ptr)); | |||
| auto offset = bld.Finish(); | |||
| *out = offset; | |||
| } else { | |||
| return 0; | |||
| // Key not in the cache. | |||
| auto offset = CreateDataLocatorMsg(*fbb, key, 0, 0, 0); | |||
| *out = offset; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -19,11 +19,14 @@ | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||
| #include "minddata/dataset/engine/cache/cache_numa.h" | |||
| #include "minddata/dataset/engine/cache/storage_manager.h" | |||
| #include "minddata/dataset/util/allocator.h" | |||
| #include "minddata/dataset/util/service.h" | |||
| #include "minddata/dataset/util/slice.h" | |||
| #include "minddata/dataset/util/storage_manager.h" | |||
| #include "minddata/dataset/util/auto_index.h" | |||
| #include "minddata/dataset/util/btree.h" | |||
| @@ -45,13 +48,15 @@ class CachePool : public Service { | |||
| // An internal class to locate the whereabouts of a backed up buffer which can be either in | |||
| class DataLocator { | |||
| public: | |||
| DataLocator() : ptr(nullptr), sz(0), storage_key(0) {} | |||
| DataLocator() : ptr(nullptr), sz(0), node_id(0), node_hit(false), storage_key(0) {} | |||
| ~DataLocator() = default; | |||
| DataLocator(const DataLocator &other) = default; | |||
| DataLocator &operator=(const DataLocator &other) = default; | |||
| DataLocator(DataLocator &&other) noexcept { | |||
| ptr = other.ptr; | |||
| sz = other.sz; | |||
| node_id = other.node_id; | |||
| node_hit = other.node_hit; | |||
| storage_key = other.storage_key; | |||
| other.ptr = nullptr; | |||
| other.sz = 0; | |||
| @@ -61,6 +66,8 @@ class CachePool : public Service { | |||
| if (&other != this) { | |||
| ptr = other.ptr; | |||
| sz = other.sz; | |||
| node_id = other.node_id; | |||
| node_hit = other.node_hit; | |||
| storage_key = other.storage_key; | |||
| other.ptr = nullptr; | |||
| other.sz = 0; | |||
| @@ -70,6 +77,8 @@ class CachePool : public Service { | |||
| } | |||
| pointer ptr; | |||
| size_t sz; | |||
| numa_id_t node_id; // the numa node the memory is allocated on | |||
| bool node_hit; // whether the allocation landed on the preferred node | |||
| StorageManager::key_type storage_key; | |||
| }; | |||
| @@ -85,19 +94,20 @@ class CachePool : public Service { | |||
| int64_t num_mem_cached; | |||
| int64_t num_disk_cached; | |||
| int64_t average_cache_sz; | |||
| int64_t num_numa_hit; | |||
| std::vector<key_type> gap; | |||
| }; | |||
| /// \brief Constructor | |||
| /// \param alloc Allocator to allocate memory from | |||
| /// \param root Optional disk folder to spill | |||
| explicit CachePool(const value_allocator &alloc, bool customArena, const std::string &root = ""); | |||
| explicit CachePool(std::shared_ptr<NumaMemoryPool> mp, const std::string &root = ""); | |||
| CachePool(const CachePool &) = delete; | |||
| CachePool(CachePool &&) = delete; | |||
| CachePool &operator=(const CachePool &) = delete; | |||
| CachePool &operator=(CachePool &&) = delete; | |||
| ~CachePool() noexcept; | |||
| ~CachePool() noexcept override; | |||
| Status DoServiceStart() override; | |||
| Status DoServiceStop() override; | |||
| @@ -110,7 +120,8 @@ class CachePool : public Service { | |||
| /// \param[in] buf A sequence of ReadableSlice objects. | |||
| /// \param[in] writeToDiskDirectly If true, write directly to disk when spill is enabled; otherwise return out of memory | |||
| /// \return Error code | |||
| Status Insert(key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly); | |||
| Status Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf); | |||
| /// \brief Restore a cached buffer (from memory or disk) | |||
| /// \param[in] key A previous key returned from Insert | |||
| /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice | |||
| @@ -118,18 +129,14 @@ class CachePool : public Service { | |||
| /// \return Error code | |||
| Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const; | |||
| Status Spill(DataLocator *dl); | |||
| Status Locate(DataLocator *dl); | |||
| size_t GetSize(key_type key) const; | |||
| /// \brief Serialize a DataLocator | |||
| Status GetDataLocator(key_type, const std::shared_ptr<flatbuffers::FlatBufferBuilder> &, | |||
| flatbuffers::Offset<DataLocatorMsg> *) const; | |||
| /// \brief Get statistics. | |||
| /// \return CacheStat object | |||
| CacheStat GetStat(bool GetMissingKeys = false) const; | |||
| const value_allocator &get_allocator() const; | |||
| std::string MyName() const { return subfolder_; } | |||
| /// \brief Toggle locking | |||
| @@ -137,12 +144,11 @@ class CachePool : public Service { | |||
| void SetLocking(bool on_off) { tree_->SetLocking(on_off); } | |||
| private: | |||
| value_allocator alloc_; | |||
| std::shared_ptr<NumaMemoryPool> mp_; | |||
| Path root_; | |||
| const std::string subfolder_; | |||
| std::shared_ptr<StorageManager> sm_; | |||
| std::shared_ptr<data_index> tree_; | |||
| bool custom_arena_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -14,6 +14,11 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||
| #include <sched.h> | |||
| #include <sys/types.h> | |||
| #include <unistd.h> | |||
| #endif | |||
| #include <cstdlib> | |||
| #include <thread> | |||
| #include "minddata/dataset/core/constants.h" | |||
| @@ -106,6 +111,7 @@ Status CacheRowRequest::PostReply() { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheRowRequest::Prepare() { | |||
| if (BitTest(rq_.flag(), kDataIsInSharedMemory)) { | |||
| // First one is cookie, followed by address and then size. | |||
| @@ -118,10 +124,21 @@ Status CacheRowRequest::Prepare() { | |||
| return Status::OK(); | |||
| } | |||
| BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, | |||
| bool local_bypass) | |||
| : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) { | |||
| rq_.set_connection_id(connection_id); | |||
| CacheRowRequest::CacheRowRequest(const CacheClient *cc) | |||
| : BaseRequest(RequestType::kCacheRow), | |||
| support_local_bypass_(cc->local_bypass_), | |||
| addr_(-1), | |||
| sz_(0), | |||
| row_id_from_server_(-1) { | |||
| rq_.set_connection_id(cc->server_connection_id_); | |||
| rq_.set_client_id(cc->client_id_); | |||
| rq_.add_buf_data(cc->cookie_); | |||
| } | |||
| BatchFetchRequest::BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id) | |||
| : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(cc->local_bypass_), row_id_(row_id) { | |||
| rq_.set_connection_id(cc->server_connection_id_); | |||
| rq_.set_client_id(cc->client_id_); | |||
| rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0); | |||
| // Convert the row id into a flatbuffer | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| @@ -186,9 +203,9 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, in | |||
| return Status::OK(); | |||
| } | |||
| CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||
| CreateCacheRequest::CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||
| CreateCacheRequest::CreateCacheFlag flag) | |||
| : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) { | |||
| : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag), cc_(cc) { | |||
| // Type has been set already in the base constructor. So we need to fill in the connection info. | |||
| // On successful return, we will get the connection id | |||
| rq_.mutable_connection_info()->operator=(cinfo); | |||
| @@ -209,6 +226,41 @@ Status CreateCacheRequest::Prepare() { | |||
| } | |||
| } | |||
| Status CreateCacheRequest::PostReply() { | |||
| auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data()); | |||
| cc_->server_connection_id_ = p->connection_id(); | |||
| cc_->cookie_ = p->cookie()->str(); | |||
| cc_->client_id_ = p->client_id(); | |||
| // Next is a set of cpu ids to which we should re-adjust our own affinity for better locality. | |||
| auto sz = p->cpu_id()->size(); | |||
| cc_->cpu_list_.reserve(sz); | |||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||
| std::string c_list; | |||
| cpu_set_t cpu_set; | |||
| CPU_ZERO(&cpu_set); | |||
| #endif | |||
| for (auto i = 0; i < sz; ++i) { | |||
| auto cpu_id = p->cpu_id()->Get(i); | |||
| cc_->cpu_list_.push_back(cpu_id); | |||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||
| c_list += std::to_string(cpu_id) + " "; | |||
| CPU_SET(cpu_id, &cpu_set); | |||
| #endif | |||
| } | |||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||
| if (sz > 0) { | |||
| auto err = sched_setaffinity(getpid(), sizeof(cpu_set), &cpu_set); | |||
| if (err == -1) { | |||
| RETURN_STATUS_UNEXPECTED("Unable to set affinity. Errno = " + std::to_string(errno)); | |||
| } | |||
| MS_LOG(WARNING) << "Changing cpu affinity to the following list of cpu id: " + c_list; | |||
| } | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
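PostReply above pins the client process to the cpu list suggested by the server via sched_setaffinity. A standalone, Linux-only sketch (not part of the patch) of how a client could verify the resulting mask with the complementary sched_getaffinity call:

```cpp
#include <sched.h>
#include <unistd.h>
#include <cstdio>
#include <iostream>
#include <string>

int main() {
  cpu_set_t cpu_set;
  CPU_ZERO(&cpu_set);
  // Query the affinity mask that a prior sched_setaffinity() established.
  if (sched_getaffinity(getpid(), sizeof(cpu_set), &cpu_set) == -1) {
    std::perror("sched_getaffinity");
    return 1;
  }
  std::string c_list;
  for (int i = 0; i < CPU_SETSIZE; ++i) {
    if (CPU_ISSET(i, &cpu_set)) c_list += std::to_string(i) + " ";
  }
  std::cout << "Process bound to " << CPU_COUNT(&cpu_set) << " cpu(s): " << c_list << "\n";
  return 0;
}
```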
| Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) { | |||
| try { | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| @@ -245,6 +297,7 @@ Status GetStatRequest::PostReply() { | |||
| stat_.num_disk_cached = msg->num_disk_cached(); | |||
| stat_.num_mem_cached = msg->num_mem_cached(); | |||
| stat_.avg_cache_sz = msg->avg_cache_sz(); | |||
| stat_.num_numa_hit = msg->num_numa_hit(); | |||
| stat_.max_row_id = msg->max_row_id(); | |||
| stat_.min_row_id = msg->min_row_id(); | |||
| stat_.cache_service_state = msg->state(); | |||
| @@ -255,14 +308,15 @@ Status ListSessionsRequest::PostReply() { | |||
| auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data()); | |||
| auto session_vector = msg->sessions(); | |||
| for (auto i = 0; i < session_vector->size(); ++i) { | |||
| SessionCacheInfo current_info; | |||
| CacheServiceStat stats; | |||
| SessionCacheInfo current_info{}; | |||
| CacheServiceStat stats{}; | |||
| auto current_session_info = session_vector->Get(i); | |||
| current_info.session_id = current_session_info->session_id(); | |||
| current_info.connection_id = current_session_info->connection_id(); | |||
| stats.num_mem_cached = current_session_info->stats()->num_mem_cached(); | |||
| stats.num_disk_cached = current_session_info->stats()->num_disk_cached(); | |||
| stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz(); | |||
| stats.num_numa_hit = current_session_info->stats()->num_numa_hit(); | |||
| stats.min_row_id = current_session_info->stats()->min_row_id(); | |||
| stats.max_row_id = current_session_info->stats()->max_row_id(); | |||
| stats.cache_service_state = current_session_info->stats()->state(); | |||
| @@ -41,6 +41,7 @@ struct CacheServiceStat { | |||
| int64_t num_mem_cached; | |||
| int64_t num_disk_cached; | |||
| int64_t avg_cache_sz; | |||
| int64_t num_numa_hit; | |||
| row_id_type min_row_id; | |||
| row_id_type max_row_id; | |||
| int8_t cache_service_state; | |||
| @@ -75,6 +76,8 @@ class BaseRequest { | |||
| kHeartBeat = 14, | |||
| kToggleWriteMode = 15, | |||
| kListSessions = 16, | |||
| kConnectReset = 17, | |||
| kInternalFetchRow = 18, | |||
| // Add new request before it. | |||
| kRequestUnknown = 32767 | |||
| }; | |||
| @@ -84,10 +87,15 @@ class BaseRequest { | |||
| friend class CacheClientGreeter; | |||
| friend class CacheClientRequestTag; | |||
| friend class CacheClient; | |||
| friend class CacheService; | |||
| /// \brief Base class of a cache server request | |||
| /// \param type Type of the request | |||
| explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast<int16_t>(type_)); } | |||
| explicit BaseRequest(RequestType type) : type_(type) { | |||
| rq_.set_type(static_cast<int16_t>(type_)); | |||
| rq_.set_client_id(-1); | |||
| rq_.set_flag(0); | |||
| } | |||
| virtual ~BaseRequest() = default; | |||
| /// \brief A print method for debugging | |||
| @@ -138,15 +146,7 @@ class CacheRowRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| friend class CacheClient; | |||
| explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass) | |||
| : BaseRequest(RequestType::kCacheRow), | |||
| support_local_bypass_(local_bypass), | |||
| addr_(-1), | |||
| sz_(0), | |||
| row_id_from_server_(-1) { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.add_buf_data(cookie); | |||
| } | |||
| explicit CacheRowRequest(const CacheClient *cc); | |||
| ~CacheRowRequest() override = default; | |||
| /// \brief Serialize a TensorRow for streaming to the cache server | |||
| @@ -193,7 +193,7 @@ class BatchFetchRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| friend class CacheService; | |||
| BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass); | |||
| BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id); | |||
| ~BatchFetchRequest() override = default; | |||
| Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); | |||
| @@ -212,21 +212,18 @@ class CreateCacheRequest : public BaseRequest { | |||
| /// \param connection_id | |||
| /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited | |||
| /// \param flag Attributes of the cache. | |||
| explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||
| explicit CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||
| CreateCacheFlag flag = CreateCacheFlag::kNone); | |||
| ~CreateCacheRequest() override = default; | |||
| void ParseResult(connection_id_type *id, std::string *out) { | |||
| auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data()); | |||
| *id = p->connection_id(); | |||
| *out = p->cookie()->str(); | |||
| } | |||
| /// Overload the base class Prepare | |||
| /// Overload the base class Prepare/PostReply | |||
| Status Prepare() override; | |||
| Status PostReply() override; | |||
| private: | |||
| uint64_t cache_mem_sz_; | |||
| CreateCacheFlag flag_; | |||
| CacheClient *cc_; | |||
| }; | |||
| /// \brief Request to get all the keys not present at the server. | |||
| @@ -396,6 +393,23 @@ class ToggleWriteModeRequest : public BaseRequest { | |||
| } | |||
| ~ToggleWriteModeRequest() override = default; | |||
| }; | |||
| class ConnectResetRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit ConnectResetRequest(connection_id_type connection_id, int32_t client_id) | |||
| : BaseRequest(RequestType::kConnectReset) { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.set_client_id(client_id); | |||
| } | |||
| ~ConnectResetRequest() override = default; | |||
| /// Override the base class function | |||
| Status Prepare() override { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rq_.client_id() != -1, "Invalid client id"); | |||
| return Status::OK(); | |||
| } | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ | |||
| @@ -17,6 +17,7 @@ | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <limits> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| @@ -43,36 +44,57 @@ Status CacheServer::DoServiceStart() { | |||
| MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; | |||
| } | |||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | |||
| // There will be num_workers_ threads working on the grpc queue and | |||
| // the same number of threads working on the CacheServerRequest queue. | |||
| RETURN_IF_NOT_OK(hw_info_->GetNumaNodeInfo()); | |||
| auto num_numa_nodes = GetNumaNodeCount(); | |||
| // If we are linked with the numa library, set the default memory policy. | |||
| // If we don't pin threads to cpus, then use all memory controllers to maximize | |||
| // memory bandwidth. | |||
| RETURN_IF_NOT_OK( | |||
| CacheServerHW::SetDefaultMemoryPolicy(numa_affinity_ ? CachePoolPolicy::kLocal : CachePoolPolicy::kInterleave)); | |||
| auto my_node = hw_info_->GetMyNode(); | |||
| MS_LOG(DEBUG) << "Cache server is running on numa node " << my_node; | |||
| // Bump up num_workers_ to at least the number of numa nodes | |||
| num_workers_ = std::max(num_numa_nodes, num_workers_); | |||
| // But it also shouldn't be much more than the hardware concurrency | |||
| auto num_cpus = hw_info_->GetCpuCount(); | |||
| num_workers_ = std::min(2 * num_cpus, num_workers_); | |||
| // Round up num_workers to a multiple of numa nodes. | |||
| auto remainder = num_workers_ % num_numa_nodes; | |||
| if (remainder > 0) num_workers_ += (num_numa_nodes - remainder); | |||
| MS_LOG(INFO) << "Re-adjusting the number of workers to " << num_workers_; | |||
| // There will be some threads working on the grpc queue and | |||
| // some number of threads working on the CacheServerRequest queue. | |||
| // Like a connector object we will set up the same number of queues but | |||
| // we do not need to preserve any order. We will set the capacity of | |||
| // each queue to be 128 since we are just pushing memory pointers which | |||
| // each queue to be 64 since we are just pushing memory pointers which | |||
| // are only 8 bytes each. | |||
| const int32_t que_capacity = 128; | |||
| const int32_t kQueCapacity = 64; | |||
| // This is the request queue from the client | |||
| cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>(); | |||
| cache_q_->Init(num_workers_, que_capacity); | |||
| cache_q_->Init(num_workers_, kQueCapacity); | |||
| // We will match the number of grpc workers with the number of server workers. | |||
| // But technically they don't have to be the same. | |||
| num_grpc_workers_ = num_workers_; | |||
| MS_LOG(DEBUG) << "Number of gprc workers is set to " << num_grpc_workers_; | |||
| // For the grpc completion queue to work, we need to allocate some | |||
| // tags which in our case are instances of CacheServerRequest. | |||
| // They get recycled; we will allocate them in advance and push | |||
| // them into a free list. We need more (two or three times) the | |||
| // size of the cache_q. While each worker is working on a CacheServerRequest, | |||
| // we need some extra tags circulating in the grpc completion queue. | |||
| const int32_t multiplier = 3; | |||
| const int32_t free_list_capacity = multiplier * (que_capacity + 1); | |||
| const int32_t kMultiplier = 2; | |||
| int ratio = num_workers_ / num_grpc_workers_; | |||
| if (num_workers_ % num_grpc_workers_) ++ratio; | |||
| const int32_t free_list_capacity = kMultiplier * (kQueCapacity + 1) * ratio; | |||
| free_list_ = std::make_shared<QueueList<CacheServerRequest *>>(); | |||
| free_list_->Init(num_workers_, free_list_capacity); | |||
| // We need to have a reference to the services memory pool in case | |||
| // the Services goes out of scope earlier than us since it is a singleton | |||
| mp_ = Services::GetInstance().GetServiceMemPool(); | |||
| Allocator<CacheServerRequest> alloc(mp_); | |||
| tag_.reserve(num_workers_); | |||
| // Now we populate all free list. | |||
| for (auto m = 0; m < num_workers_; ++m) { | |||
| // Ideally we allocate all the free list in one malloc. But it turns out it exceeds the | |||
| // Arena size. So we will allocate one segment at a time. | |||
| auto my_tag = std::make_unique<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>(alloc); | |||
| free_list_->Init(num_grpc_workers_, free_list_capacity); | |||
| tag_.reserve(num_grpc_workers_); | |||
| // Now we populate all the free lists, round-robining them among the numa nodes. | |||
| for (auto m = 0; m < num_grpc_workers_; ++m) { | |||
| NumaAllocator<CacheServerRequest> alloc(m % num_numa_nodes, CachePoolPolicy::kPreferred); | |||
| // Ideally we allocate all the free list in one malloc. But we will allocate one segment | |||
| // at a time so that we can change the numa policy easily per grpc worker. | |||
| auto my_tag = std::make_unique<MemGuard<CacheServerRequest, NumaAllocator<CacheServerRequest>>>(alloc); | |||
| // Allocate the tag and assign it the current queue | |||
| RETURN_IF_NOT_OK(my_tag->allocate(free_list_capacity, m)); | |||
| for (int i = 0; i < free_list_capacity; ++i) { | |||
| @@ -82,11 +104,6 @@ Status CacheServer::DoServiceStart() { | |||
| } | |||
| RETURN_IF_NOT_OK(cache_q_->Register(&vg_)); | |||
| RETURN_IF_NOT_OK(free_list_->Register(&vg_)); | |||
| // Spawn a few threads to serve the real request. | |||
| auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1); | |||
| for (auto i = 0; i < num_workers_; ++i) { | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i))); | |||
| } | |||
| // Start the comm layer | |||
| try { | |||
| comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_); | |||
| @@ -94,10 +111,29 @@ Status CacheServer::DoServiceStart() { | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| // Spawn a few threads to serve the real request. | |||
| auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1); | |||
| for (auto i = 0; i < num_workers_; ++i) { | |||
| Task *pTask; | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i), &pTask)); | |||
| // Save a copy of the pointer to the underlying Task object. We may dynamically change its affinity if needed. | |||
| numa_tasks_.emplace(i, pTask); | |||
| // Spread out all the threads to all the numa nodes if needed | |||
| if (IsNumaAffinityOn()) { | |||
| auto numa_id = i % num_numa_nodes; | |||
| RETURN_IF_NOT_OK(SetAffinity(*pTask, numa_id)); | |||
| } | |||
| } | |||
| // Finally loop forever to handle the request. | |||
| auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1); | |||
| for (auto i = 0; i < num_workers_; ++i) { | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i))); | |||
| for (auto i = 0; i < num_grpc_workers_; ++i) { | |||
| Task *pTask; | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i), &pTask)); | |||
| // Each grpc worker will be bound to the numa node on which we allocated its free tag | |||
| // memory. | |||
| if (IsNumaAffinityOn()) { | |||
| RETURN_IF_NOT_OK(SetAffinity(*pTask, i % num_numa_nodes)); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
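The worker-count adjustment near the top of DoServiceStart clamps the requested count to at least one worker per numa node and at most twice the hardware concurrency, then rounds up to a multiple of the node count so workers stripe evenly across nodes. A standalone sketch of that arithmetic (function name hypothetical):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Hypothetical standalone mirror of the adjustment in CacheServer::DoServiceStart.
int32_t AdjustNumWorkers(int32_t requested, int32_t num_numa_nodes, int32_t num_cpus) {
  // Bump up to at least the number of numa nodes ...
  int32_t n = std::max(num_numa_nodes, requested);
  // ... but not much more than the hardware concurrency ...
  n = std::min(2 * num_cpus, n);
  // ... then round up to a multiple of the node count.
  int32_t remainder = n % num_numa_nodes;
  if (remainder > 0) n += (num_numa_nodes - remainder);
  return n;
}

int main() {
  std::cout << AdjustNumWorkers(10, 4, 16) << "\n";   // 10 -> 12 (next multiple of 4)
  std::cout << AdjustNumWorkers(1, 4, 16) << "\n";    // 1  -> 4  (at least one per node)
  std::cout << AdjustNumWorkers(100, 4, 16) << "\n";  // 100 -> 32 (capped at 2 * cpus)
  return 0;
}
```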
| @@ -108,8 +144,6 @@ Status CacheServer::DoServiceStop() { | |||
| // First stop all the threads. | |||
| RETURN_IF_NOT_OK(vg_.ServiceStop()); | |||
| // Clean up all the caches if any. | |||
| // Dump a message how much memory we have consumed in total. | |||
| MS_LOG(INFO) << "Memory usage for the current server: " << GetMemoryUsage() << " bytes."; | |||
| UniqueLock lck(&rwLock_); | |||
| auto it = all_caches_.begin(); | |||
| while (it != all_caches_.end()) { | |||
| @@ -134,13 +168,14 @@ CacheService *CacheServer::GetService(connection_id_type id) const { | |||
| Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info"); | |||
| std::string cookie; | |||
| int32_t client_id; | |||
| auto session_id = rq->connection_info().session_id(); | |||
| auto crc = rq->connection_info().crc(); | |||
| // Before allowing the creation, make sure the session had already been created by the user | |||
| // Our intention is to add this cache to the active sessions list so leave the list locked during | |||
| // this entire function. | |||
| UniqueLock lock(&sessions_lock_); | |||
| UniqueLock sess_lck(&sessions_lock_); | |||
| auto session_it = active_sessions_.find(session_id); | |||
| if (session_it == active_sessions_.end()) { | |||
| RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!"); | |||
| @@ -163,6 +198,7 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| } | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| flatbuffers::Offset<flatbuffers::String> off_cookie; | |||
| flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list; | |||
| // Before creating the cache, first check if this is a request for a shared usage of an existing cache | |||
| // If two CreateService come in with identical connection_id, we need to serialize the create. | |||
| // The first create will be successful and be given a special cookie. | |||
| @@ -171,32 +207,74 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| if (global_shutdown_) { | |||
| return Status::OK(); | |||
| } | |||
| // We would like to protect ourselves from over-allocating. We will walk over the existing | |||
| // caches and calculate how much memory we have consumed so far. | |||
| auto end = all_caches_.end(); | |||
| auto it = all_caches_.find(connection_id); | |||
| auto it = all_caches_.begin(); | |||
| bool duplicate = false; | |||
| auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_; | |||
| int64_t max_avail = avail_mem; | |||
| while (it != end) { | |||
| if (it->first == connection_id) { | |||
| duplicate = true; | |||
| break; | |||
| } else { | |||
| auto &cs = it->second; | |||
| CacheService::ServiceStat stat; | |||
| RETURN_IF_NOT_OK(cs->GetStat(&stat)); | |||
| int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz; | |||
| max_avail -= mem_consumed; | |||
| if (max_avail <= 0) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||
| } | |||
| } | |||
| ++it; | |||
| } | |||
| if (it == end) { | |||
| // If some caches are already using memory, make a reasonable decision on whether we | |||
| // should return out of memory. | |||
| if (max_avail < avail_mem) { | |||
| int64_t req_mem = cache_mem_sz * 1048576L; // cache_mem_sz is in MB. | |||
| if (req_mem > max_avail) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||
| } else if (req_mem == 0) { | |||
| // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than | |||
| // 85% of our limit, fail this request. | |||
| if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||
| } | |||
| } | |||
| } | |||
| std::unique_ptr<CacheService> cs; | |||
| try { | |||
| cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id); | |||
| RETURN_IF_NOT_OK(cs->ServiceStart()); | |||
| cookie = cs->cookie(); | |||
| client_id = cs->num_clients_.fetch_add(1); | |||
| all_caches_.emplace(connection_id, std::move(cs)); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| // Add the cache into the active session tracking. | |||
| // We have already validated that the session exists and that this is a newly created cache. | |||
| session_it->second.insert(connection_id); | |||
| } else { | |||
| duplicate = true; | |||
| client_id = it->second->num_clients_.fetch_add(1); | |||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; | |||
| } | |||
| // Shuffle the worker threads. But we need to release the locks or we will deadlock when calling | |||
| // the following function | |||
| lck.Unlock(); | |||
| sess_lck.Unlock(); | |||
| auto numa_id = client_id % GetNumaNodeCount(); | |||
| std::vector<cpu_id_t> cpu_list = hw_info_->GetCpuList(numa_id); | |||
| // Send back the data | |||
| off_cookie = fbb.CreateString(cookie); | |||
| off_cpu_list = fbb.CreateVector(cpu_list); | |||
| CreateCacheReplyMsgBuilder bld(fbb); | |||
| bld.add_connection_id(connection_id); | |||
| bld.add_cookie(off_cookie); | |||
| bld.add_client_id(client_id); | |||
| // The last thing we send back is the set of cpu ids that we suggest the client should bind itself to | |||
| bld.add_cpu_id(off_cpu_list); | |||
| auto off = bld.Finish(); | |||
| fbb.Finish(off); | |||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||
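The admission logic above walks the existing caches, subtracts their consumption from the server's memory cap, refuses a fixed request that no longer fits, and refuses an unlimited request once 85% or more of the cap is gone. A standalone mirror of that check, under the assumption that per-cache consumption is num_mem_cached * average_cache_sz as computed above:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone mirror of the admission check in CreateService.
bool AdmitNewCache(int64_t avail_mem, const std::vector<int64_t> &consumed, int64_t req_mem) {
  int64_t max_avail = avail_mem;
  for (int64_t c : consumed) {
    max_avail -= c;
    if (max_avail <= 0) return false;  // cap already exhausted
  }
  if (max_avail < avail_mem) {  // some memory is in use already
    if (req_mem > max_avail) return false;
    // Unlimited request (req_mem == 0): refuse once <= 15% of the cap remains.
    if (req_mem == 0 && static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15f) {
      return false;
    }
  }
  return true;
}

int main() {
  const int64_t kGB = 1 << 30;
  std::cout << AdmitNewCache(10 * kGB, {8 * kGB}, 1 * kGB) << "\n";  // 1 (2 GB still free)
  std::cout << AdmitNewCache(10 * kGB, {9 * kGB}, 0) << "\n";        // 0 (only 10% left)
  return 0;
}
```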
| @@ -220,26 +298,8 @@ Status CacheServer::DestroyCache(CacheRequest *rq) { | |||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | |||
| } | |||
| } | |||
| // Now that this cache is removed, we need to also remove its connection id from active session tracking | |||
| auto session_id = GetSessionID(id); | |||
| UniqueLock sess_lck(&sessions_lock_); | |||
| auto it = active_sessions_.find(session_id); | |||
| if (it == active_sessions_.end()) { | |||
| // The session was not found in the active sessions | |||
| RETURN_STATUS_UNEXPECTED("A destroy cache request has been completed but it had a stale session id!"); | |||
| } | |||
| auto connection_it = it->second.find(id); | |||
| if (connection_it == it->second.end()) { | |||
| RETURN_STATUS_UNEXPECTED("A destroy cache request could not find the connection in the activate sessions!"); | |||
| } | |||
| // remove that connection id from the set | |||
| it->second.erase(connection_it); | |||
| MS_LOG(INFO) << "Destroyed cache " << id << " and removed from active session " << session_id; | |||
| // We aren't touching the session list even though we may be dropping the last remaining cache of a session. | |||
| // Leave that to be done by the drop session command. | |||
| return Status::OK(); | |||
| } | |||
| @@ -266,6 +326,7 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) { | |||
| buffers.push_back(rq->buf_data(i).data()); | |||
| } | |||
| row_id_type id = -1; | |||
| // We will allocate the memory on the same numa node this thread is bound to. | |||
| RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id)); | |||
| reply->set_result(std::to_string(id)); | |||
| } else { | |||
| @@ -301,6 +362,7 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) { | |||
| if (!cs->HasBuildPhase() || cookie == cs->cookie()) { | |||
| row_id_type id = -1; | |||
| ReadableSlice src(p, sz); | |||
| // We will allocate the memory on the same numa node this thread is bound to. | |||
| rc = cs->FastCacheRow(src, &id); | |||
| reply->set_result(std::to_string(id)); | |||
| } else { | |||
| @@ -330,9 +392,19 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { | |||
| for (auto i = 0; i < sz; ++i) { | |||
| row_id.push_back(p->row_id()->Get(i)); | |||
| } | |||
| int64_t mem_sz = 0; | |||
| std::vector<key_size_pair> v; | |||
| RETURN_IF_NOT_OK(cs->PreBatchFetch(row_id, &v, &mem_sz)); | |||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>(); | |||
| RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb)); | |||
| auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer()); | |||
| int64_t mem_sz = sizeof(int64_t) * (sz + 1); | |||
| for (auto i = 0; i < sz; ++i) { | |||
| auto row_sz = locator->rows()->Get(i)->size(); | |||
| // row_sz is the size of the cached data. Later we will spawn multiple threads | |||
| // each of which will copy the data into either shared memory or protobuf concurrently but | |||
| // to different regions. | |||
| // To avoid false sharing, we will bump up row_sz to be a multiple of 4k, i.e. 4096 bytes | |||
| row_sz = round_up_4K(row_sz); | |||
| mem_sz += row_sz; | |||
| } | |||
| auto client_flag = rq->flag(); | |||
| bool local_client = BitTest(client_flag, kLocalClientSupport); | |||
| // For large amounts of data to be sent back, we will use shared memory provided it is a local | |||
| @@ -346,7 +418,11 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { | |||
| void *q = nullptr; | |||
| RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q)); | |||
| WritableSlice dest(q, mem_sz); | |||
| RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest)); | |||
| Status rc = cs->BatchFetch(fbb, &dest); | |||
| if (rc.IsError()) { | |||
| shared_pool->Deallocate(q); | |||
| return rc; | |||
| } | |||
| // We can't return the absolute address which makes no sense to the client. | |||
| // Instead we return the difference. | |||
| auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base); | |||
| @@ -363,7 +439,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| WritableSlice dest(mem.data(), mem_sz); | |||
| RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest)); | |||
| RETURN_IF_NOT_OK(cs->BatchFetch(fbb, &dest)); | |||
| reply->set_result(std::move(mem)); | |||
| } | |||
| } | |||
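round_up_4K itself is not shown in this hunk; a minimal sketch of the assumed power-of-two round-up used when sizing the shared-memory region above, so that concurrent per-row copies land in separate 4 KB regions and avoid false sharing:

```cpp
#include <cstdint>
#include <iostream>

// Assumed implementation: round sz up to the next multiple of 4096.
int64_t round_up_4K(int64_t sz) {
  constexpr int64_t kAlign = 4096;
  return (sz + kAlign - 1) & ~(kAlign - 1);
}

int main() {
  std::cout << round_up_4K(1) << " "       // 4096
            << round_up_4K(4096) << " "    // 4096 (already aligned)
            << round_up_4K(5000) << "\n";  // 8192
  return 0;
}
```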
| @@ -386,6 +462,7 @@ Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) { | |||
| bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | |||
| bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | |||
| bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz); | |||
| bld.add_num_numa_hit(svc_stat.stat_.num_numa_hit); | |||
| bld.add_max_row_id(svc_stat.stat_.max_key); | |||
| bld.add_min_row_id(svc_stat.stat_.min_key); | |||
| bld.add_state(svc_stat.state_); | |||
| @@ -506,30 +583,27 @@ Status CacheServer::ToggleWriteMode(CacheRequest *rq) { | |||
| } | |||
| Status CacheServer::ListSessions(CacheReply *reply) { | |||
| SharedLock lck(&sessions_lock_); | |||
| SharedLock sess_lck(&sessions_lock_); | |||
| SharedLock lck(&rwLock_); | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector; | |||
| for (auto it = active_sessions_.begin(); it != active_sessions_.end(); it++) { | |||
| session_id_type current_session_id = it->first; | |||
| // Loop over each cache inside this session | |||
| if (!it->second.empty()) { | |||
| for (auto current_conn_id : it->second) { | |||
| CacheService *cs = GetService(current_conn_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(current_conn_id) + " not found during list sessions."; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| CacheService::ServiceStat svc_stat; | |||
| RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); | |||
| auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached, | |||
| svc_stat.stat_.average_cache_sz, svc_stat.stat_.min_key, | |||
| svc_stat.stat_.max_key, svc_stat.state_); | |||
| auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats); | |||
| session_msgs_vector.push_back(current_session_info); | |||
| } | |||
| for (auto const ¤t_session_id : active_sessions_) { | |||
| bool found = false; | |||
| for (auto const &it : all_caches_) { | |||
| auto current_conn_id = it.first; | |||
| if (GetSessionID(current_conn_id) == current_session_id) { | |||
| found = true; | |||
| auto &cs = it.second; | |||
| CacheService::ServiceStat svc_stat; | |||
| RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); | |||
| auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached, | |||
| svc_stat.stat_.average_cache_sz, svc_stat.stat_.num_numa_hit, | |||
| svc_stat.stat_.min_key, svc_stat.stat_.max_key, svc_stat.state_); | |||
| auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats); | |||
| session_msgs_vector.push_back(current_session_info); | |||
| } | |||
| } else { | |||
| } | |||
| if (!found) { | |||
| // If there is no cache created yet, assign a connection id of 0 along with empty stats | |||
| auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0); | |||
| auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats); | |||
| @@ -542,18 +616,35 @@ Status CacheServer::ListSessions(CacheReply *reply) { | |||
| auto offset = s_builder.Finish(); | |||
| fbb.Finish(offset); | |||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::ConnectReset(CacheRequest *rq) { | |||
| auto connection_id = rq->connection_id(); | |||
| // Hold the shared lock to prevent the cache from being dropped. | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| auto client_id = rq->client_id(); | |||
| MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects"; | |||
| cs->num_clients_--; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| /// \brief This is the main loop that the cache server thread(s) run. | |||
| /// Each thread will pop a request and send the result back to the client using grpc | |||
| /// \return | |||
| Status CacheServer::ServerRequest(int32_t worker_id) { | |||
| Status CacheServer::ServerRequest(worker_id_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| MS_LOG(DEBUG) << "Worker id " << worker_id << " is running on node " << hw_info_->GetMyNode(); | |||
| auto &my_que = cache_q_->operator[](worker_id); | |||
| // Loop forever until we are interrupted or shutdown. | |||
| while (!global_shutdown_) { | |||
| bool internal_request = false; | |||
| CacheServerRequest *cache_req = nullptr; | |||
| RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); | |||
| auto &rq = cache_req->rq_; | |||
| @@ -571,8 +662,17 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kBatchFetchRows: { | |||
| cache_req->rc_ = BatchFetchRows(&rq, &reply); | |||
| case BaseRequest::RequestType::kInternalFetchRow: { | |||
| internal_request = true; | |||
| auto connection_id = rq.connection_id(); | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data())); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kCreateCache: { | |||
| @@ -636,6 +736,10 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||
| cache_req->rc_ = ListSessions(&reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kConnectReset: { | |||
| cache_req->rc_ = ConnectReset(&rq); | |||
| break; | |||
| } | |||
| default: | |||
| std::string errMsg("Unknown request type : "); | |||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | |||
| @@ -647,7 +751,13 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||
| // We will re-tag the request back to the grpc queue. Once it comes back from the client, | |||
| // the CacheServerRequest, i.e. the pointer cache_req, will be freed | |||
| if (!global_shutdown_) { | |||
| cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); | |||
| if (!internal_request) { | |||
| cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); | |||
| } else { | |||
| // This is an internal request and is not tied to rpc. But we need to post because there | |||
| // is a thread waiting on the completion of this request. | |||
| cache_req->wp_.Set(); | |||
| } | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| @@ -667,12 +777,20 @@ CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int | |||
| int32_t shared_memory_sz_in_gb, float memory_cap_ratio) | |||
| : top_(spill_path), | |||
| num_workers_(num_workers), | |||
| num_grpc_workers_(num_workers_), | |||
| port_(port), | |||
| shared_memory_sz_in_gb_(shared_memory_sz_in_gb), | |||
| global_shutdown_(false), | |||
| memory_cap_ratio_(memory_cap_ratio), | |||
| cur_mem_usage_(0) { | |||
| memory_cap_ = CacheServer::GetTotalSystemMemory() * memory_cap_ratio_; | |||
| numa_affinity_(true) { | |||
| hw_info_ = std::make_shared<CacheServerHW>(); | |||
| // If we are not linked with the numa library (i.e. NUMA_ENABLED is false), turn off cpu | |||
| // affinity since it can make performance worse in that case. | |||
| if (!CacheServerHW::numa_enabled()) { | |||
| numa_affinity_ = false; | |||
| MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build " | |||
| "that is compiled with numa support for more optimal performance"; | |||
| } | |||
| } | |||
| Status CacheServer::Run(int msg_qid) { | |||
| @@ -719,51 +837,52 @@ Status CacheServer::ReturnRequestTag(CacheServerRequest *p) { | |||
| Status CacheServer::DestroySession(CacheRequest *rq) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id"); | |||
| auto drop_session_id = rq->connection_info().session_id(); | |||
| UniqueLock lck(&sessions_lock_); | |||
| // First validate that this session exists | |||
| auto it = active_sessions_.find(drop_session_id); | |||
| if (it == active_sessions_.end()) { | |||
| RETURN_STATUS_UNEXPECTED("A destroy session command has been requested but the session was not found!"); | |||
| } | |||
| // Grab the locks in the correct order to avoid deadlock. | |||
| UniqueLock sess_lck(&sessions_lock_); | |||
| UniqueLock lck(&rwLock_); | |||
| // Iterate over the set of connection ids for this session that we're dropping and erase each one. | |||
| { | |||
| UniqueLock rwlck(&rwLock_); | |||
| for (auto drop_connection_id : it->second) { | |||
| auto cache_drop_it = all_caches_.find(drop_connection_id); | |||
| if (cache_drop_it == all_caches_.end()) { | |||
| RETURN_STATUS_UNEXPECTED("active session tracking had stale or incorrect cache entry."); | |||
| } | |||
| all_caches_.erase(cache_drop_it); | |||
| MS_LOG(INFO) << "Session destroy: Destroy cache with id " << drop_connection_id; | |||
| // **Do not bother to remove the cache connection id from the active session because we will soon remove the | |||
| // entire session. | |||
| bool found = false; | |||
| for (auto it = all_caches_.begin(); it != all_caches_.end();) { | |||
| auto connection_id = it->first; | |||
| auto session_id = GetSessionID(connection_id); | |||
| // We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock. | |||
| // So we will just manually do it. | |||
| if (session_id == drop_session_id) { | |||
| found = true; | |||
| it = all_caches_.erase(it); | |||
| MS_LOG(INFO) << "Destroy cache with id " << connection_id; | |||
| } else { | |||
| ++it; | |||
| } | |||
| } | |||
| // Finally remove the session itself | |||
| active_sessions_.erase(it); | |||
| MS_LOG(INFO) << "Session destroyed with id " << drop_session_id; | |||
| return Status::OK(); | |||
| auto n = active_sessions_.erase(drop_session_id); | |||
| if (n > 0) { | |||
| MS_LOG(INFO) << "Session destroyed with id " << drop_session_id; | |||
| return Status::OK(); | |||
| } else { | |||
| if (found) { | |||
| std::string errMsg = | |||
| "A destroy session request has been completed but it had a stale session id " + std::to_string(drop_session_id); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } else { | |||
| std::string errMsg = "Session id " + std::to_string(drop_session_id) + " not found."; | |||
| return Status(StatusCode::kFileNotExist, errMsg); | |||
| } | |||
| } | |||
| } | |||
| session_id_type CacheServer::GenerateSessionID() { | |||
| UniqueLock lock(&sessions_lock_); | |||
| UniqueLock sess_lck(&sessions_lock_); | |||
| auto mt = GetRandomDevice(); | |||
| std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max()); | |||
| session_id_type session_id; | |||
| bool duplicate = false; | |||
| do { | |||
| session_id = distribution(mt); | |||
| auto it = active_sessions_.find(session_id); | |||
| duplicate = (it != active_sessions_.end()); | |||
| auto r = active_sessions_.insert(session_id); | |||
| duplicate = !r.second; | |||
| } while (duplicate); | |||
| // Add this session to our tracking of active sessions with initialized empty set of connections. | |||
| active_sessions_[session_id] = std::set<connection_id_type>(); | |||
| return session_id; | |||
| } | |||
| @@ -789,7 +908,7 @@ Status CacheServer::FreeSharedMemory(CacheRequest *rq) { | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::RpcRequest(int32_t worker_id) { | |||
| Status CacheServer::RpcRequest(worker_id_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id)); | |||
| return Status::OK(); | |||
| @@ -820,12 +939,22 @@ Status CacheServer::GlobalShutdown() { | |||
| return Status::OK(); | |||
| } | |||
| int64_t CacheServer::GetTotalSystemMemory() { | |||
| auto pages = sysconf(_SC_PHYS_PAGES); | |||
| auto page_size = sysconf(_SC_PAGE_SIZE); | |||
| auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size); | |||
| MS_LOG(INFO) << "Total physical RAM in bytes: " << total; | |||
| return total; | |||
| worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) { | |||
| auto num_numa_nodes = GetNumaNodeCount(); | |||
| MS_ASSERT(numa_id < num_numa_nodes); | |||
| auto num_workers_per_node = GetNumWorkers() / num_numa_nodes; | |||
| std::mt19937 gen = GetRandomDevice(); | |||
| std::uniform_int_distribution<worker_id_t> dist(0, num_workers_per_node - 1); | |||
| auto n = dist(gen); | |||
| worker_id_t worker_id = n * num_numa_nodes + numa_id; | |||
| MS_ASSERT(worker_id < GetNumWorkers()); | |||
| return worker_id; | |||
| } | |||
| worker_id_t CacheServer::GetRandomWorker() { | |||
| std::mt19937 gen = GetRandomDevice(); | |||
| std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1); | |||
| return dist(gen); | |||
| } | |||
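Since num_workers_ was rounded up to a multiple of the numa node count at startup, worker ids stripe across nodes and worker_id % num_numa_nodes recovers a worker's node. A standalone sketch of the GetWorkerByNumaId mapping (simplified signature):

```cpp
#include <cstdint>
#include <iostream>
#include <random>

// Hypothetical standalone mirror of CacheServer::GetWorkerByNumaId.
int32_t GetWorkerByNumaId(int32_t numa_id, int32_t num_workers, int32_t num_numa_nodes, std::mt19937 *gen) {
  int32_t per_node = num_workers / num_numa_nodes;
  std::uniform_int_distribution<int32_t> dist(0, per_node - 1);
  // Pick a random stripe, then offset by the node: always lands on the requested node.
  return dist(*gen) * num_numa_nodes + numa_id;
}

int main() {
  std::mt19937 gen(std::random_device{}());
  for (int32_t node = 0; node < 4; ++node) {
    int32_t w = GetWorkerByNumaId(node, 12, 4, &gen);
    std::cout << "node " << node << " -> worker " << w << " (worker % nodes = " << (w % 4) << ")\n";
  }
  return 0;
}
```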
| Status CacheServer::Builder::IpcResourceCleanup() { | |||
| @@ -842,6 +971,8 @@ Status CacheServer::Builder::IpcResourceCleanup() { | |||
| rc = mem.Attach(); | |||
| if (rc.IsError()) { | |||
| return Status::OK(); | |||
| } else { | |||
| RETURN_IF_NOT_OK(mem.Detach()); | |||
| } | |||
| int32_t num_attached; | |||
| RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached)); | |||
| @@ -892,5 +1023,16 @@ Status CacheServer::Builder::SanityCheck() { | |||
| RETURN_IF_NOT_OK(IpcResourceCleanup()); | |||
| return Status::OK(); | |||
| } | |||
| CacheServer::Builder::Builder() | |||
| : top_("/tmp"), | |||
| num_workers_(std::thread::hardware_concurrency() / 2), | |||
| port_(50052), | |||
| shared_memory_sz_in_gb_(4), | |||
| memory_cap_ratio_(0.8) { | |||
| if (num_workers_ == 0) { | |||
| num_workers_ = 1; | |||
| } | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -17,23 +17,31 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | |||
| #include <stdlib.h> | |||
| #include <string.h> | |||
| #include <unistd.h> | |||
| #include <algorithm> | |||
| #include <atomic> | |||
| #include <chrono> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <set> | |||
| #include <thread> | |||
| #include "minddata/dataset/engine/cache/cache_hw.h" | |||
| #include "minddata/dataset/engine/cache/cache_numa.h" | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/cache/cache_grpc_server.h" | |||
| #include "minddata/dataset/engine/cache/cache_pool.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/util/allocator.h" | |||
| #include "minddata/dataset/util/arena.h" | |||
| #include "minddata/dataset/util/cache_pool.h" | |||
| #include "minddata/dataset/util/lock.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/dataset/util/semaphore.h" | |||
| #include "minddata/dataset/util/service.h" | |||
| #include "minddata/dataset/util/services.h" | |||
| #include "minddata/dataset/util/system_pool.h" | |||
| @@ -47,9 +55,10 @@ class CacheServer : public Service { | |||
| public: | |||
| friend class Services; | |||
| using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; | |||
| class Builder { | |||
| public: | |||
| Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4), memory_cap_ratio_(0.8) {} | |||
| Builder(); | |||
| ~Builder() = default; | |||
| @@ -161,26 +170,40 @@ class CacheServer : public Service { | |||
| /// \return Status object | |||
| static Status ReturnRequestTag(CacheServerRequest *p); | |||
| /// \brief This returns the size (in bytes) of the physical RAM on the machine. | |||
| /// \return the size (in bytes) of the physical RAM on the machine. | |||
| static int64_t GetTotalSystemMemory(); | |||
| /// Return an instance of the numa control | |||
| std::shared_ptr<CacheServerHW> GetHWControl() { return hw_info_; } | |||
| /// \brief Internally this is how much we will try to use without exceeding the limit | |||
| /// \return Internal cap maximum | |||
| int64_t GetAvailableSystemMemory() { return memory_cap_; } | |||
| /// \brief Set CPU affinity | |||
| Status SetAffinity(const Task &tk, numa_id_t numa_node) { return hw_info_->SetAffinity(tk, numa_node); } | |||
| /// \brief Find out the current memory usage | |||
| int64_t GetMemoryUsage() { return cur_mem_usage_; } | |||
| /// \brief return number of workers | |||
| auto GetNumWorkers() const { return num_workers_; } | |||
| /// \brief This updates our current memory usage. | |||
| enum MemUsageOp : int8_t { kAllocate = 1, kFree = 2 }; | |||
| void UpdateMemoryUsage(int64_t sz, MemUsageOp op) { | |||
| if (op == MemUsageOp::kAllocate) { | |||
| cur_mem_usage_ += sz; | |||
| } else { | |||
| cur_mem_usage_ -= sz; | |||
| } | |||
| } | |||
| /// \brief return number of grpc workers | |||
| auto GetNumGrpcWorkers() const { return num_grpc_workers_; } | |||
| /// \brief return number of numa nodes | |||
| auto GetNumaNodeCount() const { return hw_info_->GetNumaNodeCount(); } | |||
| /// \brief Assign a worker by a numa id | |||
| /// \return worker id | |||
| worker_id_t GetWorkerByNumaId(numa_id_t node_id); | |||
| /// \brief Randomly pick a worker | |||
| /// \return worker id | |||
| worker_id_t GetRandomWorker(); | |||
| /// \brief Check if we bind threads to numa cores | |||
| bool IsNumaAffinityOn() const { return numa_affinity_; } | |||
| /// \brief Internal function to do row batch fetch | |||
| /// \param rq Request | |||
| /// \param reply Reply | |||
| /// \return Status object | |||
| Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Return the memory cap ratio | |||
| float GetMemoryCapRatio() const { return memory_cap_ratio_; } | |||
| private: | |||
| static std::once_flag init_instance_flag_; | |||
| @@ -189,20 +212,21 @@ class CacheServer : public Service { | |||
| mutable RWLock sessions_lock_; | |||
| std::string top_; | |||
| cache_index all_caches_; | |||
| std::map<session_id_type, std::set<connection_id_type>> active_sessions_; | |||
| std::set<session_id_type> active_sessions_; | |||
| std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_; | |||
| std::shared_ptr<QueueList<CacheServerRequest *>> free_list_; | |||
| std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_; | |||
| std::vector<std::unique_ptr<MemGuard<CacheServerRequest, NumaAllocator<CacheServerRequest>>>> tag_; | |||
| std::shared_ptr<CacheServerGreeterImpl> comm_layer_; | |||
| std::shared_ptr<MemoryPool> mp_; | |||
| TaskGroup vg_; | |||
| int32_t num_workers_; | |||
| int32_t num_grpc_workers_; | |||
| int32_t port_; | |||
| int32_t shared_memory_sz_in_gb_; | |||
| std::atomic<bool> global_shutdown_; | |||
| float memory_cap_ratio_; | |||
| int64_t memory_cap_; | |||
| std::atomic<int64_t> cur_mem_usage_; | |||
| std::shared_ptr<CacheServerHW> hw_info_; | |||
| std::map<worker_id_t, Task *> numa_tasks_; | |||
| bool numa_affinity_; | |||
| /// \brief Constructor | |||
| /// \param spill_path Top directory for spilling buffers to. | |||
| @@ -226,11 +250,11 @@ class CacheServer : public Service { | |||
| Status DestroyCache(CacheRequest *rq); | |||
| /// \brief Entry point for all internal server threads. | |||
| Status ServerRequest(int32_t worker_id); | |||
| Status ServerRequest(worker_id_t worker_id); | |||
| /// \brief Entry point for all grpc threads. | |||
| /// \return | |||
| Status RpcRequest(int32_t worker_id); | |||
| Status RpcRequest(worker_id_t worker_id); | |||
| Status DestroySession(CacheRequest *rq); | |||
| @@ -266,12 +290,6 @@ class CacheServer : public Service { | |||
| Status FastCacheRow(CacheRequest *rq, CacheReply *reply); | |||
| Status CacheRow(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Internal function to do row batch fetch | |||
| /// \param rq Request | |||
| /// \param reply Reply | |||
| /// \return Status object | |||
| Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Internal function to get statistics | |||
| /// \param rq | |||
| /// \param reply | |||
| @@ -309,6 +327,9 @@ class CacheServer : public Service { | |||
| /// \param reply | |||
| /// \return Status object | |||
| Status ListSessions(CacheReply *reply); | |||
| /// \brief Handle a connection reset request from a pipeline | |||
| Status ConnectReset(CacheRequest *rq); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -13,51 +13,45 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <random> | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/cache/cache_server.h" | |||
| #include "minddata/dataset/engine/cache/cache_numa.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/dataset/util/slice.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id) | |||
| : root_(root), | |||
| cache_mem_sz_(mem_sz), | |||
| cache_mem_sz_(mem_sz * 1048576L), // mem_sz is in MB unit | |||
| cp_(nullptr), | |||
| next_id_(0), | |||
| generate_id_(generate_id), | |||
| st_(generate_id ? State::kBuildPhase : State::kNone), | |||
| cur_mem_usage_(0), | |||
| cur_disk_usage_(0) {} | |||
| num_clients_(0), | |||
| st_(generate_id ? CacheServiceState::kBuildPhase : CacheServiceState::kNone) {} | |||
| CacheService::~CacheService() { (void)ServiceStop(); } | |||
| bool CacheService::UseArena() { | |||
| // If fixed size, use Arena instead of the pool from global context. | |||
| return (cache_mem_sz_ > 0); | |||
| } | |||
| Status CacheService::DoServiceStart() { | |||
| std::shared_ptr<MemoryPool> mp_; | |||
| CacheServer &cs = CacheServer::GetInstance(); | |||
| if (UseArena()) { | |||
| auto avail_mem = cs.GetAvailableSystemMemory() / 1048576L; | |||
| float memory_cap_ratio = cs.GetMemoryCapRatio(); | |||
| if (cache_mem_sz_ > 0) { | |||
| auto avail_mem = CacheServerHW::GetTotalSystemMemory(); | |||
| if (cache_mem_sz_ > avail_mem) { | |||
| // Output a warning that we use more than recommended. If we fail to allocate, we will fail anyway. | |||
| MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " MB while available system memory " << avail_mem | |||
| << " MB"; | |||
| MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " while available system memory " << avail_mem; | |||
| } | |||
| // Create a fixed size arena based on the parameter. | |||
| std::shared_ptr<Arena> arena; | |||
| RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); | |||
| mp_ = std::move(arena); | |||
| // update the global usage only. | |||
| cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kAllocate); | |||
| } else { | |||
| // Unlimited size. Simply use a system pool. Another choice is CircularPool. | |||
| mp_ = std::make_shared<SystemPool>(); | |||
| memory_cap_ratio = static_cast<float>(cache_mem_sz_) / avail_mem; | |||
| } | |||
| numa_pool_ = std::make_shared<NumaMemoryPool>(cs.GetHWControl(), memory_cap_ratio); | |||
| // It is possible we aren't able to allocate the pool for many reasons. | |||
| std::vector<numa_id_t> avail_nodes = numa_pool_->GetAvailableNodes(); | |||
| if (avail_nodes.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Unable to bring up numa memory pool"); | |||
| } | |||
| // Put together a CachePool for backing up the Tensor | |||
| cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), UseArena(), root_); | |||
| // Put together a CachePool for backing up the Tensor. | |||
| cp_ = std::make_shared<CachePool>(numa_pool_, root_); | |||
| RETURN_IF_NOT_OK(cp_->ServiceStart()); | |||
| // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. | |||
| cookie_ = cp_->MyName(); | |||
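DoServiceStart now sizes the NumaMemoryPool by a ratio rather than an Arena: a fixed per-cache budget becomes its fraction of total system memory, while an unlimited cache inherits the server-wide cap ratio. A standalone sketch of that derivation (function name hypothetical):

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical mirror of the ratio selection in CacheService::DoServiceStart.
float ComputeCapRatio(uint64_t cache_mem_sz_bytes, uint64_t total_mem_bytes, float server_cap_ratio) {
  if (cache_mem_sz_bytes > 0) {
    // Fixed-size cache: its byte budget becomes a fraction of system RAM.
    return static_cast<float>(cache_mem_sz_bytes) / static_cast<float>(total_mem_bytes);
  }
  // Unlimited cache: inherit the server-wide cap ratio (e.g. 0.8).
  return server_cap_ratio;
}

int main() {
  const uint64_t kGB = 1ULL << 30;
  std::cout << ComputeCapRatio(4 * kGB, 16 * kGB, 0.8f) << "\n";  // 0.25
  std::cout << ComputeCapRatio(0, 16 * kGB, 0.8f) << "\n";        // 0.8
  return 0;
}
```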
| @@ -68,26 +62,18 @@ Status CacheService::DoServiceStop() { | |||
| if (cp_ != nullptr) { | |||
| RETURN_IF_NOT_OK(cp_->ServiceStop()); | |||
| } | |||
| CacheServer &cs = CacheServer::GetInstance(); | |||
| if (UseArena()) { | |||
| cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kFree); | |||
| } else { | |||
| MS_LOG(INFO) << "Memory/disk usage for the current service: " << GetMemoryUsage() << " bytes and " << GetDiskUsage() | |||
| << " bytes."; | |||
| cs.UpdateMemoryUsage(GetMemoryUsage(), CacheServer::MemUsageOp::kFree); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) { | |||
| SharedLock rw(&rw_lock_); | |||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | |||
| if (st_ == State::kFetchPhase) { | |||
| if (st_ == CacheServiceState::kFetchPhase) { | |||
| // For this kind of cache service, once we move from the build phase into the fetch phase, we can't | |||
| // allow others to cache more rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| if (st_ == CacheServiceState::kNoLocking) { | |||
| // We ignore this write request once we turn off locking on the B+ tree. So we will just | |||
| // return out of memory from now on. | |||
| return Status(StatusCode::kOutOfMemory); | |||
| @@ -128,26 +114,13 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||
| all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); | |||
| total_sz += msg->data_sz()->Get(i); | |||
| } | |||
| // Now we cache the buffer. | |||
| Status rc = cp_->Insert(*row_id_generated, all_data); | |||
| if (rc == Status(StatusCode::kDuplicateKey)) { | |||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | |||
| } else { | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| return Status::OK(); | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| @@ -157,12 +130,12 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||
| Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) { | |||
| SharedLock rw(&rw_lock_); | |||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | |||
| if (st_ == CacheServiceState::kFetchPhase) { | |||
| // For this kind of cache service, once we move from the build phase into the fetch phase, we can't | |||
| // allow others to cache more rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| if (st_ == CacheServiceState::kNoLocking) { | |||
| // We ignore this write request once we turn off locking on the B+ tree. So we will just | |||
| // return out of memory from now on. | |||
| return Status(StatusCode::kOutOfMemory); | |||
| @@ -183,27 +156,13 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ | |||
| } | |||
| *row_id_generated = msg->row_id(); | |||
| } | |||
| // Now we cache the buffer. | |||
| Status rc = cp_->Insert(*row_id_generated, {src}); | |||
| if (rc == Status(StatusCode::kDuplicateKey)) { | |||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | |||
| } else { | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| return Status::OK(); | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| @@ -247,52 +206,116 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) { | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v, | |||
| const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) { | |||
| SharedLock rw(&rw_lock_); | |||
| std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v; | |||
| datalocator_v.reserve(v.size()); | |||
| for (auto row_id : v) { | |||
| flatbuffers::Offset<DataLocatorMsg> offset; | |||
| RETURN_IF_NOT_OK(cp_->GetDataLocator(row_id, fbb, &offset)); | |||
| datalocator_v.push_back(offset); | |||
| } | |||
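| // Serialize all the locators into a single BatchDataLocatorMsg; BatchFetch will walk this | |||
| // buffer and dispatch one fetch per row without another tree lookup. | |||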
| auto offset_v = fbb->CreateVector(datalocator_v); | |||
| BatchDataLocatorMsgBuilder bld(*fbb); | |||
| bld.add_connection_id(connection_id); | |||
| bld.add_rows(offset_v); | |||
| auto offset_final = bld.Finish(); | |||
| fbb->Finish(offset_final); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) const { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == CacheServiceState::kBuildPhase) { | |||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept fetch request in build phase"); | |||
| } | |||
| CacheServer &cs = CacheServer::GetInstance(); | |||
| int32_t numQ = cs.GetNumGrpcWorkers(); | |||
| auto rng = GetRandomDevice(); | |||
| std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1); | |||
| int32_t qID = distribution(rng); | |||
| std::vector<CacheServerRequest *> cache_rq_list; | |||
| auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer()); | |||
| const auto num_elements = p->rows()->size(); | |||
| auto connection_id = p->connection_id(); | |||
| cache_rq_list.reserve(num_elements); | |||
| int64_t data_offset = (num_elements + 1) * sizeof(int64_t); | |||
| auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer()); | |||
| offset_array[0] = data_offset; | |||
| for (auto i = 0; i < num_elements; ++i) { | |||
| auto data_locator = p->rows()->Get(i); | |||
| auto node_id = data_locator->node_id(); | |||
| size_t sz = data_locator->size(); | |||
| void *source_addr = reinterpret_cast<void *>(data_locator->addr()); | |||
| auto key = data_locator->key(); | |||
| // Please read the comment in CacheServer::BatchFetchRows where we allocate | |||
| // the buffer big enough so each thread (which we are going to dispatch) will | |||
| // not run into false sharing problem. We are going to round up sz to 4k. | |||
| auto sz_4k = round_up_4K(sz); | |||
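| // For example, a 100-byte row still advances the offset array by a full 4096-byte slot. | |||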
| offset_array[i + 1] = offset_array[i] + sz_4k; | |||
| if (sz > 0) { | |||
| WritableSlice row_data(*out, offset_array[i], sz); | |||
| // Get a request and send to the proper worker (at some numa node) to do the fetch. | |||
| worker_id_t worker_id = cs.IsNumaAffinityOn() ? cs.GetWorkerByNumaId(node_id) : cs.GetRandomWorker(); | |||
| CacheServerRequest *cache_rq; | |||
| RETURN_IF_NOT_OK(cs.GetFreeRequestTag(qID++ % numQ, &cache_rq)); | |||
| cache_rq_list.push_back(cache_rq); | |||
| // Set up all the necessary fields. | |||
| cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow; | |||
| cache_rq->st_ = CacheServerRequest::STATE::PROCESS; | |||
| cache_rq->rq_.set_connection_id(connection_id); | |||
| cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_)); | |||
| auto dest_addr = row_data.GetMutablePointer(); | |||
| flatbuffers::FlatBufferBuilder fb2; | |||
| FetchRowMsgBuilder bld(fb2); | |||
| bld.add_key(key); | |||
| bld.add_size(sz); | |||
| bld.add_source_addr(reinterpret_cast<int64_t>(source_addr)); | |||
| bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr)); | |||
| auto offset = bld.Finish(); | |||
| fb2.Finish(offset); | |||
| cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize()); | |||
| RETURN_IF_NOT_OK(cs.PushRequest(worker_id, cache_rq)); | |||
| } | |||
| } | |||
| // Now wait for all of them to come back. Let go of the shared lock. We shouldn't be holding | |||
| // any lock while we may be waiting for a long time. | |||
| rw.Unlock(); | |||
| Status rc; | |||
| for (CacheServerRequest *rq : cache_rq_list) { | |||
| RETURN_IF_NOT_OK(rq->Wait()); | |||
| if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) { | |||
| rc = rq->rc_; | |||
| } | |||
| RETURN_IF_NOT_OK(cs.ReturnRequestTag(rq)); | |||
| } | |||
| return rc; | |||
| } | |||
| Status CacheService::InternalFetchRow(const FetchRowMsg *p) { | |||
| RETURN_UNEXPECTED_IF_NULL(p); | |||
| SharedLock rw(&rw_lock_); | |||
| size_t bytesRead = 0; | |||
| int64_t key = p->key(); | |||
| size_t sz = p->size(); | |||
| void *source_addr = reinterpret_cast<void *>(p->source_addr()); | |||
| void *dest_addr = reinterpret_cast<void *>(p->dest_addr()); | |||
| WritableSlice dest(dest_addr, sz); | |||
| if (source_addr != nullptr) { | |||
| // We are not checking if the row is still present but simply use the information passed in. | |||
| // This saves another tree lookup and is faster. | |||
| ReadableSlice src(source_addr, sz); | |||
| RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src)); | |||
| } else { | |||
| RETURN_IF_NOT_OK(cp_->Read(key, &dest, &bytesRead)); | |||
| if (bytesRead != sz) { | |||
| std::string errMsg = "Unexpected length. Read " + std::to_string(bytesRead) + ". Expected " + std::to_string(sz) + | |||
| "." + " Internal key: " + std::to_string(key); | |||
| MS_LOG(ERROR) << errMsg; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| @@ -312,7 +335,7 @@ Status CacheService::CacheSchema(const void *buf, int64_t len) { | |||
| Status CacheService::FetchSchema(std::string *out) const { | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == CacheServiceState::kBuildPhase) { | |||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept fetch request in build phase"); | |||
| } | |||
| @@ -333,7 +356,7 @@ Status CacheService::BuildPhaseDone() { | |||
| if (HasBuildPhase()) { | |||
| // Exclusive lock to switch phase | |||
| UniqueLock rw(&rw_lock_); | |||
| st_ = CacheServiceState::kFetchPhase; | |||
| cp_->SetLocking(false); | |||
| return Status::OK(); | |||
| } else { | |||
| @@ -348,12 +371,12 @@ Status CacheService::ToggleWriteMode(bool on_off) { | |||
| } else { | |||
| // If we stop accepting write request, we turn off locking for the | |||
| // underlying B+ tree. All future write request we will return kOutOfMemory. | |||
| if (st_ == CacheServiceState::kNone && !on_off) { | |||
| st_ = CacheServiceState::kNoLocking; | |||
| cp_->SetLocking(on_off); | |||
| MS_LOG(WARNING) << "Locking mode is switched off."; | |||
| } else if (st_ == CacheServiceState::kNoLocking && on_off) { | |||
| st_ = CacheServiceState::kNone; | |||
| cp_->SetLocking(on_off); | |||
| } | |||
| } | |||
| @@ -29,36 +29,28 @@ | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #include "minddata/dataset/engine/cache/cache_pool.h" | |||
| #include "minddata/dataset/util/arena.h" | |||
| #include "minddata/dataset/util/btree.h" | |||
| #include "minddata/dataset/util/service.h" | |||
| #include "minddata/dataset/util/services.h" | |||
| #include "minddata/dataset/util/system_pool.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief A cache service for storing/fetching buffers in an in-memory cache; it may spill to disk | |||
| /// if the cache service is created to support spilling. | |||
| class CacheService : public Service { | |||
| public: | |||
| friend class CacheServer; | |||
| /// \brief Constructor | |||
| /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited | |||
| /// \param root Spill path. Empty string means no spilling | |||
| /// \param generate_id If the cache service should generate row id for buffer that is cached. | |||
| /// For non-mappable dataset, this should be set to true. | |||
| CacheService(uint64_t mem_sz, const std::string &root, bool generate_id); | |||
| ~CacheService() override; | |||
| Status DoServiceStart() override; | |||
| Status DoServiceStop() override; | |||
| @@ -77,18 +69,18 @@ class CacheService : public Service { | |||
| Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated); | |||
| /// \brief This function is used in preparation for batch fetching. | |||
| /// It calculates how much memory we should allocate and which row ids are present, etc. | |||
| /// All needed results are stored in the flat buffer. | |||
| /// \return Status object | |||
| Status PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v, | |||
| const std::shared_ptr<flatbuffers::FlatBufferBuilder> &); | |||
| /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded | |||
| /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. | |||
| /// \param[in] fbb Flat buffer holding the row locators produced by PreBatchFetch. | |||
| /// \param[out] out A contiguous memory buffer that holds the requested rows. | |||
| /// \return Status object | |||
| Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &, WritableSlice *out) const; | |||
| /// \brief Getter function | |||
| /// \return Spilling path | |||
| @@ -96,7 +88,7 @@ class CacheService : public Service { | |||
| /// \brief A structure returned from the cache server for statistics request. | |||
| class ServiceStat { | |||
| public: | |||
| using state_type = std::underlying_type<CacheServiceState>::type; | |||
| ServiceStat() : state_(0) {} | |||
| ~ServiceStat() = default; | |||
| CachePool::CacheStat stat_{}; | |||
| @@ -134,10 +126,6 @@ class CacheService : public Service { | |||
| /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. | |||
| /// \return Status object | |||
| Status BuildPhaseDone(); | |||
| /// \brief For kToggleWriteMode request | |||
| Status ToggleWriteMode(bool on_off); | |||
| @@ -149,14 +137,10 @@ class CacheService : public Service { | |||
| std::atomic<row_id_type> next_id_; | |||
| bool generate_id_; | |||
| std::string cookie_; | |||
| std::atomic<int32_t> num_clients_; | |||
| CacheServiceState st_; | |||
| std::string schema_; | |||
| std::shared_ptr<NumaMemoryPool> numa_pool_; | |||
| // We also cache the result from calling FindKeysMiss because it is expensive. Besides, users make | |||
| // this request after we hit memory full or disk full, so the result is unlikely to change. | |||
| std::mutex get_key_miss_mux_; | |||
| @@ -164,6 +148,8 @@ class CacheService : public Service { | |||
| /// \brief Private function to generate a row id | |||
| /// \return Row id assigned. | |||
| row_id_type GetNextRowId() { return next_id_.fetch_add(1); } | |||
| Status InternalFetchRow(const FetchRowMsg *p); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -65,6 +65,7 @@ table ServiceStatMsg { | |||
| num_mem_cached:int64; | |||
| num_disk_cached:int64; | |||
| avg_cache_sz:int64; | |||
| num_numa_hit:int64; | |||
| min_row_id:int64; | |||
| max_row_id:int64; | |||
| state:int8; | |||
| @@ -89,8 +90,10 @@ table CreateCacheRequestMsg { | |||
| /// Return result of CreateCacheRequest | |||
| table CreateCacheReplyMsg { | |||
| client_id:int32; | |||
| connection_id:uint64; | |||
| cookie:string; | |||
| cpu_id:[int32]; | |||
| } | |||
| table ListSessionMsg { | |||
| @@ -102,3 +105,22 @@ table ListSessionMsg { | |||
| table ListSessionsMsg { | |||
| sessions:[ListSessionMsg]; | |||
| } | |||
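| /// Row locator tables used by the numa-aware batch fetch path. PreBatchFetch emits one | |||
| /// DataLocatorMsg per requested row id inside a BatchDataLocatorMsg; BatchFetch then turns | |||
| /// each locator into a FetchRowMsg for a server worker. The addr fields carry raw pointers | |||
| /// cast to int64 and are only meaningful inside the cache server process. | |||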
| table DataLocatorMsg { | |||
| key:int64; | |||
| node_id:int32; | |||
| addr:int64; | |||
| size:int64; | |||
| } | |||
| table BatchDataLocatorMsg { | |||
| connection_id:uint64; | |||
| rows:[DataLocatorMsg]; | |||
| } | |||
| table FetchRowMsg { | |||
| key:int64; | |||
| source_addr:int64; | |||
| dest_addr:int64; | |||
| size:int64; | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
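| # Cache performance test tools; only built when the cache feature itself is enabled. | |||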
| if (ENABLE_CACHE) | |||
| ms_protobuf_generate(CACHE_PERF_PROTO_SRCS CACHE_PERF_PROTO_HDRS cache_perf.proto) | |||
| add_executable(cache_perf cache_perf.cc cache_msg.cc cache_perf_run.cc ${CACHE_PERF_PROTO_SRCS}) | |||
| target_link_libraries(cache_perf | |||
| _c_dataengine | |||
| _c_mindrecord | |||
| mindspore::protobuf | |||
| mindspore_gvar | |||
| ${PYTHON_LIBRARIES} | |||
| pthread) | |||
| if (USE_GLOG) | |||
| target_link_libraries(cache_perf mindspore::glog) | |||
| endif () | |||
| add_executable(cache_pipeline cache_pipeline.cc cache_msg.cc cache_pipeline_run.cc ${CACHE_PERF_PROTO_SRCS}) | |||
| target_link_libraries(cache_pipeline | |||
| _c_dataengine | |||
| _c_mindrecord | |||
| mindspore::protobuf | |||
| mindspore_gvar | |||
| ${PYTHON_LIBRARIES} | |||
| pthread) | |||
| if (USE_GLOG) | |||
| target_link_libraries(cache_pipeline mindspore::glog) | |||
| endif () | |||
| endif () | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/perf/cache_msg.h" | |||
| #include <sys/types.h> | |||
| #include <sys/ipc.h> | |||
| #include <sys/msg.h> | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status CachePerfMsg::Send(int32_t qID) { | |||
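| // IPC_NOWAIT makes the send non-blocking: if the queue is full we return an error instead of waiting. | |||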
| auto err = msgsnd(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), IPC_NOWAIT); | |||
| if (err == -1) { | |||
| std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CachePerfMsg::Receive(int32_t qID) { | |||
| // This is a blocking call. Either there is some message, or the queue is removed when | |||
| // the destructor is called. | |||
| auto err = msgrcv(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), 0, MSG_NOERROR); | |||
| if (err == -1) { | |||
| if (errno == EIDRM) { | |||
| return Status(StatusCode::kInterrupted); | |||
| } else { | |||
| std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_ | |||
| #include <cstdint> | |||
| #include <limits> | |||
| #include <string> | |||
| #include "proto/cache_perf.pb.h" | |||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // All our messages are very small. So we will use the stack version without the need | |||
| // to allocate memory. | |||
| struct CacheSmallMsg { | |||
| int64_t mtype; | |||
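| // System V message queues require a positive message type; the constructor always sets mtype to 1. | |||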
| union { | |||
| char mtext[1]; | |||
| struct { | |||
| int32_t type; // the first 4 bytes is the RequestType | |||
| int32_t proto_sz; | |||
| char proto_buffer[kSharedMessageSize]; | |||
| } msg; | |||
| } body; | |||
| }; | |||
| /// A message queue structure between the parent and the child process | |||
| class CachePerfMsg { | |||
| public: | |||
| enum MessageType : int16_t { | |||
| kInterrupt = 0, | |||
| kEpochResult = 1, | |||
| kEpochStart = 2, | |||
| kEpochEnd = 3, | |||
| kError = 4, | |||
| // Add new message types before this line. | |||
| kUnknownMessage = 32767 | |||
| }; | |||
| CachePerfMsg() : small_msg_{1} { | |||
| small_msg_.body.msg.type = kUnknownMessage; | |||
| small_msg_.body.msg.proto_sz = 0; | |||
| small_msg_.body.msg.proto_buffer[0] = 0; | |||
| } | |||
| ~CachePerfMsg() = default; | |||
| char *GetMutableBuffer() { return small_msg_.body.msg.proto_buffer; } | |||
| Status Send(int32_t qID); | |||
| void SetType(MessageType requestType) { small_msg_.body.msg.type = requestType; } | |||
| void SetProtoBufSz(size_t sz) { small_msg_.body.msg.proto_sz = sz; } | |||
| MessageType GetType() const { return static_cast<MessageType>(small_msg_.body.msg.type); } | |||
| size_t GetProtoBufSz() const { return small_msg_.body.msg.proto_sz; } | |||
| Status Receive(int32_t qID); | |||
| private: | |||
| CacheSmallMsg small_msg_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_ | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef USE_GLOG | |||
| #include <glog/logging.h> | |||
| #endif | |||
| #include <iostream> | |||
| #include "minddata/dataset/engine/cache/perf/cache_perf_run.h" | |||
| namespace ds = mindspore::dataset; | |||
| int main(int argc, char **argv) { | |||
| #ifdef USE_GLOG | |||
| FLAGS_log_dir = "/tmp"; | |||
| google::InitGoogleLogging(argv[0]); | |||
| #endif | |||
| ds::CachePerfRun cachePerfRun; | |||
| if (cachePerfRun.ProcessArgs(argc, argv) == 0) { | |||
| std::cout << cachePerfRun << std::endl; | |||
| ds::Status rc = cachePerfRun.Run(); | |||
| if (rc.IsError()) { | |||
| std::cerr << rc.ToString() << std::endl; | |||
| } | |||
| return static_cast<int>(rc.get_code()); | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| syntax = "proto3"; | |||
| package mindspore.dataset; | |||
| option cc_enable_arenas = true; | |||
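| // Per-pipeline, per-worker latency summary for one epoch. min/max/avg/med are row fetch | |||
| // times in microseconds, cnt is the buffer count, and elapse is the epoch elapsed time in | |||
| // seconds (see CachePerfRun::PrintEpochSummary). | |||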
| message PipelineWorkerEpochSummary { | |||
| int32 pipeline = 1; | |||
| int32 worker = 2; | |||
| int64 min = 3; | |||
| int64 max = 4; | |||
| int64 avg = 5; | |||
| int64 med = 6; | |||
| int64 cnt = 7; | |||
| int64 elapse = 8; | |||
| } | |||
| message EpochDone { | |||
| int32 pipeline = 1; | |||
| } | |||
| message ErrorMsg { | |||
| int32 rc = 1; | |||
| string msg = 2; | |||
| } | |||
| @@ -0,0 +1,575 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/perf/cache_perf_run.h" | |||
| #include <string.h> | |||
| #include <sys/types.h> | |||
| #include <sys/stat.h> | |||
| #include <sys/wait.h> | |||
| #include <sys/ipc.h> | |||
| #include <sys/msg.h> | |||
| #include <unistd.h> | |||
| #include <algorithm> | |||
| #include <chrono> | |||
| #include <iomanip> | |||
| #include <sstream> | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/dataset/util/services.h" | |||
| #include "minddata/dataset/util/sig_handler.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| const char CachePerfRun::kCachePipelineBinary[] = "cache_pipeline"; | |||
| void CachePerfRun::PrintHelp() { | |||
| std::cout << "Options:\n" | |||
| " -h,--help: Show this usage message\n" | |||
| " -s,--num_rows: Set the sample size, i.e., the number of " | |||
| "rows\n" | |||
| " -r,--row_size: Set the average row size\n" | |||
| " -n,--pipeline: Set the number of parallel pipelines. Default = " | |||
| << kDftNumOfPipelines | |||
| << "\n" | |||
| " -e,--epoch: Set the number of epochs. Default = " | |||
| << kDftNumberOfEpochs | |||
| << "\n" | |||
| " --shuffle: Set shuffle=True. Default = " | |||
| << std::boolalpha << kDftShuffle | |||
| << "\n" | |||
| " -p,--prefetch_size: Set the prefetch size for cache. Default = " | |||
| << kDftPrefetchSize << "\n" | |||
| << " -a,--cache_size: Set cache size. Default = " << kDftCacheSize | |||
| << " (Mb)\n" | |||
| " --spill: Set spill to disk to True. Default = " | |||
| << std::boolalpha << kDftSpill << "\n" | |||
| << " -w,--workers: Set the number of parallel workers. Default = " << cfg_.num_parallel_workers() | |||
| << "\n" | |||
| " --connection: Set number of TCP/IP connections per pipeline. Default = " | |||
| << kDftNumConnections << "\n" | |||
| << " --port: TCP/IP port of the cache server. Default = " << kCfgDefaultCachePort << "\n" | |||
| << " --hostname: Hostname of the cache server. Default = " << kCfgDefaultCacheHost << "\n"; | |||
| } | |||
| int32_t CachePerfRun::ProcessArgs(int argc, char **argv) { | |||
| if (argc == 1) { | |||
| PrintHelp(); | |||
| return -1; | |||
| } | |||
| const int32_t port_opt = 1000; // there is no short option for port | |||
| const int32_t hostname_opt = 1001; // there is no short option for hostname | |||
| const int32_t connect_opt = 1002; // there is no short option for connect | |||
| int shuffle = 0; | |||
| int spill = 0; | |||
| const char *const short_opts = ":n:e:p:a:s:r:w:"; | |||
| const option long_opts[] = {{"pipeline", required_argument, nullptr, 'n'}, | |||
| {"epoch", required_argument, nullptr, 'e'}, | |||
| {"prefetch_size", required_argument, nullptr, 'p'}, | |||
| {"shuffle", no_argument, &shuffle, 1}, | |||
| {"cache_size", required_argument, nullptr, 'a'}, | |||
| {"num_rows", required_argument, nullptr, 's'}, | |||
| {"row_size", required_argument, nullptr, 'r'}, | |||
| {"workers", required_argument, nullptr, 'w'}, | |||
| {"port", required_argument, nullptr, port_opt}, | |||
| {"hostname", required_argument, nullptr, hostname_opt}, | |||
| {"spill", no_argument, &spill, 1}, | |||
| {"connection", required_argument, nullptr, connect_opt}, | |||
| {"help", no_argument, nullptr, 'h'}, | |||
| {nullptr, no_argument, nullptr, 0}}; | |||
| std::map<int32_t, int32_t> seen_opts; | |||
| int32_t rc = 0; | |||
| try { | |||
| while (rc == 0) { | |||
| int32_t option_index; | |||
| const auto opt = getopt_long(argc, argv, short_opts, long_opts, &option_index); | |||
| if (-1 == opt) { | |||
| if (optind < argc) { | |||
| rc = -1; | |||
| std::cerr << "Unknown arguments: "; | |||
| while (optind < argc) { | |||
| std::cerr << argv[optind++] << " "; | |||
| } | |||
| std::cerr << std::endl; | |||
| } | |||
| break; | |||
| } | |||
| if (opt > 0) { | |||
| seen_opts[opt]++; | |||
| if (seen_opts[opt] > 1) { | |||
| std::string long_name = long_opts[option_index].name; | |||
| std::cerr << "The " << long_name << " argument was given more than once." << std::endl; | |||
| rc = -1; | |||
| continue; | |||
| } | |||
| } | |||
| switch (opt) { | |||
| case 0: { | |||
| if (long_opts[option_index].flag == &shuffle) { | |||
| shuffle_ = true; | |||
| } else if (long_opts[option_index].flag == &spill) { | |||
| cache_builder_.SetSpill(true); | |||
| } | |||
| break; | |||
| } | |||
| case 'n': { | |||
| num_pipelines_ = std::stoi(optarg); | |||
| break; | |||
| } | |||
| case 'e': { | |||
| num_epoches_ = std::stoi(optarg); | |||
| break; | |||
| } | |||
| case 'p': { | |||
| int32_t prefetch_sz = std::stoi(optarg); | |||
| cache_builder_.SetPrefetchSize(prefetch_sz); | |||
| break; | |||
| } | |||
| case 'a': { | |||
| int32_t cache_sz = std::stoi(optarg); | |||
| cache_builder_.SetCacheMemSz(cache_sz); | |||
| break; | |||
| } | |||
| case 's': { | |||
| num_rows_ = std::stoi(optarg); | |||
| break; | |||
| } | |||
| case 'r': { | |||
| row_size_ = std::stoi(optarg); | |||
| break; | |||
| } | |||
| case 'w': { | |||
| cfg_.set_num_parallel_workers(std::stoi(optarg)); | |||
| break; | |||
| } | |||
| case connect_opt: { | |||
| int32_t connection_sz = std::stoi(optarg); | |||
| cache_builder_.SetNumConnections(connection_sz); | |||
| break; | |||
| } | |||
| case port_opt: { | |||
| int32_t port = std::stoi(optarg); | |||
| cache_builder_.SetPort(port); | |||
| break; | |||
| } | |||
| case hostname_opt: { | |||
| std::string hostname = optarg; | |||
| cache_builder_.SetHostname(hostname); | |||
| break; | |||
| } | |||
| case 'h': // -h or --help | |||
| PrintHelp(); | |||
| rc = -1; | |||
| break; | |||
| case ':': | |||
| std::cerr << "Missing argument for option " << char(optopt) << std::endl; | |||
| rc = -1; | |||
| break; | |||
| case '?': // Unrecognized option | |||
| default: | |||
| std::cerr << "Unknown option " << char(optopt) << std::endl; | |||
| PrintHelp(); | |||
| rc = -1; | |||
| break; | |||
| } | |||
| } | |||
| } catch (const std::exception &e) { | |||
| PrintHelp(); | |||
| rc = -1; | |||
| } | |||
| if (rc < 0) { | |||
| return rc; | |||
| } | |||
| // We have all the defaults except sample size and average row size which the user must specify. | |||
| auto it = seen_opts.find('s'); | |||
| if (it == seen_opts.end()) { | |||
| std::cerr << "Missing sample size." << std::endl; | |||
| return -1; | |||
| } | |||
| it = seen_opts.find('r'); | |||
| if (it == seen_opts.end()) { | |||
| std::cerr << "Missing average row size." << std::endl; | |||
| return -1; | |||
| } | |||
| if (num_rows_ <= 0) { | |||
| std::cerr << "Sample size must be positive." << std::endl; | |||
| return -1; | |||
| } | |||
| if (row_size_ <= 0) { | |||
| std::cerr << "Average row size must be positive." << std::endl; | |||
| return -1; | |||
| } | |||
| if (num_pipelines_ <= 0) { | |||
| std::cerr << "Number of pipelines must be positive." << std::endl; | |||
| return -1; | |||
| } | |||
| if (num_epoches_ <= 0) { | |||
| std::cerr << "Number of epochs must be positive." << std::endl; | |||
| return -1; | |||
| } | |||
| if (num_rows_ < num_pipelines_) { | |||
| std::cerr << "Sample size is smaller than the number of pipelines." << std::endl; | |||
| return -1; | |||
| } | |||
| pid_lists_.reserve(num_pipelines_); | |||
| return 0; | |||
| } | |||
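| // Ask the cache server (over a short-lived connection) to generate a new session id. | |||
| // All child pipelines will attach their cache client to this session. | |||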
| Status CachePerfRun::GetSession() { | |||
| CacheClientGreeter comm(cache_builder_.GetHostname(), cache_builder_.GetPort(), 1); | |||
| RETURN_IF_NOT_OK(comm.ServiceStart()); | |||
| auto rq = std::make_shared<GenerateSessionIdRequest>(); | |||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | |||
| RETURN_IF_NOT_OK(rq->Wait()); | |||
| session_ = rq->GetSessionId(); | |||
| std::cout << "Session: " << session_ << std::endl; | |||
| cache_builder_.SetSessionId(session_); | |||
| return Status::OK(); | |||
| } | |||
| CachePerfRun::CachePerfRun() | |||
| : my_pipeline_(-1), | |||
| num_pipelines_(kDftNumOfPipelines), | |||
| num_epoches_(kDftNumberOfEpochs), | |||
| num_rows_(0), | |||
| row_size_(0), | |||
| shuffle_(kDftShuffle), | |||
| session_(0), | |||
| crc_(0), | |||
| epoch_sync_cnt_(0) { | |||
| cache_builder_.SetSpill(kDftSpill).SetCacheMemSz(kDftCacheSize); | |||
| } | |||
| CachePerfRun::~CachePerfRun() { | |||
| if (session_ != 0) { | |||
| Status rc; | |||
| CacheClientGreeter comm(cache_builder_.GetHostname(), cache_builder_.GetPort(), 1); | |||
| rc = comm.ServiceStart(); | |||
| if (rc.IsOk()) { | |||
| CacheClientInfo cinfo; | |||
| cinfo.set_session_id(session_); | |||
| auto rq = std::make_shared<DropSessionRequest>(cinfo); | |||
| rc = comm.HandleRequest(rq); | |||
| if (rc.IsOk()) { | |||
| rc = rq->Wait(); | |||
| if (rc.IsOk()) { | |||
| std::cout << "Drop session " << session_ << " successful" << std::endl; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // Send an interrupt message to each child. | |||
| for (auto msg_qid : msg_send_lists_) { | |||
| CachePerfMsg msg; | |||
| msg.SetType(CachePerfMsg::MessageType::kInterrupt); | |||
| (void)msg.Send(msg_qid); | |||
| } | |||
| // Wait for each child to return | |||
| for (auto pid : pid_lists_) { | |||
| int status; | |||
| if (waitpid(pid, &status, 0) == -1) { | |||
| std::string errMsg = "waitpid fails. errno = " + std::to_string(errno); | |||
| std::cerr << errMsg << std::endl; | |||
| } else { | |||
| MS_LOG(INFO) << "Child pid " << pid << " returns." << std::endl; | |||
| } | |||
| } | |||
| // Remove all the message queues | |||
| for (auto msg_qid : msg_send_lists_) { | |||
| // Remove the message queue and ignore the return code. | |||
| (void)msgctl(msg_qid, IPC_RMID, nullptr); | |||
| } | |||
| for (auto msg_qid : msg_recv_lists_) { | |||
| // Remove the message queue and ignore the return code. | |||
| (void)msgctl(msg_qid, IPC_RMID, nullptr); | |||
| } | |||
| } | |||
| void CachePerfRun::PrintEpochSummary() const { | |||
| std::cout << std::setw(12) << "Pipeline #" << std::setw(10) << "worker id" << std::setw(11) << "min (μs)" | |||
| << std::setw(11) << "max (μs)" << std::setw(11) << "avg (μs)" << std::setw(14) << "median (μs)" | |||
| << std::setw(14) << "buffer count" << std::setw(18) << "Elapsed time (s)" << std::endl; | |||
| for (auto &it : epoch_results_) { | |||
| auto epoch_worker_summary = it.second; | |||
| std::cout << std::setw(12) << epoch_worker_summary.pipeline() + 1 << std::setw(10) << epoch_worker_summary.worker() | |||
| << std::setw(10) << epoch_worker_summary.min() << std::setw(10) << epoch_worker_summary.max() | |||
| << std::setw(10) << epoch_worker_summary.avg() << std::setw(13) << epoch_worker_summary.med() | |||
| << std::setw(14) << epoch_worker_summary.cnt() << std::setw(18) << epoch_worker_summary.elapse() | |||
| << std::endl; | |||
| } | |||
| } | |||
| Status CachePerfRun::ListenToPipeline(int32_t workerId) { | |||
| TaskManager::FindMe()->Post(); | |||
| int32_t qID = msg_recv_lists_[workerId]; | |||
| do { | |||
| RETURN_IF_INTERRUPTED(); | |||
| CachePerfMsg msg; | |||
| RETURN_IF_NOT_OK(msg.Receive(qID)); | |||
| // Decode the messages. | |||
| auto type = msg.GetType(); | |||
| char *p = msg.GetMutableBuffer(); | |||
| switch (type) { | |||
| case CachePerfMsg::MessageType::kEpochResult: { | |||
| PipelineWorkerEpochSummary epoch_worker_summary; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(epoch_worker_summary.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail"); | |||
| { | |||
| auto pipeline = epoch_worker_summary.pipeline(); | |||
| auto worker = epoch_worker_summary.worker(); | |||
| std::unique_lock<std::mutex> lock(mux_); | |||
| // sort by pipeline/worker | |||
| auto r = | |||
| epoch_results_.emplace(std::pair<int32_t, int32_t>(pipeline, worker), std::move(epoch_worker_summary)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(r.second, "Insert failed"); | |||
| } | |||
| break; | |||
| } | |||
| case CachePerfMsg::MessageType::kEpochEnd: { | |||
| EpochDone proto; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(proto.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail"); | |||
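| // Count this pipeline's end-of-epoch report; the last one to arrive wakes up the parent. | |||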
| auto n = epoch_sync_cnt_.fetch_add(1); | |||
| if (n + 1 == num_pipelines_) { | |||
| pipeline_wp_.Set(); | |||
| } | |||
| break; | |||
| } | |||
| case CachePerfMsg::MessageType::kInterrupt: { | |||
| TaskManager::WakeUpWatchDog(); | |||
| return Status::OK(); | |||
| } | |||
| case CachePerfMsg::kError: { | |||
| ErrorMsg proto; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(proto.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail"); | |||
| return Status(static_cast<const StatusCode>(proto.rc()), proto.msg()); | |||
| } | |||
| default: | |||
| std::string errMsg = "Unknown request type: " + std::to_string(type); | |||
| MS_LOG(ERROR) << errMsg; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| break; | |||
| } | |||
| } while (true); | |||
| return Status::OK(); | |||
| } | |||
| Status CachePerfRun::Run() { | |||
| // Now we bring up TaskManager. | |||
| RETURN_IF_NOT_OK(Services::CreateInstance()); | |||
| // Handle Control-C | |||
| RegisterHandlers(); | |||
| // Get a session from the server. | |||
| RETURN_IF_NOT_OK(GetSession()); | |||
| // Generate a random crc. | |||
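| // All pipelines create or attach to the same cache, which is keyed by this crc. | |||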
| auto mt = GetRandomDevice(); | |||
| std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<int32_t>::max()); | |||
| crc_ = distribution(mt); | |||
| std::cout << "CRC: " << crc_ << std::endl; | |||
| // Create all the resources required by the pipelines before we fork. | |||
| for (auto i = 0; i < num_pipelines_; ++i) { | |||
| // We will use shared message queues for communication between parent (this process) | |||
| // and each pipelines. | |||
| auto access_mode = S_IRUSR | S_IWUSR; | |||
| int32_t msg_send_qid = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode); | |||
| if (msg_send_qid == -1) { | |||
| std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| msg_send_lists_.push_back(msg_send_qid); | |||
| int32_t msg_recv_qid = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode); | |||
| if (msg_recv_qid == -1) { | |||
| std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| msg_recv_lists_.push_back(msg_recv_qid); | |||
| } | |||
| // Now we create the children knowing all two sets of message queues are constructed. | |||
| for (auto i = 0; i < num_pipelines_; ++i) { | |||
| auto pid = fork(); | |||
| if (pid == 0) { | |||
| // Child. We will call another binary but with different (hidden) parameters. | |||
| // The parent process is waiting on a wait post. On any error we hit here, we must interrupt | |||
| // the parent process. | |||
| auto interrupt_parent = [this, i]() { | |||
| CachePerfMsg msg; | |||
| msg.SetType(CachePerfMsg::MessageType::kInterrupt); | |||
| msg.Send(msg_recv_lists_[i]); | |||
| }; | |||
| const std::string self_proc = "/proc/self/exe"; | |||
| std::string canonical_path; | |||
| canonical_path.resize(400); // PATH_MAX is large. This value should be big enough for our use. | |||
| // Some lower level OS library calls are needed here to determine the binary path. | |||
| if (realpath(self_proc.data(), canonical_path.data()) == nullptr) { | |||
| std::cerr << "Failed to identify cache_perf binary path: " + std::to_string(errno) << ": " << strerror(errno) | |||
| << std::endl; | |||
| interrupt_parent(); | |||
| // Call _exit instead of exit because we will hang in TaskManager destructor for a forked child process. | |||
| _exit(-1); | |||
| } | |||
| canonical_path.resize(strlen(canonical_path.data())); | |||
| std::size_t last_separator = canonical_path.find_last_of('/'); | |||
| if (last_separator == std::string::npos) { | |||
| std::cerr << "Canonical path can't locate / " << canonical_path << std::endl; | |||
| interrupt_parent(); | |||
| // Call _exit instead of exit because we will hang in TaskManager destructor for a forked child process. | |||
| _exit(-1); | |||
| } | |||
| // truncate the binary name so we are left with the absolute path of the cache_pipeline binary | |||
| canonical_path.resize(last_separator + 1); | |||
| std::string cache_pipeline_binary = canonical_path + std::string(kCachePipelineBinary); | |||
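| // Pack the per-pipeline parameters into comma-separated strings; the child parses them | |||
| // positionally in CachePipelineRun::ProcessArgs (our send queue becomes its receive queue | |||
| // and vice versa). | |||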
| std::string pipeline_cfg = std::to_string(i) + "," + std::to_string(session_) + "," + std::to_string(crc_) + "," + | |||
| std::to_string(msg_send_lists_[i]) + "," + std::to_string(msg_recv_lists_[i]) + "," + | |||
| std::to_string(num_pipelines_) + "," + std::to_string(num_epoches_) + "," + | |||
| std::to_string(num_rows_) + "," + std::to_string(row_size_) + "," + | |||
| std::to_string(cfg_.num_parallel_workers()) + "," + | |||
| (shuffle_ ? std::string("true").data() : std::string("false").data()); | |||
| std::string client_cfg = cache_builder_.GetHostname() + "," + std::to_string(cache_builder_.GetPort()) + "," + | |||
| std::to_string(cache_builder_.GetPrefetchSize()) + "," + | |||
| std::to_string(cache_builder_.GetCacheMemSz()) + "," + | |||
| std::to_string(cache_builder_.GetNumConnections()) + "," + | |||
| (cache_builder_.isSpill() ? std::string("true").data() : std::string("false").data()); | |||
| char *argv[4]; | |||
| argv[0] = const_cast<char *>(kCachePipelineBinary); | |||
| argv[1] = pipeline_cfg.data(); | |||
| argv[2] = client_cfg.data(); | |||
| argv[3] = nullptr; | |||
| // Invoke the binary. | |||
| execv(cache_pipeline_binary.data(), argv); | |||
| std::cerr << "Unable to exec. Errno = " + std::to_string(errno) << ": " << strerror(errno) << std::endl; | |||
| interrupt_parent(); | |||
| // Call _exit instead of exit because we will hang TaskManager destructor for a forked child process. | |||
| _exit(-1); | |||
| } else if (pid > 0) { | |||
| std::cout << "Pipeline number " << i + 1 << " has been created with process id: " << pid << std::endl; | |||
| pid_lists_.push_back(pid); | |||
| } else { | |||
| std::string errMsg = "Failed to fork process for cache pipeline: " + std::to_string(errno); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| } | |||
| // Spawn a few threads to monitor the communications from the pipeline. | |||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | |||
| auto f = std::bind(&CachePerfRun::ListenToPipeline, this, std::placeholders::_1); | |||
| for (auto i = 0; i < num_pipelines_; ++i) { | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Queue listener", std::bind(f, i))); | |||
| } | |||
| // Wait until all pipelines finish the first epoch. | |||
| RETURN_IF_NOT_OK(pipeline_wp_.Wait()); | |||
| std::cout << "Epoch one (build phase) per pipeline per worker summary. Buffer size = " << cfg_.rows_per_buffer() | |||
| << std::endl; | |||
| PrintEpochSummary(); | |||
| // Get some stats, but we need to connect first. The server will think it is the (n+1)th pipeline. | |||
| RETURN_IF_NOT_OK(cache_builder_.Build(&cc_)); | |||
| Status rc = cc_->CreateCache(crc_, false); | |||
| // Duplicate key is fine. | |||
| if (rc.IsError() && rc.get_code() != StatusCode::kDuplicateKey) { | |||
| return rc; | |||
| } | |||
| CacheServiceStat stat{}; | |||
| RETURN_IF_NOT_OK(cc_->GetStat(&stat)); | |||
| std::cout << "Get statistics for this session:\n"; | |||
| std::cout << std::setw(12) << "Mem cached" << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" | |||
| << std::setw(10) << "Numa hit" << std::endl; | |||
| std::string stat_mem_cached; | |||
| std::string stat_disk_cached; | |||
| std::string stat_avg_cached; | |||
| std::string stat_numa_hit; | |||
| stat_mem_cached = (stat.num_mem_cached == 0) ? "n/a" : std::to_string(stat.num_mem_cached); | |||
| stat_disk_cached = (stat.num_disk_cached == 0) ? "n/a" : std::to_string(stat.num_disk_cached); | |||
| stat_avg_cached = (stat.avg_cache_sz == 0) ? "n/a" : std::to_string(stat.avg_cache_sz); | |||
| stat_numa_hit = (stat.num_numa_hit == 0) ? "n/a" : std::to_string(stat.num_numa_hit); | |||
| std::cout << std::setw(12) << stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached | |||
| << std::setw(10) << stat_numa_hit << std::endl; | |||
| // Toggle write mode off since the rest are just read only. | |||
| // Simplest way is to call this special internal function. | |||
| cc_->ServerRunningOutOfResources(); | |||
| // The rest of the epochs are just fetching. | |||
| auto epoch_num = 2; | |||
| while (epoch_num <= num_epoches_) { | |||
| epoch_sync_cnt_ = 0; | |||
| pipeline_wp_.Clear(); | |||
| epoch_results_.clear(); | |||
| // Signal each pipeline to start | |||
| for (auto msg_qid : msg_send_lists_) { | |||
| CachePerfMsg msg; | |||
| msg.SetType(CachePerfMsg::MessageType::kEpochStart); | |||
| (void)msg.Send(msg_qid); | |||
| } | |||
| // Wait for the child to finish | |||
| RETURN_IF_NOT_OK(pipeline_wp_.Wait()); | |||
| std::cout << "Epoch " << epoch_num | |||
| << " (read phase) per pipeline per worker summary. Buffer size = " << cc_->GetPrefetchSize() << std::endl; | |||
| PrintEpochSummary(); | |||
| ++epoch_num; | |||
| } | |||
| // Destroy the cache. We no longer need it around. | |||
| RETURN_IF_NOT_OK(cc_->DestroyCache()); | |||
| // Unreserve the session | |||
| CacheClientInfo cinfo; | |||
| cinfo.set_session_id(session_); | |||
| auto rq = std::make_shared<DropSessionRequest>(cinfo); | |||
| RETURN_IF_NOT_OK(cc_->PushRequest(rq)); | |||
| RETURN_IF_NOT_OK(rq->Wait()); | |||
| std::cout << "Drop session " << session_ << " successful" << std::endl; | |||
| session_ = 0; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,100 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_ | |||
| #include <getopt.h> | |||
| #include <atomic> | |||
| #include <cstdint> | |||
| #include <limits> | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <random> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/engine/cache/perf/cache_msg.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| #include "minddata/dataset/util/wait_post.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| constexpr int32_t kDftNumOfPipelines = 8; | |||
| constexpr int32_t kDftNumberOfEpochs = 10; | |||
| constexpr int32_t kDftCacheSize = 0; | |||
| constexpr bool kDftShuffle = false; | |||
| constexpr bool kDftSpill = false; | |||
| class CachePerfRun { | |||
| public: | |||
| static const char kCachePipelineBinary[]; | |||
| CachePerfRun(); | |||
| ~CachePerfRun(); | |||
| void PrintHelp(); | |||
| int32_t ProcessArgs(int argc, char **argv); | |||
| void Print(std::ostream &out) const { | |||
| out << "Number of pipelines: " << num_pipelines_ << "\n" | |||
| << "Number of epochs: " << num_epoches_ << "\n" | |||
| << "Sample size: " << num_rows_ << "\n" | |||
| << "Average row size: " << row_size_ << "\n" | |||
| << "Shuffle: " << std::boolalpha << shuffle_; | |||
| } | |||
| friend std::ostream &operator<<(std::ostream &out, const CachePerfRun &cp) { | |||
| cp.Print(out); | |||
| return out; | |||
| } | |||
| Status Run(); | |||
| private: | |||
| std::mutex mux_; | |||
| int32_t my_pipeline_; | |||
| int32_t num_pipelines_; | |||
| int32_t num_epoches_; | |||
| int64_t num_rows_; | |||
| int32_t row_size_; | |||
| bool shuffle_; | |||
| CacheClient::Builder cache_builder_; | |||
| session_id_type session_; | |||
| int32_t crc_; | |||
| std::vector<int32_t> pid_lists_; | |||
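| // System V message queue ids: one send queue and one receive queue per child pipeline. | |||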
| std::vector<int32_t> msg_send_lists_; | |||
| std::vector<int32_t> msg_recv_lists_; | |||
| TaskGroup vg_; | |||
| std::atomic<int32_t> epoch_sync_cnt_; | |||
| WaitPost pipeline_wp_; | |||
| std::map<std::pair<int32_t, int32_t>, PipelineWorkerEpochSummary> epoch_results_; | |||
| ConfigManager cfg_; | |||
| std::shared_ptr<CacheClient> cc_; | |||
| Status GetSession(); | |||
| Status ListenToPipeline(int32_t workerId); | |||
| void PrintEpochSummary() const; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_ | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef USE_GLOG | |||
| #include <glog/logging.h> | |||
| #endif | |||
| #include <string.h> | |||
| #include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h" | |||
| namespace ds = mindspore::dataset; | |||
| int main(int argc, char **argv) { | |||
| #ifdef USE_GLOG | |||
| FLAGS_log_dir = "/tmp"; | |||
| FLAGS_minloglevel = google::WARNING; | |||
| google::InitGoogleLogging(argv[0]); | |||
| #endif | |||
| ds::CachePipelineRun cachePipelineRun; | |||
| if (cachePipelineRun.ProcessArgs(argc, argv) == 0) { | |||
| ds::Status rc = cachePipelineRun.Run(); | |||
| // If we hit any error, send the rc back to the parent. | |||
| if (rc.IsError()) { | |||
| ds::ErrorMsg proto; | |||
| proto.set_rc(static_cast<int32_t>(rc.get_code())); | |||
| proto.set_msg(rc.ToString()); | |||
| ds::CachePerfMsg msg; | |||
| (void)cachePipelineRun.SendMessage(&msg, ds::CachePerfMsg::MessageType::kError, &proto); | |||
| } | |||
| return static_cast<int>(rc.get_code()); | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -0,0 +1,471 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h" | |||
| #include <string.h> | |||
| #include <sys/types.h> | |||
| #include <algorithm> | |||
| #include <chrono> | |||
| #include <iomanip> | |||
| #include <sstream> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/dataset/util/services.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| void CachePipelineRun::PrintHelp() { std::cout << "Please run the executable cache_perf instead." << std::endl; } | |||
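| // ProcessArgs decodes the two comma-separated strings the parent passes on the command line: | |||
| // argv[1] is the pipeline config (pipeline index, session id, crc, recv/send message queue ids, | |||
| // number of pipelines, number of epochs, number of rows, row size, number of workers, shuffle flag) | |||
| // and argv[2] is the cache client config (hostname, port, prefetch size, cache memory size, | |||
| // number of connections, spill flag), matching the order parsed below. | |||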
| int32_t CachePipelineRun::ProcessArgs(int argc, char **argv) { | |||
| if (argc != 3) { | |||
| PrintHelp(); | |||
| return -1; | |||
| } | |||
| try { | |||
| std::stringstream cfg_ss(argv[1]); | |||
| std::string s; | |||
| int32_t numArgs = 0; | |||
| while (std::getline(cfg_ss, s, ',')) { | |||
| if (numArgs == 0) { | |||
| my_pipeline_ = std::stoi(s); | |||
| } else if (numArgs == 1) { | |||
| session_ = std::stoul(s); | |||
| cache_builder_.SetSessionId(session_); | |||
| } else if (numArgs == 2) { | |||
| crc_ = std::stoi(s); | |||
| } else if (numArgs == 3) { | |||
| recv_id_ = std::stoi(s); | |||
| } else if (numArgs == 4) { | |||
| send_id_ = std::stoi(s); | |||
| } else if (numArgs == 5) { | |||
| num_pipelines_ = std::stoi(s); | |||
| } else if (numArgs == 6) { | |||
| num_epoches_ = std::stoi(s); | |||
| } else if (numArgs == 7) { | |||
| num_rows_ = std::stol(s); | |||
| } else if (numArgs == 8) { | |||
| row_size_ = std::stoi(s); | |||
| } else if (numArgs == 9) { | |||
| cfg_.set_num_parallel_workers(std::stol(s)); | |||
| } else if (numArgs == 10) { | |||
| shuffle_ = strcmp(s.data(), "true") == 0; | |||
| } | |||
| ++numArgs; | |||
| } | |||
| if (numArgs != 11) { | |||
| std::cerr << "Incomplete arguments. Expect 11. But get " << numArgs << std::endl; | |||
| return -1; | |||
| } | |||
| std::stringstream client_ss(argv[2]); | |||
| numArgs = 0; | |||
| while (std::getline(client_ss, s, ',')) { | |||
| if (numArgs == 0) { | |||
| cache_builder_.SetHostname(s); | |||
| } else if (numArgs == 1) { | |||
| cache_builder_.SetPort(std::stoi(s)); | |||
| } else if (numArgs == 2) { | |||
| cache_builder_.SetPrefetchSize(std::stoi(s)); | |||
| } else if (numArgs == 3) { | |||
| cache_builder_.SetCacheMemSz(std::stoi(s)); | |||
| } else if (numArgs == 4) { | |||
| cache_builder_.SetNumConnections(std::stoi(s)); | |||
| } else if (numArgs == 5) { | |||
| cache_builder_.SetSpill(strcmp(s.data(), "true") == 0); | |||
| } | |||
| ++numArgs; | |||
| } | |||
| if (numArgs != 6) { | |||
| std::cerr << "Incomplete arguments. Expect 6. But get " << numArgs << std::endl; | |||
| return -1; | |||
| } | |||
| } catch (const std::exception &e) { | |||
| std::cerr << "Parse error: " << e.what() << std::endl; | |||
| return -1; | |||
| } | |||
| return 0; | |||
| } | |||
| CachePipelineRun::CachePipelineRun() | |||
| : my_pipeline_(-1), | |||
| num_pipelines_(kDftNumOfPipelines), | |||
| num_epoches_(kDftNumberOfEpochs), | |||
| num_rows_(0), | |||
| row_size_(0), | |||
| shuffle_(kDftShuffle), | |||
| session_(0), | |||
| crc_(0), | |||
| send_id_(-1), | |||
| recv_id_(-1), | |||
| start_row_(-1), | |||
| end_row_(-1) { | |||
| cache_builder_.SetSpill(kDftSpill).SetCacheMemSz(kDftCacheSize); | |||
| } | |||
| CachePipelineRun::~CachePipelineRun() { | |||
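| // Best effort: tell the parent process this pipeline is going away. | |||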
| CachePerfMsg msg; | |||
| (void)SendMessage<ErrorMsg>(&msg, CachePerfMsg::MessageType::kInterrupt, nullptr); | |||
| } | |||
| Status CachePipelineRun::ListenToParent() { | |||
| TaskManager::FindMe()->Post(); | |||
| do { | |||
| RETURN_IF_INTERRUPTED(); | |||
| CachePerfMsg msg; | |||
| RETURN_IF_NOT_OK(msg.Receive(recv_id_)); | |||
| // Decode the messages. | |||
| auto type = msg.GetType(); | |||
| switch (type) { | |||
| case CachePerfMsg::MessageType::kInterrupt: { | |||
| TaskManager::WakeUpWatchDog(); | |||
| return Status::OK(); | |||
| } | |||
| case CachePerfMsg::MessageType::kEpochStart: { | |||
| pipeline_wp_.Set(); | |||
| break; | |||
| } | |||
| default: | |||
| std::string errMsg = "Unknown request type: " + std::to_string(type); | |||
| MS_LOG(ERROR) << errMsg; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| break; | |||
| } | |||
| } while (true); | |||
| return Status::OK(); | |||
| } | |||
| Status CachePipelineRun::Run() { | |||
| RETURN_IF_NOT_OK(cache_builder_.Build(&cc_)); | |||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | |||
| auto num_workers = cfg_.num_parallel_workers(); | |||
| io_block_queues_.Init(num_workers, cfg_.op_connector_size()); | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(&vg_)); | |||
| Status rc = cc_->CreateCache(crc_, false); | |||
| // Duplicate key is fine. | |||
| if (rc.IsError() && rc.get_code() != StatusCode::kDuplicateKey) { | |||
| return rc; | |||
| } | |||
| // Log a warning level message so we can see it in the log file when it starts. | |||
| MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " successfully creating cache service." << std::endl; | |||
| // Spawn a thread to listen to the parent process | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Queue listener", std::bind(&CachePipelineRun::ListenToParent, this))); | |||
| RETURN_IF_NOT_OK(RunFirstEpoch()); | |||
| // The rest of the epochs are just fetching. | |||
| auto remaining_epochs = num_epoches_ - 1; | |||
| while (remaining_epochs > 0) { | |||
| // Wait for parent process signal to start | |||
| pipeline_wp_.Wait(); | |||
| pipeline_wp_.Clear(); | |||
| RETURN_IF_NOT_OK(RunReadEpoch()); | |||
| --remaining_epochs; | |||
| } | |||
| // The listener thread is blocked on a shared message queue. It will be woken up by | |||
| // the parent process, which will send an interrupt message to it, and this program | |||
| // will exit. | |||
| RETURN_IF_NOT_OK(vg_.join_all()); | |||
| return Status::OK(); | |||
| } | |||
| Status CachePipelineRun::RunFirstEpoch() { | |||
| auto sz = num_rows_ / num_pipelines_; | |||
| start_row_ = my_pipeline_ * sz; | |||
| end_row_ = (my_pipeline_ + 1) * sz - 1; | |||
| if (my_pipeline_ + 1 == num_pipelines_) { | |||
| end_row_ = num_rows_ - 1; | |||
| } | |||
| std::cout << "Pipeline number " << my_pipeline_ + 1 << " row id range: [" << start_row_ << "," << end_row_ << "]" | |||
| << std::endl; | |||
| // Spawn the worker threads. | |||
| auto f = std::bind(&CachePipelineRun::WriterWorkerEntry, this, std::placeholders::_1); | |||
| std::vector<Task *> worker_threads; | |||
| auto num_workers = cfg_.num_parallel_workers(); | |||
| worker_threads.reserve(num_workers); | |||
| for (int32_t i = 0; i < num_workers; ++i) { | |||
| Task *pTask; | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Parallel worker", std::bind(f, i), &pTask)); | |||
| worker_threads.push_back(pTask); | |||
| } | |||
| std::vector<row_id_type> keys; | |||
| auto rows_per_buffer = cfg_.rows_per_buffer(); | |||
| keys.reserve(rows_per_buffer); | |||
| int32_t worker_id = 0; | |||
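| // Deal the row ids to the worker queues round-robin, rows_per_buffer keys per IOBlock. | |||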
| for (auto i = start_row_; i <= end_row_; ++i) { | |||
| keys.push_back(i); | |||
| if (keys.size() == rows_per_buffer) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); | |||
| keys.clear(); | |||
| } | |||
| } | |||
| if (!keys.empty()) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); | |||
| keys.clear(); | |||
| } | |||
| // Shutdown threads and wait for them to come back. | |||
| for (int32_t i = 0; i < num_workers; i++) { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| for (auto *pTask : worker_threads) { | |||
| RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking)); | |||
| } | |||
| // Send a message saying epoch one done for this pipeline. | |||
| EpochDone proto; | |||
| proto.set_pipeline(my_pipeline_); | |||
| CachePerfMsg msg; | |||
| RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochEnd, &proto)); | |||
| return Status::OK(); | |||
| } | |||
| Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) { | |||
| Status rc; | |||
| TaskManager::FindMe()->Post(); | |||
| int64_t min_val = std::numeric_limits<int64_t>::max(); | |||
| int64_t max_val = 0; | |||
| int64_t total_val = 0; | |||
| int64_t cnt = 0; | |||
| std::vector<int64_t> duration; | |||
| duration.reserve(num_rows_ / num_pipelines_ / cfg_.num_parallel_workers()); | |||
| bool resource_err = false; | |||
| auto col_desc = std::make_unique<ColDescriptor>("int64", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1); | |||
| auto num_elements = row_size_ / sizeof(int64_t); | |||
| TensorShape shape(std::vector<dsize_t>(1, num_elements)); | |||
| std::unique_ptr<IOBlock> blk; | |||
| auto epoch_start = std::chrono::steady_clock::now(); | |||
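| // Each worker pops IOBlocks, materializes rows of roughly row_size_ bytes, and times each | |||
| // WriteBuffer round trip to the cache server. | |||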
| do { | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); | |||
| std::vector<int64_t> keys; | |||
| RETURN_IF_NOT_OK(blk->GetKeys(&keys)); | |||
| if (keys.empty()) { | |||
| // empty key is a quit signal for workers | |||
| break; | |||
| } | |||
| // Once we hit resource error, we drain the io block. No point to send anything to the server. | |||
| if (!resource_err) { | |||
| auto buffer = std::make_unique<DataBuffer>(cnt++, DataBuffer::kDeBFlagNone); | |||
| auto tensor_table = std::make_unique<TensorQTable>(); | |||
| for (auto id : keys) { | |||
| TensorRow row; | |||
| std::shared_ptr<Tensor> element; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, col_desc->type(), &element)); | |||
| row.setId(id); | |||
| // CreateEmpty allocates the memory but only in virtual address space. Touch every | |||
| // element to commit the pages so we get an accurate timing. | |||
| auto it = element->begin<int64_t>(); | |||
| for (size_t i = 0; i < num_elements; ++i, ++it) { | |||
| *it = i; | |||
| } | |||
| row.push_back(std::move(element)); | |||
| tensor_table->push_back(std::move(row)); | |||
| } | |||
| buffer->set_tensor_table(std::move(tensor_table)); | |||
| // Measure the time to call WriteBuffer | |||
| auto start_tick = std::chrono::steady_clock::now(); | |||
| rc = cc_->WriteBuffer(std::move(buffer)); | |||
| auto end_tick = std::chrono::steady_clock::now(); | |||
| if (rc.IsError()) { | |||
| if (rc.IsOutofMemory() || rc.IsNoSpace()) { | |||
| MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " worker id " << worker_id << ": " | |||
| << rc.ToString(); | |||
| resource_err = true; | |||
| cc_->ServerRunningOutOfResources(); | |||
| continue; | |||
| } else { | |||
| return rc; | |||
| } | |||
| } else { | |||
| int64_t us = std::chrono::duration_cast<std::chrono::microseconds>(end_tick - start_tick).count(); | |||
| min_val = std::min(min_val, us); | |||
| max_val = std::max(max_val, us); | |||
| duration.push_back(us); | |||
| total_val += us; | |||
| } | |||
| } | |||
| } while (true); | |||
| auto epoch_end = std::chrono::steady_clock::now(); | |||
| int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(epoch_end - epoch_start).count(); | |||
| PipelineWorkerEpochSummary proto; | |||
| proto.set_pipeline(my_pipeline_); | |||
| proto.set_worker(worker_id); | |||
| proto.set_min(min_val); | |||
| proto.set_max(max_val); | |||
| proto.set_elapse(elapse_time); | |||
| auto sz = duration.size(); | |||
| proto.set_cnt(sz); | |||
| if (sz > 0) { | |||
| // median | |||
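| // std::nth_element is an O(n) selection; duration[n] ends up being the n-th smallest sample, | |||
| // which for an even count is the upper median. | |||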
| auto n = sz / 2; | |||
| std::nth_element(duration.begin(), duration.begin() + n, duration.end()); | |||
| auto median = duration[n]; | |||
| proto.set_med(median); | |||
| // average | |||
| int64_t avg = total_val / sz; | |||
| proto.set_avg(avg); | |||
| } | |||
| CachePerfMsg msg; | |||
| RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochResult, &proto)); | |||
| return Status::OK(); | |||
| } | |||
| Status CachePipelineRun::RunReadEpoch() { | |||
| std::vector<row_id_type> keys; | |||
| auto rows_per_buffer = cc_->GetPrefetchSize(); // We will use prefetch size to read. | |||
| auto num_workers = cfg_.num_parallel_workers(); | |||
| keys.reserve(rows_per_buffer); | |||
| // Spawn workers | |||
| auto f = std::bind(&CachePipelineRun::ReaderWorkerEntry, this, std::placeholders::_1); | |||
| std::vector<Task *> worker_threads; | |||
| worker_threads.reserve(num_workers); | |||
| for (int32_t i = 0; i < num_workers; ++i) { | |||
| Task *pTask; | |||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Parallel worker", std::bind(f, i), &pTask)); | |||
| worker_threads.push_back(pTask); | |||
| } | |||
| std::vector<row_id_type> all_keys; | |||
| all_keys.reserve(end_row_ - start_row_ + 1); | |||
| for (auto i = start_row_; i <= end_row_; ++i) { | |||
| all_keys.push_back(i); | |||
| } | |||
| // If we need to shuffle the keys | |||
| if (shuffle_) { | |||
| std::shuffle(all_keys.begin(), all_keys.end(), GetRandomDevice()); | |||
| } | |||
| int32_t worker_id = 0; | |||
| for (auto id : all_keys) { | |||
| keys.push_back(id); | |||
| if (keys.size() == rows_per_buffer) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); | |||
| keys.clear(); | |||
| } | |||
| } | |||
| if (!keys.empty()) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); | |||
| keys.clear(); | |||
| } | |||
| // Shutdown threads and wait for them to come back. | |||
| for (int32_t i = 0; i < num_workers; i++) { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| for (auto *pTask : worker_threads) { | |||
| RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking)); | |||
| } | |||
| // Send a message saying this epoch is done for this pipeline. | |||
| EpochDone proto; | |||
| proto.set_pipeline(my_pipeline_); | |||
| CachePerfMsg msg; | |||
| RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochEnd, &proto)); | |||
| return Status::OK(); | |||
| } | |||
| Status CachePipelineRun::ReaderWorkerEntry(int32_t worker_id) { | |||
| Status rc; | |||
| TaskManager::FindMe()->Post(); | |||
| int64_t min_val = std::numeric_limits<int64_t>::max(); | |||
| int64_t max_val = 0; | |||
| int64_t total_val = 0; | |||
| int64_t cnt = 0; | |||
| std::vector<int64_t> duration; | |||
| duration.reserve(num_rows_ / num_pipelines_ / cfg_.num_parallel_workers()); | |||
| std::unique_ptr<IOBlock> blk; | |||
| auto epoch_start = std::chrono::steady_clock::now(); | |||
| do { | |||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); | |||
| std::vector<int64_t> keys; | |||
| RETURN_IF_NOT_OK(blk->GetKeys(&keys)); | |||
| if (keys.empty()) { | |||
| // empty key is a quit signal for workers | |||
| break; | |||
| } | |||
| std::vector<row_id_type> prefetch_keys; | |||
| prefetch_keys.reserve(keys.size()); | |||
| // Filter out the keys that we are unlikely to find at the server | |||
| for (auto row_id : keys) { | |||
| if (!cc_->KeyIsCacheMiss(row_id)) { | |||
| prefetch_keys.push_back(row_id); | |||
| } | |||
| } | |||
| // Early exit if nothing to fetch | |||
| if (prefetch_keys.empty()) { | |||
| continue; | |||
| } | |||
| // Get the rows from the server | |||
| TensorTable ttbl; | |||
| // Measure how long it takes for the row to come back. | |||
| auto start_tick = std::chrono::steady_clock::now(); | |||
| RETURN_IF_NOT_OK(cc_->GetRows(prefetch_keys, &ttbl)); | |||
| auto end_tick = std::chrono::steady_clock::now(); | |||
| int64_t us = std::chrono::duration_cast<std::chrono::microseconds>(end_tick - start_tick).count(); | |||
| min_val = std::min(min_val, us); | |||
| max_val = std::max(max_val, us); | |||
| duration.push_back(us); | |||
| total_val += us; | |||
| ++cnt; | |||
| } while (true); | |||
| auto epoch_end = std::chrono::steady_clock::now(); | |||
| int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(epoch_end - epoch_start).count(); | |||
| PipelineWorkerEpochSummary proto; | |||
| proto.set_pipeline(my_pipeline_); | |||
| proto.set_worker(worker_id); | |||
| proto.set_min(min_val); | |||
| proto.set_max(max_val); | |||
| proto.set_elapse(elapse_time); | |||
| auto sz = duration.size(); | |||
| proto.set_cnt(sz); | |||
| if (sz > 0) { | |||
| // median | |||
| auto n = sz / 2; | |||
| std::nth_element(duration.begin(), duration.begin() + n, duration.end()); | |||
| auto median = duration[n]; | |||
| proto.set_med(median); | |||
| // average | |||
| int64_t avg = total_val / sz; | |||
| proto.set_avg(avg); | |||
| } | |||
| CachePerfMsg msg; | |||
| RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochResult, &proto)); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,117 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_ | |||
| #include <getopt.h> | |||
| #include <atomic> | |||
| #include <cstdint> | |||
| #include <iostream> | |||
| #include <limits> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <random> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/engine/cache/perf/cache_msg.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| #include "minddata/dataset/util/wait_post.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| constexpr int32_t kDftNumOfPipelines = 8; | |||
| constexpr int32_t kDftNumberOfEpochs = 10; | |||
| constexpr int32_t kDftCacheSize = 0; | |||
| constexpr bool kDftShuffle = false; | |||
| constexpr bool kDftSpill = false; | |||
| class CachePipelineRun { | |||
| public: | |||
| CachePipelineRun(); | |||
| ~CachePipelineRun(); | |||
| static void PrintHelp(); | |||
| int32_t ProcessArgs(int argc, char **argv); | |||
| void Print(std::ostream &out) const { | |||
| out << "Number of pipelines: " << num_pipelines_ << "\n" | |||
| << "Number of epochs: " << num_epoches_ << "\n" | |||
| << "Sample size: " << num_rows_ << "\n" | |||
| << "Average row size: " << row_size_ << "\n" | |||
| << "Shuffle: " << std::boolalpha << shuffle_; | |||
| } | |||
| friend std::ostream &operator<<(std::ostream &out, const CachePipelineRun &cp) { | |||
| cp.Print(out); | |||
| return out; | |||
| } | |||
| Status Run(); | |||
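| // Wire format for parent/child communication: the message carries a type tag plus the | |||
| // protobuf payload serialized into a fixed-size buffer (kSharedMessageSize bytes), and is | |||
| // sent over the message queue identified by send_id_. | |||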
| template <typename T> | |||
| Status SendMessage(CachePerfMsg *msg, CachePerfMsg::MessageType type, T *proto) { | |||
| RETURN_UNEXPECTED_IF_NULL(msg); | |||
| msg->SetType(type); | |||
| if (proto != nullptr) { | |||
| auto size_needed = proto->ByteSizeLong(); | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| size_needed <= kSharedMessageSize, | |||
| "Shared message set too small. Suggest to increase to " + std::to_string(size_needed)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(proto->SerializeToArray(msg->GetMutableBuffer(), kSharedMessageSize), | |||
| "Serialization fails"); | |||
| msg->SetProtoBufSz(size_needed); | |||
| } | |||
| RETURN_IF_NOT_OK(msg->Send(send_id_)); | |||
| return Status::OK(); | |||
| } | |||
| private: | |||
| int32_t my_pipeline_; | |||
| int32_t num_pipelines_; | |||
| int32_t num_epoches_; | |||
| int64_t num_rows_; | |||
| int32_t row_size_; | |||
| bool shuffle_; | |||
| CacheClient::Builder cache_builder_; | |||
| session_id_type session_; | |||
| int32_t crc_; | |||
| TaskGroup vg_; | |||
| WaitPost pipeline_wp_; | |||
| int32_t send_id_; | |||
| int32_t recv_id_; | |||
| row_id_type start_row_; | |||
| row_id_type end_row_; | |||
| ConfigManager cfg_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks | |||
| std::shared_ptr<CacheClient> cc_; | |||
| Status ListenToParent(); | |||
| Status RunFirstEpoch(); | |||
| Status RunReadEpoch(); | |||
| Status WriterWorkerEntry(int32_t worker_id); | |||
| Status ReaderWorkerEntry(int32_t worker_id); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_ | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/util/storage_container.h" | |||
| #include "minddata/dataset/engine/cache/storage_container.h" | |||
| #include <fcntl.h> | |||
| #include <sys/stat.h> | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/util/storage_manager.h" | |||
| #include "minddata/dataset/engine/cache/storage_manager.h" | |||
| #include <iomanip> | |||
| #include <sstream> | |||
| @@ -21,6 +21,7 @@ | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/cache/storage_container.h" | |||
| #include "minddata/dataset/util/allocator.h" | |||
| #include "minddata/dataset/util/auto_index.h" | |||
| #include "minddata/dataset/util/lock.h" | |||
| @@ -28,7 +29,6 @@ | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/service.h" | |||
| #include "minddata/dataset/util/slice.h" | |||
| #include "minddata/dataset/util/storage_container.h" | |||
| using ListOfContainers = std::vector<std::shared_ptr<mindspore::dataset::StorageContainer>>; | |||
| namespace mindspore { | |||
| @@ -271,29 +271,18 @@ Status CacheBase::PrefetchRows(const std::vector<row_id_type> &keys, std::vector | |||
| } | |||
| // Get the rows from the server | |||
| TensorTable ttbl; | |||
| Status rc = cache_client_->GetRows(prefetch_keys, &ttbl); | |||
| if (rc.IsOk()) { | |||
| auto row_it = ttbl.begin(); | |||
| for (auto row_id : prefetch_keys) { | |||
| auto &row = *row_it; | |||
| if (row.empty()) { | |||
| cache_miss->push_back(row_id); | |||
| } | |||
| // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row | |||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||
| ++row_it; | |||
| } | |||
| } else { | |||
| // In case any thread is waiting for the rows to come back and blocked on a semaphore, | |||
| // we will put an empty row in the local cache. | |||
| for (auto row_id : prefetch_keys) { | |||
| TensorRow row; | |||
| row.setId(row_id); | |||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||
| RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl)); | |||
| auto row_it = ttbl.begin(); | |||
| for (auto row_id : prefetch_keys) { | |||
| auto &row = *row_it; | |||
| if (row.empty()) { | |||
| cache_miss->push_back(row_id); | |||
| } | |||
| // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row | |||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||
| ++row_it; | |||
| } | |||
| return rc; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheBase::Prefetcher(int32_t worker_id) { | |||
| @@ -322,6 +311,16 @@ Status CacheBase::Prefetcher(int32_t worker_id) { | |||
| return rc; | |||
| } | |||
| } while (rc.IsNetWorkError()); | |||
| // In case any thread is waiting for the rows to come back and blocked on a semaphore, | |||
| // we will put an empty row in the local cache. | |||
| if (rc.IsError() && AllowCacheMiss()) { | |||
| for (auto row_id : prefetch_keys) { | |||
| TensorRow row; | |||
| row.setId(row_id); | |||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||
| cache_miss.push_back(row_id); | |||
| } | |||
| } | |||
| } else { | |||
| if (AllowCacheMiss()) { | |||
| // This code path is for CacheLookupOp acting as a sampler. If we get an eoe from | |||
| @@ -24,7 +24,6 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/connector.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| @@ -309,8 +309,7 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha | |||
| if (st_.compare_exchange_strong(expected, State::kDirty)) { | |||
| // We will do a deep copy but write directly into CacheRequest protobuf or shared memory | |||
| Status rc; | |||
| cleaner_copy_ = | |||
| std::make_shared<CacheRowRequest>(cc->server_connection_id_, cc->cookie(), cc->SupportLocalClient()); | |||
| cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get()); | |||
| rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); | |||
| if (rc.IsOk()) { | |||
| // Send the request async. The cleaner will check the return code. | |||
| @@ -153,7 +153,7 @@ Status CacheOp::WaitForCachingAllRows() { | |||
| bool BuildPhaseDone = true; | |||
| do { | |||
| RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); | |||
| BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase); | |||
| BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase); | |||
| if (!BuildPhaseDone) { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(100)); | |||
| } | |||
| @@ -24,7 +24,7 @@ namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| CacheErrorPass::CacheErrorPass() : is_cached_(false) {} | |||
| CacheErrorPass::CacheErrorPass() : is_cached_(false), is_mappable_(false) {} | |||
| // Identifies the subtree below this node as being cached | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| @@ -75,5 +75,81 @@ Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modifi | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||
| is_mappable_ = true; | |||
| return Status::OK(); | |||
| } | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| // Turn off the flag that we're under a merge op | |||
| is_cached_ = false; | |||
| return Status::OK(); | |||
| } | |||
| // Currently returns an error if a RepeatOp exists under a cache. | |||
| // Because there is no operator in the cache hit stream to consume eoes, caching above a repeat causes problems. | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| if (is_cached_ && is_mappable_) { | |||
| RETURN_STATUS_UNEXPECTED("Repeat is not supported as a descendant operator under a mappable cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
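| // See test_cache_map_failure8 for the pipeline shape (cache over map over repeat over a | |||
| // mappable leaf) this pass guards against. | |||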
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -67,8 +67,81 @@ class CacheErrorPass : public NodePass { | |||
| Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; | |||
| #endif | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||
| /// \brief Identifies the subtree above this node as not being cached | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Identifies and block repeat under cache scenarios | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | |||
| private: | |||
| bool is_cached_; | |||
| bool is_mappable_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -3,7 +3,6 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE | |||
| add_library(utils OBJECT | |||
| arena.cc | |||
| buddy.cc | |||
| cache_pool.cc | |||
| circular_pool.cc | |||
| data_helper.cc | |||
| memory_pool.cc | |||
| @@ -16,8 +15,6 @@ add_library(utils OBJECT | |||
| lock.cc | |||
| semaphore.cc | |||
| status.cc | |||
| storage_container.cc | |||
| storage_manager.cc | |||
| slice.cc | |||
| path.cc | |||
| wait_post.cc | |||
| @@ -94,6 +94,11 @@ Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, C alloc, | |||
| CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive"); | |||
| try { | |||
| T *data = alloc.allocate(n); | |||
| // Some of our allocator implementations (e.g. NumaAllocator) don't throw std::bad_alloc, | |||
| // so we have to check for a null pointer instead. | |||
| if (data == nullptr) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| if (!std::is_arithmetic<T>::value) { | |||
| for (auto i = 0; i < n; i++) { | |||
| std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...); | |||
| @@ -78,6 +78,18 @@ class Path { | |||
| Path operator/(const char *); | |||
| bool operator==(const Path &rhs) const { return (path_ == rhs.path_); } | |||
| bool operator!=(const Path &rhs) const { return (path_ != rhs.path_); } | |||
| bool operator<(const Path &rhs) const { return (path_ < rhs.path_); } | |||
| bool operator>(const Path &rhs) const { return (path_ > rhs.path_); } | |||
| bool operator<=(const Path &rhs) const { return (path_ <= rhs.path_); } | |||
| bool operator>=(const Path &rhs) const { return (path_ >= rhs.path_); } | |||
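| // The ordering operators above let Path be used as a key in ordered containers such as | |||
| // std::map and std::set. | |||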
| bool Exists(); | |||
| bool IsDirectory(); | |||
| @@ -37,6 +37,11 @@ void Task::operator()() { | |||
| ss << Services::GetUniqueID(); | |||
| #endif | |||
| MS_LOG(DEBUG) << my_name_ << " Thread ID " << ss.str() << " Started."; | |||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||
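| // Remember the thread's native handle so the owner can later target this thread directly | |||
| // (for example to set CPU affinity). | |||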
| native_handle_ = pthread_self(); | |||
| #endif | |||
| try { | |||
| // Previously there was a timing hole where the thread was spawned but hit an error immediately before we could set | |||
| // the TaskGroup pointer and register. We move the registration logic to here (after we spawn) so we can | |||
| @@ -96,7 +101,8 @@ Task::Task(const std::string &myName, const std::function<Status()> &f) | |||
| task_group_(nullptr), | |||
| is_master_(false), | |||
| running_(false), | |||
| caught_severe_exception_(false) { | |||
| caught_severe_exception_(false), | |||
| native_handle_(0) { | |||
| IntrpResource::ResetIntrpState(); | |||
| wp_.ResetIntrpState(); | |||
| wp_.Clear(); | |||
| @@ -164,5 +170,10 @@ Status Task::OverrideInterruptRc(const Status &rc) { | |||
| } | |||
| return rc; | |||
| } | |||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||
| pthread_t Task::GetNativeHandle() const { return native_handle_; } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -16,6 +16,9 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ | |||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||
| #include <pthread.h> | |||
| #endif | |||
| #include <chrono> | |||
| #include <exception> | |||
| #include <functional> | |||
| @@ -84,7 +87,7 @@ class Task : public IntrpResource { | |||
| std::thread::id get_id() { return id_; } | |||
| std::string MyName() { return my_name_; } | |||
| std::string MyName() const { return my_name_; } | |||
| // An operator used by std::find | |||
| bool operator==(const Task &other) const { return (this == &other); } | |||
| @@ -97,6 +100,10 @@ class Task : public IntrpResource { | |||
| static Status OverrideInterruptRc(const Status &rc); | |||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||
| pthread_t GetNativeHandle() const; | |||
| #endif | |||
| private: | |||
| mutable std::mutex mux_; | |||
| std::string my_name_; | |||
| @@ -113,6 +120,12 @@ class Task : public IntrpResource { | |||
| volatile bool running_; | |||
| volatile bool caught_severe_exception_; | |||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||
| pthread_t native_handle_; | |||
| #else | |||
| uint64_t native_handle_; | |||
| #endif | |||
| void ShutdownGroup(); | |||
| TaskGroup *MyTaskGroup(); | |||
| void set_task_group(TaskGroup *vg); | |||
| @@ -24,7 +24,6 @@ | |||
| #include "common/common.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "minddata/dataset/util/storage_container.h" // lint !e322 | |||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| @@ -31,7 +31,7 @@ HandleRcExit $? 1 1 | |||
| export RUN_CACHE_TEST=TRUE | |||
| # Each of these tests will create session, use it, then destroy it after the test | |||
| for i in $(seq 1 6) | |||
| for i in $(seq 1 5) | |||
| do | |||
| test_name="test_cache_map_basic${i}" | |||
| GetSession | |||
| @@ -121,6 +121,12 @@ HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_voc" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_python_sampler" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat" | |||
| HandleRcExit $? 0 0 | |||
| # Run two parallel pipelines (sharing cache) | |||
| for i in $(seq 1 2) | |||
| do | |||
| @@ -309,6 +315,9 @@ HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_nested_repeat" | |||
| HandleRcExit $? 0 0 | |||
| for i in $(seq 1 3) | |||
| do | |||
| test_name="test_cache_nomap_multiple_cache${i}" | |||
| @@ -107,49 +107,10 @@ def test_cache_map_basic2(): | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_basic3(): | |||
| """ | |||
| Test a repeat under mappable cache | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| Repeat | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache basic 3") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.repeat(4) | |||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||
| logger.info("ds1.dataset_size is ", ds1.get_dataset_size()) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||
| logger.info("get data from dataset") | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 8 | |||
| logger.info('test_cache_basic3 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_basic4(): | |||
| """ | |||
| Test different rows result in core dump | |||
| """ | |||
| logger.info("Test cache basic 4") | |||
| logger.info("Test cache basic 3") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| @@ -171,11 +132,11 @@ def test_cache_map_basic4(): | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 8 | |||
| logger.info('test_cache_basic4 Ended.\n') | |||
| logger.info('test_cache_basic3 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_basic5(): | |||
| def test_cache_map_basic4(): | |||
| """ | |||
| Test Map with non-deterministic TensorOps above cache | |||
| @@ -188,7 +149,7 @@ def test_cache_map_basic5(): | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache failure 5") | |||
| logger.info("Test cache basic 4") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| @@ -211,11 +172,11 @@ def test_cache_map_basic5(): | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 8 | |||
| logger.info('test_cache_failure5 Ended.\n') | |||
| logger.info('test_cache_basic4 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_basic6(): | |||
| def test_cache_map_basic5(): | |||
| """ | |||
| Test cache as root node | |||
| @@ -223,7 +184,7 @@ def test_cache_map_basic6(): | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache basic 6") | |||
| logger.info("Test cache basic 5") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| @@ -239,7 +200,7 @@ def test_cache_map_basic6(): | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 2 | |||
| logger.info('test_cache_basic6 Ended.\n') | |||
| logger.info('test_cache_basic5 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| @@ -502,6 +463,7 @@ def test_cache_map_failure7(): | |||
| Generator | |||
| """ | |||
| def generator_1d(): | |||
| for i in range(64): | |||
| yield (np.array(i),) | |||
| @@ -528,6 +490,44 @@ def test_cache_map_failure7(): | |||
| logger.info('test_cache_failure7 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_failure8(): | |||
| """ | |||
| Test a repeat under mappable cache (failure) | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| Repeat | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache failure 8") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.repeat(4) | |||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||
| with pytest.raises(RuntimeError) as e: | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||
| num_iter += 1 | |||
| assert "Repeat is not supported as a descendant operator under a mappable cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure8 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_parameter_check(): | |||
| """ | |||
| @@ -1702,6 +1702,125 @@ def test_cache_map_voc2(): | |||
| logger.info("test_cache_map_voc2 Ended.\n") | |||
| class ReverseSampler(ds.Sampler): | |||
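| # A user-defined sampler only needs to implement __iter__ and yield row indices; | |||
| # this one walks the dataset indices in reverse order. | |||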
| def __iter__(self): | |||
| for i in range(self.dataset_size - 1, -1, -1): | |||
| yield i | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_python_sampler1(): | |||
| """ | |||
| Test using a python sampler, and cache after leaf | |||
| Repeat | |||
| | | |||
| Map(decode) | |||
| | | |||
| cache | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache map python sampler1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op) | |||
| ds1 = ds1.repeat(4) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 8 | |||
| logger.info("test_cache_map_python_sampler1 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_python_sampler2(): | |||
| """ | |||
| Test using a python sampler, and cache after map | |||
| Repeat | |||
| | | |||
| cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache map python sampler2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler()) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 8 | |||
| logger.info("test_cache_map_python_sampler2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_map_nested_repeat(): | |||
| """ | |||
| Test cache on pipeline with nested repeat ops | |||
| Repeat | |||
| | | |||
| Map(decode) | |||
| | | |||
| Repeat | |||
| | | |||
| Cache | |||
| | | |||
| ImageFolder | |||
| """ | |||
| logger.info("Test cache map nested repeat") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.repeat(4) | |||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"]) | |||
| ds1 = ds1.repeat(2) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||
| logger.info("get data from dataset") | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 16 | |||
| logger.info('test_cache_map_nested_repeat Ended.\n') | |||
| if __name__ == '__main__': | |||
| test_cache_map_basic1() | |||
| test_cache_map_basic2() | |||
| @@ -1292,6 +1292,50 @@ def test_cache_nomap_epoch_ctrl3(): | |||
| logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_epoch_ctrl4(): | |||
| """ | |||
| Test using two-loops method with repeat under cache | |||
| cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| repeat | |||
| | | |||
| TFRecord | |||
| """ | |||
| logger.info("Test cache nomap epoch ctrl4") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| ds1 = ds1.repeat(2) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| num_epoch = 5 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||
| epoch_count = 0 | |||
| for _ in range(num_epoch): | |||
| row_count = 0 | |||
| for _ in iter1: | |||
| row_count += 1 | |||
| logger.info("Number of data in ds1: {} ".format(row_count)) | |||
| assert row_count == 6 | |||
| epoch_count += 1 | |||
| assert epoch_count == num_epoch | |||
| logger.info("test_cache_nomap_epoch_ctrl4 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_multiple_cache1(): | |||
| """ | |||
| @@ -1734,6 +1778,47 @@ def test_cache_nomap_textfile2(): | |||
| logger.info("test_cache_nomap_textfile2 Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_nested_repeat(): | |||
| """ | |||
| Test cache on pipeline with nested repeat ops | |||
| Repeat | |||
| | | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| Repeat | |||
| | | |||
| TFRecord | |||
| """ | |||
| logger.info("Test cache nomap nested repeat") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.repeat(4) | |||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||
| ds1 = ds1.repeat(2) | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||
| logger.info("get data from dataset") | |||
| num_iter += 1 | |||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||
| assert num_iter == 24 | |||
| logger.info('test_cache_nomap_nested_repeat Ended.\n') | |||
| if __name__ == '__main__': | |||
| test_cache_nomap_basic1() | |||
| test_cache_nomap_basic2() | |||
| @@ -0,0 +1,10 @@ | |||
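| # Bring up the cache server and create a session, then stop the server while the | |||
| # test pipeline is still running to exercise the server-stop handling path. | |||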
| ~/cache/cache_admin --start | |||
| session_id=$(~/cache/cache_admin -g | awk '{print $NF}') | |||
| export SESSION_ID=${session_id} | |||
| pytest dataset/test_cache_nomap.py::test_cache_nomap_server_stop & | |||
| pid=("$!") | |||
| sleep 2 | |||
| ~/cache/cache_admin --stop | |||
| sleep 1 | |||
| wait ${pid} | |||