@@ -1,3 +1,4 @@
+add_subdirectory(perf EXCLUDE_FROM_ALL)
 include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
 set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
 ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
@@ -5,6 +6,18 @@ ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engin
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
+# Try to find the numa header file and its library
+find_file(NUMA_HDR NAMES "numa.h")
+if (EXISTS ${NUMA_HDR})
+  ADD_DEFINITIONS(-DNUMA_ENABLED)
+  MESSAGE("Numa package found")
+endif ()
+if (${CMAKE_SYSTEM_NAME} MATCHES "Linux")
+  ADD_DEFINITIONS(-DCACHE_LOCAL_CLIENT)
+endif ()
 add_library(engine-cache-client OBJECT
     cache_client.cc
     cache_fbb.cc
@@ -20,8 +33,13 @@ if (ENABLE_CACHE)
     ${CACHE_GRPC_SRCS}
     cache_grpc_server.cc
     cache_arena.cc
+    cache_hw.cc
+    cache_numa.cc
+    cache_pool.cc
     cache_service.cc
-    cache_server.cc)
+    cache_server.cc
+    storage_manager.cc
+    storage_container.cc)
   add_executable(cache_server cache_main.cc)
   target_link_libraries(cache_server
@@ -39,6 +57,10 @@ if (ENABLE_CACHE)
     target_link_libraries(cache_server mindspore::glog)
   endif ()
+  if (EXISTS ${NUMA_HDR})
+    target_link_libraries(cache_server numa)
+  endif ()
   add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
   target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES})
@@ -49,7 +71,7 @@ if (ENABLE_CACHE)
   add_dependencies(engine-cache-server generated_engine_files)
 else ()
-  ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PRTO_HDRS cache_grpc.proto)
+  ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PROTO_HDRS cache_grpc.proto)
   target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
 endif ()
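
Note: NUMA support is purely opportunistic at build time. If `find_file` locates `numa.h`, the build defines `NUMA_ENABLED` and links `cache_server` against libnuma; otherwise the code falls back to plain `malloc` paths. On Debian/Ubuntu-style build hosts (an assumption, not stated in this diff) the header typically comes from the `libnuma-dev` package, so installing it is enough to flip the feature on at the next configure.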
@@ -18,6 +18,7 @@
 #include <sys/stat.h>
 #include <sys/wait.h>
 #include <unistd.h>
+#include <algorithm>
 #include <cerrno>
 #include <iomanip>
 #include <iostream>
@@ -31,7 +32,9 @@
 namespace mindspore {
 namespace dataset {
+const int32_t CacheAdminArgHandler::kDefaultNumWorkers = std::thread::hardware_concurrency() > 2
+                                                           ? std::thread::hardware_concurrency() / 2
+                                                           : 1;
 const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
 const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";
@@ -304,8 +307,10 @@ Status CacheAdminArgHandler::Validate() {
   }
   // Additional checks here
-  if (num_workers_ < 1 || num_workers_ > 100)
-    return Status(StatusCode::kSyntaxError, "Number of workers must be in range of 1 and 100.");
+  auto max_num_workers = std::max<int32_t>(std::thread::hardware_concurrency(), 100);
+  if (num_workers_ < 1 || num_workers_ > max_num_workers)
+    return Status(StatusCode::kSyntaxError,
+                  "Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + ".");
   if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
   if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1)
     return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1");
@@ -354,13 +359,15 @@ Status CacheAdminArgHandler::RunCommand() {
   std::vector<SessionCacheInfo> session_info = rq->GetSessionCacheInfo();
   if (!session_info.empty()) {
     std::cout << std::setw(12) << "Session" << std::setw(12) << "Cache Id" << std::setw(12) << "Mem cached"
-              << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::endl;
+              << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::setw(10) << "Numa hit"
+              << std::endl;
     for (auto curr_session : session_info) {
       std::string cache_id;
       std::string stat_mem_cached;
       std::string stat_disk_cached;
       std::string stat_avg_cached;
-      int32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF);
+      std::string stat_numa_hit;
+      uint32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF);
       cache_id = (curr_session.connection_id == 0) ? "n/a" : std::to_string(crc);
       stat_mem_cached =
         (curr_session.stats.num_mem_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_mem_cached);
@@ -368,10 +375,12 @@ Status CacheAdminArgHandler::RunCommand() {
         (curr_session.stats.num_disk_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_disk_cached);
       stat_avg_cached =
         (curr_session.stats.avg_cache_sz == 0) ? "n/a" : std::to_string(curr_session.stats.avg_cache_sz);
+      stat_numa_hit =
+        (curr_session.stats.num_numa_hit == 0) ? "n/a" : std::to_string(curr_session.stats.num_numa_hit);
       std::cout << std::setw(12) << curr_session.session_id << std::setw(12) << cache_id << std::setw(12)
                 << stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached
-                << std::endl;
+                << std::setw(10) << stat_numa_hit << std::endl;
     }
   } else {
     std::cout << "No active sessions." << std::endl;
@@ -21,6 +21,7 @@
 #include <memory>
 #include <string>
 #include <sstream>
+#include <thread>
 #include "minddata/dataset/util/status.h"
 #include "minddata/dataset/engine/cache/cache_client.h"
@@ -29,7 +30,7 @@ namespace dataset {
 class CacheAdminArgHandler {
  public:
-  static constexpr int32_t kDefaultNumWorkers = 32;
+  static const int32_t kDefaultNumWorkers;
   static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
   static constexpr int32_t kDefaultLogLevel = 1;
   static constexpr float kMemoryCapRatio = 0.8;
@@ -17,7 +17,6 @@
 #include <iomanip>
 #include "minddata/dataset/engine/cache/cache_client.h"
 #include "minddata/dataset/engine/cache/cache_request.h"
-#include "minddata/dataset/engine/cache/cache_service.h"
 #include "minddata/dataset/engine/cache/cache_fbb.h"
 #include "minddata/dataset/util/bit.h"
@@ -59,6 +58,7 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool
     : server_connection_id_(0),
       cache_mem_sz_(cache_mem_sz),
       spill_(spill),
+      client_id_(-1),
       local_bypass_(false),
       hostname_(std::move(hostname)),
       port_(port),
@@ -71,6 +71,22 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool
 CacheClient::~CacheClient() {
   cache_miss_keys_wp_.Set();
+  if (client_id_ != -1) {
+    try {
+      // Send a message to the server, saying I am done.
+      auto rq = std::make_shared<ConnectResetRequest>(server_connection_id_, client_id_);
+      Status rc = PushRequest(rq);
+      if (rc.IsOk()) {
+        rc = rq->Wait();
+        if (rc.IsOk()) {
+          MS_LOG(INFO) << "Disconnect from server successful";
+        }
+      }
+    } catch (const std::exception &e) {
+      // We can't do anything in a destructor, so just log the error.
+      MS_LOG(ERROR) << e.what();
+    }
+  }
   (void)comm_->ServiceStop();
 }
@@ -85,7 +101,7 @@ void CacheClient::Print(std::ostream &out) const {
 }
 Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
-  auto rq = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
+  auto rq = std::make_shared<CacheRowRequest>(this);
   RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
   RETURN_IF_NOT_OK(PushRequest(rq));
   RETURN_IF_NOT_OK(rq->Wait());
@@ -104,7 +120,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
   for (auto i = 0; i < num_rows; ++i) {
     TensorRow row;
     RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
-    arr[i] = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
+    arr[i] = std::make_shared<CacheRowRequest>(this);
     RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row));
     RETURN_IF_NOT_OK(PushRequest(arr[i]));
   }
@@ -118,7 +134,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
 Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
   RETURN_UNEXPECTED_IF_NULL(out);
-  auto rq = std::make_shared<BatchFetchRequest>(server_connection_id_, row_id, SupportLocalClient());
+  auto rq = std::make_shared<BatchFetchRequest>(this, row_id);
   RETURN_IF_NOT_OK(PushRequest(rq));
   RETURN_IF_NOT_OK(rq->Wait());
   int64_t mem_addr;
@@ -167,7 +183,7 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
     lck.Unlock();  // GetStat will grab the mutex again. So unlock it to prevent deadlock.
     CacheServiceStat stat{};
     RETURN_IF_NOT_OK(GetStat(&stat));
-    if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
+    if (stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase)) {
       return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
     }
   } else {
@@ -183,18 +199,16 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
   // Start the comm layer to receive reply
   RETURN_IF_NOT_OK(comm_->ServiceStart());
   // Initiate connection
-  auto rq = std::make_shared<CreateCacheRequest>(cinfo_, cache_mem_sz_, createFlag);
+  auto rq = std::make_shared<CreateCacheRequest>(this, cinfo_, cache_mem_sz_, createFlag);
   RETURN_IF_NOT_OK(PushRequest(rq));
   Status rc = rq->Wait();
-  if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
-    std::string cookie;
-    rq->ParseResult(&server_connection_id_, &cookie);
-    if (rc.IsOk()) {
-      // The 1st guy creating the cache will get a cookie back.
-      // But this object may be shared among pipelines and we don't want
-      // overwrite it.
-      cookie_ = cookie;
-    }
+  bool success = (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey);
+  // If we get kDuplicateKey, it just means we aren't the first one to create the cache,
+  // and we will continue to parse the result.
+  if (rc.get_code() == StatusCode::kDuplicateKey) {
+    RETURN_IF_NOT_OK(rq->PostReply());
+  }
+  if (success) {
     // Attach to shared memory for local client
     RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
   }
@@ -47,6 +47,9 @@ namespace dataset {
 class CacheClient {
  public:
   friend class CacheMergeOp;
+  friend class CreateCacheRequest;
+  friend class CacheRowRequest;
+  friend class BatchFetchRequest;
   /// \brief A builder to help creating a CacheClient object
   class Builder {
@@ -115,7 +118,7 @@ class CacheClient {
   session_id_type GetSessionId() const { return session_id_; }
   uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
   bool isSpill() const { return spill_; }
-  const std::string &getHostname() const { return hostname_; }
+  const std::string &GetHostname() const { return hostname_; }
   int32_t GetPort() const { return port_; }
   int32_t GetNumConnections() const { return num_connections_; }
   int32_t GetPrefetchSize() const { return prefetch_size_; }
@@ -256,8 +259,10 @@ class CacheClient {
   CacheClientInfo cinfo_;
   // The server_connection_id_ is the actual id we use for operations after the cache is built
   connection_id_type server_connection_id_;
-  // Some magic cookie returned from the cache server.
+  // Some magic cookie/id returned from the cache server.
   std::string cookie_;
+  int32_t client_id_;
+  std::vector<int32_t> cpu_list_;
   // Comm layer
   bool local_bypass_;
   std::string hostname_;
@@ -20,11 +20,6 @@
 /// both client and server side codes. Do not put code that is not common here.
 /// There are client and server specific header files.
-// On platform like Windows, we may support only tcp/ip clients
-#if !defined(_WIN32) && !defined(_WIN64)
-#define CACHE_LOCAL_CLIENT 1
-#endif
 #ifdef ENABLE_CACHE
 #include <grpcpp/grpcpp.h>
 #endif
@@ -50,6 +45,9 @@ constexpr static uint32_t kDataIsInSharedMemory = 2;
 /// \brief Size of each message used in message queue.
 constexpr static int32_t kSharedMessageSize = 2048;
+/// \brief State of CacheService at the server.
+enum class CacheServiceState : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };
 /// \brief Convert a Status object into a protobuf
 /// \param rc[in] Status object
 /// \param reply[in/out] pointer to pre-allocated protobuf object
@@ -61,6 +59,22 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
 /// \param port
 /// \return unix socket url
 inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }
+/// \brief Round up to the next 4k
+inline int64_t round_up_4K(int64_t sz) {
+  // Since 4096 is a power of 2, a simple way to round up is to add 4095 and then
+  // mask off the low-order bits of 4095.
+  return static_cast<uint64_t>(sz + 4095) & ~static_cast<uint64_t>(4095);
+}
+/// Memory policy
+enum CachePoolPolicy : int8_t { kOnNode, kPreferred, kLocal, kInterleave, kNone };
+/// Misc typedefs
+using worker_id_t = int32_t;
+using numa_id_t = int32_t;
+using cpu_id_t = int32_t;
 }  // namespace dataset
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
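
Note: the mask trick in round_up_4K() works because 4096 is a power of two; adding 4095 and clearing the low 12 bits rounds any size up to the next 4 KB boundary. Expected behaviour, for reference:

    int64_t round_up_4K(int64_t sz);  // as defined above
    // round_up_4K(0)    -> 0
    // round_up_4K(1)    -> 4096
    // round_up_4K(4096) -> 4096
    // round_up_4K(4097) -> 8192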
@@ -32,12 +32,13 @@ message CacheRequest {
   uint32 flag = 2;
   oneof connect_info {
     // The server_connection_id is the actual id we use for operations after the cache is built
-    int64 connection_id = 3;
+    uint64 connection_id = 3;
     // But some request like CreateCache we have to use the session id and crc to connect to the server.
     CacheClientInfo connection_info = 4;
   }
+  int32 client_id = 5;
   // Everything else is just vector of buffers
-  repeated bytes buf_data = 5;
+  repeated bytes buf_data = 6;
 }
 message CacheReply {
@@ -74,6 +74,9 @@ Status CacheServerGreeterImpl::Run() {
 #if CACHE_LOCAL_CLIENT
     RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
     MS_LOG(INFO) << "Creation of local socket and shared memory successful";
+    auto cs = CacheServer::GetInstance().GetHWControl();
+    // This shared memory is hot memory, so we will interleave it among all the numa nodes.
+    cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
 #endif
   } else {
     std::string errMsg = "Fail to start server. ";
@@ -127,8 +130,13 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
     st_ = STATE::PROCESS;
     svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this);
   } else if (st_ == STATE::PROCESS) {
+    auto &cs = CacheServer::GetInstance();
     // Get a new tag and handle the next request before we serve the current request.
-    // The tag will be recycled when its state is changed to FINISH
+    // The tag will be recycled when its state is changed to FINISH.
+    // The number of free list queues is the same as the number of grpc threads.
+    // Where we get the free list from doesn't matter (as long as we return it back to the right queue).
+    // We can round robin, use the qid, or even use the worker id. We will use the free list queue
+    // where the current request comes from.
     CacheServerRequest *next_rq;
     RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq));
     RETURN_IF_NOT_OK((*next_rq)(svc, cq));
@@ -138,8 +146,24 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
     type_ = static_cast<RequestType>(rq_.type());
     // Now we pass the address of this instance to CacheServer's main loop.
     MS_LOG(DEBUG) << "Handle request " << *this;
-    auto &cs = CacheServer::GetInstance();
-    RETURN_IF_NOT_OK(cs.PushRequest(myQID, this));
+    // We will distribute the requests evenly (or randomly) over all the numa nodes.
+    // The exception is BatchFetch, which we need to pre-process here.
+    if (type_ == BaseRequest::RequestType::kBatchFetchRows) {
+      rc_ = cs.BatchFetchRows(&rq_, &reply_);
+      if (!rc_.IsInterrupted()) {
+        Status2CacheReply(rc_, &reply_);
+        st_ = CacheServerRequest::STATE::FINISH;
+        responder_.Finish(reply_, grpc::Status::OK, this);
+      } else {
+        return rc_;
+      }
+    } else {
+      // When the number of grpc workers is the same as the number of server workers, we will use this
+      // queue id and push to the corresponding queue.
+      bool random = cs.GetNumWorkers() != cs.GetNumGrpcWorkers();
+      worker_id_t worker_id = random ? cs.GetRandomWorker() : myQID;
+      RETURN_IF_NOT_OK(cs.PushRequest(worker_id, this));
+    }
   } else if (st_ == STATE::FINISH) {
     MS_LOG(DEBUG) << *this << " Finished.";
     // Return back to the free list.
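
Note: the random fan-out above relies on CacheServer::GetRandomWorker(), whose body is not part of this diff. A plausible sketch, assuming it draws uniformly over the server workers (the real implementation may differ; GetRandomDevice() is the existing helper from minddata/dataset/util/random.h also used in cache_numa.cc below):

    // Hypothetical sketch only, not the committed implementation.
    worker_id_t CacheServer::GetRandomWorker() {
      auto mt = GetRandomDevice();
      std::uniform_int_distribution<worker_id_t> dist(0, GetNumWorkers() - 1);
      return dist(mt);
    }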
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
+#include <atomic>
 #include <memory>
 #include <string>
 #include <utility>
@@ -34,6 +35,7 @@ namespace dataset {
 class CacheServerRequest : public BaseRequest {
  public:
   friend class CacheServer;
+  friend class CacheService;
   enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
   explicit CacheServerRequest(int32_t queue_id)
       : BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown),
@@ -0,0 +1,220 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/cache/cache_hw.h"
+#ifdef NUMA_ENABLED
+#include <numa.h>
+#endif
+#include <sched.h>
+#include <cstdlib>
+#include <cstring>
+#include <cctype>
+#include <fstream>
+#include <regex>
+#include <thread>
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+CacheServerHW::CacheServerHW() {
+  num_cpus_ = std::thread::hardware_concurrency();
+  MS_LOG(DEBUG) << "Number of cpu(s) : " << num_cpus_;
+#ifdef NUMA_ENABLED
+  if (numa_enabled()) {
+    MS_LOG(WARNING) << "Numa support enabled";
+    for (auto i = 0; i <= numa_max_node(); ++i) {
+      int64_t free_avail;
+      int64_t mem_avail = numa_node_size(i, &free_avail);
+      MS_LOG(INFO) << "Total physical/free RAM in bytes at node " << i << " : " << mem_avail << "/" << free_avail;
+    }
+  }
+#endif
+}
+
+int64_t CacheServerHW::GetTotalSystemMemory() {
+  auto pages = sysconf(_SC_PHYS_PAGES);
+  auto page_size = sysconf(_SC_PAGE_SIZE);
+  auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size);
+  MS_LOG(INFO) << "Total physical RAM in bytes: " << total;
+  return total;
+}
+
+Status CacheServerHW::SetDefaultMemoryPolicy(CachePoolPolicy policy) {
+#ifdef NUMA_ENABLED
+  if (numa_enabled()) {
+    // Set our default memory policy.
+    switch (policy) {
+      case kLocal:
+        numa_set_localalloc();
+        MS_LOG(DEBUG) << "Setting memory default policy to local node. Low level code may override the setting";
+        break;
+      case kInterleave:
+        numa_set_interleave_mask(numa_all_nodes_ptr);
+        MS_LOG(DEBUG) << "Numa affinity is turned off. Use interleave memory policy as default.";
+        break;
+      case kOnNode:
+      case kPreferred:
+        RETURN_STATUS_UNEXPECTED("Unsupported memory policy");
+        break;
+      case kNone:
+      default:
+        // No action taken.
+        break;
+    }
+  }
+#endif
+  return Status::OK();
+}
+
+Status CacheServerHW::GetNumaNodeInfo() {
+  std::set<Path> numa_nodes_;
+  Path node(kSysNodePath);
+  auto it = Path::DirIterator::OpenDirectory(&node);
+  if (it == nullptr) {
+    MS_LOG(WARNING) << "Unable to open directory " << kSysNodePath << ". Skip scanning hardware info";
+    return Status::OK();
+  }
+  auto isdigit_string = [](const char *str) -> bool {
+    bool r = true;
+    for (auto i = 0; i < strlen(str); ++i) {
+      if (!std::isdigit(str[i])) {
+        r = false;
+        break;
+      }
+    }
+    return r;
+  };
+  // Look for names that start with 'node' followed by digits.
+  const char kNodeName[] = "node";
+  while (it->hasNext()) {
+    auto p = it->next();
+    const std::string entry = p.Basename();
+    const char *name = entry.data();
+    if (strncmp(name, kNodeName, 4) == 0 && isdigit_string(name + strlen(kNodeName))) {
+      numa_nodes_.insert(p);
+    }
+  }
+  // There should be at least one. But if none is found for any reason, just move on with the
+  // rest of the server start up.
+  if (numa_nodes_.empty()) {
+    MS_LOG(WARNING) << "No numa nodes? Skip scanning hardware info";
+    return Status::OK();
+  }
+  // For each numa node, get the list of CPUs that is associated with it.
+  const char kCpuList[] = "cpulist";
+  auto r = std::regex("[0-9]*-[0-9]*");
+  for (Path p : numa_nodes_) {
+    auto node_dir = p.Basename().data();
+    numa_id_t numa_node = strtol(node_dir + strlen(kNodeName), nullptr, 10);
+    Path f = p / kCpuList;
+    std::ifstream fs(f.toString());
+    CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + f.toString());
+    std::string cpu_string;
+    cpu_set_t cpuset;
+    CPU_ZERO(&cpuset);
+    int32_t cpu_cnt = 0;
+    while (getline(fs, cpu_string)) {
+      // Now we parse the content of cpu_string.
+      std::sregex_iterator iter(cpu_string.begin(), cpu_string.end(), r);
+      std::sregex_iterator end;
+      while (iter != end) {
+        auto match = iter->str();
+        auto pos = match.find_first_of('-');
+        std::string min = match.substr(0, pos);
+        std::string max = match.substr(pos + 1);
+        cpu_id_t cpu_min = strtol(min.data(), nullptr, 10);
+        cpu_id_t cpu_max = strtol(max.data(), nullptr, 10);
+        MS_LOG(DEBUG) << "Numa node " << numa_node << " CPU(s) : " << cpu_min << "-" << cpu_max;
+        for (int i = cpu_min; i <= cpu_max; ++i) {
+          CPU_SET(i, &cpuset);
+          ++cpu_cnt;
+        }
+        ++iter;
+      }
+    }
+    CHECK_FAIL_RETURN_UNEXPECTED(!fs.bad(), "Fail to read file: " + f.toString());
+    fs.close();
+    // Remember which cpus are attached to this numa node.
+    numa_cpuset_.emplace(numa_node, cpuset);
+    numa_cpu_cnt_.emplace(numa_node, cpu_cnt);
+  }
+  MS_LOG(DEBUG) << "Number of numa nodes : " << numa_cpuset_.size();
+  return Status::OK();
+}
+
+Status CacheServerHW::SetAffinity(const Task &tk, numa_id_t numa_node) {
+  auto r = numa_cpuset_.find(numa_node);
+  if (r != numa_cpuset_.end()) {
+    auto err = pthread_setaffinity_np(tk.GetNativeHandle(), sizeof(r->second), &r->second);
+    if (err) {
+      std::string errMsg = "Unable to set affinity. Errno = " + std::to_string(errno);
+      RETURN_STATUS_UNEXPECTED(errMsg);
+    }
+  } else {
+    RETURN_STATUS_UNEXPECTED("Numa node " + std::to_string(numa_node) + " not found");
+  }
+  return Status::OK();
+}
+
+std::vector<cpu_id_t> CacheServerHW::GetCpuList(numa_id_t numa_id) {
+  std::vector<cpu_id_t> v;
+  auto it = numa_cpuset_.find(numa_id);
+  if (it != numa_cpuset_.end()) {
+    auto &cpu_set = it->second;
+    for (auto i = 0; i < num_cpus_; ++i) {
+      if (CPU_ISSET(i, &cpu_set)) {
+        v.push_back(i);
+      }
+    }
+  }
+  return v;
+}
+
+numa_id_t CacheServerHW::GetMyNode() const {
+  numa_id_t node_id = 0;
+  auto cpu = sched_getcpu();
+#ifdef NUMA_ENABLED
+  node_id = numa_node_of_cpu(cpu);
+#else
+  bool found = false;
+  for (auto it : numa_cpuset_) {
+    cpu_set_t &cpu_set = it.second;
+    if (CPU_ISSET(cpu, &cpu_set)) {
+      node_id = it.first;
+      found = true;
+      break;
+    }
+  }
+  MS_LOG(DEBUG) << "cpu id " << cpu << " found : " << std::boolalpha << found;
+#endif
+  return node_id;
+}
+
+void CacheServerHW::InterleaveMemory(void *ptr, size_t sz) {
+#ifdef NUMA_ENABLED
+  if (numa_enabled()) {
+    numa_interleave_memory(ptr, sz, numa_all_nodes_ptr);
+  }
+#endif
+}
+
+bool CacheServerHW::numa_enabled() {
+#ifdef NUMA_ENABLED
+  return (numa_available() != -1);
+#else
+  return false;
+#endif
+}
+}  // namespace dataset
+}  // namespace mindspore
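
Note: to illustrate what GetNumaNodeInfo() parses, on a hypothetical two-socket machine the sysfs tree it walks would contain entries like the following (paths real, values illustrative):

    /sys/devices/system/node/node0/cpulist   ->  "0-17,36-53"
    /sys/devices/system/node/node1/cpulist   ->  "18-35,54-71"

Each "min-max" range matches the regex above and is expanded into the node's cpu_set_t, so node0 in this example contributes 36 cpus to numa_cpu_cnt_.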
@@ -0,0 +1,81 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_
+
+#ifdef NUMA_ENABLED
+#include <numa.h>
+#endif
+#include <sched.h>
+#include <stdlib.h>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+#include "minddata/dataset/engine/cache/cache_common.h"
+#include "minddata/dataset/util/memory_pool.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/util/task.h"
+
+namespace mindspore {
+namespace dataset {
+class CacheServerHW {
+ public:
+  CacheServerHW();
+  ~CacheServerHW() = default;
+
+  /// \brief Get numa node info without using the numa library
+  /// \return Status object
+  Status GetNumaNodeInfo();
+
+  /// \brief Set thread affinity
+  Status SetAffinity(const Task &tk, numa_id_t numa_node);
+
+  /// \brief Get total number of cpu(s)
+  int32_t GetCpuCount() const { return num_cpus_; }
+
+  /// \brief Get total number of numa nodes
+  int32_t GetNumaNodeCount() const { return numa_cpuset_.empty() ? 1 : numa_cpuset_.size(); }
+
+  /// \brief Get a list of cpus for a given numa node.
+  std::vector<cpu_id_t> GetCpuList(numa_id_t numa_id);
+
+  static bool numa_enabled();
+
+  /// \brief Return the numa node the current thread is running on.
+  numa_id_t GetMyNode() const;
+
+  /// \brief Interleave a given memory block. Used by shared memory only.
+  static void InterleaveMemory(void *ptr, size_t sz);
+
+  /// \brief Set default memory policy.
+  static Status SetDefaultMemoryPolicy(CachePoolPolicy);
+
+  /// \brief This returns the size (in bytes) of the physical RAM on the machine.
+  /// \return the size (in bytes) of the physical RAM on the machine.
+  static int64_t GetTotalSystemMemory();
+
+ private:
+  constexpr static char kSysNodePath[] = "/sys/devices/system/node";
+  int32_t num_cpus_;
+  std::map<numa_id_t, cpu_set_t> numa_cpuset_;
+  std::map<numa_id_t, int32_t> numa_cpu_cnt_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_
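
Note: a minimal usage sketch of this class (an assumed call sequence; my_task is a hypothetical Task object and the calls are made from a function returning Status, none of which is shown in this diff):

    auto hw = std::make_shared<CacheServerHW>();
    RETURN_IF_NOT_OK(hw->GetNumaNodeInfo());  // scan /sys/devices/system/node
    MS_LOG(INFO) << "numa nodes: " << hw->GetNumaNodeCount() << ", cpus: " << hw->GetCpuCount();
    // Pin a server task to the cpus of numa node 0:
    RETURN_IF_NOT_OK(hw->SetAffinity(my_task, 0));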
@@ -54,6 +54,8 @@ ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::
 #endif
   try {
     rq->set_type(static_cast<int16_t>(type));
+    rq->set_client_id(-1);
+    rq->set_flag(0);
     grpc::ChannelArguments args;
     grpc::ClientContext ctx;
     grpc::CompletionQueue cq;
@@ -0,0 +1,224 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <algorithm>
+#include <iterator>
+#include <limits>
+#include "minddata/dataset/engine/cache/cache_hw.h"
+#include "minddata/dataset/engine/cache/cache_numa.h"
+#include "minddata/dataset/util/random.h"
+
+namespace mindspore {
+namespace dataset {
+NumaMemoryPool::NumaMemoryPool(std::shared_ptr<CacheServerHW> hw, float memory_cap_ratio)
+    : hw_(std::move(hw)), memory_cap_ratio_(memory_cap_ratio) {
+  int64_t total_avail = 0;
+  // We will create a number of small Arenas to spread out the server threads so
+  // there will be less contention. If we link with the numa library, i.e. if
+  // NUMA_ENABLED is defined, we will make use of the low level numa library such that
+  // each Arena comes solely from one particular socket.
+  // The total number of Arenas is capped by the number of cpus.
+  auto num_cpus = hw_->GetCpuCount();
+  memory_segments_.reserve(num_cpus);
+  arena_list_.reserve(num_cpus);
+  mux_ = std::make_unique<std::mutex[]>(num_cpus);
+  auto num_memory_nodes = num_cpus;
+  int64_t max_avail = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
+  int64_t arena_sz = max_avail / num_memory_nodes;
+  // If arena_sz is too small, lower the number of Arenas.
+  if (arena_sz < std::numeric_limits<int32_t>::max()) {
+    arena_sz = round_up_4K(std::numeric_limits<int32_t>::max());
+    num_memory_nodes = max_avail / arena_sz;
+    if (num_memory_nodes == 0) {
+      num_memory_nodes = 1;
+      arena_sz = max_avail;
+    }
+  }
+  MS_LOG(INFO) << "Creating " << num_memory_nodes << " number of arena. Each one of size " << arena_sz;
+#ifdef NUMA_ENABLED
+  if (numa_available() != -1) {
+    auto num_numa_nodes = hw_->GetNumaNodeCount();
+    numa_id_t node_id = 0;
+    for (auto i = 0; i < num_memory_nodes; ++i) {
+      auto success = CreateMultipleArenas(arena_sz, node_id++ % num_numa_nodes, 1);
+      total_avail += success * arena_sz;
+    }
+  } else {
+    auto success = CreateMultipleArenas(arena_sz, 0, num_memory_nodes);
+    total_avail += success * arena_sz;
+  }
+#else
+  auto success = CreateMultipleArenas(arena_sz, 0, num_memory_nodes);
+  total_avail += success * arena_sz;
+#endif
+  memory_cap_ = total_avail;
+  MS_LOG(WARNING) << "Memory pool created. Total available memory " << memory_cap_ << " spread in " << nodes_.size()
+                  << " arenas";
+  int32_t slot = 0;
+  // Set up a map for future easy access.
+  for (auto node_id : nodes_) {
+    numa_map_[node_id].push_back(slot);
+    ++slot;
+  }
+}
+
+int32_t NumaMemoryPool::CreateMultipleArenas(int64_t segment_sz, numa_id_t node_id, int32_t repeat_count) {
+  int32_t success = 0;
+  for (auto i = 0; i < repeat_count; ++i) {
+#ifdef NUMA_ENABLED
+    void *ptr = numa_alloc_onnode(segment_sz, node_id);
+#else
+    void *ptr = malloc(segment_sz);
+#endif
+    if (ptr != nullptr) {
+      memory_segments_.emplace_back(ptr, segment_sz);
+      arena_list_.push_back(std::make_unique<ArenaImpl>(ptr, segment_sz));
+      nodes_.push_back(node_id);
+      ++success;
+    } else {
+      // Skip the rest.
+      break;
+    }
+  }
+  MS_LOG(DEBUG) << "Allocate " << success << " arenas from node " << node_id;
+  return success;
+}
+
+NumaMemoryPool::~NumaMemoryPool() {
+  if (!memory_segments_.empty()) {
+    for (auto &s : memory_segments_) {
+#ifdef NUMA_ENABLED
+      numa_free(s.first, s.second);
+#else
+      free(s.first);
+#endif
+    }
+  }
+}
+
+Status NumaMemoryPool::Allocate(size_t n, void **p) {
+  RETURN_UNEXPECTED_IF_NULL(p);
+  auto mt = GetRandomDevice();
+  Status rc;
+  void *ptr = nullptr;
+  auto num_segments = memory_segments_.size();
+  CHECK_FAIL_RETURN_UNEXPECTED(num_segments > 0, "No numa nodes available");
+  if (NumaAware()) {
+    auto num_numa_nodes = hw_->GetNumaNodeCount();
+    // We will start from the numa node this worker is running on and do a round robin search.
+    numa_id_t start = hw_->GetMyNode();
+    numa_id_t node_id = start;
+    do {
+      auto it = numa_map_.find(node_id);
+      if (it != numa_map_.end()) {
+        auto &slots = it->second;
+        auto num_slots = slots.size();
+        std::uniform_int_distribution<int32_t> distribution(0, num_slots - 1);
+        auto start_slot = distribution(mt);
+        int32_t inx = start_slot;
+        do {
+          int32_t k = slots.at(inx);
+          std::unique_lock lock_x(mux_[k]);
+          auto &impl = arena_list_.at(k);
+          rc = impl->Allocate(n, &ptr);
+          if (rc.IsOk()) {
+            *p = ptr;
+            break;
+          } else if (rc.IsOutofMemory()) {
+            inx = (inx + 1) % num_slots;
+          } else {
+            return rc;
+          }
+        } while (inx != start_slot);
+      }
+      // We are done searching this numa node. If nothing was found, move on to the next node.
+      if (ptr == nullptr) {
+        node_id = (node_id + 1) % num_numa_nodes;
+      } else {
+        break;
+      }
+    } while (node_id != start);
+  } else {
+    // If not numa aware, just randomly pick a slot.
+    std::uniform_int_distribution<int32_t> distribution(0, num_segments - 1);
+    auto start_slot = distribution(mt);
+    int32_t slot = start_slot;
+    do {
+      std::unique_lock lock_x(mux_[slot]);
+      auto &impl = arena_list_.at(slot);
+      rc = impl->Allocate(n, &ptr);
+      if (rc.IsOk()) {
+        *p = ptr;
+        break;
+      } else if (rc.IsOutofMemory()) {
+        // Move on to the next arena and continue.
+        slot = (slot + 1) % num_segments;
+      } else {
+        return rc;
+      }
+    } while (slot != start_slot);
+  }
+  // Handle the case where we have completed a full round robin search without success.
+  if (ptr == nullptr) {
+    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
+  }
+  return rc;
+}
+
+void NumaMemoryPool::Deallocate(void *p) {
+  // Find out which numa slot it comes from.
+  auto slot = Locate(p);
+  MS_ASSERT(slot != -1);
+  std::unique_lock lock_x(mux_[slot]);
+  auto &impl = arena_list_.at(slot);
+  impl->Deallocate(p);
+}
+
+int NumaMemoryPool::PercentFree() const {
+  int percent_free = 0;
+  int num_arena = 0;
+  for (auto const &p : arena_list_) {
+    percent_free += p->PercentFree();
+    num_arena++;
+  }
+  if (num_arena) {
+    return percent_free / num_arena;
+  } else {
+    return 100;
+  }
+}
+
+int32_t NumaMemoryPool::Locate(void *p) const {
+  int32_t slot = 0;
+  char *mem = reinterpret_cast<char *>(p);
+  for (slot = 0; slot < memory_segments_.size(); ++slot) {
+    auto elem = memory_segments_.at(slot);
+    char *q = reinterpret_cast<char *>(elem.first);
+    if (mem >= q && mem < q + elem.second) {
+      return slot;
+    }
+  }
+  return -1;
+}
+
+std::vector<numa_id_t> NumaMemoryPool::GetAvailableNodes() const {
+  std::vector<numa_id_t> v;
+  std::transform(numa_map_.begin(), numa_map_.end(), std::back_inserter(v),
+                 [](const std::pair<numa_id_t, std::vector<int32_t>> &v) { return v.first; });
+  return v;
+}
+}  // namespace dataset
+}  // namespace mindspore
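
Note: a worked example of the arena sizing in the NumaMemoryPool constructor, with illustrative numbers. On a host with 16 GiB of physical RAM, 8 cpus, and memory_cap_ratio = 0.8: max_avail ≈ 13.7 GB, and the first-cut arena_sz ≈ 1.7 GB per cpu. Since that is below INT32_MAX (~2 GiB), arena_sz is rounded up to 2 GiB and num_memory_nodes becomes 13.7 GB / 2 GiB = 6, so the pool ends up with 6 arenas of 2 GiB each, placed round-robin across the numa nodes when NUMA is enabled; memory_cap_ then reports 6 x 2 GiB = 12 GiB (assuming every arena allocation succeeds).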
@@ -0,0 +1,195 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_
+
+#include <limits>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <utility>
+#include <vector>
+#include "minddata/dataset/engine/cache/cache_hw.h"
+#include "minddata/dataset/util/arena.h"
+#include "minddata/dataset/util/memory_pool.h"
+
+namespace mindspore {
+namespace dataset {
+/// \brief An allocator but for a particular numa node.
+template <typename T>
+class NumaAllocator {
+ public:
+  explicit NumaAllocator(numa_id_t node_id, CachePoolPolicy policy)
+      : policy_(policy), numa_enabled_(false), node_id_(node_id) {
+#ifdef NUMA_ENABLED
+    numa_enabled_ = numa_available() != -1;
+#endif
+  }
+  ~NumaAllocator() = default;
+
+  template <typename U>
+  explicit NumaAllocator(NumaAllocator<U> const &rhs)
+      : policy_(rhs.policy_), numa_enabled_(rhs.numa_enabled_), node_id_(rhs.node_id_) {}
+
+  template <typename U>
+  bool operator==(Allocator<U> const &rhs) const {
+    return node_id_ == rhs.node_id_;
+  }
+
+  template <typename U>
+  bool operator!=(Allocator<U> const &rhs) const {
+    return node_id_ != rhs.node_id_;
+  }
+
+  template <typename U>
+  friend class NumaAllocator;
+
+  using value_type = T;
+  using pointer = T *;
+  using const_pointer = const T *;
+  using reference = T &;
+  using const_reference = const T &;
+  using size_type = uint64_t;
+  using difference_type = std::ptrdiff_t;
+
+  template <typename U>
+  struct rebind {
+    using other = Allocator<U>;
+  };
+
+  using propagate_on_container_copy_assignment = std::true_type;
+  using propagate_on_container_move_assignment = std::true_type;
+  using propagate_on_container_swap = std::true_type;
+
+  /// Allocate memory on this node only. Return nullptr if there is no memory on this numa node.
+  /// \note This version will not throw if we can't allocate memory from this node.
+  /// The user must check if the pointer returned is null or not.
+  pointer allocate(std::size_t n) noexcept {
+    auto sz = n * sizeof(T);
+    void *p = nullptr;
+#ifdef NUMA_ENABLED
+    if (numa_enabled_) {
+      switch (policy_) {
+        case kPreferred:
+          numa_set_preferred(node_id_);
+          p = numa_alloc(sz);
+          break;
+        case kLocal:
+          p = numa_alloc_local(sz);
+          break;
+        case kInterleave:
+          p = numa_alloc_interleaved(sz);
+          break;
+        case kOnNode:
+          p = numa_alloc_onnode(sz, node_id_);
+          break;
+        case kNone:
+        default:
+          p = numa_alloc(sz);
+          break;
+      }
+    } else {
+      p = malloc(sz);
+    }
+#else
+    p = malloc(sz);
+#endif
+    return reinterpret_cast<pointer>(p);
+  }
+
+  /// Free memory allocated on this node.
+  void deallocate(pointer p, std::size_t n) noexcept {
+#ifdef NUMA_ENABLED
+    if (numa_enabled_) {
+      numa_free(p, n * sizeof(T));
+    } else {
+      free(p);
+    }
+#else
+    free(p);
+#endif
+  }
+
+  /// \brief Allow one to change to another numa node
+  void SetNodeId(numa_id_t node_id) { node_id_ = node_id; }
+
+  /// \brief Getter for node_id
+  numa_id_t GetNodeId() const { return node_id_; }
+
+  /// \brief Getter for policy
+  CachePoolPolicy GetPolicy() const { return policy_; }
+
+ private:
+  CachePoolPolicy policy_;
+  bool numa_enabled_;
+  numa_id_t node_id_;
+};
+
+/// \brief A NumaMemoryPool is like a CircularPool but all the arenas have already been allocated
+/// and each one comes from a numa socket. Memory is allocated using the OnNode policy. That is,
+/// it comes solely from one particular numa node and is not interleaved.
+class NumaMemoryPool : public MemoryPool {
+ public:
+  explicit NumaMemoryPool(std::shared_ptr<CacheServerHW> hw, float memory_cap_ratio);
+  ~NumaMemoryPool() override;
+
+  // As a derived class, we override the following functions
+  Status Allocate(size_t size, void **pVoid) override;
+  void Deallocate(void *pVoid) override;
+  Status Reallocate(void **pVoid, size_t old_sz, size_t new_sz) override { RETURN_STATUS_UNEXPECTED("Not supported"); }
+  uint64_t get_max_size() const override { return std::numeric_limits<uint64_t>::max(); }
+  int PercentFree() const override;
+
+  /// \brief Return if the memory pool is numa aware
+  bool NumaAware() const { return CacheServerHW::numa_enabled(); }
+
+  /// \brief This returns all the numa nodes that we are able to allocate memory from.
+  std::vector<numa_id_t> GetAvailableNodes() const;
+
+  /// \brief Given a pointer (allocated from this pool), return the numa node where it is located.
+  /// \note -1 is returned if not found.
+  numa_id_t FindNode(void *p) const {
+    auto slot = Locate(p);
+    if (slot != -1) {
+      return nodes_.at(slot);
+    } else {
+      return -1;
+    }
+  }
+
+  /// \brief Return maximum available memory
+  int64_t GetAvailableMemory() const { return memory_cap_; }
+
+ private:
+  std::shared_ptr<CacheServerHW> hw_;
+  float memory_cap_ratio_;
+  int64_t memory_cap_;
+  std::vector<std::pair<void *, int64_t>> memory_segments_;
+  std::vector<std::unique_ptr<ArenaImpl>> arena_list_;
+  std::unique_ptr<std::mutex[]> mux_;
+  std::vector<numa_id_t> nodes_;
+  std::map<numa_id_t, std::vector<int32_t>> numa_map_;
+
+  /// \brief Returns the slot that a given memory block comes from.
+  /// \return slot from numa_segments; -1 if not found.
+  int32_t Locate(void *p) const;
+
+  /// If the numa library is not linked, or numa_available() returns -1, we will fall back to this method.
+  int32_t CreateMultipleArenas(int64_t segment_sz, numa_id_t node_id, int32_t repeat_count);
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_
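
Note: a usage sketch for NumaAllocator (hypothetical caller code, not part of this diff). Unlike std::allocator, allocate() is noexcept and returns nullptr on failure, so the caller must check the pointer:

    NumaAllocator<int> alloc(/*node_id=*/0, CachePoolPolicy::kOnNode);
    int *buf = alloc.allocate(1024);   // memory for 1024 ints, on numa node 0 when possible
    if (buf != nullptr) {
      // ... use buf[0] .. buf[1023] ...
      alloc.deallocate(buf, 1024);     // must pass back the same element count
    }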
| @@ -15,18 +15,14 @@ | |||||
| */ | */ | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "minddata/dataset/util/cache_pool.h" | |||||
| #include "minddata/dataset/engine/cache/cache_pool.h" | |||||
| #include "minddata/dataset/engine/cache/cache_server.h" | |||||
| #include "minddata/dataset/util/services.h" | #include "minddata/dataset/util/services.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| CachePool::CachePool(const value_allocator &alloc, bool ourOwnArena, const std::string &root) | |||||
| : alloc_(alloc), | |||||
| root_(root), | |||||
| subfolder_(Services::GetUniqueID()), | |||||
| sm_(nullptr), | |||||
| tree_(nullptr), | |||||
| custom_arena_(ourOwnArena) {} | |||||
| CachePool::CachePool(std::shared_ptr<NumaMemoryPool> mp, const std::string &root) | |||||
| : mp_(std::move(mp)), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} | |||||
| Status CachePool::DoServiceStart() { | Status CachePool::DoServiceStart() { | ||||
| tree_ = std::make_shared<data_index>(); | tree_ = std::make_shared<data_index>(); | ||||
| @@ -36,10 +32,11 @@ Status CachePool::DoServiceStart() { | |||||
| RETURN_IF_NOT_OK(spill.CreateDirectories()); | RETURN_IF_NOT_OK(spill.CreateDirectories()); | ||||
| sm_ = std::make_shared<StorageManager>(spill); | sm_ = std::make_shared<StorageManager>(spill); | ||||
| RETURN_IF_NOT_OK(sm_->ServiceStart()); | RETURN_IF_NOT_OK(sm_->ServiceStart()); | ||||
| MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString()); | |||||
| MS_LOG(INFO) << "CachePool will use disk folder: " << spill.toString(); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CachePool::DoServiceStop() { | Status CachePool::DoServiceStop() { | ||||
| Status rc; | Status rc; | ||||
| Status rc2; | Status rc2; | ||||
| @@ -50,14 +47,14 @@ Status CachePool::DoServiceStop() { | |||||
| } | } | ||||
| } | } | ||||
| sm_.reset(); | sm_.reset(); | ||||
| // If it is our own arena, skip freeing individual pieces. | |||||
| if (!custom_arena_) { | |||||
| for (auto &bl : *tree_) { | |||||
| if (bl.ptr != nullptr) { | |||||
| alloc_.deallocate(bl.ptr, bl.sz); | |||||
| } | |||||
| value_allocator alloc(mp_); | |||||
| for (auto &bl : *tree_) { | |||||
| if (bl.ptr != nullptr) { | |||||
| alloc.deallocate(bl.ptr, bl.sz); | |||||
| } | } | ||||
| } | } | ||||
| tree_.reset(); | tree_.reset(); | ||||
| if (!root_.toString().empty()) { | if (!root_.toString().empty()) { | ||||
| Path spill = GetSpillPath(); | Path spill = GetSpillPath(); | ||||
| @@ -75,8 +72,10 @@ Status CachePool::DoServiceStop() { | |||||
| } | } | ||||
| return rc2; | return rc2; | ||||
| } | } | ||||
| CachePool::~CachePool() noexcept { (void)ServiceStop(); } | CachePool::~CachePool() noexcept { (void)ServiceStop(); } | ||||
| Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly) { | |||||
| Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf) { | |||||
| DataLocator bl; | DataLocator bl; | ||||
| Status rc; | Status rc; | ||||
| size_t sz = 0; | size_t sz = 0; | ||||
| @@ -85,26 +84,35 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic | |||||
| sz += v.GetSize(); | sz += v.GetSize(); | ||||
| } | } | ||||
| bl.sz = sz; | bl.sz = sz; | ||||
| try { | |||||
| if (!writeToDiskDirectly) { | |||||
| bl.ptr = alloc_.allocate(sz); | |||||
| // We will do a piecewise copy. | |||||
| WritableSlice dest(bl.ptr, bl.sz); | |||||
| size_t pos = 0; | |||||
| for (auto &v : buf) { | |||||
| WritableSlice out(dest, pos); | |||||
| rc = WritableSlice::Copy(&out, v); | |||||
| if (rc.IsError()) { | |||||
| break; | |||||
| } | |||||
| pos += v.GetSize(); | |||||
| } | |||||
| rc = mp_->Allocate(sz, reinterpret_cast<void **>(&bl.ptr)); | |||||
| if (rc.IsOk()) { | |||||
| // Record the numa node we allocated from. It only makes sense if the policy is kOnNode. | |||||
| if (CacheServerHW::numa_enabled()) { | |||||
| auto &cs = CacheServer::GetInstance(); | |||||
| auto node_id = cs.GetHWControl()->GetMyNode(); | |||||
| bl.node_id = mp_->FindNode(bl.ptr); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(bl.node_id != -1, "Allocated memory is not from the numa memory pool"); | |||||
| bl.node_hit = (bl.node_id == node_id); | |||||
| } | |||||
| // We will do a piecewise copy. | |||||
| WritableSlice dest(bl.ptr, bl.sz); | |||||
| size_t pos = 0; | |||||
| for (auto &v : buf) { | |||||
| WritableSlice out(dest, pos); | |||||
| rc = WritableSlice::Copy(&out, v); | |||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| alloc_.deallocate(bl.ptr, sz); | |||||
| bl.ptr = nullptr; | |||||
| return rc; | |||||
| break; | |||||
| } | } | ||||
| } else if (sm_ != nullptr) { | |||||
| pos += v.GetSize(); | |||||
| } | |||||
| if (rc.IsError()) { | |||||
| mp_->Deallocate(bl.ptr); | |||||
| bl.ptr = nullptr; | |||||
| return rc; | |||||
| } | |||||
| } else if (rc.IsOutofMemory()) { | |||||
| // If no memory, write to disk. | |||||
| if (sm_ != nullptr) { | |||||
| MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes."; | MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes."; | ||||
| RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); | RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); | ||||
| } else { | } else { | ||||
| @@ -112,12 +120,8 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic | |||||
| // instead. | // instead. | ||||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | ||||
| } | } | ||||
| } catch (std::bad_alloc &e) { | |||||
| if (sm_ != nullptr) { | |||||
| RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); | |||||
| } else { | |||||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||||
| } | |||||
| } else { | |||||
| return rc; | |||||
| } | } | ||||
| // Insert into the B+ tree. We may still get out of memory error. So need to catch it. | // Insert into the B+ tree. We may still get out of memory error. So need to catch it. | ||||
| try { | try { | ||||
| @@ -127,10 +131,13 @@ Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlic | |||||
| } | } | ||||
| // Duplicate key is treated as error and we will also free the memory. | // Duplicate key is treated as error and we will also free the memory. | ||||
| if (rc.IsError() && bl.ptr != nullptr) { | if (rc.IsError() && bl.ptr != nullptr) { | ||||
| alloc_.deallocate(bl.ptr, sz); | |||||
| mp_->Deallocate(bl.ptr); | |||||
| bl.ptr = nullptr; | |||||
| return rc; | |||||
| } | } | ||||
| return rc; | return rc; | ||||
| } | } | ||||
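The control flow above — try the memory pool first, treat kOutOfMemory as the signal to spill through the StorageManager, and propagate any other error — can be restated in a self-contained toy. Everything below (the Rc enum, the capped pool, the spill stub) is illustrative, not the real API:

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>

enum class Rc { kOk, kOutOfMemory, kError };

// Toy stand-ins: a capped "pool" (which never frees) and a disk spill stub.
static size_t g_pool_used = 0;
static const size_t kPoolCap = 4096;

Rc PoolAllocate(size_t sz, void **out) {
  if (g_pool_used + sz > kPoolCap) return Rc::kOutOfMemory;
  *out = std::malloc(sz);
  if (*out == nullptr) return Rc::kOutOfMemory;
  g_pool_used += sz;
  return Rc::kOk;
}

Rc SpillToDisk(const void *, size_t sz) {
  std::printf("spilled %zu bytes to disk\n", sz);
  return Rc::kOk;
}

Rc Insert(const std::vector<char> &row) {
  void *ptr = nullptr;
  Rc rc = PoolAllocate(row.size(), &ptr);
  if (rc == Rc::kOk) {
    std::memcpy(ptr, row.data(), row.size());  // the real code does a piecewise copy
    return Rc::kOk;
  }
  if (rc == Rc::kOutOfMemory) return SpillToDisk(row.data(), row.size());
  return rc;  // any other error is propagated unchanged
}

int main() {
  std::vector<char> small(100), big(8192);
  Insert(small);  // fits in the pool
  Insert(big);    // pool exhausted -> spills to disk
  return 0;
}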
| Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { | Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { | ||||
| RETURN_UNEXPECTED_IF_NULL(dest); | RETURN_UNEXPECTED_IF_NULL(dest); | ||||
| auto r = tree_->Search(key); | auto r = tree_->Search(key); | ||||
| @@ -156,13 +163,14 @@ Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *byt | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; } | |||||
| Path CachePool::GetSpillPath() const { | Path CachePool::GetSpillPath() const { | ||||
| auto spill = Path(root_) / subfolder_; | auto spill = Path(root_) / subfolder_; | ||||
| return spill; | return spill; | ||||
| } | } | ||||
| CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { | CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { | ||||
| CacheStat cs{-1, -1, 0, 0, 0}; | |||||
| CacheStat cs{-1, -1, 0, 0, 0, 0}; | |||||
| int64_t total_sz = 0; | int64_t total_sz = 0; | ||||
| if (tree_->begin() != tree_->end()) { | if (tree_->begin() != tree_->end()) { | ||||
| cs.min_key = tree_->begin().key(); | cs.min_key = tree_->begin().key(); | ||||
| @@ -174,6 +182,9 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { | |||||
| } else { | } else { | ||||
| ++cs.num_disk_cached; | ++cs.num_disk_cached; | ||||
| } | } | ||||
| if (it.value().node_hit) { | |||||
| ++cs.num_numa_hit; | |||||
| } | |||||
| auto cur_key = it.key(); | auto cur_key = it.key(); | ||||
| if (GetMissingKeys) { | if (GetMissingKeys) { | ||||
| for (auto i = cs.max_key + 1; i < cur_key; ++i) { | for (auto i = cs.max_key + 1; i < cur_key; ++i) { | ||||
| @@ -192,49 +203,26 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { | |||||
| } | } | ||||
| return cs; | return cs; | ||||
| } | } | ||||
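With num_numa_hit now part of the statistics (see the CacheStat change in the header below), a caller can turn a GetStat result into a locality ratio. A hypothetical sketch with a mirrored two-field struct:

#include <cstdint>
#include <cstdio>

// Mirrors the relevant CacheStat fields (illustrative, not the real type).
struct Stat {
  int64_t num_mem_cached;
  int64_t num_numa_hit;
};

// Fraction of in-memory rows that landed on the preferred numa node.
double NumaHitRatio(const Stat &cs) {
  return cs.num_mem_cached > 0
             ? static_cast<double>(cs.num_numa_hit) / static_cast<double>(cs.num_mem_cached)
             : 0.0;
}

int main() {
  Stat cs{1000, 875};
  std::printf("numa hit ratio: %.1f%%\n", 100.0 * NumaHitRatio(cs));  // 87.5%
  return 0;
}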
| Status CachePool::Spill(CachePool::DataLocator *dl) { | |||||
| if (sm_ == nullptr) { | |||||
| RETURN_STATUS_UNEXPECTED("No disk storage to spill"); | |||||
| } | |||||
| RETURN_UNEXPECTED_IF_NULL(dl); | |||||
| RETURN_UNEXPECTED_IF_NULL(dl->ptr); | |||||
| if (dl->storage_key == 0) { | |||||
| ReadableSlice data(dl->ptr, dl->sz); | |||||
| RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data})); | |||||
| } | |||||
| alloc_.deallocate(dl->ptr, dl->sz); | |||||
| dl->ptr = nullptr; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CachePool::Locate(CachePool::DataLocator *dl) { | |||||
| RETURN_UNEXPECTED_IF_NULL(dl); | |||||
| if (dl->ptr == nullptr) { | |||||
| if (sm_ == nullptr) { | |||||
| RETURN_STATUS_UNEXPECTED("No disk storage to locate the data"); | |||||
| } | |||||
| try { | |||||
| dl->ptr = alloc_.allocate(dl->sz); | |||||
| WritableSlice dest(dl->ptr, dl->sz); | |||||
| Status rc = Read(dl->storage_key, &dest); | |||||
| if (rc.IsError()) { | |||||
| alloc_.deallocate(dl->ptr, dl->sz); | |||||
| dl->ptr = nullptr; | |||||
| return rc; | |||||
| } | |||||
| } catch (const std::bad_alloc &e) { | |||||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| size_t CachePool::GetSize(CachePool::key_type key) const { | |||||
| Status CachePool::GetDataLocator(key_type key, const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, | |||||
| flatbuffers::Offset<DataLocatorMsg> *out) const { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| auto r = tree_->Search(key); | auto r = tree_->Search(key); | ||||
| if (r.second) { | if (r.second) { | ||||
| auto &it = r.first; | auto &it = r.first; | ||||
| return it->sz; | |||||
| DataLocatorMsgBuilder bld(*fbb); | |||||
| bld.add_key(key); | |||||
| bld.add_size(it->sz); | |||||
| bld.add_node_id(it->node_id); | |||||
| bld.add_addr(reinterpret_cast<int64_t>(it->ptr)); | |||||
| auto offset = bld.Finish(); | |||||
| *out = offset; | |||||
| } else { | } else { | ||||
| return 0; | |||||
| // Key not in the cache. | |||||
| auto offset = CreateDataLocatorMsg(*fbb, key, 0, 0, 0); | |||||
| *out = offset; | |||||
| } | } | ||||
| return Status::OK(); | |||||
| } | } | ||||
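GetDataLocator packs (key, size, node_id, addr) through the generated DataLocatorMsg builder. For readers new to the flatbuffers builder pattern, here is an analogous schema-less sketch using FlexBuffers from the same library — the same build/read shape, but not the actual message type:

#include <cstdint>
#include <cstdio>
#include "flatbuffers/flexbuffers.h"

int main() {
  // Build: the same four fields GetDataLocator writes, as a flexbuffer map.
  flexbuffers::Builder fbb;
  fbb.Map([&]() {
    fbb.Int("key", 42);
    fbb.Int("size", 8192);
    fbb.Int("node_id", 1);
    fbb.Int("addr", 0x7f0000000000LL);
  });
  fbb.Finish();

  // Read back without any schema.
  auto root = flexbuffers::GetRoot(fbb.GetBuffer()).AsMap();
  std::printf("key=%lld node=%lld\n",
              static_cast<long long>(root["key"].AsInt64()),
              static_cast<long long>(root["node_id"].AsInt64()));
  return 0;
}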
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,11 +19,14 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||||
| #include "minddata/dataset/engine/cache/cache_numa.h" | |||||
| #include "minddata/dataset/engine/cache/storage_manager.h" | |||||
| #include "minddata/dataset/util/allocator.h" | #include "minddata/dataset/util/allocator.h" | ||||
| #include "minddata/dataset/util/service.h" | #include "minddata/dataset/util/service.h" | ||||
| #include "minddata/dataset/util/slice.h" | #include "minddata/dataset/util/slice.h" | ||||
| #include "minddata/dataset/util/storage_manager.h" | |||||
| #include "minddata/dataset/util/auto_index.h" | #include "minddata/dataset/util/auto_index.h" | ||||
| #include "minddata/dataset/util/btree.h" | #include "minddata/dataset/util/btree.h" | ||||
| @@ -45,13 +48,15 @@ class CachePool : public Service { | |||||
| // An internal class to locate the whereabouts of a backed up buffer which can be either in | // An internal class to locate the whereabouts of a backed up buffer which can be either in | ||||
| class DataLocator { | class DataLocator { | ||||
| public: | public: | ||||
| DataLocator() : ptr(nullptr), sz(0), storage_key(0) {} | |||||
| DataLocator() : ptr(nullptr), sz(0), node_id(0), node_hit(false), storage_key(0) {} | |||||
| ~DataLocator() = default; | ~DataLocator() = default; | ||||
| DataLocator(const DataLocator &other) = default; | DataLocator(const DataLocator &other) = default; | ||||
| DataLocator &operator=(const DataLocator &other) = default; | DataLocator &operator=(const DataLocator &other) = default; | ||||
| DataLocator(DataLocator &&other) noexcept { | DataLocator(DataLocator &&other) noexcept { | ||||
| ptr = other.ptr; | ptr = other.ptr; | ||||
| sz = other.sz; | sz = other.sz; | ||||
| node_id = other.node_id; | |||||
| node_hit = other.node_hit; | |||||
| storage_key = other.storage_key; | storage_key = other.storage_key; | ||||
| other.ptr = nullptr; | other.ptr = nullptr; | ||||
| other.sz = 0; | other.sz = 0; | ||||
| @@ -61,6 +66,8 @@ class CachePool : public Service { | |||||
| if (&other != this) { | if (&other != this) { | ||||
| ptr = other.ptr; | ptr = other.ptr; | ||||
| sz = other.sz; | sz = other.sz; | ||||
| node_id = other.node_id; | |||||
| node_hit = other.node_hit; | |||||
| storage_key = other.storage_key; | storage_key = other.storage_key; | ||||
| other.ptr = nullptr; | other.ptr = nullptr; | ||||
| other.sz = 0; | other.sz = 0; | ||||
| @@ -70,6 +77,8 @@ class CachePool : public Service { | |||||
| } | } | ||||
| pointer ptr; | pointer ptr; | ||||
| size_t sz; | size_t sz; | ||||
| numa_id_t node_id; // the numa node where the memory is allocated | |||||
| bool node_hit; // true if we were able to allocate from the preferred node | |||||
| StorageManager::key_type storage_key; | StorageManager::key_type storage_key; | ||||
| }; | }; | ||||
| @@ -85,19 +94,20 @@ class CachePool : public Service { | |||||
| int64_t num_mem_cached; | int64_t num_mem_cached; | ||||
| int64_t num_disk_cached; | int64_t num_disk_cached; | ||||
| int64_t average_cache_sz; | int64_t average_cache_sz; | ||||
| int64_t num_numa_hit; | |||||
| std::vector<key_type> gap; | std::vector<key_type> gap; | ||||
| }; | }; | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \param alloc Allocator to allocate memory from | /// \param mp Memory pool from which to allocate | ||||
| /// \param root Optional disk folder to spill | /// \param root Optional disk folder to spill | ||||
| explicit CachePool(const value_allocator &alloc, bool customArena, const std::string &root = ""); | |||||
| explicit CachePool(std::shared_ptr<NumaMemoryPool> mp, const std::string &root = ""); | |||||
| CachePool(const CachePool &) = delete; | CachePool(const CachePool &) = delete; | ||||
| CachePool(CachePool &&) = delete; | CachePool(CachePool &&) = delete; | ||||
| CachePool &operator=(const CachePool &) = delete; | CachePool &operator=(const CachePool &) = delete; | ||||
| CachePool &operator=(CachePool &&) = delete; | CachePool &operator=(CachePool &&) = delete; | ||||
| ~CachePool() noexcept; | |||||
| ~CachePool() noexcept override; | |||||
| Status DoServiceStart() override; | Status DoServiceStart() override; | ||||
| Status DoServiceStop() override; | Status DoServiceStop() override; | ||||
| @@ -110,7 +120,8 @@ class CachePool : public Service { | |||||
| /// \param[in] buf A sequence of ReadableSlice objects. | /// \param[in] buf A sequence of ReadableSlice objects. | ||||
| /// \param[in] writeToDiskDirectly If true, no spill to disk if spill is enabled, or return no memory | |||||
| /// \return Error code | /// \return Error code | ||||
| Status Insert(key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly); | |||||
| Status Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf); | |||||
| /// \brief Restore a cached buffer (from memory or disk) | /// \brief Restore a cached buffer (from memory or disk) | ||||
| /// \param[in] key A previous key returned from Insert | /// \param[in] key A previous key returned from Insert | ||||
| /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice | /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice | ||||
| @@ -118,18 +129,14 @@ class CachePool : public Service { | |||||
| /// \return Error code | /// \return Error code | ||||
| Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const; | Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const; | ||||
| Status Spill(DataLocator *dl); | |||||
| Status Locate(DataLocator *dl); | |||||
| size_t GetSize(key_type key) const; | |||||
| /// \brief Serialize a DataLocator | |||||
| Status GetDataLocator(key_type, const std::shared_ptr<flatbuffers::FlatBufferBuilder> &, | |||||
| flatbuffers::Offset<DataLocatorMsg> *) const; | |||||
| /// \brief Get statistics. | /// \brief Get statistics. | ||||
| /// \return CacheStat object | /// \return CacheStat object | ||||
| CacheStat GetStat(bool GetMissingKeys = false) const; | CacheStat GetStat(bool GetMissingKeys = false) const; | ||||
| const value_allocator &get_allocator() const; | |||||
| std::string MyName() const { return subfolder_; } | std::string MyName() const { return subfolder_; } | ||||
| /// \brief Toggle locking | /// \brief Toggle locking | ||||
| @@ -137,12 +144,11 @@ class CachePool : public Service { | |||||
| void SetLocking(bool on_off) { tree_->SetLocking(on_off); } | void SetLocking(bool on_off) { tree_->SetLocking(on_off); } | ||||
| private: | private: | ||||
| value_allocator alloc_; | |||||
| std::shared_ptr<NumaMemoryPool> mp_; | |||||
| Path root_; | Path root_; | ||||
| const std::string subfolder_; | const std::string subfolder_; | ||||
| std::shared_ptr<StorageManager> sm_; | std::shared_ptr<StorageManager> sm_; | ||||
| std::shared_ptr<data_index> tree_; | std::shared_ptr<data_index> tree_; | ||||
| bool custom_arena_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,6 +14,11 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "minddata/dataset/engine/cache/cache_request.h" | #include "minddata/dataset/engine/cache/cache_request.h" | ||||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||||
| #include <sched.h> | |||||
| #include <sys/types.h> | |||||
| #include <unistd.h> | |||||
| #endif | |||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <thread> | #include <thread> | ||||
| #include "minddata/dataset/core/constants.h" | #include "minddata/dataset/core/constants.h" | ||||
| @@ -106,6 +111,7 @@ Status CacheRowRequest::PostReply() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheRowRequest::Prepare() { | Status CacheRowRequest::Prepare() { | ||||
| if (BitTest(rq_.flag(), kDataIsInSharedMemory)) { | if (BitTest(rq_.flag(), kDataIsInSharedMemory)) { | ||||
| // First one is cookie, followed by address and then size. | // First one is cookie, followed by address and then size. | ||||
| @@ -118,10 +124,21 @@ Status CacheRowRequest::Prepare() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, | |||||
| bool local_bypass) | |||||
| : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| CacheRowRequest::CacheRowRequest(const CacheClient *cc) | |||||
| : BaseRequest(RequestType::kCacheRow), | |||||
| support_local_bypass_(cc->local_bypass_), | |||||
| addr_(-1), | |||||
| sz_(0), | |||||
| row_id_from_server_(-1) { | |||||
| rq_.set_connection_id(cc->server_connection_id_); | |||||
| rq_.set_client_id(cc->client_id_); | |||||
| rq_.add_buf_data(cc->cookie_); | |||||
| } | |||||
| BatchFetchRequest::BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id) | |||||
| : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(cc->local_bypass_), row_id_(row_id) { | |||||
| rq_.set_connection_id(cc->server_connection_id_); | |||||
| rq_.set_client_id(cc->client_id_); | |||||
| rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0); | rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0); | ||||
| // Convert the row id into a flatbuffer | // Convert the row id into a flatbuffer | ||||
| flatbuffers::FlatBufferBuilder fbb; | flatbuffers::FlatBufferBuilder fbb; | ||||
| @@ -186,9 +203,9 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, in | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||||
| CreateCacheRequest::CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||||
| CreateCacheRequest::CreateCacheFlag flag) | CreateCacheRequest::CreateCacheFlag flag) | ||||
| : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) { | |||||
| : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag), cc_(cc) { | |||||
| // Type has been set already in the base constructor. So we need to fill in the connection info. | // Type has been set already in the base constructor. So we need to fill in the connection info. | ||||
| // On successful return, we will get the connection id | // On successful return, we will get the connection id | ||||
| rq_.mutable_connection_info()->operator=(cinfo); | rq_.mutable_connection_info()->operator=(cinfo); | ||||
| @@ -209,6 +226,41 @@ Status CreateCacheRequest::Prepare() { | |||||
| } | } | ||||
| } | } | ||||
| Status CreateCacheRequest::PostReply() { | |||||
| auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data()); | |||||
| cc_->server_connection_id_ = p->connection_id(); | |||||
| cc_->cookie_ = p->cookie()->str(); | |||||
| cc_->client_id_ = p->client_id(); | |||||
| // Next is a set of cpu ids to which we should re-adjust ourselves for better affinity. | |||||
| auto sz = p->cpu_id()->size(); | |||||
| cc_->cpu_list_.reserve(sz); | |||||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||||
| std::string c_list; | |||||
| cpu_set_t cpu_set; | |||||
| CPU_ZERO(&cpu_set); | |||||
| #endif | |||||
| for (uint32_t i = 0; i < sz; ++i) { | |||||
| auto cpu_id = p->cpu_id()->Get(i); | |||||
| cc_->cpu_list_.push_back(cpu_id); | |||||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||||
| c_list += std::to_string(cpu_id) + " "; | |||||
| CPU_SET(cpu_id, &cpu_set); | |||||
| #endif | |||||
| } | |||||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||||
| if (sz > 0) { | |||||
| auto err = sched_setaffinity(getpid(), sizeof(cpu_set), &cpu_set); | |||||
| if (err == -1) { | |||||
| RETURN_STATUS_UNEXPECTED("Unable to set affinity. Errno = " + std::to_string(errno)); | |||||
| } | |||||
| MS_LOG(WARNING) << "Changing cpu affinity to the following list of cpu id: " + c_list; | |||||
| } | |||||
| #endif | |||||
| return Status::OK(); | |||||
| } | |||||
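On the client side the affinity change can be verified with the matching getter. A Linux-only sketch using sched_getaffinity, the counterpart of the sched_setaffinity call above:

#include <sched.h>
#include <unistd.h>
#include <cstdio>

int main() {
  cpu_set_t cpu_set;
  CPU_ZERO(&cpu_set);
  // Read back the affinity mask of the current process.
  if (sched_getaffinity(getpid(), sizeof(cpu_set), &cpu_set) == 0) {
    for (int cpu = 0; cpu < CPU_SETSIZE; ++cpu) {
      if (CPU_ISSET(cpu, &cpu_set)) std::printf("%d ", cpu);
    }
    std::printf("\n");
  }
  return 0;
}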
| Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) { | Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) { | ||||
| try { | try { | ||||
| flatbuffers::FlatBufferBuilder fbb; | flatbuffers::FlatBufferBuilder fbb; | ||||
| @@ -245,6 +297,7 @@ Status GetStatRequest::PostReply() { | |||||
| stat_.num_disk_cached = msg->num_disk_cached(); | stat_.num_disk_cached = msg->num_disk_cached(); | ||||
| stat_.num_mem_cached = msg->num_mem_cached(); | stat_.num_mem_cached = msg->num_mem_cached(); | ||||
| stat_.avg_cache_sz = msg->avg_cache_sz(); | stat_.avg_cache_sz = msg->avg_cache_sz(); | ||||
| stat_.num_numa_hit = msg->num_numa_hit(); | |||||
| stat_.max_row_id = msg->max_row_id(); | stat_.max_row_id = msg->max_row_id(); | ||||
| stat_.min_row_id = msg->min_row_id(); | stat_.min_row_id = msg->min_row_id(); | ||||
| stat_.cache_service_state = msg->state(); | stat_.cache_service_state = msg->state(); | ||||
| @@ -255,14 +308,15 @@ Status ListSessionsRequest::PostReply() { | |||||
| auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data()); | auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data()); | ||||
| auto session_vector = msg->sessions(); | auto session_vector = msg->sessions(); | ||||
| for (auto i = 0; i < session_vector->size(); ++i) { | for (auto i = 0; i < session_vector->size(); ++i) { | ||||
| SessionCacheInfo current_info; | |||||
| CacheServiceStat stats; | |||||
| SessionCacheInfo current_info{}; | |||||
| CacheServiceStat stats{}; | |||||
| auto current_session_info = session_vector->Get(i); | auto current_session_info = session_vector->Get(i); | ||||
| current_info.session_id = current_session_info->session_id(); | current_info.session_id = current_session_info->session_id(); | ||||
| current_info.connection_id = current_session_info->connection_id(); | current_info.connection_id = current_session_info->connection_id(); | ||||
| stats.num_mem_cached = current_session_info->stats()->num_mem_cached(); | stats.num_mem_cached = current_session_info->stats()->num_mem_cached(); | ||||
| stats.num_disk_cached = current_session_info->stats()->num_disk_cached(); | stats.num_disk_cached = current_session_info->stats()->num_disk_cached(); | ||||
| stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz(); | stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz(); | ||||
| stats.num_numa_hit = current_session_info->stats()->num_numa_hit(); | |||||
| stats.min_row_id = current_session_info->stats()->min_row_id(); | stats.min_row_id = current_session_info->stats()->min_row_id(); | ||||
| stats.max_row_id = current_session_info->stats()->max_row_id(); | stats.max_row_id = current_session_info->stats()->max_row_id(); | ||||
| stats.cache_service_state = current_session_info->stats()->state(); | stats.cache_service_state = current_session_info->stats()->state(); | ||||
| @@ -41,6 +41,7 @@ struct CacheServiceStat { | |||||
| int64_t num_mem_cached; | int64_t num_mem_cached; | ||||
| int64_t num_disk_cached; | int64_t num_disk_cached; | ||||
| int64_t avg_cache_sz; | int64_t avg_cache_sz; | ||||
| int64_t num_numa_hit; | |||||
| row_id_type min_row_id; | row_id_type min_row_id; | ||||
| row_id_type max_row_id; | row_id_type max_row_id; | ||||
| int8_t cache_service_state; | int8_t cache_service_state; | ||||
| @@ -75,6 +76,8 @@ class BaseRequest { | |||||
| kHeartBeat = 14, | kHeartBeat = 14, | ||||
| kToggleWriteMode = 15, | kToggleWriteMode = 15, | ||||
| kListSessions = 16, | kListSessions = 16, | ||||
| kConnectReset = 17, | |||||
| kInternalFetchRow = 18, | |||||
| // Add new request before it. | // Add new request before it. | ||||
| kRequestUnknown = 32767 | kRequestUnknown = 32767 | ||||
| }; | }; | ||||
| @@ -84,10 +87,15 @@ class BaseRequest { | |||||
| friend class CacheClientGreeter; | friend class CacheClientGreeter; | ||||
| friend class CacheClientRequestTag; | friend class CacheClientRequestTag; | ||||
| friend class CacheClient; | friend class CacheClient; | ||||
| friend class CacheService; | |||||
| /// \brief Base class of a cache server request | /// \brief Base class of a cache server request | ||||
| /// \param type Type of the request | /// \param type Type of the request | ||||
| explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast<int16_t>(type_)); } | |||||
| explicit BaseRequest(RequestType type) : type_(type) { | |||||
| rq_.set_type(static_cast<int16_t>(type_)); | |||||
| rq_.set_client_id(-1); | |||||
| rq_.set_flag(0); | |||||
| } | |||||
| virtual ~BaseRequest() = default; | virtual ~BaseRequest() = default; | ||||
| /// \brief A print method for debugging | /// \brief A print method for debugging | ||||
| @@ -138,15 +146,7 @@ class CacheRowRequest : public BaseRequest { | |||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| friend class CacheClient; | friend class CacheClient; | ||||
| explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass) | |||||
| : BaseRequest(RequestType::kCacheRow), | |||||
| support_local_bypass_(local_bypass), | |||||
| addr_(-1), | |||||
| sz_(0), | |||||
| row_id_from_server_(-1) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| rq_.add_buf_data(cookie); | |||||
| } | |||||
| explicit CacheRowRequest(const CacheClient *cc); | |||||
| ~CacheRowRequest() override = default; | ~CacheRowRequest() override = default; | ||||
| /// \brief Serialize a TensorRow for streaming to the cache server | /// \brief Serialize a TensorRow for streaming to the cache server | ||||
| @@ -193,7 +193,7 @@ class BatchFetchRequest : public BaseRequest { | |||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| friend class CacheService; | friend class CacheService; | ||||
| BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass); | |||||
| BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id); | |||||
| ~BatchFetchRequest() override = default; | ~BatchFetchRequest() override = default; | ||||
| Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); | Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); | ||||
| @@ -212,21 +212,18 @@ class CreateCacheRequest : public BaseRequest { | |||||
| /// \param connection_id | /// \param cc The client requesting the cache | ||||
| /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited | /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited | ||||
| /// \param flag Attributes of the cache. | /// \param flag Attributes of the cache. | ||||
| explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||||
| explicit CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||||
| CreateCacheFlag flag = CreateCacheFlag::kNone); | CreateCacheFlag flag = CreateCacheFlag::kNone); | ||||
| ~CreateCacheRequest() override = default; | ~CreateCacheRequest() override = default; | ||||
| void ParseResult(connection_id_type *id, std::string *out) { | |||||
| auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data()); | |||||
| *id = p->connection_id(); | |||||
| *out = p->cookie()->str(); | |||||
| } | |||||
| /// Overload the base class Prepare | |||||
| /// Override the base class Prepare/PostReply | |||||
| Status Prepare() override; | Status Prepare() override; | ||||
| Status PostReply() override; | |||||
| private: | private: | ||||
| uint64_t cache_mem_sz_; | uint64_t cache_mem_sz_; | ||||
| CreateCacheFlag flag_; | CreateCacheFlag flag_; | ||||
| CacheClient *cc_; | |||||
| }; | }; | ||||
| /// \brief Request to get all the keys not present at the server. | /// \brief Request to get all the keys not present at the server. | ||||
| @@ -396,6 +393,23 @@ class ToggleWriteModeRequest : public BaseRequest { | |||||
| } | } | ||||
| ~ToggleWriteModeRequest() override = default; | ~ToggleWriteModeRequest() override = default; | ||||
| }; | }; | ||||
| class ConnectResetRequest : public BaseRequest { | |||||
| public: | |||||
| friend class CacheServer; | |||||
| explicit ConnectResetRequest(connection_id_type connection_id, int32_t client_id) | |||||
| : BaseRequest(RequestType::kConnectReset) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| rq_.set_client_id(client_id); | |||||
| } | |||||
| ~ConnectResetRequest() override = default; | |||||
| /// Override the base class function | |||||
| Status Prepare() override { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rq_.client_id() != -1, "Invalid client id"); | |||||
| return Status::OK(); | |||||
| } | |||||
| }; | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <functional> | #include <functional> | ||||
| #include <limits> | #include <limits> | ||||
| #include <vector> | |||||
| #include "minddata/dataset/core/constants.h" | #include "minddata/dataset/core/constants.h" | ||||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | #include "minddata/dataset/engine/cache/cache_ipc.h" | ||||
| #include "minddata/dataset/engine/cache/cache_service.h" | #include "minddata/dataset/engine/cache/cache_service.h" | ||||
| @@ -43,36 +44,57 @@ Status CacheServer::DoServiceStart() { | |||||
| MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; | MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | RETURN_IF_NOT_OK(vg_.ServiceStart()); | ||||
| // There will be num_workers_ threads working on the grpc queue and | |||||
| // the same number of threads working on the CacheServerRequest queue. | |||||
| RETURN_IF_NOT_OK(hw_info_->GetNumaNodeInfo()); | |||||
| auto num_numa_nodes = GetNumaNodeCount(); | |||||
| // If we link with the numa library, set the default memory policy. | |||||
| // If we don't pin threads to cpus, then use up all memory controllers to maximize | |||||
| // memory bandwidth. | |||||
| RETURN_IF_NOT_OK( | |||||
| CacheServerHW::SetDefaultMemoryPolicy(numa_affinity_ ? CachePoolPolicy::kLocal : CachePoolPolicy::kInterleave)); | |||||
| auto my_node = hw_info_->GetMyNode(); | |||||
| MS_LOG(DEBUG) << "Cache server is running on numa node " << my_node; | |||||
| // Bump up num_workers_ to at least the number of numa nodes | |||||
| num_workers_ = std::max(num_numa_nodes, num_workers_); | |||||
| // But it also shouldn't be much more than the hardware concurrency. | |||||
| auto num_cpus = hw_info_->GetCpuCount(); | |||||
| num_workers_ = std::min(2 * num_cpus, num_workers_); | |||||
| // Round up num_workers to a multiple of numa nodes. | |||||
| auto remainder = num_workers_ % num_numa_nodes; | |||||
| if (remainder > 0) num_workers_ += (num_numa_nodes - remainder); | |||||
| MS_LOG(INFO) << "Re-adjusting the number of workers to " << num_workers_; | |||||
| // There will be some threads working on the grpc queue and | |||||
| // some number of threads working on the CacheServerRequest queue. | |||||
| // Like a connector object we will set up the same number of queues but | // Like a connector object we will set up the same number of queues but | ||||
| // we do not need to preserve any order. We will set the capacity of | // we do not need to preserve any order. We will set the capacity of | ||||
| // each queue to be 128 since we are just pushing memory pointers which | |||||
| // each queue to be 64 since we are just pushing memory pointers which | |||||
| // is only 8 byte each. | // is only 8 byte each. | ||||
| const int32_t que_capacity = 128; | |||||
| const int32_t kQueCapacity = 64; | |||||
| // This is the request queue from the client | // This is the request queue from the client | ||||
| cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>(); | cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>(); | ||||
| cache_q_->Init(num_workers_, que_capacity); | |||||
| cache_q_->Init(num_workers_, kQueCapacity); | |||||
| // We will match the number of grpc workers with the number of server workers. | |||||
| // But technically they don't have to be the same. | |||||
| num_grpc_workers_ = num_workers_; | |||||
| MS_LOG(DEBUG) << "Number of gprc workers is set to " << num_grpc_workers_; | |||||
| // For the grpc completion queue to work, we need to allocate some | // For the grpc completion queue to work, we need to allocate some | ||||
| // tags which in our case are instances of CacheServerRequest. | // tags which in our case are instances of CacheServerRequest. | ||||
| // They get recycled and we will allocate them in advance and push | // They get recycled and we will allocate them in advance and push | ||||
| // them into some free list. We need more (two or three times) the | // them into some free list. We need more (two or three times) the | ||||
| // size of the cache_q. While each worker is working on a CacheServerRequest, | // size of the cache_q. While each worker is working on a CacheServerRequest, | ||||
| // we need some extra requests injected into the grpc completion queue. | // we need some extra requests injected into the grpc completion queue. | ||||
| const int32_t multiplier = 3; | |||||
| const int32_t free_list_capacity = multiplier * (que_capacity + 1); | |||||
| const int32_t kMultiplier = 2; | |||||
| int ratio = num_workers_ / num_grpc_workers_; | |||||
| if (num_workers_ % num_grpc_workers_) ++ratio; | |||||
| const int32_t free_list_capacity = kMultiplier * (kQueCapacity + 1) * ratio; | |||||
| free_list_ = std::make_shared<QueueList<CacheServerRequest *>>(); | free_list_ = std::make_shared<QueueList<CacheServerRequest *>>(); | ||||
| free_list_->Init(num_workers_, free_list_capacity); | |||||
| // We need to have a reference to the services memory pool in case | |||||
| // the Services goes out of scope earlier than us since it is a singleton | |||||
| mp_ = Services::GetInstance().GetServiceMemPool(); | |||||
| Allocator<CacheServerRequest> alloc(mp_); | |||||
| tag_.reserve(num_workers_); | |||||
| // Now we populate all free list. | |||||
| for (auto m = 0; m < num_workers_; ++m) { | |||||
| // Ideally we allocate all the free list in one malloc. But it turns out it exceeds the | |||||
| // Arena size. So we will we will allocate one segment at a time. | |||||
| auto my_tag = std::make_unique<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>(alloc); | |||||
| free_list_->Init(num_grpc_workers_, free_list_capacity); | |||||
| tag_.reserve(num_grpc_workers_); | |||||
| // Now we populate all the free lists, round-robining them among the numa nodes. | |||||
| for (auto m = 0; m < num_grpc_workers_; ++m) { | |||||
| NumaAllocator<CacheServerRequest> alloc(m % num_numa_nodes, CachePoolPolicy::kPreferred); | |||||
| // Ideally we allocate all the free list in one malloc. But we will allocate one segment | |||||
| // at a time so that we can change the numa policy easily per grpc worker. | |||||
| auto my_tag = std::make_unique<MemGuard<CacheServerRequest, NumaAllocator<CacheServerRequest>>>(alloc); | |||||
| // Allocate the tag and assign it to the current queue | // Allocate the tag and assign it to the current queue | ||||
| RETURN_IF_NOT_OK(my_tag->allocate(free_list_capacity, m)); | RETURN_IF_NOT_OK(my_tag->allocate(free_list_capacity, m)); | ||||
| for (int i = 0; i < free_list_capacity; ++i) { | for (int i = 0; i < free_list_capacity; ++i) { | ||||
| @@ -82,11 +104,6 @@ Status CacheServer::DoServiceStart() { | |||||
| } | } | ||||
| RETURN_IF_NOT_OK(cache_q_->Register(&vg_)); | RETURN_IF_NOT_OK(cache_q_->Register(&vg_)); | ||||
| RETURN_IF_NOT_OK(free_list_->Register(&vg_)); | RETURN_IF_NOT_OK(free_list_->Register(&vg_)); | ||||
| // Spawn a few threads to serve the real request. | |||||
| auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1); | |||||
| for (auto i = 0; i < num_workers_; ++i) { | |||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i))); | |||||
| } | |||||
| // Start the comm layer | // Start the comm layer | ||||
| try { | try { | ||||
| comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_); | comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_); | ||||
| @@ -94,10 +111,29 @@ Status CacheServer::DoServiceStart() { | |||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| RETURN_STATUS_UNEXPECTED(e.what()); | RETURN_STATUS_UNEXPECTED(e.what()); | ||||
| } | } | ||||
| // Spawn a few threads to serve the real request. | |||||
| auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1); | |||||
| for (auto i = 0; i < num_workers_; ++i) { | |||||
| Task *pTask; | |||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i), &pTask)); | |||||
| // Save a copy of the pointer to the underlying Task object. We may dynamically change its affinity if needed. | |||||
| numa_tasks_.emplace(i, pTask); | |||||
| // Spread out all the threads to all the numa nodes if needed | |||||
| if (IsNumaAffinityOn()) { | |||||
| auto numa_id = i % num_numa_nodes; | |||||
| RETURN_IF_NOT_OK(SetAffinity(*pTask, numa_id)); | |||||
| } | |||||
| } | |||||
| // Finally loop forever to handle the request. | // Finally loop forever to handle the request. | ||||
| auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1); | auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1); | ||||
| for (auto i = 0; i < num_workers_; ++i) { | |||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i))); | |||||
| for (auto i = 0; i < num_grpc_workers_; ++i) { | |||||
| Task *pTask; | |||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i), &pTask)); | |||||
| // Each grpc worker will be bound to the numa node where we allocated its free tag | |||||
| // memory. | |||||
| if (IsNumaAffinityOn()) { | |||||
| RETURN_IF_NOT_OK(SetAffinity(*pTask, i % num_numa_nodes)); | |||||
| } | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
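The worker sizing above — raise num_workers_ to at least the numa node count, cap it near hardware concurrency, then round up to a node multiple — is easy to sanity-check in isolation. A small sketch with made-up inputs:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Reproduces the sizing logic from DoServiceStart() with illustrative inputs.
int32_t AdjustWorkers(int32_t requested, int32_t num_numa_nodes, int32_t num_cpus) {
  int32_t n = std::max(num_numa_nodes, requested);  // at least one worker per node
  n = std::min(2 * num_cpus, n);                    // but not far beyond hw concurrency
  int32_t remainder = n % num_numa_nodes;
  if (remainder > 0) n += (num_numa_nodes - remainder);  // round up to a node multiple
  return n;
}

int main() {
  // e.g. the user asked for 10 workers on a 4-node, 32-cpu box -> 12 workers
  std::printf("%d\n", AdjustWorkers(10, 4, 32));
  return 0;
}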
| @@ -108,8 +144,6 @@ Status CacheServer::DoServiceStop() { | |||||
| // First stop all the threads. | // First stop all the threads. | ||||
| RETURN_IF_NOT_OK(vg_.ServiceStop()); | RETURN_IF_NOT_OK(vg_.ServiceStop()); | ||||
| // Clean up all the caches if any. | // Clean up all the caches if any. | ||||
| // Dump a message how much memory we have consumed in total. | |||||
| MS_LOG(INFO) << "Memory usage for the current server: " << GetMemoryUsage() << " bytes."; | |||||
| UniqueLock lck(&rwLock_); | UniqueLock lck(&rwLock_); | ||||
| auto it = all_caches_.begin(); | auto it = all_caches_.begin(); | ||||
| while (it != all_caches_.end()) { | while (it != all_caches_.end()) { | ||||
| @@ -134,13 +168,14 @@ CacheService *CacheServer::GetService(connection_id_type id) const { | |||||
| Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info"); | CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info"); | ||||
| std::string cookie; | std::string cookie; | ||||
| int32_t client_id; | |||||
| auto session_id = rq->connection_info().session_id(); | auto session_id = rq->connection_info().session_id(); | ||||
| auto crc = rq->connection_info().crc(); | auto crc = rq->connection_info().crc(); | ||||
| // Before allowing the creation, make sure the session had already been created by the user | // Before allowing the creation, make sure the session had already been created by the user | ||||
| // Our intention is to add this cache to the active sessions list so leave the list locked during | // Our intention is to add this cache to the active sessions list so leave the list locked during | ||||
| // this entire function. | // this entire function. | ||||
| UniqueLock lock(&sessions_lock_); | |||||
| UniqueLock sess_lck(&sessions_lock_); | |||||
| auto session_it = active_sessions_.find(session_id); | auto session_it = active_sessions_.find(session_id); | ||||
| if (session_it == active_sessions_.end()) { | if (session_it == active_sessions_.end()) { | ||||
| RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!"); | RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!"); | ||||
| @@ -163,6 +198,7 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||||
| } | } | ||||
| flatbuffers::FlatBufferBuilder fbb; | flatbuffers::FlatBufferBuilder fbb; | ||||
| flatbuffers::Offset<flatbuffers::String> off_cookie; | flatbuffers::Offset<flatbuffers::String> off_cookie; | ||||
| flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list; | |||||
| // Before creating the cache, first check if this is a request for a shared usage of an existing cache | // Before creating the cache, first check if this is a request for a shared usage of an existing cache | ||||
| // If two CreateService come in with identical connection_id, we need to serialize the create. | // If two CreateService come in with identical connection_id, we need to serialize the create. | ||||
| // The first create will be successful and be given a special cookie. | // The first create will be successful and be given a special cookie. | ||||
| @@ -171,32 +207,74 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||||
| if (global_shutdown_) { | if (global_shutdown_) { | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // We would like to protect ourselves from over-allocating. We will go over the existing caches | |||||
| // and calculate how much memory we have consumed so far. | |||||
| auto end = all_caches_.end(); | auto end = all_caches_.end(); | ||||
| auto it = all_caches_.find(connection_id); | |||||
| auto it = all_caches_.begin(); | |||||
| bool duplicate = false; | bool duplicate = false; | ||||
| auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_; | |||||
| int64_t max_avail = avail_mem; | |||||
| while (it != end) { | |||||
| if (it->first == connection_id) { | |||||
| duplicate = true; | |||||
| break; | |||||
| } else { | |||||
| auto &cs = it->second; | |||||
| CacheService::ServiceStat stat; | |||||
| RETURN_IF_NOT_OK(cs->GetStat(&stat)); | |||||
| int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz; | |||||
| max_avail -= mem_consumed; | |||||
| if (max_avail <= 0) { | |||||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||||
| } | |||||
| } | |||||
| ++it; | |||||
| } | |||||
| if (it == end) { | if (it == end) { | ||||
| // If we have some cache using some memory already, make a reasonable decision if we should return | |||||
| // out of memory. | |||||
| if (max_avail < avail_mem) { | |||||
| int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit. | |||||
| if (req_mem > max_avail) { | |||||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||||
| } else if (req_mem == 0) { | |||||
| // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than | |||||
| // 85% of our limit, fail this request. | |||||
| if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15) { | |||||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); | |||||
| } | |||||
| } | |||||
| } | |||||
| std::unique_ptr<CacheService> cs; | std::unique_ptr<CacheService> cs; | ||||
| try { | try { | ||||
| cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id); | cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id); | ||||
| RETURN_IF_NOT_OK(cs->ServiceStart()); | RETURN_IF_NOT_OK(cs->ServiceStart()); | ||||
| cookie = cs->cookie(); | cookie = cs->cookie(); | ||||
| client_id = cs->num_clients_.fetch_add(1); | |||||
| all_caches_.emplace(connection_id, std::move(cs)); | all_caches_.emplace(connection_id, std::move(cs)); | ||||
| } catch (const std::bad_alloc &e) { | } catch (const std::bad_alloc &e) { | ||||
| return Status(StatusCode::kOutOfMemory); | return Status(StatusCode::kOutOfMemory); | ||||
| } | } | ||||
| // Add the cache into the active session tracking. | |||||
| // We have already validated that the session exists and that this is a new cache created. | |||||
| session_it->second.insert(connection_id); | |||||
| } else { | } else { | ||||
| duplicate = true; | duplicate = true; | ||||
| client_id = it->second->num_clients_.fetch_add(1); | |||||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; | MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; | ||||
| } | } | ||||
| // Shuffle the worker threads. But we need to release the locks or we will deadlock when calling | |||||
| // the following function. | |||||
| lck.Unlock(); | |||||
| sess_lck.Unlock(); | |||||
| auto numa_id = client_id % GetNumaNodeCount(); | |||||
| std::vector<cpu_id_t> cpu_list = hw_info_->GetCpuList(numa_id); | |||||
| // Send back the data | |||||
| off_cookie = fbb.CreateString(cookie); | off_cookie = fbb.CreateString(cookie); | ||||
| off_cpu_list = fbb.CreateVector(cpu_list); | |||||
| CreateCacheReplyMsgBuilder bld(fbb); | CreateCacheReplyMsgBuilder bld(fbb); | ||||
| bld.add_connection_id(connection_id); | bld.add_connection_id(connection_id); | ||||
| bld.add_cookie(off_cookie); | bld.add_cookie(off_cookie); | ||||
| bld.add_client_id(client_id); | |||||
| // The last thing we send back is a set of cpu ids that we suggest the client bind itself to | |||||
| bld.add_cpu_id(off_cpu_list); | |||||
| auto off = bld.Finish(); | auto off = bld.Finish(); | ||||
| fbb.Finish(off); | fbb.Finish(off); | ||||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | ||||
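The admission arithmetic in CreateService can be condensed into one predicate. A self-contained restatement — the constants match the code above, but the function and its inputs are illustrative:

#include <cstdint>
#include <cstdio>
#include <vector>

// Returns true if a new cache request fits under the memory cap.
// req_mem_mb == 0 means "unlimited up to the cap", which is refused once
// less than 15% of the cap remains (the 0.15 threshold in CreateService).
bool AdmitCache(int64_t total_system_mem, float memory_cap_ratio, int64_t req_mem_mb,
                const std::vector<int64_t> &consumed_per_cache) {
  const int64_t avail_mem = static_cast<int64_t>(total_system_mem * memory_cap_ratio);
  int64_t max_avail = avail_mem;
  for (int64_t consumed : consumed_per_cache) {
    max_avail -= consumed;
    if (max_avail <= 0) return false;  // cap already exhausted
  }
  if (max_avail < avail_mem) {  // some memory is already in use
    const int64_t req_mem = req_mem_mb * 1048576L;  // the request is in MB
    if (req_mem > max_avail) return false;
    if (req_mem == 0 &&
        static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= 0.15f) {
      return false;  // unlimited request but >85% of the cap is consumed
    }
  }
  return true;
}

int main() {
  // 64GB box, 80% cap, two caches consuming 20GB each, new unlimited request.
  const int64_t GB = 1LL << 30;
  bool ok = AdmitCache(64 * GB, 0.8f, 0, {20 * GB, 20 * GB});
  std::printf("admitted: %s\n", ok ? "yes" : "no");  // 11.2GB of 51.2GB left, ~22% -> yes
  return 0;
}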
| @@ -220,26 +298,8 @@ Status CacheServer::DestroyCache(CacheRequest *rq) { | |||||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | ||||
| } | } | ||||
| } | } | ||||
| // Now that this cache is removed, we need to also remove it's connection id from active session tracking | |||||
| auto session_id = GetSessionID(id); | |||||
| UniqueLock sess_lck(&sessions_lock_); | |||||
| auto it = active_sessions_.find(session_id); | |||||
| if (it == active_sessions_.end()) { | |||||
| // The session was not found in the active sessions | |||||
| RETURN_STATUS_UNEXPECTED("A destroy cache request has been completed but it had a stale session id!"); | |||||
| } | |||||
| auto connection_it = it->second.find(id); | |||||
| if (connection_it == it->second.end()) { | |||||
| RETURN_STATUS_UNEXPECTED("A destroy cache request could not find the connection in the activate sessions!"); | |||||
| } | |||||
| // remove that connection id from the set | |||||
| it->second.erase(connection_it); | |||||
| MS_LOG(INFO) << "Destroyed cache " << id << " and removed from active session " << session_id; | |||||
| // We aren't touching the session list even though we may be dropping the last remaining cache of a session. | |||||
| // Leave that to be done by the drop session command. | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -266,6 +326,7 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) { | |||||
| buffers.push_back(rq->buf_data(i).data()); | buffers.push_back(rq->buf_data(i).data()); | ||||
| } | } | ||||
| row_id_type id = -1; | row_id_type id = -1; | ||||
| // We will allocate the memory on the same numa node this thread is bound to. | |||||
| RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id)); | RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id)); | ||||
| reply->set_result(std::to_string(id)); | reply->set_result(std::to_string(id)); | ||||
| } else { | } else { | ||||
| @@ -301,6 +362,7 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) { | |||||
| if (!cs->HasBuildPhase() || cookie == cs->cookie()) { | if (!cs->HasBuildPhase() || cookie == cs->cookie()) { | ||||
| row_id_type id = -1; | row_id_type id = -1; | ||||
| ReadableSlice src(p, sz); | ReadableSlice src(p, sz); | ||||
| // We will allocate the memory on the same numa node this thread is bound to. | |||||
| rc = cs->FastCacheRow(src, &id); | rc = cs->FastCacheRow(src, &id); | ||||
| reply->set_result(std::to_string(id)); | reply->set_result(std::to_string(id)); | ||||
| } else { | } else { | ||||
| @@ -330,9 +392,19 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { | |||||
| for (auto i = 0; i < sz; ++i) { | for (auto i = 0; i < sz; ++i) { | ||||
| row_id.push_back(p->row_id()->Get(i)); | row_id.push_back(p->row_id()->Get(i)); | ||||
| } | } | ||||
| int64_t mem_sz = 0; | |||||
| std::vector<key_size_pair> v; | |||||
| RETURN_IF_NOT_OK(cs->PreBatchFetch(row_id, &v, &mem_sz)); | |||||
| std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>(); | |||||
| RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb)); | |||||
| auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer()); | |||||
| int64_t mem_sz = sizeof(int64_t) * (sz + 1); | |||||
| for (auto i = 0; i < sz; ++i) { | |||||
| auto row_sz = locator->rows()->Get(i)->size(); | |||||
| // row_sz is the size of the cached data. Later we will spawn multiple threads | |||||
| // each of which will copy the data into either shared memory or protobuf concurrently but | |||||
| // to different regions. | |||||
| // To avoid false sharing, we will bump up row_sz to be a multiple of 4k, i.e. 4096 bytes | |||||
| row_sz = round_up_4K(row_sz); | |||||
| mem_sz += row_sz; | |||||
| } | |||||
| auto client_flag = rq->flag(); | auto client_flag = rq->flag(); | ||||
| bool local_client = BitTest(client_flag, kLocalClientSupport); | bool local_client = BitTest(client_flag, kLocalClientSupport); | ||||
| // For large amount data to be sent back, we will use shared memory provided it is a local | // For large amount data to be sent back, we will use shared memory provided it is a local | ||||
| @@ -346,7 +418,11 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { | |||||
| void *q = nullptr; | void *q = nullptr; | ||||
| RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q)); | RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q)); | ||||
| WritableSlice dest(q, mem_sz); | WritableSlice dest(q, mem_sz); | ||||
| RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest)); | |||||
| Status rc = cs->BatchFetch(fbb, &dest); | |||||
| if (rc.IsError()) { | |||||
| shared_pool->Deallocate(q); | |||||
| return rc; | |||||
| } | |||||
| // We can't return the absolute address which makes no sense to the client. | // We can't return the absolute address which makes no sense to the client. | ||||
| // Instead we return the difference. | // Instead we return the difference. | ||||
| auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base); | auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base); | ||||
| @@ -363,7 +439,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { | |||||
| return Status(StatusCode::kOutOfMemory); | return Status(StatusCode::kOutOfMemory); | ||||
| } | } | ||||
| WritableSlice dest(mem.data(), mem_sz); | WritableSlice dest(mem.data(), mem_sz); | ||||
| RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest)); | |||||
| RETURN_IF_NOT_OK(cs->BatchFetch(fbb, &dest)); | |||||
| reply->set_result(std::move(mem)); | reply->set_result(std::move(mem)); | ||||
| } | } | ||||
| } | } | ||||
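round_up_4K itself is not shown in this hunk; presumably it is the usual power-of-two round-up. A hedged sketch of such a helper and the values it produces:

#include <cstdint>
#include <cstdio>

// Presumed behaviour of round_up_4K (the real helper is defined elsewhere in
// the patch): round a size up to the next multiple of 4096 so that concurrent
// writers copying adjacent rows never share a cache line or page.
inline int64_t RoundUp4K(int64_t sz) {
  constexpr int64_t kAlign = 4096;
  return (sz + kAlign - 1) & ~(kAlign - 1);
}

int main() {
  std::printf("%lld %lld %lld\n",
              static_cast<long long>(RoundUp4K(1)),      // 4096
              static_cast<long long>(RoundUp4K(4096)),   // 4096
              static_cast<long long>(RoundUp4K(4097)));  // 8192
  return 0;
}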
| @@ -386,6 +462,7 @@ Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) { | |||||
| bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | ||||
| bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | ||||
| bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz); | bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz); | ||||
| bld.add_num_numa_hit(svc_stat.stat_.num_numa_hit); | |||||
| bld.add_max_row_id(svc_stat.stat_.max_key); | bld.add_max_row_id(svc_stat.stat_.max_key); | ||||
| bld.add_min_row_id(svc_stat.stat_.min_key); | bld.add_min_row_id(svc_stat.stat_.min_key); | ||||
| bld.add_state(svc_stat.state_); | bld.add_state(svc_stat.state_); | ||||
| @@ -506,30 +583,27 @@ Status CacheServer::ToggleWriteMode(CacheRequest *rq) { | |||||
| } | } | ||||
| Status CacheServer::ListSessions(CacheReply *reply) { | Status CacheServer::ListSessions(CacheReply *reply) { | ||||
| SharedLock lck(&sessions_lock_); | |||||
| SharedLock sess_lck(&sessions_lock_); | |||||
| SharedLock lck(&rwLock_); | |||||
| flatbuffers::FlatBufferBuilder fbb; | flatbuffers::FlatBufferBuilder fbb; | ||||
| std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector; | std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector; | ||||
| for (auto it = active_sessions_.begin(); it != active_sessions_.end(); it++) { | |||||
| session_id_type current_session_id = it->first; | |||||
| // Loop over each cache inside this session | |||||
| if (!it->second.empty()) { | |||||
| for (auto current_conn_id : it->second) { | |||||
| CacheService *cs = GetService(current_conn_id); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Connection " + std::to_string(current_conn_id) + " not found during list sessions."; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| CacheService::ServiceStat svc_stat; | |||||
| RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); | |||||
| auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached, | |||||
| svc_stat.stat_.average_cache_sz, svc_stat.stat_.min_key, | |||||
| svc_stat.stat_.max_key, svc_stat.state_); | |||||
| auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats); | |||||
| session_msgs_vector.push_back(current_session_info); | |||||
| } | |||||
| for (auto const ¤t_session_id : active_sessions_) { | |||||
| bool found = false; | |||||
| for (auto const &it : all_caches_) { | |||||
| auto current_conn_id = it.first; | |||||
| if (GetSessionID(current_conn_id) == current_session_id) { | |||||
| found = true; | |||||
| auto &cs = it.second; | |||||
| CacheService::ServiceStat svc_stat; | |||||
| RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); | |||||
| auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached, | |||||
| svc_stat.stat_.average_cache_sz, svc_stat.stat_.num_numa_hit, | |||||
| svc_stat.stat_.min_key, svc_stat.stat_.max_key, svc_stat.state_); | |||||
| auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats); | |||||
| session_msgs_vector.push_back(current_session_info); | |||||
| } | } | ||||
| } else { | |||||
| } | |||||
| if (!found) { | |||||
| // If there is no cache created yet, assign a connection id of 0 along with empty stats | // If there is no cache created yet, assign a connection id of 0 along with empty stats | ||||
| auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0); | auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0); | ||||
| auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats); | auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats); | ||||
| @@ -542,18 +616,35 @@ Status CacheServer::ListSessions(CacheReply *reply) { | |||||
| auto offset = s_builder.Finish(); | auto offset = s_builder.Finish(); | ||||
| fbb.Finish(offset); | fbb.Finish(offset); | ||||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | ||||
| return Status::OK(); | |||||
| } | |||||
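ListSessions above recovers the owning session from a connection id via GetSessionID, whose encoding is not shown in this hunk. A plausible (hypothetical) scheme packs the session id into the upper 32 bits of the connection id:

```cpp
// Hypothetical sketch: if connection ids were built as
//   connection_id = (static_cast<connection_id_type>(session_id) << 32) | crc
// then the reverse mapping used by ListSessions is just a shift.
inline session_id_type GetSessionID(connection_id_type connection_id) {
  return static_cast<session_id_type>(connection_id >> 32u);
}
```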
| Status CacheServer::ConnectReset(CacheRequest *rq) { | |||||
| auto connection_id = rq->connection_id(); | |||||
| // Hold the shared lock to prevent the cache from being dropped. | |||||
| SharedLock lck(&rwLock_); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| auto client_id = rq->client_id(); | |||||
| MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects"; | |||||
| cs->num_clients_--; | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| /// \brief This is the main loop that the cache server worker thread(s) run. | /// \brief This is the main loop that the cache server worker thread(s) run. | ||||
| /// Each thread pops a request and sends the result back to the client using grpc | /// Each thread pops a request and sends the result back to the client using grpc | ||||
| /// \return | /// \return | ||||
| Status CacheServer::ServerRequest(int32_t worker_id) { | |||||
| Status CacheServer::ServerRequest(worker_id_t worker_id) { | |||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| MS_LOG(DEBUG) << "Worker id " << worker_id << " is running on node " << hw_info_->GetMyNode(); | |||||
| auto &my_que = cache_q_->operator[](worker_id); | auto &my_que = cache_q_->operator[](worker_id); | ||||
| // Loop forever until we are interrupted or shutdown. | // Loop forever until we are interrupted or shutdown. | ||||
| while (!global_shutdown_) { | while (!global_shutdown_) { | ||||
| bool internal_request = false; | |||||
| CacheServerRequest *cache_req = nullptr; | CacheServerRequest *cache_req = nullptr; | ||||
| RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); | RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); | ||||
| auto &rq = cache_req->rq_; | auto &rq = cache_req->rq_; | ||||
| @@ -571,8 +662,17 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||||
| } | } | ||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kBatchFetchRows: { | |||||
| cache_req->rc_ = BatchFetchRows(&rq, &reply); | |||||
| case BaseRequest::RequestType::kInternalFetchRow: { | |||||
| internal_request = true; | |||||
| auto connection_id = rq.connection_id(); | |||||
| SharedLock lck(&rwLock_); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||||
| cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data())); | |||||
| } | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kCreateCache: { | case BaseRequest::RequestType::kCreateCache: { | ||||
| @@ -636,6 +736,10 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||||
| cache_req->rc_ = ListSessions(&reply); | cache_req->rc_ = ListSessions(&reply); | ||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kConnectReset: { | |||||
| cache_req->rc_ = ConnectReset(&rq); | |||||
| break; | |||||
| } | |||||
| default: | default: | ||||
| std::string errMsg("Unknown request type : "); | std::string errMsg("Unknown request type : "); | ||||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | ||||
| @@ -647,7 +751,13 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||||
| // We will re-tag the request back to the grpc queue. Once it comes back from the client, | // We will re-tag the request back to the grpc queue. Once it comes back from the client, | ||||
| // the CacheServerRequest, i.e. the pointer cache_req, will be freed | // the CacheServerRequest, i.e. the pointer cache_req, will be freed | ||||
| if (!global_shutdown_) { | if (!global_shutdown_) { | ||||
| cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); | |||||
| if (!internal_request) { | |||||
| cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); | |||||
| } else { | |||||
| // This is an internal request and is not tied to rpc. But we need to post because there | |||||
| // is a thread waiting on the completion of this request. | |||||
| cache_req->wp_.Set(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
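The wp_ member used for internal requests behaves like a one-shot wait/post latch: the dispatching thread blocks on it and the worker releases it once the row copy is done. A minimal sketch of such a latch (the real WaitPost in minddata may differ):

```cpp
#include <condition_variable>
#include <mutex>

// One-shot latch in the spirit of cache_req->wp_.Set() / rq->Wait().
class WaitPostSketch {
 public:
  // Called by the worker thread when the request has been processed.
  void Set() {
    std::lock_guard<std::mutex> lk(mu_);
    done_ = true;
    cv_.notify_all();
  }
  // Called by the thread that dispatched the internal request.
  void Wait() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return done_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool done_ = false;
};
```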
| @@ -667,12 +777,20 @@ CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int | |||||
| int32_t shared_memory_sz_in_gb, float memory_cap_ratio) | int32_t shared_memory_sz_in_gb, float memory_cap_ratio) | ||||
| : top_(spill_path), | : top_(spill_path), | ||||
| num_workers_(num_workers), | num_workers_(num_workers), | ||||
| num_grpc_workers_(num_workers_), | |||||
| port_(port), | port_(port), | ||||
| shared_memory_sz_in_gb_(shared_memory_sz_in_gb), | shared_memory_sz_in_gb_(shared_memory_sz_in_gb), | ||||
| global_shutdown_(false), | global_shutdown_(false), | ||||
| memory_cap_ratio_(memory_cap_ratio), | memory_cap_ratio_(memory_cap_ratio), | ||||
| cur_mem_usage_(0) { | |||||
| memory_cap_ = CacheServer::GetTotalSystemMemory() * memory_cap_ratio_; | |||||
| numa_affinity_(true) { | |||||
| hw_info_ = std::make_shared<CacheServerHW>(); | |||||
| // If we are not linked with the numa library (i.e. NUMA_ENABLED is not defined), turn off cpu | |||||
| // affinity, since binding threads without numa awareness can make performance worse. | |||||
| if (!CacheServerHW::numa_enabled()) { | |||||
| numa_affinity_ = false; | |||||
| MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build " | |||||
| "that is compiled with numa support for more optimal performance"; | |||||
| } | |||||
| } | } | ||||
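CacheServerHW::numa_enabled() is introduced elsewhere in this change; presumably it gates on the compile-time NUMA_ENABLED define plus libnuma's runtime probe. A sketch under that assumption:

```cpp
#ifdef NUMA_ENABLED
#include <numa.h>  // libnuma
#endif

// Sketch only: true when the build has numa support and the kernel exposes it.
bool NumaEnabledSketch() {
#ifdef NUMA_ENABLED
  return numa_available() != -1;  // libnuma returns -1 when numa is unusable
#else
  return false;
#endif
}
```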
| Status CacheServer::Run(int msg_qid) { | Status CacheServer::Run(int msg_qid) { | ||||
| @@ -719,51 +837,52 @@ Status CacheServer::ReturnRequestTag(CacheServerRequest *p) { | |||||
| Status CacheServer::DestroySession(CacheRequest *rq) { | Status CacheServer::DestroySession(CacheRequest *rq) { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id"); | CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id"); | ||||
| auto drop_session_id = rq->connection_info().session_id(); | auto drop_session_id = rq->connection_info().session_id(); | ||||
| UniqueLock lck(&sessions_lock_); | |||||
| // First validate that this session exists | |||||
| auto it = active_sessions_.find(drop_session_id); | |||||
| if (it == active_sessions_.end()) { | |||||
| RETURN_STATUS_UNEXPECTED("A destroy session command has been requested but the session was not found!"); | |||||
| } | |||||
| // Grab the locks in the correct order to avoid deadlock. | |||||
| UniqueLock sess_lck(&sessions_lock_); | |||||
| UniqueLock lck(&rwLock_); | |||||
| // Iterate over the set of connection id's for this session that we're dropping and erase each one. | // Iterate over the set of connection id's for this session that we're dropping and erase each one. | ||||
| { | |||||
| UniqueLock rwlck(&rwLock_); | |||||
| for (auto drop_connection_id : it->second) { | |||||
| auto cache_drop_it = all_caches_.find(drop_connection_id); | |||||
| if (cache_drop_it == all_caches_.end()) { | |||||
| RETURN_STATUS_UNEXPECTED("active session tracking had stale or incorrect cache entry."); | |||||
| } | |||||
| all_caches_.erase(cache_drop_it); | |||||
| MS_LOG(INFO) << "Session destroy: Destroy cache with id " << drop_connection_id; | |||||
| // **Do not bother to remove the cache connection id from the active session because we will soon remove the | |||||
| // entire session. | |||||
| bool found = false; | |||||
| for (auto it = all_caches_.begin(); it != all_caches_.end();) { | |||||
| auto connection_id = it->first; | |||||
| auto session_id = GetSessionID(connection_id); | |||||
| // We could just call DestroyCache(), but we are already holding a lock, so doing so would deadlock. | |||||
| // Instead we do the work manually here. | |||||
| if (session_id == drop_session_id) { | |||||
| found = true; | |||||
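| // map::erase(it) returns the next valid iterator, so the loop stays well-defined while erasing. | |||||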
| it = all_caches_.erase(it); | |||||
| MS_LOG(INFO) << "Destroy cache with id " << connection_id; | |||||
| } else { | |||||
| ++it; | |||||
| } | } | ||||
| } | } | ||||
| // Finally remove the session itself | // Finally remove the session itself | ||||
| active_sessions_.erase(it); | |||||
| MS_LOG(INFO) << "Session destroyed with id " << drop_session_id; | |||||
| return Status::OK(); | |||||
| auto n = active_sessions_.erase(drop_session_id); | |||||
| if (n > 0) { | |||||
| MS_LOG(INFO) << "Session destroyed with id " << drop_session_id; | |||||
| return Status::OK(); | |||||
| } else { | |||||
| if (found) { | |||||
| std::string errMsg = | |||||
| "A destroy cache request has been completed but it had a stale session id " + std::to_string(drop_session_id); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } else { | |||||
| std::string errMsg = "Session id " + std::to_string(drop_session_id) + " not found."; | |||||
| return Status(StatusCode::kFileNotExist, errMsg); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| session_id_type CacheServer::GenerateSessionID() { | session_id_type CacheServer::GenerateSessionID() { | ||||
| UniqueLock lock(&sessions_lock_); | |||||
| UniqueLock sess_lck(&sessions_lock_); | |||||
| auto mt = GetRandomDevice(); | auto mt = GetRandomDevice(); | ||||
| std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max()); | std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max()); | ||||
| session_id_type session_id; | session_id_type session_id; | ||||
| bool duplicate = false; | bool duplicate = false; | ||||
| do { | do { | ||||
| session_id = distribution(mt); | session_id = distribution(mt); | ||||
| auto it = active_sessions_.find(session_id); | |||||
| duplicate = (it != active_sessions_.end()); | |||||
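| // std::set::insert returns {iterator, inserted}; inserted == false means the id is already taken, so retry. | |||||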
| auto r = active_sessions_.insert(session_id); | |||||
| duplicate = !r.second; | |||||
| } while (duplicate); | } while (duplicate); | ||||
| // Add this session to our tracking of active sessions with initialized empty set of connections. | |||||
| active_sessions_[session_id] = std::set<connection_id_type>(); | |||||
| return session_id; | return session_id; | ||||
| } | } | ||||
| @@ -789,7 +908,7 @@ Status CacheServer::FreeSharedMemory(CacheRequest *rq) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheServer::RpcRequest(int32_t worker_id) { | |||||
| Status CacheServer::RpcRequest(worker_id_t worker_id) { | |||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id)); | RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -820,12 +939,22 @@ Status CacheServer::GlobalShutdown() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| int64_t CacheServer::GetTotalSystemMemory() { | |||||
| auto pages = sysconf(_SC_PHYS_PAGES); | |||||
| auto page_size = sysconf(_SC_PAGE_SIZE); | |||||
| auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size); | |||||
| MS_LOG(INFO) << "Total physical RAM in bytes: " << total; | |||||
| return total; | |||||
| worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) { | |||||
| auto num_numa_nodes = GetNumaNodeCount(); | |||||
| MS_ASSERT(numa_id < num_numa_nodes); | |||||
| auto num_workers_per_node = GetNumWorkers() / num_numa_nodes; | |||||
| std::mt19937 gen = GetRandomDevice(); | |||||
| std::uniform_int_distribution<worker_id_t> dist(0, num_workers_per_node - 1); | |||||
| auto n = dist(gen); | |||||
| worker_id_t worker_id = n * num_numa_nodes + numa_id; | |||||
| MS_ASSERT(worker_id < GetNumWorkers()); | |||||
| return worker_id; | |||||
| } | |||||
| worker_id_t CacheServer::GetRandomWorker() { | |||||
| std::mt19937 gen = GetRandomDevice(); | |||||
| std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1); | |||||
| return dist(gen); | |||||
| } | } | ||||
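Workers are striped across numa nodes, i.e. worker w serves node w % num_numa_nodes, so picking a random n in [0, workers_per_node) and returning n * num_numa_nodes + numa_id always lands on a worker bound to the requested node (this assumes the worker count is at least, and ideally a multiple of, the node count). A standalone sketch of the mapping:

```cpp
#include <cstdint>
#include <random>

// Sketch of the striping used by GetWorkerByNumaId. With 4 nodes and
// 8 workers: node 0 -> {0, 4}, node 1 -> {1, 5}, node 2 -> {2, 6}, ...
int32_t WorkerForNode(int32_t numa_id, int32_t num_nodes, int32_t num_workers,
                      std::mt19937 *gen) {
  int32_t per_node = num_workers / num_nodes;  // assumed >= 1
  std::uniform_int_distribution<int32_t> dist(0, per_node - 1);
  return dist(*gen) * num_nodes + numa_id;
}
```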
| Status CacheServer::Builder::IpcResourceCleanup() { | Status CacheServer::Builder::IpcResourceCleanup() { | ||||
| @@ -842,6 +971,8 @@ Status CacheServer::Builder::IpcResourceCleanup() { | |||||
| rc = mem.Attach(); | rc = mem.Attach(); | ||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } else { | |||||
| RETURN_IF_NOT_OK(mem.Detach()); | |||||
| } | } | ||||
| int32_t num_attached; | int32_t num_attached; | ||||
| RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached)); | RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached)); | ||||
| @@ -892,5 +1023,16 @@ Status CacheServer::Builder::SanityCheck() { | |||||
| RETURN_IF_NOT_OK(IpcResourceCleanup()); | RETURN_IF_NOT_OK(IpcResourceCleanup()); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| CacheServer::Builder::Builder() | |||||
| : top_("/tmp"), | |||||
| num_workers_(std::thread::hardware_concurrency() / 2), | |||||
| port_(50052), | |||||
| shared_memory_sz_in_gb_(4), | |||||
| memory_cap_ratio_(0.8) { | |||||
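| // std::thread::hardware_concurrency() may return 0 when it cannot be determined; fall back to one worker. | |||||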
| if (num_workers_ == 0) { | |||||
| num_workers_ = 1; | |||||
| } | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,23 +17,31 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | ||||
| #include <stdlib.h> | |||||
| #include <string.h> | #include <string.h> | ||||
| #include <unistd.h> | #include <unistd.h> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <atomic> | #include <atomic> | ||||
| #include <chrono> | |||||
| #include <iostream> | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | #include <map> | ||||
| #include <set> | #include <set> | ||||
| #include <thread> | |||||
| #include "minddata/dataset/engine/cache/cache_hw.h" | |||||
| #include "minddata/dataset/engine/cache/cache_numa.h" | |||||
| #include "minddata/dataset/engine/cache/cache_service.h" | #include "minddata/dataset/engine/cache/cache_service.h" | ||||
| #include "minddata/dataset/engine/cache/cache_grpc_server.h" | #include "minddata/dataset/engine/cache/cache_grpc_server.h" | ||||
| #include "minddata/dataset/engine/cache/cache_pool.h" | |||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/util/allocator.h" | #include "minddata/dataset/util/allocator.h" | ||||
| #include "minddata/dataset/util/arena.h" | #include "minddata/dataset/util/arena.h" | ||||
| #include "minddata/dataset/util/cache_pool.h" | |||||
| #include "minddata/dataset/util/lock.h" | #include "minddata/dataset/util/lock.h" | ||||
| #include "minddata/dataset/util/random.h" | |||||
| #include "minddata/dataset/util/semaphore.h" | |||||
| #include "minddata/dataset/util/service.h" | #include "minddata/dataset/util/service.h" | ||||
| #include "minddata/dataset/util/services.h" | #include "minddata/dataset/util/services.h" | ||||
| #include "minddata/dataset/util/system_pool.h" | #include "minddata/dataset/util/system_pool.h" | ||||
| @@ -47,9 +55,10 @@ class CacheServer : public Service { | |||||
| public: | public: | ||||
| friend class Services; | friend class Services; | ||||
| using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; | using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; | ||||
| class Builder { | class Builder { | ||||
| public: | public: | ||||
| Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4), memory_cap_ratio_(0.8) {} | |||||
| Builder(); | |||||
| ~Builder() = default; | ~Builder() = default; | ||||
| @@ -161,26 +170,40 @@ class CacheServer : public Service { | |||||
| /// \return Status object | /// \return Status object | ||||
| static Status ReturnRequestTag(CacheServerRequest *p); | static Status ReturnRequestTag(CacheServerRequest *p); | ||||
| /// \brief This returns the size (in bytes) of the physical RAM on the machine. | |||||
| /// \return the size (in bytes) of the physical RAM on the machine. | |||||
| static int64_t GetTotalSystemMemory(); | |||||
| /// \brief Return an instance of the numa control | |||||
| std::shared_ptr<CacheServerHW> GetHWControl() { return hw_info_; } | |||||
| /// \brief Internally this is how much we will try to use without exceeding the limit | |||||
| /// \return Internal cap maximum | |||||
| int64_t GetAvailableSystemMemory() { return memory_cap_; } | |||||
| /// \brief Set CPU affinity | |||||
| Status SetAffinity(const Task &tk, numa_id_t numa_node) { return hw_info_->SetAffinity(tk, numa_node); } | |||||
| /// \brief Find out the current memory usage | |||||
| int64_t GetMemoryUsage() { return cur_mem_usage_; } | |||||
| /// \brief return number of workers | |||||
| auto GetNumWorkers() const { return num_workers_; } | |||||
| /// \brief This updates our current memory usage. | |||||
| enum MemUsageOp : int8_t { kAllocate = 1, kFree = 2 }; | |||||
| void UpdateMemoryUsage(int64_t sz, MemUsageOp op) { | |||||
| if (op == MemUsageOp::kAllocate) { | |||||
| cur_mem_usage_ += sz; | |||||
| } else { | |||||
| cur_mem_usage_ -= sz; | |||||
| } | |||||
| } | |||||
| /// \brief return number of grpc workers | |||||
| auto GetNumGrpcWorkers() const { return num_grpc_workers_; } | |||||
| /// \brief return number of numa nodes | |||||
| auto GetNumaNodeCount() const { return hw_info_->GetNumaNodeCount(); } | |||||
| /// \brief Assign a worker by a numa id | |||||
| /// \return worker id | |||||
| worker_id_t GetWorkerByNumaId(numa_id_t node_id); | |||||
| /// \brief Randomly pick a worker | |||||
| /// \return worker id | |||||
| worker_id_t GetRandomWorker(); | |||||
| /// \brief Check if we bind threads to numa cores | |||||
| bool IsNumaAffinityOn() const { return numa_affinity_; } | |||||
| /// \brief Internal function to do row batch fetch | |||||
| /// \param rq Request | |||||
| /// \param reply Reply | |||||
| /// \return Status object | |||||
| Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief Return the memory cap ratio | |||||
| float GetMemoryCapRatio() const { return memory_cap_ratio_; } | |||||
| private: | private: | ||||
| static std::once_flag init_instance_flag_; | static std::once_flag init_instance_flag_; | ||||
| @@ -189,20 +212,21 @@ class CacheServer : public Service { | |||||
| mutable RWLock sessions_lock_; | mutable RWLock sessions_lock_; | ||||
| std::string top_; | std::string top_; | ||||
| cache_index all_caches_; | cache_index all_caches_; | ||||
| std::map<session_id_type, std::set<connection_id_type>> active_sessions_; | |||||
| std::set<session_id_type> active_sessions_; | |||||
| std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_; | std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_; | ||||
| std::shared_ptr<QueueList<CacheServerRequest *>> free_list_; | std::shared_ptr<QueueList<CacheServerRequest *>> free_list_; | ||||
| std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_; | |||||
| std::vector<std::unique_ptr<MemGuard<CacheServerRequest, NumaAllocator<CacheServerRequest>>>> tag_; | |||||
| std::shared_ptr<CacheServerGreeterImpl> comm_layer_; | std::shared_ptr<CacheServerGreeterImpl> comm_layer_; | ||||
| std::shared_ptr<MemoryPool> mp_; | |||||
| TaskGroup vg_; | TaskGroup vg_; | ||||
| int32_t num_workers_; | int32_t num_workers_; | ||||
| int32_t num_grpc_workers_; | |||||
| int32_t port_; | int32_t port_; | ||||
| int32_t shared_memory_sz_in_gb_; | int32_t shared_memory_sz_in_gb_; | ||||
| std::atomic<bool> global_shutdown_; | std::atomic<bool> global_shutdown_; | ||||
| float memory_cap_ratio_; | float memory_cap_ratio_; | ||||
| int64_t memory_cap_; | |||||
| std::atomic<int64_t> cur_mem_usage_; | |||||
| std::shared_ptr<CacheServerHW> hw_info_; | |||||
| std::map<worker_id_t, Task *> numa_tasks_; | |||||
| bool numa_affinity_; | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \param spill_path Top directory for spilling buffers to. | /// \param spill_path Top directory for spilling buffers to. | ||||
| @@ -226,11 +250,11 @@ class CacheServer : public Service { | |||||
| Status DestroyCache(CacheRequest *rq); | Status DestroyCache(CacheRequest *rq); | ||||
| /// \brief Entry point for all internal server threads. | /// \brief Entry point for all internal server threads. | ||||
| Status ServerRequest(int32_t worker_id); | |||||
| Status ServerRequest(worker_id_t worker_id); | |||||
| /// \brief Entry point for all grpc threads. | /// \brief Entry point for all grpc threads. | ||||
| /// \return | /// \return | ||||
| Status RpcRequest(int32_t worker_id); | |||||
| Status RpcRequest(worker_id_t worker_id); | |||||
| Status DestroySession(CacheRequest *rq); | Status DestroySession(CacheRequest *rq); | ||||
| @@ -266,12 +290,6 @@ class CacheServer : public Service { | |||||
| Status FastCacheRow(CacheRequest *rq, CacheReply *reply); | Status FastCacheRow(CacheRequest *rq, CacheReply *reply); | ||||
| Status CacheRow(CacheRequest *rq, CacheReply *reply); | Status CacheRow(CacheRequest *rq, CacheReply *reply); | ||||
| /// \brief Internal function to do row batch fetch | |||||
| /// \param rq Request | |||||
| /// \param reply Reply | |||||
| /// \return Status object | |||||
| Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief Internal function to get statistics | /// \brief Internal function to get statistics | ||||
| /// \param rq | /// \param rq | ||||
| /// \param reply | /// \param reply | ||||
| @@ -309,6 +327,9 @@ class CacheServer : public Service { | |||||
| /// \param reply | /// \param reply | ||||
| /// \return Status object | /// \return Status object | ||||
| Status ListSessions(CacheReply *reply); | Status ListSessions(CacheReply *reply); | ||||
| /// \brief Handle a connection reset (client disconnect) request from a pipeline | |||||
| Status ConnectReset(CacheRequest *rq); | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -13,51 +13,45 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <random> | |||||
| #include "minddata/dataset/engine/cache/cache_service.h" | #include "minddata/dataset/engine/cache/cache_service.h" | ||||
| #include "minddata/dataset/engine/cache/cache_server.h" | #include "minddata/dataset/engine/cache/cache_server.h" | ||||
| #include "minddata/dataset/engine/cache/cache_numa.h" | |||||
| #include "minddata/dataset/util/random.h" | |||||
| #include "minddata/dataset/util/slice.h" | #include "minddata/dataset/util/slice.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id) | CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id) | ||||
| : root_(root), | : root_(root), | ||||
| cache_mem_sz_(mem_sz), | |||||
| cache_mem_sz_(mem_sz * 1048576L), // mem_sz is given in MB; convert to bytes | |||||
| cp_(nullptr), | cp_(nullptr), | ||||
| next_id_(0), | next_id_(0), | ||||
| generate_id_(generate_id), | generate_id_(generate_id), | ||||
| st_(generate_id ? State::kBuildPhase : State::kNone), | |||||
| cur_mem_usage_(0), | |||||
| cur_disk_usage_(0) {} | |||||
| num_clients_(0), | |||||
| st_(generate_id ? CacheServiceState::kBuildPhase : CacheServiceState::kNone) {} | |||||
| CacheService::~CacheService() { (void)ServiceStop(); } | CacheService::~CacheService() { (void)ServiceStop(); } | ||||
| bool CacheService::UseArena() { | |||||
| // If fixed size, use Arena instead of the pool from global context. | |||||
| return (cache_mem_sz_ > 0); | |||||
| } | |||||
| Status CacheService::DoServiceStart() { | Status CacheService::DoServiceStart() { | ||||
| std::shared_ptr<MemoryPool> mp_; | |||||
| CacheServer &cs = CacheServer::GetInstance(); | CacheServer &cs = CacheServer::GetInstance(); | ||||
| if (UseArena()) { | |||||
| auto avail_mem = cs.GetAvailableSystemMemory() / 1048576L; | |||||
| float memory_cap_ratio = cs.GetMemoryCapRatio(); | |||||
| if (cache_mem_sz_ > 0) { | |||||
| auto avail_mem = CacheServerHW::GetTotalSystemMemory(); | |||||
| if (cache_mem_sz_ > avail_mem) { | if (cache_mem_sz_ > avail_mem) { | ||||
| // Output a warning that we use more than recommended. If we fail to allocate, we will fail anyway. | // Output a warning that we use more than recommended. If we fail to allocate, we will fail anyway. | ||||
| MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " MB while available system memory " << avail_mem | |||||
| << " MB"; | |||||
| MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " while available system memory " << avail_mem; | |||||
| } | } | ||||
| // Create a fixed size arena based on the parameter. | |||||
| std::shared_ptr<Arena> arena; | |||||
| RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); | |||||
| mp_ = std::move(arena); | |||||
| // update the global usage only. | |||||
| cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kAllocate); | |||||
| } else { | |||||
| // Unlimited size. Simply use a system pool. Another choice is CircularPool. | |||||
| mp_ = std::make_shared<SystemPool>(); | |||||
| memory_cap_ratio = static_cast<float>(cache_mem_sz_) / avail_mem; | |||||
| } | |||||
| numa_pool_ = std::make_shared<NumaMemoryPool>(cs.GetHWControl(), memory_cap_ratio); | |||||
| // It is possible we aren't able to allocate the pool for many reasons. | |||||
| std::vector<numa_id_t> avail_nodes = numa_pool_->GetAvailableNodes(); | |||||
| if (avail_nodes.empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("Unable to bring up numa memory pool"); | |||||
| } | } | ||||
| // Put together a CachePool for backing up the Tensor | |||||
| cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), UseArena(), root_); | |||||
| // Put together a CachePool for backing up the Tensor. | |||||
| cp_ = std::make_shared<CachePool>(numa_pool_, root_); | |||||
| RETURN_IF_NOT_OK(cp_->ServiceStart()); | RETURN_IF_NOT_OK(cp_->ServiceStart()); | ||||
| // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. | // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. | ||||
| cookie_ = cp_->MyName(); | cookie_ = cp_->MyName(); | ||||
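For a fixed-size cache, the cap ratio handed to the NumaMemoryPool is derived from the request rather than taken from the server-wide default. A worked example of the arithmetic (values are illustrative only):

```cpp
// A 16 GB cache request on a 64 GB host:
uint64_t cache_mem_sz = 16384ULL * 1048576ULL;  // mem_sz was 16384 MB
int64_t avail_mem = 65536LL * 1048576LL;        // GetTotalSystemMemory()
float ratio = static_cast<float>(cache_mem_sz) / avail_mem;  // = 0.25
```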
| @@ -68,26 +62,18 @@ Status CacheService::DoServiceStop() { | |||||
| if (cp_ != nullptr) { | if (cp_ != nullptr) { | ||||
| RETURN_IF_NOT_OK(cp_->ServiceStop()); | RETURN_IF_NOT_OK(cp_->ServiceStop()); | ||||
| } | } | ||||
| CacheServer &cs = CacheServer::GetInstance(); | |||||
| if (UseArena()) { | |||||
| cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kFree); | |||||
| } else { | |||||
| MS_LOG(INFO) << "Memory/disk usage for the current service: " << GetMemoryUsage() << " bytes and " << GetDiskUsage() | |||||
| << " bytes."; | |||||
| cs.UpdateMemoryUsage(GetMemoryUsage(), CacheServer::MemUsageOp::kFree); | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) { | Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) { | ||||
| SharedLock rw(&rw_lock_); | SharedLock rw(&rw_lock_); | ||||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | RETURN_UNEXPECTED_IF_NULL(row_id_generated); | ||||
| if (st_ == State::kFetchPhase) { | |||||
| if (st_ == CacheServiceState::kFetchPhase) { | |||||
| // For this kind of cache service, once we move from the build phase into the fetch phase, we can't | // For this kind of cache service, once we move from the build phase into the fetch phase, we can't | ||||
| // allow others to cache more rows. | // allow others to cache more rows. | ||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | ||||
| } | } | ||||
| if (st_ == State::kNoLocking) { | |||||
| if (st_ == CacheServiceState::kNoLocking) { | |||||
| // We ignore this write request once locking on the B+ tree is turned off. So we will just | // We ignore this write request once locking on the B+ tree is turned off. So we will just | ||||
| // return out of memory from now on. | // return out of memory from now on. | ||||
| return Status(StatusCode::kOutOfMemory); | return Status(StatusCode::kOutOfMemory); | ||||
| @@ -128,26 +114,13 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||||
| all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); | all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); | ||||
| total_sz += msg->data_sz()->Get(i); | total_sz += msg->data_sz()->Get(i); | ||||
| } | } | ||||
| // Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it. | |||||
| // Otherwise, we check how much (globally) how much we use and may simply spill to disk | |||||
| // directly. | |||||
| CacheServer &cs = CacheServer::GetInstance(); | |||||
| bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory(); | |||||
| Status rc = cp_->Insert(*row_id_generated, all_data, write_to_disk_directly); | |||||
| // Now we cache the buffer. | |||||
| Status rc = cp_->Insert(*row_id_generated, all_data); | |||||
| if (rc == Status(StatusCode::kDuplicateKey)) { | if (rc == Status(StatusCode::kDuplicateKey)) { | ||||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | MS_LOG(DEBUG) << "Ignoring duplicate key."; | ||||
| } else { | } else { | ||||
| RETURN_IF_NOT_OK(rc); | RETURN_IF_NOT_OK(rc); | ||||
| } | } | ||||
| // All good, then update the memory usage local and global (if not using arena) | |||||
| if (write_to_disk_directly) { | |||||
| cur_disk_usage_ += total_sz; | |||||
| } else { | |||||
| cur_mem_usage_ += total_sz; | |||||
| if (!UseArena()) { | |||||
| cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| RETURN_STATUS_UNEXPECTED(e.what()); | RETURN_STATUS_UNEXPECTED(e.what()); | ||||
| @@ -157,12 +130,12 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||||
| Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) { | Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) { | ||||
| SharedLock rw(&rw_lock_); | SharedLock rw(&rw_lock_); | ||||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | RETURN_UNEXPECTED_IF_NULL(row_id_generated); | ||||
| if (st_ == State::kFetchPhase) { | |||||
| if (st_ == CacheServiceState::kFetchPhase) { | |||||
| // For this kind of cache service, once we move from the build phase into the fetch phase, we can't | // For this kind of cache service, once we move from the build phase into the fetch phase, we can't | ||||
| // allow others to cache more rows. | // allow others to cache more rows. | ||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | ||||
| } | } | ||||
| if (st_ == State::kNoLocking) { | |||||
| if (st_ == CacheServiceState::kNoLocking) { | |||||
| // We ignore this write request once locking on the B+ tree is turned off. So we will just | // We ignore this write request once locking on the B+ tree is turned off. So we will just | ||||
| // return out of memory from now on. | // return out of memory from now on. | ||||
| return Status(StatusCode::kOutOfMemory); | return Status(StatusCode::kOutOfMemory); | ||||
| @@ -183,27 +156,13 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ | |||||
| } | } | ||||
| *row_id_generated = msg->row_id(); | *row_id_generated = msg->row_id(); | ||||
| } | } | ||||
| // Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it. | |||||
| // Otherwise, we check how much (globally) how much we use and may simply spill to disk | |||||
| // directly. | |||||
| auto total_sz = src.GetSize(); | |||||
| CacheServer &cs = CacheServer::GetInstance(); | |||||
| bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory(); | |||||
| Status rc = cp_->Insert(*row_id_generated, {src}, write_to_disk_directly); | |||||
| // Now we cache the buffer. | |||||
| Status rc = cp_->Insert(*row_id_generated, {src}); | |||||
| if (rc == Status(StatusCode::kDuplicateKey)) { | if (rc == Status(StatusCode::kDuplicateKey)) { | ||||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | MS_LOG(DEBUG) << "Ignoring duplicate key."; | ||||
| } else { | } else { | ||||
| RETURN_IF_NOT_OK(rc); | RETURN_IF_NOT_OK(rc); | ||||
| } | } | ||||
| // All good, then update the memory usage local and global (if not using arena) | |||||
| if (write_to_disk_directly) { | |||||
| cur_disk_usage_ += total_sz; | |||||
| } else { | |||||
| cur_mem_usage_ += total_sz; | |||||
| if (!UseArena()) { | |||||
| cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| RETURN_STATUS_UNEXPECTED(e.what()); | RETURN_STATUS_UNEXPECTED(e.what()); | ||||
| @@ -247,52 +206,116 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out, | |||||
| int64_t *mem_sz) { | |||||
| Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v, | |||||
| const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) { | |||||
| SharedLock rw(&rw_lock_); | SharedLock rw(&rw_lock_); | ||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| RETURN_UNEXPECTED_IF_NULL(mem_sz); | |||||
| const auto num_elements = v.size(); | |||||
| *mem_sz = (num_elements + 1) * sizeof(int64_t); | |||||
| (*out).reserve(num_elements); | |||||
| std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v; | |||||
| datalocator_v.reserve(v.size()); | |||||
| for (auto row_id : v) { | for (auto row_id : v) { | ||||
| auto sz = cp_->GetSize(row_id); | |||||
| if (sz > 0) { | |||||
| (*out).emplace_back(row_id, sz); | |||||
| (*mem_sz) += sz; | |||||
| } else { | |||||
| // key not found | |||||
| (*out).emplace_back(-1, 0); | |||||
| } | |||||
| flatbuffers::Offset<DataLocatorMsg> offset; | |||||
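| // Assumption: for a cache miss, GetDataLocator emits a locator with size 0, which BatchFetch skips via its sz > 0 check. | |||||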
| RETURN_IF_NOT_OK(cp_->GetDataLocator(row_id, fbb, &offset)); | |||||
| datalocator_v.push_back(offset); | |||||
| } | } | ||||
| auto offset_v = fbb->CreateVector(datalocator_v); | |||||
| BatchDataLocatorMsgBuilder bld(*fbb); | |||||
| bld.add_connection_id(connection_id); | |||||
| bld.add_rows(offset_v); | |||||
| auto offset_final = bld.Finish(); | |||||
| fbb->Finish(offset_final); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info, | |||||
| WritableSlice *out) const { | |||||
| Status CacheService::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) const { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | RETURN_UNEXPECTED_IF_NULL(out); | ||||
| SharedLock rw(&rw_lock_); | SharedLock rw(&rw_lock_); | ||||
| if (st_ == State::kBuildPhase) { | |||||
| if (st_ == CacheServiceState::kBuildPhase) { | |||||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | ||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | ||||
| } | } | ||||
| const auto num_elements = v.size(); | |||||
| CacheServer &cs = CacheServer::GetInstance(); | |||||
| int32_t numQ = cs.GetNumGrpcWorkers(); | |||||
| auto rng = GetRandomDevice(); | |||||
| std::uniform_int_distribution<int32_t> distribution(0, numQ - 1); | |||||
| int32_t qID = distribution(rng); | |||||
| std::vector<CacheServerRequest *> cache_rq_list; | |||||
| auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer()); | |||||
| const auto num_elements = p->rows()->size(); | |||||
| auto connection_id = p->connection_id(); | |||||
| cache_rq_list.reserve(num_elements); | |||||
| int64_t data_offset = (num_elements + 1) * sizeof(int64_t); | int64_t data_offset = (num_elements + 1) * sizeof(int64_t); | ||||
| auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer()); | auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer()); | ||||
| offset_array[0] = data_offset; | offset_array[0] = data_offset; | ||||
| for (auto i = 0; i < num_elements; ++i) { | for (auto i = 0; i < num_elements; ++i) { | ||||
| auto sz = info.at(i).second; | |||||
| offset_array[i + 1] = offset_array[i] + sz; | |||||
| auto data_locator = p->rows()->Get(i); | |||||
| auto node_id = data_locator->node_id(); | |||||
| size_t sz = data_locator->size(); | |||||
| void *source_addr = reinterpret_cast<void *>(data_locator->addr()); | |||||
| auto key = data_locator->key(); | |||||
| // Please read the comment in CacheServer::BatchFetchRows where we allocate | |||||
| // a buffer big enough that each thread (which we are going to dispatch) will | |||||
| // not run into the false sharing problem. We round sz up to 4K here as well. | |||||
| auto sz_4k = round_up_4K(sz); | |||||
| offset_array[i + 1] = offset_array[i] + sz_4k; | |||||
| if (sz > 0) { | if (sz > 0) { | ||||
| WritableSlice row_data(*out, offset_array[i], sz); | WritableSlice row_data(*out, offset_array[i], sz); | ||||
| auto key = info.at(i).first; | |||||
| size_t bytesRead = 0; | |||||
| RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead)); | |||||
| if (bytesRead != sz) { | |||||
| MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "." | |||||
| << " Internal key: " << key << "\n"; | |||||
| RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); | |||||
| } | |||||
| // Get a request and send to the proper worker (at some numa node) to do the fetch. | |||||
| worker_id_t worker_id = cs.IsNumaAffinityOn() ? cs.GetWorkerByNumaId(node_id) : cs.GetRandomWorker(); | |||||
| CacheServerRequest *cache_rq; | |||||
| RETURN_IF_NOT_OK(cs.GetFreeRequestTag(qID++ % numQ, &cache_rq)); | |||||
| cache_rq_list.push_back(cache_rq); | |||||
| // Set up all the necessary fields. | |||||
| cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow; | |||||
| cache_rq->st_ = CacheServerRequest::STATE::PROCESS; | |||||
| cache_rq->rq_.set_connection_id(connection_id); | |||||
| cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_)); | |||||
| auto dest_addr = row_data.GetMutablePointer(); | |||||
| flatbuffers::FlatBufferBuilder fb2; | |||||
| FetchRowMsgBuilder bld(fb2); | |||||
| bld.add_key(key); | |||||
| bld.add_size(sz); | |||||
| bld.add_source_addr(reinterpret_cast<int64_t>(source_addr)); | |||||
| bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr)); | |||||
| auto offset = bld.Finish(); | |||||
| fb2.Finish(offset); | |||||
| cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize()); | |||||
| RETURN_IF_NOT_OK(cs.PushRequest(worker_id, cache_rq)); | |||||
| } | |||||
| } | |||||
| // Now wait for all of them to come back. Let go of the shared lock. We shouldn't be holding | |||||
| // any lock while we may be waiting for a long time. | |||||
| rw.Unlock(); | |||||
| Status rc; | |||||
| for (CacheServerRequest *rq : cache_rq_list) { | |||||
| RETURN_IF_NOT_OK(rq->Wait()); | |||||
| if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) { | |||||
| rc = rq->rc_; | |||||
| } | |||||
| RETURN_IF_NOT_OK(cs.ReturnRequestTag(rq)); | |||||
| } | |||||
| return rc; | |||||
| } | |||||
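The reply buffer laid out above is an (n+1)-entry int64 offset table followed by the row payloads, each slot padded to a 4K multiple. A worked example of the bookkeeping:

```cpp
// Three requested rows with cached sizes {100, 5000, 0 (miss)}:
//   data_offset  = (3 + 1) * sizeof(int64_t) = 32
//   round_up_4K  : 100 -> 4096, 5000 -> 8192, 0 -> 0
//   offset_array = {32, 4128, 12320, 12320}
// Row 0 is written at [32, 132), row 1 at [4128, 9128); row 2 produces no
// write, and the zero-length slot marks the miss for the client.
```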
| Status CacheService::InternalFetchRow(const FetchRowMsg *p) { | |||||
| RETURN_UNEXPECTED_IF_NULL(p); | |||||
| SharedLock rw(&rw_lock_); | |||||
| size_t bytesRead = 0; | |||||
| int64_t key = p->key(); | |||||
| size_t sz = p->size(); | |||||
| void *source_addr = reinterpret_cast<void *>(p->source_addr()); | |||||
| void *dest_addr = reinterpret_cast<void *>(p->dest_addr()); | |||||
| WritableSlice dest(dest_addr, sz); | |||||
| if (source_addr != nullptr) { | |||||
| // We do not re-check that the row is still present; we simply use the information passed in. | |||||
| // This saves another tree lookup and is faster. | |||||
| ReadableSlice src(source_addr, sz); | |||||
| RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src)); | |||||
| } else { | |||||
| RETURN_IF_NOT_OK(cp_->Read(key, &dest, &bytesRead)); | |||||
| if (bytesRead != sz) { | |||||
| std::string errMsg = "Unexpected length. Read " + std::to_string(bytesRead) + ". Expected " + std::to_string(sz) + | |||||
| "." + " Internal key: " + std::to_string(key); | |||||
| MS_LOG(ERROR) << errMsg; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | } | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -312,7 +335,7 @@ Status CacheService::CacheSchema(const void *buf, int64_t len) { | |||||
| Status CacheService::FetchSchema(std::string *out) const { | Status CacheService::FetchSchema(std::string *out) const { | ||||
| SharedLock rw(&rw_lock_); | SharedLock rw(&rw_lock_); | ||||
| if (st_ == State::kBuildPhase) { | |||||
| if (st_ == CacheServiceState::kBuildPhase) { | |||||
| // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. | ||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | ||||
| } | } | ||||
| @@ -333,7 +356,7 @@ Status CacheService::BuildPhaseDone() { | |||||
| if (HasBuildPhase()) { | if (HasBuildPhase()) { | ||||
| // Exclusive lock to switch phase | // Exclusive lock to switch phase | ||||
| UniqueLock rw(&rw_lock_); | UniqueLock rw(&rw_lock_); | ||||
| st_ = State::kFetchPhase; | |||||
| st_ = CacheServiceState::kFetchPhase; | |||||
| cp_->SetLocking(false); | cp_->SetLocking(false); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } else { | } else { | ||||
| @@ -348,12 +371,12 @@ Status CacheService::ToggleWriteMode(bool on_off) { | |||||
| } else { | } else { | ||||
| // If we stop accepting write requests, we turn off locking for the | // If we stop accepting write requests, we turn off locking for the | ||||
| // underlying B+ tree. All future write requests will get kOutOfMemory. | // underlying B+ tree. All future write requests will get kOutOfMemory. | ||||
| if (st_ == State::kNone && !on_off) { | |||||
| st_ = State::kNoLocking; | |||||
| if (st_ == CacheServiceState::kNone && !on_off) { | |||||
| st_ = CacheServiceState::kNoLocking; | |||||
| cp_->SetLocking(on_off); | cp_->SetLocking(on_off); | ||||
| MS_LOG(WARNING) << "Locking mode is switched off."; | MS_LOG(WARNING) << "Locking mode is switched off."; | ||||
| } else if (st_ == State::kNoLocking && on_off) { | |||||
| st_ = State::kNone; | |||||
| } else if (st_ == CacheServiceState::kNoLocking && on_off) { | |||||
| st_ = CacheServiceState::kNone; | |||||
| cp_->SetLocking(on_off); | cp_->SetLocking(on_off); | ||||
| } | } | ||||
| } | } | ||||
| @@ -29,36 +29,28 @@ | |||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/cache/cache_request.h" | #include "minddata/dataset/engine/cache/cache_request.h" | ||||
| #include "minddata/dataset/engine/cache/cache_pool.h" | |||||
| #include "minddata/dataset/util/arena.h" | #include "minddata/dataset/util/arena.h" | ||||
| #include "minddata/dataset/util/btree.h" | #include "minddata/dataset/util/btree.h" | ||||
| #include "minddata/dataset/util/cache_pool.h" | |||||
| #include "minddata/dataset/util/service.h" | #include "minddata/dataset/util/service.h" | ||||
| #include "minddata/dataset/util/services.h" | #include "minddata/dataset/util/services.h" | ||||
| #include "minddata/dataset/util/system_pool.h" | #include "minddata/dataset/util/system_pool.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| /// Some typedef used for BatchFetch | |||||
| using key_size_pair = std::pair<CachePool::key_type, size_t>; | |||||
| /// \brief A cache service for storing/fetching buffers in an in-memory cache, which may spill to disk if the | /// \brief A cache service for storing/fetching buffers in an in-memory cache, which may spill to disk if the | ||||
| /// service is created to support spilling. | /// service is created to support spilling. | ||||
| class CacheService : public Service { | class CacheService : public Service { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking }; | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited | /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited | ||||
| /// \param root Spill path. Empty string means no spilling | /// \param root Spill path. Empty string means no spilling | ||||
| /// \param generate_id If the cache service should generate row id for buffer that is cached. | /// \param generate_id If the cache service should generate row id for buffer that is cached. | ||||
| /// For non-mappable dataset, this should be set to true. | /// For non-mappable dataset, this should be set to true. | ||||
| CacheService(uint64_t mem_sz, const std::string &root, bool generate_id); | CacheService(uint64_t mem_sz, const std::string &root, bool generate_id); | ||||
| ~CacheService(); | |||||
| /// \brief For fixed size memory, we will create an Arena. | |||||
| /// \return false if unlimited memory. | |||||
| bool UseArena(); | |||||
| ~CacheService() override; | |||||
| Status DoServiceStart() override; | Status DoServiceStart() override; | ||||
| Status DoServiceStop() override; | Status DoServiceStop() override; | ||||
| @@ -77,18 +69,18 @@ class CacheService : public Service { | |||||
| Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated); | Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated); | ||||
| /// \brief This function is used in preparation for batch fetching. | /// \brief This function is used in preparation for batch fetching. | ||||
| /// It calculates how much memory we should allocate and which row id are present. | |||||
| /// \param[in/out] Pointer to vector of <CachePool::key_type, size_t> | |||||
| /// \param[in/out] mem_sz how much memory is required to batch fetch | |||||
| /// It calculates how much memory we should allocate and which row ids are present, etc. | |||||
| /// All needed results are stored in the flat buffer. | |||||
| /// \return Status object | /// \return Status object | ||||
| Status PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *, int64_t *mem_sz); | |||||
| Status PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v, | |||||
| const std::shared_ptr<flatbuffers::FlatBufferBuilder> &); | |||||
| /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded | /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded | ||||
| /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. | /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. | ||||
| /// \param[in] fbb Flat buffer built by PreBatchFetch, holding the locator of each requested row. | /// \param[in] fbb Flat buffer built by PreBatchFetch, holding the locator of each requested row. | ||||
| /// \param[out] out A contiguous memory buffer that holds the requested rows. | /// \param[out] out A contiguous memory buffer that holds the requested rows. | ||||
| /// \return Status object | /// \return Status object | ||||
| Status BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &, WritableSlice *out) const; | |||||
| Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &, WritableSlice *out) const; | |||||
| /// \brief Getter function | /// \brief Getter function | ||||
| /// \return Spilling path | /// \return Spilling path | ||||
| @@ -96,7 +88,7 @@ class CacheService : public Service { | |||||
| /// \brief A structure returned from the cache server for statistics request. | /// \brief A structure returned from the cache server for statistics request. | ||||
| class ServiceStat { | class ServiceStat { | ||||
| public: | public: | ||||
| using state_type = std::underlying_type<State>::type; | |||||
| using state_type = std::underlying_type<CacheServiceState>::type; | |||||
| ServiceStat() : state_(0) {} | ServiceStat() : state_(0) {} | ||||
| ~ServiceStat() = default; | ~ServiceStat() = default; | ||||
| CachePool::CacheStat stat_{}; | CachePool::CacheStat stat_{}; | ||||
| @@ -134,10 +126,6 @@ class CacheService : public Service { | |||||
| /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. | /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. | ||||
| /// \return Status object | /// \return Status object | ||||
| Status BuildPhaseDone(); | Status BuildPhaseDone(); | ||||
| /// \brief Find out the current memory usage | |||||
| int64_t GetMemoryUsage() { return cur_mem_usage_; } | |||||
| /// \brief Find out the current disk usage | |||||
| int64_t GetDiskUsage() { return cur_disk_usage_; } | |||||
| /// \brief For kToggleWriteMode request | /// \brief For kToggleWriteMode request | ||||
| Status ToggleWriteMode(bool on_off); | Status ToggleWriteMode(bool on_off); | ||||
| @@ -149,14 +137,10 @@ class CacheService : public Service { | |||||
| std::atomic<row_id_type> next_id_; | std::atomic<row_id_type> next_id_; | ||||
| bool generate_id_; | bool generate_id_; | ||||
| std::string cookie_; | std::string cookie_; | ||||
| State st_; | |||||
| std::atomic<int32_t> num_clients_; | |||||
| CacheServiceState st_; | |||||
| std::string schema_; | std::string schema_; | ||||
| // If we use an Arena, cur_disk_usage is always 0 as we don't know how CachePool manages it. | |||||
| // Otherwise we track how much is in memory and how much is on disk (if root_ is not empty). | |||||
| // We use them to control when we should stop caching in memory in the case when there is no | |||||
| // Arena. | |||||
| std::atomic<int64_t> cur_mem_usage_; | |||||
| std::atomic<int64_t> cur_disk_usage_; | |||||
| std::shared_ptr<NumaMemoryPool> numa_pool_; | |||||
| // We also cache the result from calling FindKeysMiss because it is expensive. Besides, users make | // We also cache the result from calling FindKeysMiss because it is expensive. Besides, users make | ||||
| // this request after we hit memory full or disk full, so the result is unlikely to change. | // this request after we hit memory full or disk full, so the result is unlikely to change. | ||||
| std::mutex get_key_miss_mux_; | std::mutex get_key_miss_mux_; | ||||
| @@ -164,6 +148,8 @@ class CacheService : public Service { | |||||
| /// \brief Private function to generate a row id | /// \brief Private function to generate a row id | ||||
| /// \return Row id assigned. | /// \return Row id assigned. | ||||
| row_id_type GetNextRowId() { return next_id_.fetch_add(1); } | row_id_type GetNextRowId() { return next_id_.fetch_add(1); } | ||||
| Status InternalFetchRow(const FetchRowMsg *p); | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -65,6 +65,7 @@ table ServiceStatMsg { | |||||
| num_mem_cached:int64; | num_mem_cached:int64; | ||||
| num_disk_cached:int64; | num_disk_cached:int64; | ||||
| avg_cache_sz:int64; | avg_cache_sz:int64; | ||||
| num_numa_hit:int64; | |||||
| min_row_id:int64; | min_row_id:int64; | ||||
| max_row_id:int64; | max_row_id:int64; | ||||
| state:int8; | state:int8; | ||||
| @@ -89,8 +90,10 @@ table CreateCacheRequestMsg { | |||||
| /// Return result of CreateCacheRequest | /// Return result of CreateCacheRequest | ||||
| table CreateCacheReplyMsg { | table CreateCacheReplyMsg { | ||||
| connection_id:int64; | |||||
| client_id:int32; | |||||
| connection_id:uint64; | |||||
| cookie:string; | cookie:string; | ||||
| cpu_id:[int32]; | |||||
| } | } | ||||
| table ListSessionMsg { | table ListSessionMsg { | ||||
| @@ -102,3 +105,22 @@ table ListSessionMsg { | |||||
| table ListSessionsMsg { | table ListSessionsMsg { | ||||
| sessions:[ListSessionMsg]; | sessions:[ListSessionMsg]; | ||||
| } | } | ||||
| table DataLocatorMsg { | |||||
| key:int64; | |||||
| node_id:int32; | |||||
| addr:int64; | |||||
| size:int64; | |||||
| } | |||||
| table BatchDataLocatorMsg { | |||||
| connection_id:uint64; | |||||
| rows:[DataLocatorMsg]; | |||||
| } | |||||
| table FetchRowMsg { | |||||
| key:int64; | |||||
| source_addr:int64; | |||||
| dest_addr:int64; | |||||
| size:int64; | |||||
| } | |||||
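For readers unfamiliar with FlatBuffers codegen, here is a minimal sketch of how the new FetchRowMsg table above would be produced and consumed. The Create/GetRoot function names follow standard FlatBuffers C++ conventions for this table; the surrounding namespace and how the buffer is transported are assumptions.

    // Hedged sketch: standard FlatBuffers C++ codegen for the FetchRowMsg table.
    flatbuffers::FlatBufferBuilder fbb;
    auto off = CreateFetchRowMsg(fbb, /*key=*/42, /*source_addr=*/0, /*dest_addr=*/0, /*size=*/4096);
    fbb.Finish(off);
    // Reader side, e.g. inside the new CacheService::InternalFetchRow hook:
    auto *p = flatbuffers::GetRoot<FetchRowMsg>(fbb.GetBufferPointer());
    int64_t key = p->key();   // 42
    int64_t sz = p->size();   // 4096

The CacheService::InternalFetchRow(const FetchRowMsg *) declaration added earlier in this change is the server-side consumer of such a buffer.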
| @@ -0,0 +1,32 @@ | |||||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||||
| if (ENABLE_CACHE) | |||||
| ms_protobuf_generate(CACHE_PERF_PROTO_SRCS CACHE_PERF_PROTO_HDRS cache_perf.proto) | |||||
| add_executable(cache_perf cache_perf.cc cache_msg.cc cache_perf_run.cc ${CACHE_PERF_PROTO_SRCS}) | |||||
| target_link_libraries(cache_perf | |||||
| _c_dataengine | |||||
| _c_mindrecord | |||||
| mindspore::protobuf | |||||
| mindspore_gvar | |||||
| ${PYTHON_LIBRARIES} | |||||
| pthread) | |||||
| if (USE_GLOG) | |||||
| target_link_libraries(cache_perf mindspore::glog) | |||||
| endif () | |||||
| add_executable(cache_pipeline cache_pipeline.cc cache_msg.cc cache_pipeline_run.cc ${CACHE_PERF_PROTO_SRCS}) | |||||
| target_link_libraries(cache_pipeline | |||||
| _c_dataengine | |||||
| _c_mindrecord | |||||
| mindspore::protobuf | |||||
| mindspore_gvar | |||||
| ${PYTHON_LIBRARIES} | |||||
| pthread) | |||||
| if (USE_GLOG) | |||||
| target_link_libraries(cache_pipeline mindspore::glog) | |||||
| endif () | |||||
| endif () | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/cache/perf/cache_msg.h" | |||||
| #include <sys/types.h> | |||||
| #include <sys/ipc.h> | |||||
| #include <sys/msg.h> | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| Status CachePerfMsg::Send(int32_t qID) { | |||||
| auto err = msgsnd(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), IPC_NOWAIT); | |||||
| if (err == -1) { | |||||
| std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CachePerfMsg::Receive(int32_t qID) { | |||||
| // This is a blocking call. It returns either when a message arrives or when the queue | |||||
| // is removed by the destructor. | |||||
| auto err = msgrcv(qID, reinterpret_cast<void *>(&small_msg_), sizeof(small_msg_.body.msg), 0, MSG_NOERROR); | |||||
| if (err == -1) { | |||||
| if (errno == EIDRM) { | |||||
| return Status(StatusCode::kInterrupted); | |||||
| } else { | |||||
| std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
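Send() and Receive() assume the queue already exists. As a minimal, self-contained sketch of the System V queue lifecycle they rely on, matching what CachePerfRun does at startup and teardown:

    #include <sys/ipc.h>
    #include <sys/msg.h>
    #include <sys/stat.h>

    // Create a private queue readable/writable by the owner only,
    // exactly as CachePerfRun::Run() does before forking.
    int qid = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | S_IRUSR | S_IWUSR);
    if (qid == -1) { /* inspect errno */ }

    // ... Send()/Receive() round trips happen here ...

    // Removing the queue wakes any blocked msgrcv with errno == EIDRM,
    // which Receive() above maps to StatusCode::kInterrupted.
    (void)msgctl(qid, IPC_RMID, nullptr);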
| @@ -0,0 +1,78 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_ | |||||
| #include <cstdint> | |||||
| #include <limits> | |||||
| #include <string> | |||||
| #include "proto/cache_perf.pb.h" | |||||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // All our messages are very small, so we use a stack-based version with no need | |||||
| // to allocate memory on the heap. | |||||
| struct CacheSmallMsg { | |||||
| int64_t mtype; | |||||
| union { | |||||
| char mtext[1]; | |||||
| struct { | |||||
| int32_t type; // the first 4 bytes hold the message type | |||||
| int32_t proto_sz; | |||||
| char proto_buffer[kSharedMessageSize]; | |||||
| } msg; | |||||
| } body; | |||||
| }; | |||||
| /// A message queue structure between the parent and the child process | |||||
| class CachePerfMsg { | |||||
| public: | |||||
| enum MessageType : int16_t { | |||||
| kInterrupt = 0, | |||||
| kEpochResult = 1, | |||||
| kEpochStart = 2, | |||||
| kEpochEnd = 3, | |||||
| kError = 4, | |||||
| // Add new message before it. | |||||
| kUnknownMessage = 32767 | |||||
| }; | |||||
| CachePerfMsg() : small_msg_{1} { | |||||
| small_msg_.body.msg.type = kUnknownMessage; | |||||
| small_msg_.body.msg.proto_sz = 0; | |||||
| small_msg_.body.msg.proto_buffer[0] = 0; | |||||
| } | |||||
| ~CachePerfMsg() = default; | |||||
| char *GetMutableBuffer() { return small_msg_.body.msg.proto_buffer; } | |||||
| Status Send(int32_t qID); | |||||
| void SetType(MessageType requestType) { small_msg_.body.msg.type = requestType; } | |||||
| void SetProtoBufSz(size_t sz) { small_msg_.body.msg.proto_sz = sz; } | |||||
| MessageType GetType() const { return static_cast<MessageType>(small_msg_.body.msg.type); } | |||||
| size_t GetProtoBufSz() const { return small_msg_.body.msg.proto_sz; } | |||||
| Status Receive(int32_t qID); | |||||
| private: | |||||
| CacheSmallMsg small_msg_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_ | |||||
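A sketch, under the assumption that this mirrors what the SendMessage helper used by the perf binaries does, of how a protobuf payload travels through CachePerfMsg (kSharedMessageSize comes from cache_common.h; qid is a System V queue id):

    EpochDone proto;  // from cache_perf.proto
    proto.set_pipeline(0);
    CachePerfMsg msg;
    msg.SetType(CachePerfMsg::MessageType::kEpochEnd);
    auto sz = proto.ByteSizeLong();
    if (sz > kSharedMessageSize) { /* payload would not fit the fixed buffer */ }
    proto.SerializeToArray(msg.GetMutableBuffer(), static_cast<int>(sz));
    msg.SetProtoBufSz(sz);
    Status rc = msg.Send(qid);

    // Receiver side:
    CachePerfMsg in;
    rc = in.Receive(qid);
    EpochDone out;
    out.ParseFromArray(in.GetMutableBuffer(), static_cast<int>(in.GetProtoBufSz()));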
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifdef USE_GLOG | |||||
| #include <glog/logging.h> | |||||
| #endif | |||||
| #include <iostream> | |||||
| #include "minddata/dataset/engine/cache/perf/cache_perf_run.h" | |||||
| namespace ds = mindspore::dataset; | |||||
| int main(int argc, char **argv) { | |||||
| #ifdef USE_GLOG | |||||
| FLAGS_log_dir = "/tmp"; | |||||
| google::InitGoogleLogging(argv[0]); | |||||
| #endif | |||||
| ds::CachePerfRun cachePerfRun; | |||||
| if (cachePerfRun.ProcessArgs(argc, argv) == 0) { | |||||
| std::cout << cachePerfRun << std::endl; | |||||
| ds::Status rc = cachePerfRun.Run(); | |||||
| if (rc.IsError()) { | |||||
| std::cerr << rc.ToString() << std::endl; | |||||
| } | |||||
| return static_cast<int>(rc.get_code()); | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package mindspore.dataset; | |||||
| option cc_enable_arenas = true; | |||||
| message PipelineWorkerEpochSummary { | |||||
| int32 pipeline = 1; | |||||
| int32 worker = 2; | |||||
| int64 min = 3; | |||||
| int64 max = 4; | |||||
| int64 avg = 5; | |||||
| int64 med = 6; | |||||
| int64 cnt = 7; | |||||
| int64 elapse = 8; | |||||
| } | |||||
| message EpochDone { | |||||
| int32 pipeline = 1; | |||||
| } | |||||
| message ErrorMsg { | |||||
| int32 rc = 1; | |||||
| string msg = 2; | |||||
| } | |||||
| @@ -0,0 +1,575 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/cache/perf/cache_perf_run.h" | |||||
| #include <string.h> | |||||
| #include <sys/types.h> | |||||
| #include <sys/stat.h> | |||||
| #include <sys/wait.h> | |||||
| #include <sys/ipc.h> | |||||
| #include <sys/msg.h> | |||||
| #include <unistd.h> | |||||
| #include <algorithm> | |||||
| #include <chrono> | |||||
| #include <iomanip> | |||||
| #include <sstream> | |||||
| #include "minddata/dataset/util/random.h" | |||||
| #include "minddata/dataset/util/services.h" | |||||
| #include "minddata/dataset/util/sig_handler.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| const char CachePerfRun::kCachePipelineBinary[] = "cache_pipeline"; | |||||
| void CachePerfRun::PrintHelp() { | |||||
| std::cout << "Options:\n" | |||||
| " -h,--help: Show this usage message\n" | |||||
| " -s,--num_rows: Set the sample size, i.e., the number of " | |||||
| "rows\n" | |||||
| " -r,--row_size: Set the average row size\n" | |||||
| " -n,--pipeline: Set the number of parallel pieplines. Default = " | |||||
| << kDftNumOfPipelines | |||||
| << "\n" | |||||
| " -e,--epoch: Set the number of epochs. Default = " | |||||
| << kDftNumberOfEpochs | |||||
| << "\n" | |||||
| " --shuffle: Set shuffle=True. Default = " | |||||
| << std::boolalpha << kDftShuffle | |||||
| << "\n" | |||||
| " -p,--prefetch_size: Set the prefetch size for cache. Default = " | |||||
| << kDftPrefetchSize << "\n" | |||||
| << " -a,--cache_size: Set cache size. Default = " << kDftCacheSize | |||||
| << " (Mb)\n" | |||||
| " --spill: Set spill to disk to True. Default = " | |||||
| << std::boolalpha << kDftSpill << "\n" | |||||
| << " -w,--workers: Set the number of parallel workers. Default = " << cfg_.num_parallel_workers() | |||||
| << "\n" | |||||
| " --connection: Set number of TCP/IP connections per pipeline. Default = " | |||||
| << kDftNumConnections << "\n" | |||||
| << " --port: TCP/IP port of the cache server. Default = " << kCfgDefaultCachePort << "\n" | |||||
| << " --hostname: Hostname of the cache server. Default = " << kCfgDefaultCacheHost << "\n"; | |||||
| } | |||||
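Putting these options together, a typical invocation might look like the following (all values hypothetical; -s and -r are the only mandatory options, as the argument validation below enforces):

    ./cache_perf -s 100000 -r 4096 -n 4 -e 5 -w 8 --shuffle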
| int32_t CachePerfRun::ProcessArgs(int argc, char **argv) { | |||||
| if (argc == 1) { | |||||
| PrintHelp(); | |||||
| return -1; | |||||
| } | |||||
| const int32_t port_opt = 1000; // there is no short option for port | |||||
| const int32_t hostname_opt = 1001; // there is no short option for hostname | |||||
| const int32_t connect_opt = 1002; // there is no short option for connect | |||||
| int shuffle = 0; | |||||
| int spill = 0; | |||||
| const char *const short_opts = ":n:e:p:a:s:r:w:"; | |||||
| const option long_opts[] = {{"pipeline", required_argument, nullptr, 'n'}, | |||||
| {"epoch", required_argument, nullptr, 'e'}, | |||||
| {"prefetch_size", required_argument, nullptr, 'p'}, | |||||
| {"shuffle", no_argument, &shuffle, 1}, | |||||
| {"cache_size", required_argument, nullptr, 'a'}, | |||||
| {"num_rows", required_argument, nullptr, 's'}, | |||||
| {"row_size", required_argument, nullptr, 'r'}, | |||||
| {"workers", required_argument, nullptr, 'w'}, | |||||
| {"port", required_argument, nullptr, port_opt}, | |||||
| {"hostname", required_argument, nullptr, hostname_opt}, | |||||
| {"spill", no_argument, &spill, 1}, | |||||
| {"connection", required_argument, nullptr, connect_opt}, | |||||
| {"help", no_argument, nullptr, 'h'}, | |||||
| {nullptr, no_argument, nullptr, 0}}; | |||||
| std::map<int32_t, int32_t> seen_opts; | |||||
| int32_t rc = 0; | |||||
| try { | |||||
| while (rc == 0) { | |||||
| int32_t option_index; | |||||
| const auto opt = getopt_long(argc, argv, short_opts, long_opts, &option_index); | |||||
| if (-1 == opt) { | |||||
| if (optind < argc) { | |||||
| rc = -1; | |||||
| std::cerr << "Unknown arguments: "; | |||||
| while (optind < argc) { | |||||
| std::cerr << argv[optind++] << " "; | |||||
| } | |||||
| std::cerr << std::endl; | |||||
| } | |||||
| break; | |||||
| } | |||||
| if (opt > 0) { | |||||
| seen_opts[opt]++; | |||||
| if (seen_opts[opt] > 1) { | |||||
| std::string long_name = long_opts[option_index].name; | |||||
| std::cerr << "The " << long_name << " argument was given more than once." << std::endl; | |||||
| rc = -1; | |||||
| continue; | |||||
| } | |||||
| } | |||||
| switch (opt) { | |||||
| case 0: { | |||||
| if (long_opts[option_index].flag == &shuffle) { | |||||
| shuffle_ = true; | |||||
| } else if (long_opts[option_index].flag == &spill) { | |||||
| cache_builder_.SetSpill(true); | |||||
| } | |||||
| break; | |||||
| } | |||||
| case 'n': { | |||||
| num_pipelines_ = std::stoi(optarg); | |||||
| break; | |||||
| } | |||||
| case 'e': { | |||||
| num_epoches_ = std::stoi(optarg); | |||||
| break; | |||||
| } | |||||
| case 'p': { | |||||
| int32_t prefetch_sz = std::stoi(optarg); | |||||
| cache_builder_.SetPrefetchSize(prefetch_sz); | |||||
| break; | |||||
| } | |||||
| case 'a': { | |||||
| int32_t cache_sz = std::stoi(optarg); | |||||
| cache_builder_.SetCacheMemSz(cache_sz); | |||||
| break; | |||||
| } | |||||
| case 's': { | |||||
| num_rows_ = std::stoi(optarg); | |||||
| break; | |||||
| } | |||||
| case 'r': { | |||||
| row_size_ = std::stoi(optarg); | |||||
| break; | |||||
| } | |||||
| case 'w': { | |||||
| cfg_.set_num_parallel_workers(std::stoi(optarg)); | |||||
| break; | |||||
| } | |||||
| case connect_opt: { | |||||
| int32_t connection_sz = std::stoi(optarg); | |||||
| cache_builder_.SetNumConnections(connection_sz); | |||||
| break; | |||||
| } | |||||
| case port_opt: { | |||||
| int32_t port = std::stoi(optarg); | |||||
| cache_builder_.SetPort(port); | |||||
| break; | |||||
| } | |||||
| case hostname_opt: { | |||||
| std::string hostname = optarg; | |||||
| cache_builder_.SetHostname(hostname); | |||||
| break; | |||||
| } | |||||
| case 'h': // -h or --help | |||||
| PrintHelp(); | |||||
| rc = -1; | |||||
| break; | |||||
| case ':': | |||||
| std::cerr << "Missing argument for option " << char(optopt) << std::endl; | |||||
| rc = -1; | |||||
| break; | |||||
| case '?': // Unrecognized option | |||||
| default: | |||||
| std::cerr << "Unknown option " << char(optopt) << std::endl; | |||||
| PrintHelp(); | |||||
| rc = -1; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } catch (const std::exception &e) { | |||||
| PrintHelp(); | |||||
| rc = -1; | |||||
| } | |||||
| if (rc < 0) { | |||||
| return rc; | |||||
| } | |||||
| // We have all the defaults except sample size and average row size, which the user must specify. | |||||
| auto it = seen_opts.find('s'); | |||||
| if (it == seen_opts.end()) { | |||||
| std::cerr << "Missing sample size." << std::endl; | |||||
| return -1; | |||||
| } | |||||
| it = seen_opts.find('r'); | |||||
| if (it == seen_opts.end()) { | |||||
| std::cerr << "Missing average row size." << std::endl; | |||||
| return -1; | |||||
| } | |||||
| if (num_rows_ <= 0) { | |||||
| std::cerr << "Sample size must be positive." << std::endl; | |||||
| return -1; | |||||
| } | |||||
| if (row_size_ <= 0) { | |||||
| std::cerr << "Average row size must be positive." << std::endl; | |||||
| return -1; | |||||
| } | |||||
| if (num_pipelines_ <= 0) { | |||||
| std::cerr << "Number of pipelines must be positive." << std::endl; | |||||
| return -1; | |||||
| } | |||||
| if (num_epoches_ <= 0) { | |||||
| std::cerr << "Number of epoches must be positive." << std::endl; | |||||
| return -1; | |||||
| } | |||||
| if (num_rows_ < num_pipelines_) { | |||||
| std::cerr << "Sample size is smaller than the number of pipelines." << std::endl; | |||||
| return -1; | |||||
| } | |||||
| pid_lists_.reserve(num_pipelines_); | |||||
| return 0; | |||||
| } | |||||
| Status CachePerfRun::GetSession() { | |||||
| CacheClientGreeter comm(cache_builder_.GetHostname(), cache_builder_.GetPort(), 1); | |||||
| RETURN_IF_NOT_OK(comm.ServiceStart()); | |||||
| auto rq = std::make_shared<GenerateSessionIdRequest>(); | |||||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | |||||
| RETURN_IF_NOT_OK(rq->Wait()); | |||||
| session_ = rq->GetSessionId(); | |||||
| std::cout << "Session: " << session_ << std::endl; | |||||
| cache_builder_.SetSessionId(session_); | |||||
| return Status::OK(); | |||||
| } | |||||
| CachePerfRun::CachePerfRun() | |||||
| : my_pipeline_(-1), | |||||
| num_pipelines_(kDftNumOfPipelines), | |||||
| num_epoches_(kDftNumberOfEpochs), | |||||
| num_rows_(0), | |||||
| row_size_(0), | |||||
| shuffle_(kDftShuffle), | |||||
| session_(0), | |||||
| crc_(0), | |||||
| epoch_sync_cnt_(0) { | |||||
| cache_builder_.SetSpill(kDftSpill).SetCacheMemSz(kDftCacheSize); | |||||
| } | |||||
| CachePerfRun::~CachePerfRun() { | |||||
| if (session_ != 0) { | |||||
| Status rc; | |||||
| CacheClientGreeter comm(cache_builder_.GetHostname(), cache_builder_.GetPort(), 1); | |||||
| rc = comm.ServiceStart(); | |||||
| if (rc.IsOk()) { | |||||
| CacheClientInfo cinfo; | |||||
| cinfo.set_session_id(session_); | |||||
| auto rq = std::make_shared<DropSessionRequest>(cinfo); | |||||
| rc = comm.HandleRequest(rq); | |||||
| if (rc.IsOk()) { | |||||
| rc = rq->Wait(); | |||||
| if (rc.IsOk()) { | |||||
| std::cout << "Drop session " << session_ << " successful" << std::endl; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // Send an interrupt message to each child. | |||||
| for (auto msg_qid : msg_send_lists_) { | |||||
| CachePerfMsg msg; | |||||
| msg.SetType(CachePerfMsg::MessageType::kInterrupt); | |||||
| (void)msg.Send(msg_qid); | |||||
| } | |||||
| // Wait for each child to return | |||||
| for (auto pid : pid_lists_) { | |||||
| int status; | |||||
| if (waitpid(pid, &status, 0) == -1) { | |||||
| std::string errMsg = "waitpid fails. errno = " + std::to_string(errno); | |||||
| std::cerr << errMsg << std::endl; | |||||
| } else { | |||||
| MS_LOG(INFO) << "Child pid " << pid << " returns." << std::endl; | |||||
| } | |||||
| } | |||||
| // Remove all the message queues | |||||
| for (auto msg_qid : msg_send_lists_) { | |||||
| // Remove the message queue and ignore the return code. | |||||
| (void)msgctl(msg_qid, IPC_RMID, nullptr); | |||||
| } | |||||
| for (auto msg_qid : msg_recv_lists_) { | |||||
| // Remove the message queue and ignore the return code. | |||||
| (void)msgctl(msg_qid, IPC_RMID, nullptr); | |||||
| } | |||||
| } | |||||
| void CachePerfRun::PrintEpochSummary() const { | |||||
| std::cout << std::setw(12) << "Pipeline #" << std::setw(10) << "worker id" << std::setw(11) << "min (μs)" | |||||
| << std::setw(11) << "max (μs)" << std::setw(11) << "avg (μs)" << std::setw(14) << "median (μs)" | |||||
| << std::setw(14) << "buffer count" << std::setw(18) << "Elapsed time (s)" << std::endl; | |||||
| for (auto &it : epoch_results_) { | |||||
| auto epoch_worker_summary = it.second; | |||||
| std::cout << std::setw(12) << epoch_worker_summary.pipeline() + 1 << std::setw(10) << epoch_worker_summary.worker() | |||||
| << std::setw(10) << epoch_worker_summary.min() << std::setw(10) << epoch_worker_summary.max() | |||||
| << std::setw(10) << epoch_worker_summary.avg() << std::setw(13) << epoch_worker_summary.med() | |||||
| << std::setw(14) << epoch_worker_summary.cnt() << std::setw(18) << epoch_worker_summary.elapse() | |||||
| << std::endl; | |||||
| } | |||||
| } | |||||
| Status CachePerfRun::ListenToPipeline(int32_t workerId) { | |||||
| TaskManager::FindMe()->Post(); | |||||
| int32_t qID = msg_recv_lists_[workerId]; | |||||
| do { | |||||
| RETURN_IF_INTERRUPTED(); | |||||
| CachePerfMsg msg; | |||||
| RETURN_IF_NOT_OK(msg.Receive(qID)); | |||||
| // Decode the messages. | |||||
| auto type = msg.GetType(); | |||||
| char *p = msg.GetMutableBuffer(); | |||||
| switch (type) { | |||||
| case CachePerfMsg::MessageType::kEpochResult: { | |||||
| PipelineWorkerEpochSummary epoch_worker_summary; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(epoch_worker_summary.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail"); | |||||
| { | |||||
| auto pipeline = epoch_worker_summary.pipeline(); | |||||
| auto worker = epoch_worker_summary.worker(); | |||||
| std::unique_lock<std::mutex> lock(mux_); | |||||
| // sort by pipeline/worker | |||||
| auto r = | |||||
| epoch_results_.emplace(std::pair<int32_t, int32_t>(pipeline, worker), std::move(epoch_worker_summary)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(r.second, "Insert failed"); | |||||
| } | |||||
| break; | |||||
| } | |||||
| case CachePerfMsg::MessageType::kEpochEnd: { | |||||
| EpochDone proto; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(proto.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail"); | |||||
| auto n = epoch_sync_cnt_.fetch_add(1); | |||||
| if (n + 1 == num_pipelines_) { | |||||
| pipeline_wp_.Set(); | |||||
| } | |||||
| break; | |||||
| } | |||||
| case CachePerfMsg::MessageType::kInterrupt: { | |||||
| TaskManager::WakeUpWatchDog(); | |||||
| return Status::OK(); | |||||
| } | |||||
| case CachePerfMsg::kError: { | |||||
| ErrorMsg proto; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(proto.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail"); | |||||
| return Status(static_cast<StatusCode>(proto.rc()), proto.msg()); | |||||
| } | |||||
| default: | |||||
| std::string errMsg = "Unknown request type: " + std::to_string(type); | |||||
| MS_LOG(ERROR) << errMsg; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| break; | |||||
| } | |||||
| } while (true); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CachePerfRun::Run() { | |||||
| // Now we bring up TaskManager. | |||||
| RETURN_IF_NOT_OK(Services::CreateInstance()); | |||||
| // Handle Control-C | |||||
| RegisterHandlers(); | |||||
| // Get a session from the server. | |||||
| RETURN_IF_NOT_OK(GetSession()); | |||||
| // Generate a random crc. | |||||
| auto mt = GetRandomDevice(); | |||||
| std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<int32_t>::max()); | |||||
| crc_ = distribution(mt); | |||||
| std::cout << "CRC: " << crc_ << std::endl; | |||||
| // Create all the resources required by the pipelines before we fork. | |||||
| for (auto i = 0; i < num_pipelines_; ++i) { | |||||
| // We will use shared message queues for communication between parent (this process) | |||||
| // and each pipeline. | |||||
| auto access_mode = S_IRUSR | S_IWUSR; | |||||
| int32_t msg_send_qid = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode); | |||||
| if (msg_send_qid == -1) { | |||||
| std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| msg_send_lists_.push_back(msg_send_qid); | |||||
| int32_t msg_recv_qid = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode); | |||||
| if (msg_recv_qid == -1) { | |||||
| std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| msg_recv_lists_.push_back(msg_recv_qid); | |||||
| } | |||||
| // Now we create the children, knowing both sets of message queues are constructed. | |||||
| for (auto i = 0; i < num_pipelines_; ++i) { | |||||
| auto pid = fork(); | |||||
| if (pid == 0) { | |||||
| // Child. We will call another binary but with different (hidden) parameters. | |||||
| // The parent process is waiting on a wait post. On any error we hit here, we must interrupt | |||||
| // the parent process. | |||||
| auto interrupt_parent = [this, i]() { | |||||
| CachePerfMsg msg; | |||||
| msg.SetType(CachePerfMsg::MessageType::kInterrupt); | |||||
| msg.Send(msg_recv_lists_[i]); | |||||
| }; | |||||
| const std::string self_proc = "/proc/self/exe"; | |||||
| std::string canonical_path; | |||||
| canonical_path.resize(400); // PATH_MAX is large. This value should be big enough for our use. | |||||
| // Some lower level OS library calls are needed here to determine the binary path. | |||||
| if (realpath(self_proc.data(), canonical_path.data()) == nullptr) { | |||||
| std::cerr << "Failed to identify cache_perf binary path: " + std::to_string(errno) << ": " << strerror(errno) | |||||
| << std::endl; | |||||
| interrupt_parent(); | |||||
| // Call _exit instead of exit because we will hang in TaskManager destructor for a forked child process. | |||||
| _exit(-1); | |||||
| } | |||||
| canonical_path.resize(strlen(canonical_path.data())); | |||||
| auto last_separator = canonical_path.find_last_of('/'); | |||||
| if (last_separator == std::string::npos) { | |||||
| std::cerr << "Canonical path can't locate / " << canonical_path << std::endl; | |||||
| interrupt_parent(); | |||||
| // Call _exit instead of exit because we will hang in TaskManager destructor for a forked child process. | |||||
| _exit(-1); | |||||
| } | |||||
| // Truncate the binary name so we are left with the directory containing the current binary. | |||||
| canonical_path.resize(last_separator + 1); | |||||
| std::string cache_pipeline_binary = canonical_path + std::string(kCachePipelineBinary); | |||||
| std::string pipeline_cfg = std::to_string(i) + "," + std::to_string(session_) + "," + std::to_string(crc_) + "," + | |||||
| std::to_string(msg_send_lists_[i]) + "," + std::to_string(msg_recv_lists_[i]) + "," + | |||||
| std::to_string(num_pipelines_) + "," + std::to_string(num_epoches_) + "," + | |||||
| std::to_string(num_rows_) + "," + std::to_string(row_size_) + "," + | |||||
| std::to_string(cfg_.num_parallel_workers()) + "," + | |||||
| (shuffle_ ? std::string("true").data() : std::string("false").data()); | |||||
| std::string client_cfg = cache_builder_.GetHostname() + "," + std::to_string(cache_builder_.GetPort()) + "," + | |||||
| std::to_string(cache_builder_.GetPrefetchSize()) + "," + | |||||
| std::to_string(cache_builder_.GetCacheMemSz()) + "," + | |||||
| std::to_string(cache_builder_.GetNumConnections()) + "," + | |||||
| (cache_builder_.isSpill() ? "true" : "false"); | |||||
| char *argv[4]; | |||||
| argv[0] = const_cast<char *>(kCachePipelineBinary); | |||||
| argv[1] = pipeline_cfg.data(); | |||||
| argv[2] = client_cfg.data(); | |||||
| argv[3] = nullptr; | |||||
| // Invoke the binary. | |||||
| execv(cache_pipeline_binary.data(), argv); | |||||
| std::cerr << "Unable to exec. Errno = " + std::to_string(errno) << ": " << strerror(errno) << std::endl; | |||||
| interrupt_parent(); | |||||
| // Call _exit instead of exit because we will hang in the TaskManager destructor for a forked child process. | |||||
| _exit(-1); | |||||
| } else if (pid > 0) { | |||||
| std::cout << "Pipeline number " << i + 1 << " has been created with process id: " << pid << std::endl; | |||||
| pid_lists_.push_back(pid); | |||||
| } else { | |||||
| std::string errMsg = "Failed to fork process for cache pipeline: " + std::to_string(errno); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| } | |||||
| // Spawn a few threads to monitor the communications from the pipeline. | |||||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | |||||
| auto f = std::bind(&CachePerfRun::ListenToPipeline, this, std::placeholders::_1); | |||||
| for (auto i = 0; i < num_pipelines_; ++i) { | |||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Queue listener", std::bind(f, i))); | |||||
| } | |||||
| // Wait until all pipelines finish the first epoch. | |||||
| RETURN_IF_NOT_OK(pipeline_wp_.Wait()); | |||||
| std::cout << "Epoch one (build phase) per pipeline per worker summary. Buffer size = " << cfg_.rows_per_buffer() | |||||
| << std::endl; | |||||
| PrintEpochSummary(); | |||||
| // Get some stats, but we need to connect first. The server will think it is the (n+1)-th pipeline. | |||||
| RETURN_IF_NOT_OK(cache_builder_.Build(&cc_)); | |||||
| Status rc = cc_->CreateCache(crc_, false); | |||||
| // Duplicate key is fine. | |||||
| if (rc.IsError() && rc.get_code() != StatusCode::kDuplicateKey) { | |||||
| return rc; | |||||
| } | |||||
| CacheServiceStat stat{}; | |||||
| RETURN_IF_NOT_OK(cc_->GetStat(&stat)); | |||||
| std::cout << "Get statistics for this session:\n"; | |||||
| std::cout << std::setw(12) << "Mem cached" << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" | |||||
| << std::setw(10) << "Numa hit" << std::endl; | |||||
| std::string stat_mem_cached; | |||||
| std::string stat_disk_cached; | |||||
| std::string stat_avg_cached; | |||||
| std::string stat_numa_hit; | |||||
| stat_mem_cached = (stat.num_mem_cached == 0) ? "n/a" : std::to_string(stat.num_mem_cached); | |||||
| stat_disk_cached = (stat.num_disk_cached == 0) ? "n/a" : std::to_string(stat.num_disk_cached); | |||||
| stat_avg_cached = (stat.avg_cache_sz == 0) ? "n/a" : std::to_string(stat.avg_cache_sz); | |||||
| stat_numa_hit = (stat.num_numa_hit == 0) ? "n/a" : std::to_string(stat.num_numa_hit); | |||||
| std::cout << std::setw(12) << stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached | |||||
| << std::setw(10) << stat_numa_hit << std::endl; | |||||
| // Toggle write mode off since the remaining epochs are read only. | |||||
| // The simplest way is to call this special internal function. | |||||
| cc_->ServerRunningOutOfResources(); | |||||
| // The rest of the epochs are just fetching. | |||||
| auto epoch_num = 2; | |||||
| while (epoch_num <= num_epoches_) { | |||||
| epoch_sync_cnt_ = 0; | |||||
| pipeline_wp_.Clear(); | |||||
| epoch_results_.clear(); | |||||
| // Signal each pipeline to start | |||||
| for (auto msg_qid : msg_send_lists_) { | |||||
| CachePerfMsg msg; | |||||
| msg.SetType(CachePerfMsg::MessageType::kEpochStart); | |||||
| (void)msg.Send(msg_qid); | |||||
| } | |||||
| // Wait for the child to finish | |||||
| RETURN_IF_NOT_OK(pipeline_wp_.Wait()); | |||||
| std::cout << "Epoch " << epoch_num | |||||
| << " (read phase) per pipeline per worker summary. Buffer size = " << cc_->GetPrefetchSize() << std::endl; | |||||
| PrintEpochSummary(); | |||||
| ++epoch_num; | |||||
| } | |||||
| // Destroy the cache. We no longer need it around. | |||||
| RETURN_IF_NOT_OK(cc_->DestroyCache()); | |||||
| // Unreserve the session | |||||
| CacheClientInfo cinfo; | |||||
| cinfo.set_session_id(session_); | |||||
| auto rq = std::make_shared<DropSessionRequest>(cinfo); | |||||
| RETURN_IF_NOT_OK(cc_->PushRequest(rq)); | |||||
| RETURN_IF_NOT_OK(rq->Wait()); | |||||
| std::cout << "Drop session " << session_ << " successful" << std::endl; | |||||
| session_ = 0; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,100 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_ | |||||
| #include <getopt.h> | |||||
| #include <atomic> | |||||
| #include <cstdint> | |||||
| #include <limits> | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <random> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/core/config_manager.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||||
| #include "minddata/dataset/engine/cache/perf/cache_msg.h" | |||||
| #include "minddata/dataset/util/queue.h" | |||||
| #include "minddata/dataset/util/task_manager.h" | |||||
| #include "minddata/dataset/util/wait_post.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| constexpr int32_t kDftNumOfPipelines = 8; | |||||
| constexpr int32_t kDftNumberOfEpochs = 10; | |||||
| constexpr int32_t kDftCacheSize = 0; | |||||
| constexpr bool kDftShuffle = false; | |||||
| constexpr bool kDftSpill = false; | |||||
| class CachePerfRun { | |||||
| public: | |||||
| static const char kCachePipelineBinary[]; | |||||
| CachePerfRun(); | |||||
| ~CachePerfRun(); | |||||
| void PrintHelp(); | |||||
| int32_t ProcessArgs(int argc, char **argv); | |||||
| void Print(std::ostream &out) const { | |||||
| out << "Number of pipelines: " << num_pipelines_ << "\n" | |||||
| << "Number of epochs: " << num_epoches_ << "\n" | |||||
| << "Sample size: " << num_rows_ << "\n" | |||||
| << "Average row size: " << row_size_ << "\n" | |||||
| << "Shuffle: " << std::boolalpha << shuffle_; | |||||
| } | |||||
| friend std::ostream &operator<<(std::ostream &out, const CachePerfRun &cp) { | |||||
| cp.Print(out); | |||||
| return out; | |||||
| } | |||||
| Status Run(); | |||||
| private: | |||||
| std::mutex mux_; | |||||
| int32_t my_pipeline_; | |||||
| int32_t num_pipelines_; | |||||
| int32_t num_epoches_; | |||||
| int64_t num_rows_; | |||||
| int32_t row_size_; | |||||
| bool shuffle_; | |||||
| CacheClient::Builder cache_builder_; | |||||
| session_id_type session_; | |||||
| int32_t crc_; | |||||
| std::vector<int32_t> pid_lists_; | |||||
| std::vector<int32_t> msg_send_lists_; | |||||
| std::vector<int32_t> msg_recv_lists_; | |||||
| TaskGroup vg_; | |||||
| std::atomic<int32_t> epoch_sync_cnt_; | |||||
| WaitPost pipeline_wp_; | |||||
| std::map<std::pair<int32_t, int32_t>, PipelineWorkerEpochSummary> epoch_results_; | |||||
| ConfigManager cfg_; | |||||
| std::shared_ptr<CacheClient> cc_; | |||||
| Status GetSession(); | |||||
| Status ListenToPipeline(int32_t workerId); | |||||
| void PrintEpochSummary() const; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_ | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifdef USE_GLOG | |||||
| #include <glog/logging.h> | |||||
| #endif | |||||
| #include <string.h> | |||||
| #include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h" | |||||
| namespace ds = mindspore::dataset; | |||||
| int main(int argc, char **argv) { | |||||
| #ifdef USE_GLOG | |||||
| FLAGS_log_dir = "/tmp"; | |||||
| FLAGS_minloglevel = google::WARNING; | |||||
| google::InitGoogleLogging(argv[0]); | |||||
| #endif | |||||
| ds::CachePipelineRun cachePipelineRun; | |||||
| if (cachePipelineRun.ProcessArgs(argc, argv) == 0) { | |||||
| ds::Status rc = cachePipelineRun.Run(); | |||||
| // If we hit any error, send the rc back to the parent. | |||||
| if (rc.IsError()) { | |||||
| ds::ErrorMsg proto; | |||||
| proto.set_rc(static_cast<int32_t>(rc.get_code())); | |||||
| proto.set_msg(rc.ToString()); | |||||
| ds::CachePerfMsg msg; | |||||
| (void)cachePipelineRun.SendMessage(&msg, ds::CachePerfMsg::MessageType::kError, &proto); | |||||
| } | |||||
| return static_cast<int>(rc.get_code()); | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| @@ -0,0 +1,471 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h" | |||||
| #include <string.h> | |||||
| #include <sys/types.h> | |||||
| #include <algorithm> | |||||
| #include <chrono> | |||||
| #include <iomanip> | |||||
| #include <sstream> | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/engine/data_buffer.h" | |||||
| #include "minddata/dataset/engine/data_schema.h" | |||||
| #include "minddata/dataset/util/random.h" | |||||
| #include "minddata/dataset/util/services.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| void CachePipelineRun::PrintHelp() { std::cout << "Please run the executable cache_perf instead." << std::endl; } | |||||
| int32_t CachePipelineRun::ProcessArgs(int argc, char **argv) { | |||||
| if (argc != 3) { | |||||
| PrintHelp(); | |||||
| return -1; | |||||
| } | |||||
| try { | |||||
| std::stringstream cfg_ss(argv[1]); | |||||
| std::string s; | |||||
| int32_t numArgs = 0; | |||||
| while (std::getline(cfg_ss, s, ',')) { | |||||
| if (numArgs == 0) { | |||||
| my_pipeline_ = std::stoi(s); | |||||
| } else if (numArgs == 1) { | |||||
| session_ = std::stoul(s); | |||||
| cache_builder_.SetSessionId(session_); | |||||
| } else if (numArgs == 2) { | |||||
| crc_ = std::stoi(s); | |||||
| } else if (numArgs == 3) { | |||||
| recv_id_ = std::stoi(s); | |||||
| } else if (numArgs == 4) { | |||||
| send_id_ = std::stoi(s); | |||||
| } else if (numArgs == 5) { | |||||
| num_pipelines_ = std::stoi(s); | |||||
| } else if (numArgs == 6) { | |||||
| num_epoches_ = std::stoi(s); | |||||
| } else if (numArgs == 7) { | |||||
| num_rows_ = std::stol(s); | |||||
| } else if (numArgs == 8) { | |||||
| row_size_ = std::stoi(s); | |||||
| } else if (numArgs == 9) { | |||||
| cfg_.set_num_parallel_workers(std::stol(s)); | |||||
| } else if (numArgs == 10) { | |||||
| shuffle_ = strcmp(s.data(), "true") == 0; | |||||
| } | |||||
| ++numArgs; | |||||
| } | |||||
| if (numArgs != 11) { | |||||
| std::cerr << "Incomplete arguments. Expect 11. But get " << numArgs << std::endl; | |||||
| return -1; | |||||
| } | |||||
| std::stringstream client_ss(argv[2]); | |||||
| numArgs = 0; | |||||
| while (std::getline(client_ss, s, ',')) { | |||||
| if (numArgs == 0) { | |||||
| cache_builder_.SetHostname(s); | |||||
| } else if (numArgs == 1) { | |||||
| cache_builder_.SetPort(std::stoi(s)); | |||||
| } else if (numArgs == 2) { | |||||
| cache_builder_.SetPrefetchSize(std::stoi(s)); | |||||
| } else if (numArgs == 3) { | |||||
| cache_builder_.SetCacheMemSz(std::stoi(s)); | |||||
| } else if (numArgs == 4) { | |||||
| cache_builder_.SetNumConnections(std::stoi(s)); | |||||
| } else if (numArgs == 5) { | |||||
| cache_builder_.SetSpill(strcmp(s.data(), "true") == 0); | |||||
| } | |||||
| ++numArgs; | |||||
| } | |||||
| if (numArgs != 6) { | |||||
| std::cerr << "Incomplete arguments. Expect 6. But get " << numArgs << std::endl; | |||||
| return -1; | |||||
| } | |||||
| } catch (const std::exception &e) { | |||||
| std::cerr << "Parse error: " << e.what() << std::endl; | |||||
| return -1; | |||||
| } | |||||
| return 0; | |||||
| } | |||||
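To make the hidden calling convention concrete: the first argument carries, in order, the pipeline index, session id, crc, receive queue id, send queue id, pipeline count, epoch count, row count, row size, worker count, and the shuffle flag; the second carries the hostname, port, prefetch size, cache memory size, connection count, and the spill flag. A child might therefore be exec'ed by the parent as (all values hypothetical):

    cache_pipeline "0,1487254562,73215,5,6,4,5,100000,4096,8,false" "127.0.0.1,50052,20,0,2,false"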
| CachePipelineRun::CachePipelineRun() | |||||
| : my_pipeline_(-1), | |||||
| num_pipelines_(kDftNumOfPipelines), | |||||
| num_epoches_(kDftNumberOfEpochs), | |||||
| num_rows_(0), | |||||
| row_size_(0), | |||||
| shuffle_(kDftShuffle), | |||||
| session_(0), | |||||
| crc_(0), | |||||
| send_id_(-1), | |||||
| recv_id_(-1), | |||||
| start_row_(-1), | |||||
| end_row_(-1) { | |||||
| cache_builder_.SetSpill(kDftSpill).SetCacheMemSz(kDftCacheSize); | |||||
| } | |||||
| CachePipelineRun::~CachePipelineRun() { | |||||
| CachePerfMsg msg; | |||||
| (void)SendMessage<ErrorMsg>(&msg, CachePerfMsg::MessageType::kInterrupt, nullptr); | |||||
| } | |||||
| Status CachePipelineRun::ListenToParent() { | |||||
| TaskManager::FindMe()->Post(); | |||||
| do { | |||||
| RETURN_IF_INTERRUPTED(); | |||||
| CachePerfMsg msg; | |||||
| RETURN_IF_NOT_OK(msg.Receive(recv_id_)); | |||||
| // Decode the messages. | |||||
| auto type = msg.GetType(); | |||||
| switch (type) { | |||||
| case CachePerfMsg::MessageType::kInterrupt: { | |||||
| TaskManager::WakeUpWatchDog(); | |||||
| return Status::OK(); | |||||
| } | |||||
| case CachePerfMsg::MessageType::kEpochStart: { | |||||
| pipeline_wp_.Set(); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| std::string errMsg = "Unknown request type: " + std::to_string(type); | |||||
| MS_LOG(ERROR) << errMsg; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| break; | |||||
| } | |||||
| } while (true); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CachePipelineRun::Run() { | |||||
| RETURN_IF_NOT_OK(cache_builder_.Build(&cc_)); | |||||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | |||||
| auto num_workers = cfg_.num_parallel_workers(); | |||||
| io_block_queues_.Init(num_workers, cfg_.op_connector_size()); | |||||
| RETURN_IF_NOT_OK(io_block_queues_.Register(&vg_)); | |||||
| Status rc = cc_->CreateCache(crc_, false); | |||||
| // Duplicate key is fine. | |||||
| if (rc.IsError() && rc.get_code() != StatusCode::kDuplicateKey) { | |||||
| return rc; | |||||
| } | |||||
| // Log a warning level message so we can see it in the log file when it starts. | |||||
| MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " successfully creating cache service." << std::endl; | |||||
| // Spawn a thread to listen to the parent process | |||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Queue listener", std::bind(&CachePipelineRun::ListenToParent, this))); | |||||
| RETURN_IF_NOT_OK(RunFirstEpoch()); | |||||
| // The rest of the epochs are just fetching. | |||||
| auto remaining_epochs = num_epoches_ - 1; | |||||
| while (remaining_epochs > 0) { | |||||
| // Wait for parent process signal to start | |||||
| RETURN_IF_NOT_OK(pipeline_wp_.Wait()); | |||||
| pipeline_wp_.Clear(); | |||||
| RETURN_IF_NOT_OK(RunReadEpoch()); | |||||
| --remaining_epochs; | |||||
| } | |||||
| // The listener thread is blocked on a shared message queue. It will be woken up by | |||||
| // the parent process, which will send an interrupt message to it, and this program | |||||
| // will exit. | |||||
| RETURN_IF_NOT_OK(vg_.join_all()); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CachePipelineRun::RunFirstEpoch() { | |||||
| auto sz = num_rows_ / num_pipelines_; | |||||
| start_row_ = my_pipeline_ * sz; | |||||
| end_row_ = (my_pipeline_ + 1) * sz - 1; | |||||
| if (my_pipeline_ + 1 == num_pipelines_) { | |||||
| end_row_ = num_rows_ - 1; | |||||
| } | |||||
| std::cout << "Pipeline number " << my_pipeline_ + 1 << " row id range: [" << start_row_ << "," << end_row_ << "]" | |||||
| << std::endl; | |||||
| // Spawn the worker threads. | |||||
| auto f = std::bind(&CachePipelineRun::WriterWorkerEntry, this, std::placeholders::_1); | |||||
| std::vector<Task *> worker_threads; | |||||
| auto num_workers = cfg_.num_parallel_workers(); | |||||
| worker_threads.reserve(num_workers); | |||||
| for (int32_t i = 0; i < num_workers; ++i) { | |||||
| Task *pTask; | |||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Parallel worker", std::bind(f, i), &pTask)); | |||||
| worker_threads.push_back(pTask); | |||||
| } | |||||
| std::vector<row_id_type> keys; | |||||
| auto rows_per_buffer = cfg_.rows_per_buffer(); | |||||
| keys.reserve(rows_per_buffer); | |||||
| int32_t worker_id = 0; | |||||
| for (auto i = start_row_; i <= end_row_; ++i) { | |||||
| keys.push_back(i); | |||||
| if (keys.size() == rows_per_buffer) { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); | |||||
| keys.clear(); | |||||
| } | |||||
| } | |||||
| if (!keys.empty()) { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); | |||||
| keys.clear(); | |||||
| } | |||||
| // Shutdown threads and wait for them to come back. | |||||
| for (int32_t i = 0; i < num_workers; i++) { | |||||
| RETURN_IF_NOT_OK( | |||||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||||
| } | |||||
| for (auto *pTask : worker_threads) { | |||||
| RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking)); | |||||
| } | |||||
| // Send a message saying epoch one is done for this pipeline. | |||||
| EpochDone proto; | |||||
| proto.set_pipeline(my_pipeline_); | |||||
| CachePerfMsg msg; | |||||
| RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochEnd, &proto)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) { | |||||
| Status rc; | |||||
| TaskManager::FindMe()->Post(); | |||||
| int64_t min_val = std::numeric_limits<int64_t>::max(); | |||||
| int64_t max_val = 0; | |||||
| int64_t total_val = 0; | |||||
| int64_t cnt = 0; | |||||
| std::vector<int64_t> duration; | |||||
| duration.reserve(num_rows_ / num_pipelines_ / cfg_.num_parallel_workers()); | |||||
| bool resource_err = false; | |||||
| auto col_desc = std::make_unique<ColDescriptor>("int64", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1); | |||||
| auto num_elements = row_size_ / sizeof(int64_t); | |||||
| TensorShape shape(std::vector<dsize_t>(1, num_elements)); | |||||
| std::unique_ptr<IOBlock> blk; | |||||
| auto epoch_start = std::chrono::steady_clock::now(); | |||||
| do { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); | |||||
| std::vector<int64_t> keys; | |||||
| RETURN_IF_NOT_OK(blk->GetKeys(&keys)); | |||||
| if (keys.empty()) { | |||||
| // empty key is a quit signal for workers | |||||
| break; | |||||
| } | |||||
| // Once we hit resource error, we drain the io block. No point to send anything to the server. | |||||
| if (!resource_err) { | |||||
| auto buffer = std::make_unique<DataBuffer>(cnt++, DataBuffer::kDeBFlagNone); | |||||
| auto tensor_table = std::make_unique<TensorQTable>(); | |||||
| for (auto id : keys) { | |||||
| TensorRow row; | |||||
| std::shared_ptr<Tensor> element; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, col_desc->type(), &element)); | |||||
| row.setId(id); | |||||
| // CreateEmpty allocates the memory, but only as virtual address space. Touch every | |||||
| // element to commit the pages so we get an accurate timing. | |||||
| auto it = element->begin<int64_t>(); | |||||
| for (auto i = 0; i < num_elements; ++i, ++it) { | |||||
| *it = i; | |||||
| } | |||||
| row.push_back(std::move(element)); | |||||
| tensor_table->push_back(std::move(row)); | |||||
| } | |||||
| buffer->set_tensor_table(std::move(tensor_table)); | |||||
| // Measure the time to call WriteBuffer | |||||
| auto start_tick = std::chrono::steady_clock::now(); | |||||
| rc = cc_->WriteBuffer(std::move(buffer)); | |||||
| auto end_tick = std::chrono::steady_clock::now(); | |||||
| if (rc.IsError()) { | |||||
| if (rc.IsOutofMemory() || rc.IsNoSpace()) { | |||||
| MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " worker id " << worker_id << ": " | |||||
| << rc.ToString(); | |||||
| resource_err = true; | |||||
| cc_->ServerRunningOutOfResources(); | |||||
| continue; | |||||
| } else { | |||||
| return rc; | |||||
| } | |||||
| } else { | |||||
| int64_t ms = std::chrono::duration_cast<std::chrono::microseconds>(end_tick - start_tick).count(); | |||||
| min_val = std::min(min_val, ms); | |||||
| max_val = std::max(max_val, ms); | |||||
| duration.push_back(ms); | |||||
| total_val += ms; | |||||
| } | |||||
| } | |||||
| } while (true); | |||||
| auto epoch_end = std::chrono::steady_clock::now(); | |||||
| int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(epoch_end - epoch_start).count(); | |||||
| PipelineWorkerEpochSummary proto; | |||||
| proto.set_pipeline(my_pipeline_); | |||||
| proto.set_worker(worker_id); | |||||
| proto.set_min(min_val); | |||||
| proto.set_max(max_val); | |||||
| proto.set_elapse(elapse_time); | |||||
| auto sz = duration.size(); | |||||
| proto.set_cnt(sz); | |||||
| if (sz > 0) { | |||||
| // median | |||||
| auto n = sz / 2; | |||||
| std::nth_element(duration.begin(), duration.begin() + n, duration.end()); | |||||
| auto median = duration[n]; | |||||
| proto.set_med(median); | |||||
| // average | |||||
| int64_t avg = total_val / sz; | |||||
| proto.set_avg(avg); | |||||
| } | |||||
| CachePerfMsg msg; | |||||
| RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochResult, &proto)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CachePipelineRun::RunReadEpoch() { | |||||
| std::vector<row_id_type> keys; | |||||
| auto rows_per_buffer = cc_->GetPrefetchSize(); // We will use prefetch size to read. | |||||
| auto num_workers = cfg_.num_parallel_workers(); | |||||
| keys.reserve(rows_per_buffer); | |||||
| // Spawn workers | |||||
| auto f = std::bind(&CachePipelineRun::ReaderWorkerEntry, this, std::placeholders::_1); | |||||
| std::vector<Task *> worker_threads; | |||||
| worker_threads.reserve(num_workers); | |||||
| for (int32_t i = 0; i < num_workers; ++i) { | |||||
| Task *pTask; | |||||
| RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Parallel worker", std::bind(f, i), &pTask)); | |||||
| worker_threads.push_back(pTask); | |||||
| } | |||||
| std::vector<row_id_type> all_keys; | |||||
| all_keys.reserve(end_row_ - start_row_ + 1); | |||||
| for (auto i = start_row_; i <= end_row_; ++i) { | |||||
| all_keys.push_back(i); | |||||
| } | |||||
| // If we need to shuffle the keys | |||||
| if (shuffle_) { | |||||
| std::shuffle(all_keys.begin(), all_keys.end(), GetRandomDevice()); | |||||
| } | |||||
| int32_t worker_id = 0; | |||||
| for (auto id : all_keys) { | |||||
| keys.push_back(id); | |||||
| if (keys.size() == rows_per_buffer) { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); | |||||
| keys.clear(); | |||||
| } | |||||
| } | |||||
| if (!keys.empty()) { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); | |||||
| keys.clear(); | |||||
| } | |||||
| // Shutdown threads and wait for them to come back. | |||||
| for (int32_t i = 0; i < num_workers; i++) { | |||||
| RETURN_IF_NOT_OK( | |||||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||||
| } | |||||
| for (auto *pTask : worker_threads) { | |||||
| RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking)); | |||||
| } | |||||
| // Send a message saying this epoch is done for this pipeline. | |||||
| EpochDone proto; | |||||
| proto.set_pipeline(my_pipeline_); | |||||
| CachePerfMsg msg; | |||||
| RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochEnd, &proto)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CachePipelineRun::ReaderWorkerEntry(int32_t worker_id) { | |||||
| Status rc; | |||||
| TaskManager::FindMe()->Post(); | |||||
| int64_t min_val = std::numeric_limits<int64_t>::max(); | |||||
| int64_t max_val = 0; | |||||
| int64_t total_val = 0; | |||||
| int64_t cnt = 0; | |||||
| std::vector<int64_t> duration; | |||||
| duration.reserve(num_rows_ / num_pipelines_ / cfg_.num_parallel_workers()); | |||||
| std::unique_ptr<IOBlock> blk; | |||||
| auto epoch_start = std::chrono::steady_clock::now(); | |||||
| do { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); | |||||
| std::vector<int64_t> keys; | |||||
| RETURN_IF_NOT_OK(blk->GetKeys(&keys)); | |||||
| if (keys.empty()) { | |||||
| // empty key is a quit signal for workers | |||||
| break; | |||||
| } | |||||
| std::vector<row_id_type> prefetch_keys; | |||||
| prefetch_keys.reserve(keys.size()); | |||||
| // Filter out all the keys that we are unlikely to find at the server | |||||
| for (auto row_id : keys) { | |||||
| if (!cc_->KeyIsCacheMiss(row_id)) { | |||||
| prefetch_keys.push_back(row_id); | |||||
| } | |||||
| } | |||||
| // Early exit if nothing to fetch | |||||
| if (prefetch_keys.empty()) { | |||||
| continue; | |||||
| } | |||||
| // Get the rows from the server | |||||
| TensorTable ttbl; | |||||
| // Measure how long it takes for the row to come back. | |||||
| auto start_tick = std::chrono::steady_clock::now(); | |||||
| RETURN_IF_NOT_OK(cc_->GetRows(prefetch_keys, &ttbl)); | |||||
| auto end_tick = std::chrono::steady_clock::now(); | |||||
| int64_t ms = std::chrono::duration_cast<std::chrono::microseconds>(end_tick - start_tick).count(); | |||||
| min_val = std::min(min_val, ms); | |||||
| max_val = std::max(max_val, ms); | |||||
| duration.push_back(ms); | |||||
| total_val += ms; | |||||
| ++cnt; | |||||
| } while (true); | |||||
| auto epoch_end = std::chrono::steady_clock::now(); | |||||
| int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(epoch_end - epoch_start).count(); | |||||
| PipelineWorkerEpochSummary proto; | |||||
| proto.set_pipeline(my_pipeline_); | |||||
| proto.set_worker(worker_id); | |||||
| proto.set_min(min_val); | |||||
| proto.set_max(max_val); | |||||
| proto.set_elapse(elapse_time); | |||||
| auto sz = duration.size(); | |||||
| proto.set_cnt(sz); | |||||
| if (sz > 0) { | |||||
| // median | |||||
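| // std::nth_element partially sorts in linear time; for an even count this picks the upper median. | |||||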
| auto n = sz / 2; | |||||
| std::nth_element(duration.begin(), duration.begin() + n, duration.end()); | |||||
| auto median = duration[n]; | |||||
| proto.set_med(median); | |||||
| // average | |||||
| int64_t avg = total_val / sz; | |||||
| proto.set_avg(avg); | |||||
| } | |||||
| CachePerfMsg msg; | |||||
| RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochResult, &proto)); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,117 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_ | |||||
| #include <getopt.h> | |||||
| #include <atomic> | |||||
| #include <cstdint> | |||||
| #include <limits> | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <random> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/core/config_manager.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||||
| #include "minddata/dataset/engine/cache/perf/cache_msg.h" | |||||
| #include "minddata/dataset/util/queue.h" | |||||
| #include "minddata/dataset/util/task_manager.h" | |||||
| #include "minddata/dataset/util/wait_post.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| constexpr int32_t kDftNumOfPipelines = 8; | |||||
| constexpr int32_t kDftNumberOfEpochs = 10; | |||||
| constexpr int32_t kDftCacheSize = 0; | |||||
| constexpr bool kDftShuffle = false; | |||||
| constexpr bool kDftSpill = false; | |||||
| class CachePipelineRun { | |||||
| public: | |||||
| CachePipelineRun(); | |||||
| ~CachePipelineRun(); | |||||
| static void PrintHelp(); | |||||
| int32_t ProcessArgs(int argc, char **argv); | |||||
| void Print(std::ostream &out) const { | |||||
| out << "Number of pipelines: " << num_pipelines_ << "\n" | |||||
| << "Number of epochs: " << num_epoches_ << "\n" | |||||
| << "Sample size: " << num_rows_ << "\n" | |||||
| << "Average row size: " << row_size_ << "\n" | |||||
| << "Shuffle: " << std::boolalpha << shuffle_; | |||||
| } | |||||
| friend std::ostream &operator<<(std::ostream &out, const CachePipelineRun &cp) { | |||||
| cp.Print(out); | |||||
| return out; | |||||
| } | |||||
| Status Run(); | |||||
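| /// \brief Serialize a protobuf message (when non-null) into the shared message buffer and send it out. | |||||
| /// Example, mirroring the epoch-end path in the run loop: | |||||
| ///   EpochDone proto; | |||||
| ///   proto.set_pipeline(my_pipeline_); | |||||
| ///   CachePerfMsg msg; | |||||
| ///   RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochEnd, &proto)); | |||||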
| template <typename T> | |||||
| Status SendMessage(CachePerfMsg *msg, CachePerfMsg::MessageType type, T *proto) { | |||||
| RETURN_UNEXPECTED_IF_NULL(msg); | |||||
| msg->SetType(type); | |||||
| if (proto != nullptr) { | |||||
| auto size_needed = proto->ByteSizeLong(); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| size_needed <= kSharedMessageSize, | |||||
| "Shared message set too small. Suggest to increase to " + std::to_string(size_needed)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(proto->SerializeToArray(msg->GetMutableBuffer(), kSharedMessageSize), | |||||
| "Serialization fails"); | |||||
| msg->SetProtoBufSz(size_needed); | |||||
| } | |||||
| RETURN_IF_NOT_OK(msg->Send(send_id_)); | |||||
| return Status::OK(); | |||||
| } | |||||
| private: | |||||
| int32_t my_pipeline_; | |||||
| int32_t num_pipelines_; | |||||
| int32_t num_epoches_; | |||||
| int64_t num_rows_; | |||||
| int32_t row_size_; | |||||
| bool shuffle_; | |||||
| CacheClient::Builder cache_builder_; | |||||
| session_id_type session_; | |||||
| int32_t crc_; | |||||
| TaskGroup vg_; | |||||
| WaitPost pipeline_wp_; | |||||
| int32_t send_id_; | |||||
| int32_t recv_id_; | |||||
| row_id_type start_row_; | |||||
| row_id_type end_row_; | |||||
| ConfigManager cfg_; | |||||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; // queues of IOBlocks | |||||
| std::shared_ptr<CacheClient> cc_; | |||||
| Status ListenToParent(); | |||||
| Status RunFirstEpoch(); | |||||
| Status RunReadEpoch(); | |||||
| Status WriterWorkerEntry(int32_t worker_id); | |||||
| Status ReaderWorkerEntry(int32_t worker_id); | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_ | |||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "minddata/dataset/util/storage_container.h" | |||||
| #include "minddata/dataset/engine/cache/storage_container.h" | |||||
| #include <fcntl.h> | #include <fcntl.h> | ||||
| #include <sys/stat.h> | #include <sys/stat.h> | ||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "minddata/dataset/util/storage_manager.h" | |||||
| #include "minddata/dataset/engine/cache/storage_manager.h" | |||||
| #include <iomanip> | #include <iomanip> | ||||
| #include <sstream> | #include <sstream> | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/cache/storage_container.h" | |||||
| #include "minddata/dataset/util/allocator.h" | #include "minddata/dataset/util/allocator.h" | ||||
| #include "minddata/dataset/util/auto_index.h" | #include "minddata/dataset/util/auto_index.h" | ||||
| #include "minddata/dataset/util/lock.h" | #include "minddata/dataset/util/lock.h" | ||||
| @@ -28,7 +29,6 @@ | |||||
| #include "minddata/dataset/util/path.h" | #include "minddata/dataset/util/path.h" | ||||
| #include "minddata/dataset/util/service.h" | #include "minddata/dataset/util/service.h" | ||||
| #include "minddata/dataset/util/slice.h" | #include "minddata/dataset/util/slice.h" | ||||
| #include "minddata/dataset/util/storage_container.h" | |||||
| using ListOfContainers = std::vector<std::shared_ptr<mindspore::dataset::StorageContainer>>; | using ListOfContainers = std::vector<std::shared_ptr<mindspore::dataset::StorageContainer>>; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -271,29 +271,18 @@ Status CacheBase::PrefetchRows(const std::vector<row_id_type> &keys, std::vector | |||||
| } | } | ||||
| // Get the rows from the server | // Get the rows from the server | ||||
| TensorTable ttbl; | TensorTable ttbl; | ||||
| Status rc = cache_client_->GetRows(prefetch_keys, &ttbl); | |||||
| if (rc.IsOk()) { | |||||
| auto row_it = ttbl.begin(); | |||||
| for (auto row_id : prefetch_keys) { | |||||
| auto &row = *row_it; | |||||
| if (row.empty()) { | |||||
| cache_miss->push_back(row_id); | |||||
| } | |||||
| // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row | |||||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||||
| ++row_it; | |||||
| } | |||||
| } else { | |||||
| // In case any thread is waiting for the rows to come back and blocked on a semaphore, | |||||
| // we will put an empty row in the local cache. | |||||
| for (auto row_id : prefetch_keys) { | |||||
| TensorRow row; | |||||
| row.setId(row_id); | |||||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||||
| RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl)); | |||||
| auto row_it = ttbl.begin(); | |||||
| for (auto row_id : prefetch_keys) { | |||||
| auto &row = *row_it; | |||||
| if (row.empty()) { | |||||
| cache_miss->push_back(row_id); | cache_miss->push_back(row_id); | ||||
| } | } | ||||
| // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row | |||||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||||
| ++row_it; | |||||
| } | } | ||||
| return rc; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| Status CacheBase::Prefetcher(int32_t worker_id) { | Status CacheBase::Prefetcher(int32_t worker_id) { | ||||
| @@ -322,6 +311,16 @@ Status CacheBase::Prefetcher(int32_t worker_id) { | |||||
| return rc; | return rc; | ||||
| } | } | ||||
| } while (rc.IsNetWorkError()); | } while (rc.IsNetWorkError()); | ||||
| // In case any thread is waiting for the rows to come back and blocked on a semaphore, | |||||
| // we will put an empty row in the local cache. | |||||
| if (rc.IsError() && AllowCacheMiss()) { | |||||
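| // Also record each key as a cache miss so downstream consumers fall back to the source dataset. | |||||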
| for (auto row_id : prefetch_keys) { | |||||
| TensorRow row; | |||||
| row.setId(row_id); | |||||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||||
| cache_miss.push_back(row_id); | |||||
| } | |||||
| } | |||||
| } else { | } else { | ||||
| if (AllowCacheMiss()) { | if (AllowCacheMiss()) { | ||||
| // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from | // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from | ||||
| @@ -24,7 +24,6 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/connector.h" | #include "minddata/dataset/engine/connector.h" | ||||
| #include "minddata/dataset/engine/cache/cache_client.h" | #include "minddata/dataset/engine/cache/cache_client.h" | ||||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | #include "minddata/dataset/engine/datasetops/repeat_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | ||||
| @@ -309,8 +309,7 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha | |||||
| if (st_.compare_exchange_strong(expected, State::kDirty)) { | if (st_.compare_exchange_strong(expected, State::kDirty)) { | ||||
| // We will do a deep copy but write directly into CacheRequest protobuf or shared memory | // We will do a deep copy but write directly into CacheRequest protobuf or shared memory | ||||
| Status rc; | Status rc; | ||||
| cleaner_copy_ = | |||||
| std::make_shared<CacheRowRequest>(cc->server_connection_id_, cc->cookie(), cc->SupportLocalClient()); | |||||
| cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get()); | |||||
| rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); | rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); | ||||
| if (rc.IsOk()) { | if (rc.IsOk()) { | ||||
| // Send the request async. The cleaner will check the return code. | // Send the request async. The cleaner will check the return code. | ||||
| @@ -153,7 +153,7 @@ Status CacheOp::WaitForCachingAllRows() { | |||||
| bool BuildPhaseDone = true; | bool BuildPhaseDone = true; | ||||
| do { | do { | ||||
| RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); | RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); | ||||
| BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase); | |||||
| BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase); | |||||
| if (!BuildPhaseDone) { | if (!BuildPhaseDone) { | ||||
| std::this_thread::sleep_for(std::chrono::milliseconds(100)); | std::this_thread::sleep_for(std::chrono::milliseconds(100)); | ||||
| } | } | ||||
| @@ -24,7 +24,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor | // Constructor | ||||
| CacheErrorPass::CacheErrorPass() : is_cached_(false) {} | |||||
| CacheErrorPass::CacheErrorPass() : is_cached_(false), is_mappable_(false) {} | |||||
| // Identifies the subtree below this node as being cached | // Identifies the subtree below this node as being cached | ||||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | ||||
| @@ -75,5 +75,81 @@ Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modifi | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| #endif | #endif | ||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||||
| is_mappable_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) { | |||||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||||
| is_mappable_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||||
| is_mappable_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||||
| is_mappable_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||||
| is_mappable_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||||
| is_mappable_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||||
| is_mappable_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||||
| is_mappable_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||||
| is_mappable_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||||
| // Turn on the flag that this is a tree with mappable leaf dataset | |||||
| is_mappable_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| // Turn off the flag that we're under a cache | |||||
| is_cached_ = false; | |||||
| return Status::OK(); | |||||
| } | |||||
| // Currently, returns an error if a RepeatOp exists under a cache. | |||||
| // Because there is no operator in the cache hit stream to consume eoes, caching above a repeat causes problems. | |||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| if (is_cached_ && is_mappable_) { | |||||
| RETURN_STATUS_UNEXPECTED("Repeat is not supported as a descendant operator under a mappable cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -67,8 +67,81 @@ class CacheErrorPass : public NodePass { | |||||
| Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; | ||||
| #endif | #endif | ||||
| /// \brief Identifies the leaf dataset as being mappable | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||||
| /// \brief Identifies the leaf dataset as being mappable | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; | |||||
| /// \brief Identifies the leaf dataset as being mappable | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | |||||
| /// \brief Identifies the leaf dataset as being mappable | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | |||||
| /// \brief Identifies the leaf dataset as being mappable | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | |||||
| /// \brief Identifies the leaf dataset as being mappable | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | |||||
| /// \brief Identifies the leaf dataset as being mappable | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | |||||
| /// \brief Identifies the leaf dataset as being mappable | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | |||||
| /// \brief Identifies the leaf dataset as being mappable | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||||
| /// \brief Identifies the leaf dataset as being mappable | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||||
| /// \brief Identifies the subtree above this node as not being cached | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Identifies and blocks repeat under cache scenarios | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | |||||
| private: | private: | ||||
| bool is_cached_; | bool is_cached_; | ||||
| bool is_mappable_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -3,7 +3,6 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE | |||||
| add_library(utils OBJECT | add_library(utils OBJECT | ||||
| arena.cc | arena.cc | ||||
| buddy.cc | buddy.cc | ||||
| cache_pool.cc | |||||
| circular_pool.cc | circular_pool.cc | ||||
| data_helper.cc | data_helper.cc | ||||
| memory_pool.cc | memory_pool.cc | ||||
| @@ -16,8 +15,6 @@ add_library(utils OBJECT | |||||
| lock.cc | lock.cc | ||||
| semaphore.cc | semaphore.cc | ||||
| status.cc | status.cc | ||||
| storage_container.cc | |||||
| storage_manager.cc | |||||
| slice.cc | slice.cc | ||||
| path.cc | path.cc | ||||
| wait_post.cc | wait_post.cc | ||||
| @@ -94,6 +94,11 @@ Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, C alloc, | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive"); | CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive"); | ||||
| try { | try { | ||||
| T *data = alloc.allocate(n); | T *data = alloc.allocate(n); | ||||
| // Some of our allocator implementations (e.g. NumaAllocator) don't throw std::bad_alloc, | |||||
| // so we have to check for a null pointer instead. | |||||
| if (data == nullptr) { | |||||
| return Status(StatusCode::kOutOfMemory); | |||||
| } | |||||
| if (!std::is_arithmetic<T>::value) { | if (!std::is_arithmetic<T>::value) { | ||||
| for (auto i = 0; i < n; i++) { | for (auto i = 0; i < n; i++) { | ||||
| std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...); | std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...); | ||||
| @@ -78,6 +78,18 @@ class Path { | |||||
| Path operator/(const char *); | Path operator/(const char *); | ||||
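| // Lexicographic comparisons on the underlying path string, e.g. so Path can serve as an ordered container key. | |||||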
| bool operator==(const Path &rhs) const { return (path_ == rhs.path_); } | |||||
| bool operator!=(const Path &rhs) const { return (path_ != rhs.path_); } | |||||
| bool operator<(const Path &rhs) const { return (path_ < rhs.path_); } | |||||
| bool operator>(const Path &rhs) const { return (path_ > rhs.path_); } | |||||
| bool operator<=(const Path &rhs) const { return (path_ <= rhs.path_); } | |||||
| bool operator>=(const Path &rhs) const { return (path_ >= rhs.path_); } | |||||
| bool Exists(); | bool Exists(); | ||||
| bool IsDirectory(); | bool IsDirectory(); | ||||
| @@ -37,6 +37,11 @@ void Task::operator()() { | |||||
| ss << Services::GetUniqueID(); | ss << Services::GetUniqueID(); | ||||
| #endif | #endif | ||||
| MS_LOG(DEBUG) << my_name_ << " Thread ID " << ss.str() << " Started."; | MS_LOG(DEBUG) << my_name_ << " Thread ID " << ss.str() << " Started."; | ||||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||||
| native_handle_ = pthread_self(); | |||||
| #endif | |||||
| try { | try { | ||||
| // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set | // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set | ||||
| // the TaskGroup pointer and register. We move the registration logic to here (after we spawn) so we can | // the TaskGroup pointer and register. We move the registration logic to here (after we spawn) so we can | ||||
| @@ -96,7 +101,8 @@ Task::Task(const std::string &myName, const std::function<Status()> &f) | |||||
| task_group_(nullptr), | task_group_(nullptr), | ||||
| is_master_(false), | is_master_(false), | ||||
| running_(false), | running_(false), | ||||
| caught_severe_exception_(false) { | |||||
| caught_severe_exception_(false), | |||||
| native_handle_(0) { | |||||
| IntrpResource::ResetIntrpState(); | IntrpResource::ResetIntrpState(); | ||||
| wp_.ResetIntrpState(); | wp_.ResetIntrpState(); | ||||
| wp_.Clear(); | wp_.Clear(); | ||||
| @@ -164,5 +170,10 @@ Status Task::OverrideInterruptRc(const Status &rc) { | |||||
| } | } | ||||
| return rc; | return rc; | ||||
| } | } | ||||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||||
| pthread_t Task::GetNativeHandle() const { return native_handle_; } | |||||
| #endif | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,9 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ | ||||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||||
| #include <pthread.h> | |||||
| #endif | |||||
| #include <chrono> | #include <chrono> | ||||
| #include <exception> | #include <exception> | ||||
| #include <functional> | #include <functional> | ||||
| @@ -84,7 +87,7 @@ class Task : public IntrpResource { | |||||
| std::thread::id get_id() { return id_; } | std::thread::id get_id() { return id_; } | ||||
| std::string MyName() { return my_name_; } | |||||
| std::string MyName() const { return my_name_; } | |||||
| // An operator used by std::find | // An operator used by std::find | ||||
| bool operator==(const Task &other) const { return (this == &other); } | bool operator==(const Task &other) const { return (this == &other); } | ||||
| @@ -97,6 +100,10 @@ class Task : public IntrpResource { | |||||
| static Status OverrideInterruptRc(const Status &rc); | static Status OverrideInterruptRc(const Status &rc); | ||||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||||
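| /// \brief Return the native pthread handle captured at thread start, presumably so callers can set CPU/NUMA affinity. | |||||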
| pthread_t GetNativeHandle() const; | |||||
| #endif | |||||
| private: | private: | ||||
| mutable std::mutex mux_; | mutable std::mutex mux_; | ||||
| std::string my_name_; | std::string my_name_; | ||||
| @@ -113,6 +120,12 @@ class Task : public IntrpResource { | |||||
| volatile bool running_; | volatile bool running_; | ||||
| volatile bool caught_severe_exception_; | volatile bool caught_severe_exception_; | ||||
| #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) | |||||
| pthread_t native_handle_; | |||||
| #else | |||||
| uint64_t native_handle_; | |||||
| #endif | |||||
| void ShutdownGroup(); | void ShutdownGroup(); | ||||
| TaskGroup *MyTaskGroup(); | TaskGroup *MyTaskGroup(); | ||||
| void set_task_group(TaskGroup *vg); | void set_task_group(TaskGroup *vg); | ||||
| @@ -24,7 +24,6 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "gtest/gtest.h" | #include "gtest/gtest.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "minddata/dataset/util/storage_container.h" // lint !e322 | |||||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | ||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| @@ -31,7 +31,7 @@ HandleRcExit $? 1 1 | |||||
| export RUN_CACHE_TEST=TRUE | export RUN_CACHE_TEST=TRUE | ||||
| # Each of these tests will create session, use it, then destroy it after the test | # Each of these tests will create session, use it, then destroy it after the test | ||||
| for i in $(seq 1 6) | |||||
| for i in $(seq 1 5) | |||||
| do | do | ||||
| test_name="test_cache_map_basic${i}" | test_name="test_cache_map_basic${i}" | ||||
| GetSession | GetSession | ||||
| @@ -121,6 +121,12 @@ HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_map.py" "test_cache_map_voc" 1 | PytestCmd "test_cache_map.py" "test_cache_map_voc" 1 | ||||
| HandleRcExit $? 0 0 | HandleRcExit $? 0 0 | ||||
| PytestCmd "test_cache_map.py" "test_cache_map_python_sampler" 1 | |||||
| HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat" | |||||
| HandleRcExit $? 0 0 | |||||
| # Run two parallel pipelines (sharing cache) | # Run two parallel pipelines (sharing cache) | ||||
| for i in $(seq 1 2) | for i in $(seq 1 2) | ||||
| do | do | ||||
| @@ -309,6 +315,9 @@ HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1 | PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1 | ||||
| HandleRcExit $? 0 0 | HandleRcExit $? 0 0 | ||||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_nested_repeat" | |||||
| HandleRcExit $? 0 0 | |||||
| for i in $(seq 1 3) | for i in $(seq 1 3) | ||||
| do | do | ||||
| test_name="test_cache_nomap_multiple_cache${i}" | test_name="test_cache_nomap_multiple_cache${i}" | ||||
| @@ -107,49 +107,10 @@ def test_cache_map_basic2(): | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | ||||
| def test_cache_map_basic3(): | def test_cache_map_basic3(): | ||||
| """ | |||||
| Test a repeat under mappable cache | |||||
| Cache | |||||
| | | |||||
| Map(decode) | |||||
| | | |||||
| Repeat | |||||
| | | |||||
| ImageFolder | |||||
| """ | |||||
| logger.info("Test cache basic 3") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This DATA_DIR only has 2 images in it | |||||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||||
| decode_op = c_vision.Decode() | |||||
| ds1 = ds1.repeat(4) | |||||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||||
| logger.info("ds1.dataset_size is ", ds1.get_dataset_size()) | |||||
| num_iter = 0 | |||||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||||
| logger.info("get data from dataset") | |||||
| num_iter += 1 | |||||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||||
| assert num_iter == 8 | |||||
| logger.info('test_cache_basic3 Ended.\n') | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_basic4(): | |||||
| """ | """ | ||||
| Test different rows result in core dump | Test different rows result in core dump | ||||
| """ | """ | ||||
| logger.info("Test cache basic 4") | |||||
| logger.info("Test cache basic 3") | |||||
| if "SESSION_ID" in os.environ: | if "SESSION_ID" in os.environ: | ||||
| session_id = int(os.environ['SESSION_ID']) | session_id = int(os.environ['SESSION_ID']) | ||||
| else: | else: | ||||
| @@ -171,11 +132,11 @@ def test_cache_map_basic4(): | |||||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | logger.info("Number of data in ds1: {} ".format(num_iter)) | ||||
| assert num_iter == 8 | assert num_iter == 8 | ||||
| logger.info('test_cache_basic4 Ended.\n') | |||||
| logger.info('test_cache_basic3 Ended.\n') | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | ||||
| def test_cache_map_basic5(): | |||||
| def test_cache_map_basic4(): | |||||
| """ | """ | ||||
| Test Map with non-deterministic TensorOps above cache | Test Map with non-deterministic TensorOps above cache | ||||
| @@ -188,7 +149,7 @@ def test_cache_map_basic5(): | |||||
| ImageFolder | ImageFolder | ||||
| """ | """ | ||||
| logger.info("Test cache failure 5") | |||||
| logger.info("Test cache basic 4") | |||||
| if "SESSION_ID" in os.environ: | if "SESSION_ID" in os.environ: | ||||
| session_id = int(os.environ['SESSION_ID']) | session_id = int(os.environ['SESSION_ID']) | ||||
| else: | else: | ||||
| @@ -211,11 +172,11 @@ def test_cache_map_basic5(): | |||||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | logger.info("Number of data in ds1: {} ".format(num_iter)) | ||||
| assert num_iter == 8 | assert num_iter == 8 | ||||
| logger.info('test_cache_failure5 Ended.\n') | |||||
| logger.info('test_cache_basic4 Ended.\n') | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | ||||
| def test_cache_map_basic6(): | |||||
| def test_cache_map_basic5(): | |||||
| """ | """ | ||||
| Test cache as root node | Test cache as root node | ||||
| @@ -223,7 +184,7 @@ def test_cache_map_basic6(): | |||||
| | | | | ||||
| ImageFolder | ImageFolder | ||||
| """ | """ | ||||
| logger.info("Test cache basic 6") | |||||
| logger.info("Test cache basic 5") | |||||
| if "SESSION_ID" in os.environ: | if "SESSION_ID" in os.environ: | ||||
| session_id = int(os.environ['SESSION_ID']) | session_id = int(os.environ['SESSION_ID']) | ||||
| else: | else: | ||||
| @@ -239,7 +200,7 @@ def test_cache_map_basic6(): | |||||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | logger.info("Number of data in ds1: {} ".format(num_iter)) | ||||
| assert num_iter == 2 | assert num_iter == 2 | ||||
| logger.info('test_cache_basic6 Ended.\n') | |||||
| logger.info('test_cache_basic5 Ended.\n') | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | ||||
| @@ -502,6 +463,7 @@ def test_cache_map_failure7(): | |||||
| Generator | Generator | ||||
| """ | """ | ||||
| def generator_1d(): | def generator_1d(): | ||||
| for i in range(64): | for i in range(64): | ||||
| yield (np.array(i),) | yield (np.array(i),) | ||||
| @@ -528,6 +490,44 @@ def test_cache_map_failure7(): | |||||
| logger.info('test_cache_failure7 Ended.\n') | logger.info('test_cache_failure7 Ended.\n') | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_failure8(): | |||||
| """ | |||||
| Test a repeat under mappable cache (failure) | |||||
| Cache | |||||
| | | |||||
| Map(decode) | |||||
| | | |||||
| Repeat | |||||
| | | |||||
| ImageFolder | |||||
| """ | |||||
| logger.info("Test cache failure 8") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This DATA_DIR only has 2 images in it | |||||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||||
| decode_op = c_vision.Decode() | |||||
| ds1 = ds1.repeat(4) | |||||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||||
| with pytest.raises(RuntimeError) as e: | |||||
| num_iter = 0 | |||||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||||
| num_iter += 1 | |||||
| assert "Repeat is not supported as a descendant operator under a mappable cache" in str(e.value) | |||||
| assert num_iter == 0 | |||||
| logger.info('test_cache_failure8 Ended.\n') | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | ||||
| def test_cache_map_parameter_check(): | def test_cache_map_parameter_check(): | ||||
| """ | """ | ||||
| @@ -1702,6 +1702,125 @@ def test_cache_map_voc2(): | |||||
| logger.info("test_cache_map_voc2 Ended.\n") | logger.info("test_cache_map_voc2 Ended.\n") | ||||
| class ReverseSampler(ds.Sampler): | |||||
| def __iter__(self): | |||||
| for i in range(self.dataset_size - 1, -1, -1): | |||||
| yield i | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_python_sampler1(): | |||||
| """ | |||||
| Test using a python sampler, and cache after leaf | |||||
| Repeat | |||||
| | | |||||
| Map(decode) | |||||
| | | |||||
| cache | |||||
| | | |||||
| ImageFolder | |||||
| """ | |||||
| logger.info("Test cache map python sampler1") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This DATA_DIR only has 2 images in it | |||||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache) | |||||
| decode_op = c_vision.Decode() | |||||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op) | |||||
| ds1 = ds1.repeat(4) | |||||
| num_iter = 0 | |||||
| for _ in ds1.create_dict_iterator(): | |||||
| num_iter += 1 | |||||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||||
| assert num_iter == 8 | |||||
| logger.info("test_cache_map_python_sampler1 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_python_sampler2(): | |||||
| """ | |||||
| Test using a python sampler, and cache after map | |||||
| Repeat | |||||
| | | |||||
| cache | |||||
| | | |||||
| Map(decode) | |||||
| | | |||||
| ImageFolder | |||||
| """ | |||||
| logger.info("Test cache map python sampler2") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This DATA_DIR only has 2 images in it | |||||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler()) | |||||
| decode_op = c_vision.Decode() | |||||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||||
| ds1 = ds1.repeat(4) | |||||
| num_iter = 0 | |||||
| for _ in ds1.create_dict_iterator(): | |||||
| num_iter += 1 | |||||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||||
| assert num_iter == 8 | |||||
| logger.info("test_cache_map_python_sampler2 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_map_nested_repeat(): | |||||
| """ | |||||
| Test cache on pipeline with nested repeat ops | |||||
| Repeat | |||||
| | | |||||
| Map(decode) | |||||
| | | |||||
| Repeat | |||||
| | | |||||
| Cache | |||||
| | | |||||
| ImageFolder | |||||
| """ | |||||
| logger.info("Test cache map nested repeat") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This DATA_DIR only has 2 images in it | |||||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||||
| decode_op = c_vision.Decode() | |||||
| ds1 = ds1.repeat(4) | |||||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"]) | |||||
| ds1 = ds1.repeat(2) | |||||
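| # 2 images x repeat(4) x repeat(2) = 16 rows expected | |||||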
| num_iter = 0 | |||||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||||
| logger.info("get data from dataset") | |||||
| num_iter += 1 | |||||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||||
| assert num_iter == 16 | |||||
| logger.info('test_cache_map_nested_repeat Ended.\n') | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_cache_map_basic1() | test_cache_map_basic1() | ||||
| test_cache_map_basic2() | test_cache_map_basic2() | ||||
| @@ -1292,6 +1292,50 @@ def test_cache_nomap_epoch_ctrl3(): | |||||
| logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n") | logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_nomap_epoch_ctrl4(): | |||||
| """ | |||||
| Test using the two-loops method with repeat under the cache | |||||
| cache | |||||
| | | |||||
| Map(decode) | |||||
| | | |||||
| repeat | |||||
| | | |||||
| TFRecord | |||||
| """ | |||||
| logger.info("Test cache nomap epoch ctrl4") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 3 records in it only | |||||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||||
| ds1 = ds1.repeat(2) | |||||
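| # 3 records x repeat(2) = 6 rows per epoch | |||||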
| decode_op = c_vision.Decode() | |||||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||||
| num_epoch = 5 | |||||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) | |||||
| epoch_count = 0 | |||||
| for _ in range(num_epoch): | |||||
| row_count = 0 | |||||
| for _ in iter1: | |||||
| row_count += 1 | |||||
| logger.info("Number of data in ds1: {} ".format(row_count)) | |||||
| assert row_count == 6 | |||||
| epoch_count += 1 | |||||
| assert epoch_count == num_epoch | |||||
| logger.info("test_cache_nomap_epoch_ctrl4 Ended.\n") | |||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | ||||
| def test_cache_nomap_multiple_cache1(): | def test_cache_nomap_multiple_cache1(): | ||||
| """ | """ | ||||
| @@ -1734,6 +1778,47 @@ def test_cache_nomap_textfile2(): | |||||
| logger.info("test_cache_nomap_textfile2 Ended.\n") | logger.info("test_cache_nomap_textfile2 Ended.\n") | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_nomap_nested_repeat(): | |||||
| """ | |||||
| Test cache on pipeline with nested repeat ops | |||||
| Repeat | |||||
| | | |||||
| Cache | |||||
| | | |||||
| Map(decode) | |||||
| | | |||||
| Repeat | |||||
| | | |||||
| TFRecord | |||||
| """ | |||||
| logger.info("Test cache nomap nested repeat") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 3 records in it only | |||||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) | |||||
| decode_op = c_vision.Decode() | |||||
| ds1 = ds1.repeat(4) | |||||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||||
| ds1 = ds1.repeat(2) | |||||
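| # 3 records x repeat(4) x repeat(2) = 24 rows expected | |||||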
| num_iter = 0 | |||||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||||
| logger.info("get data from dataset") | |||||
| num_iter += 1 | |||||
| logger.info("Number of data in ds1: {} ".format(num_iter)) | |||||
| assert num_iter == 24 | |||||
| logger.info('test_cache_nomap_nested_repeat Ended.\n') | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_cache_nomap_basic1() | test_cache_nomap_basic1() | ||||
| test_cache_nomap_basic2() | test_cache_nomap_basic2() | ||||
| @@ -0,0 +1,10 @@ | |||||
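| # Start the cache server, launch the server-stop test in the background, then stop the server while the test runs. | |||||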
| ~/cache/cache_admin --start | |||||
| session_id=$(~/cache/cache_admin -g | awk '{print $NF}') | |||||
| export SESSION_ID=${session_id} | |||||
| pytest dataset/test_cache_nomap.py::test_cache_nomap_server_stop & | |||||
| pid=("$!") | |||||
| sleep 2 | |||||
| ~/cache/cache_admin --stop | |||||
| sleep 1 | |||||
| wait ${pid} | |||||