Merge pull request !5930 from Jamie/CacheOp_devtags/v1.1.0
| @@ -36,6 +36,7 @@ include(CPack) | |||
| set(INSTALL_LIB_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Installation directory for libraries") | |||
| set(INSTALL_PY_DIR ".") | |||
| set(INSTALL_BASE_DIR ".") | |||
| set(INSTALL_BIN_DIR "bin") | |||
| if (CMAKE_SYSTEM_NAME MATCHES "Windows") | |||
| set(INSTALL_LIB_DIR ".") | |||
| @@ -78,7 +79,14 @@ if (ENABLE_MINDDATA) | |||
| DESTINATION ${INSTALL_BASE_DIR} | |||
| COMPONENT mindspore | |||
| ) | |||
| if (CMAKE_SYSTEM_NAME MATCHES "Linux") | |||
| install( | |||
| TARGETS cache_admin cache_server | |||
| OPTIONAL | |||
| DESTINATION ${INSTALL_BIN_DIR} | |||
| COMPONENT mindspore | |||
| ) | |||
| endif() | |||
| file(GLOB_RECURSE OPENCV_LIB_LIST | |||
| ${opencv_LIBPATH}/libopencv_core* | |||
| ${opencv_LIBPATH}/libopencv_imgcodecs* | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <optional> | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| @@ -22,17 +23,19 @@ namespace dataset { | |||
| PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { | |||
| (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient") | |||
| .def( | |||
| py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname, | |||
| std::optional<int32_t> port, int32_t prefetch_sz) { | |||
| std::shared_ptr<CacheClient> cc; | |||
| CacheClient::Builder builder; | |||
| builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz); | |||
| if (hostname) builder.SetHostname(hostname.value()); | |||
| if (port) builder.SetPort(port.value()); | |||
| THROW_IF_ERROR(builder.Build(&cc)); | |||
| return cc; | |||
| })) | |||
| .def(py::init([](session_id_type id, uint64_t mem_sz, bool spill, | |||
| std::optional<std::string> hostname, std::optional<int32_t> port, | |||
| std::optional<int32_t> num_connections, std::optional<int32_t> prefetch_sz) { | |||
| std::shared_ptr<CacheClient> cc; | |||
| CacheClient::Builder builder; | |||
| builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill); | |||
| if (hostname) builder.SetHostname(hostname.value()); | |||
| if (port) builder.SetPort(port.value()); | |||
| if (num_connections) builder.SetNumConnections(num_connections.value()); | |||
| if (prefetch_sz) builder.SetPrefetchSize(prefetch_sz.value()); | |||
| THROW_IF_ERROR(builder.Build(&cc)); | |||
| return cc; | |||
| })) | |||
| .def("GetStat", [](CacheClient &cc) { | |||
| CacheServiceStat stat{}; | |||
| THROW_IF_ERROR(cc.GetStat(&stat)); | |||
| @@ -18,6 +18,7 @@ | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| #include "minddata/dataset/util/system_pool.h" | |||
| @@ -33,7 +34,9 @@ ConfigManager::ConfigManager() | |||
| monitor_sampling_interval_(kCfgMonitorSamplingInterval), | |||
| callback_timout_(kCfgCallbackTimeout), | |||
| cache_host_(kCfgDefaultCacheHost), | |||
| cache_port_(kCfgDefaultCachePort) { | |||
| cache_port_(kCfgDefaultCachePort), | |||
| num_connections_(kDftNumConnections), | |||
| prefetch_size_(kDftPrefetchSize) { | |||
| auto env_cache_host = std::getenv("MS_CACHE_HOST"); | |||
| auto env_cache_port = std::getenv("MS_CACHE_PORT"); | |||
| if (env_cache_host != nullptr) { | |||
| @@ -71,6 +74,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) { | |||
| set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); | |||
| set_cache_host(j.value("cacheHost", cache_host_)); | |||
| set_cache_port(j.value("cachePort", cache_port_)); | |||
| set_num_connections(j.value("numConnections", num_connections_)); | |||
| set_prefetch_size(j.value("prefetchSize", prefetch_size_)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -120,8 +125,12 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s | |||
| void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } | |||
| void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = cache_host; } | |||
| void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = std::move(cache_host); } | |||
| void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; } | |||
| void ConfigManager::set_num_connections(int32_t num_connections) { num_connections_ = num_connections; } | |||
| void ConfigManager::set_prefetch_size(int32_t prefetch_size) { prefetch_size_ = prefetch_size; } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -97,6 +97,14 @@ class ConfigManager { | |||
| // @return The port of cache server | |||
| int32_t cache_port() const { return cache_port_; } | |||
| /// getter function | |||
| /// \return Number of tcp/ip connection | |||
| int32_t num_connections() const { return num_connections_; } | |||
| /// getter function | |||
| /// \return Prefetch size | |||
| int32_t prefetch_size() const { return prefetch_size_; } | |||
| // setter function | |||
| // @param rows_per_buffer - The setting to apply to the config | |||
| void set_rows_per_buffer(int32_t rows_per_buffer); | |||
| @@ -121,6 +129,14 @@ class ConfigManager { | |||
| // @param cache_port - The port of cache server | |||
| void set_cache_port(int32_t cache_port); | |||
| /// setter function | |||
| /// \param num_connections | |||
| void set_num_connections(int32_t num_connections); | |||
| /// setter function | |||
| /// \param prefetch_size | |||
| void set_prefetch_size(int32_t prefetch_size); | |||
| uint32_t seed() const; | |||
| // setter function | |||
| @@ -153,6 +169,8 @@ class ConfigManager { | |||
| uint32_t callback_timout_; | |||
| std::string cache_host_; | |||
| int32_t cache_port_; | |||
| int32_t num_connections_; | |||
| int32_t prefetch_size_; | |||
| // Private helper function that takes a nlohmann json format and populates the settings | |||
| // @param j - The json nlohmann json info | |||
| @@ -71,6 +71,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10; | |||
| constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds | |||
| constexpr int32_t kCfgDefaultCachePort = 50052; | |||
| constexpr char kCfgDefaultCacheHost[] = "127.0.0.1"; | |||
| constexpr int32_t kDftPrefetchSize = 20; | |||
| constexpr int32_t kDftNumConnections = 12; | |||
| // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) | |||
| constexpr uint8_t kCVInvalidType = 255; | |||
| @@ -79,6 +79,14 @@ class TensorRow { | |||
| const vector_type &getRow() const { return row_; } | |||
| int64_t SizeInBytes() const { | |||
| size_t sz = 0; | |||
| for (auto &it : row_) { | |||
| sz += it->SizeInBytes(); | |||
| } | |||
| return sz; | |||
| } | |||
| // Wrapper functions to support vector operations | |||
| void emplace_back(value_type t) { row_.emplace_back(t); } | |||
| @@ -12,7 +12,9 @@ add_library(engine-cache-client OBJECT | |||
| if (ENABLE_CACHE) | |||
| ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto) | |||
| target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc) | |||
| target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} | |||
| cache_grpc_client.cc | |||
| cache_ipc.cc) | |||
| add_library(engine-cache-server OBJECT | |||
| ${CACHE_GRPC_SRCS} | |||
| @@ -37,12 +37,17 @@ int main(int argc, char **argv) { | |||
| warningMsg += "WARNING:\n"; | |||
| warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research"; | |||
| warningMsg += " purposes at this time.\n"; | |||
| warningMsg += "This command is currently disabled. Quitting.\n"; | |||
| auto env_enable_cache = std::getenv("MS_ENABLE_CACHE"); | |||
| if (env_enable_cache == nullptr || strcmp(env_enable_cache, "TRUE") != 0) { | |||
| // temporary disable cache feature in the current release | |||
| warningMsg += "This command is currently disabled. Quitting.\n"; | |||
| std::cerr << warningMsg << std::endl; | |||
| return 0; | |||
| } | |||
| warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n"; | |||
| // A warning message until the code is mature enough. | |||
| std::cerr << warningMsg << std::endl; | |||
| // temporary disable cache feature in the current release | |||
| return 0; | |||
| if (argc == 1) { | |||
| args.Help(); | |||
| @@ -19,9 +19,11 @@ | |||
| #include <sys/wait.h> | |||
| #include <unistd.h> | |||
| #include <cerrno> | |||
| #include <iomanip> | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <cstdlib> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| @@ -39,6 +41,7 @@ CacheAdminArgHandler::CacheAdminArgHandler() | |||
| num_workers_(kDefaultNumWorkers), | |||
| shm_mem_sz_(kDefaultSharedMemorySizeInGB), | |||
| log_level_(kDefaultLogLevel), | |||
| memory_cap_ratio_(kMemoryCapRatio), | |||
| hostname_(kCfgDefaultCacheHost), | |||
| spill_dir_(kDefaultSpillDir), | |||
| command_id_(CommandId::kCmdUnknown) { | |||
| @@ -62,6 +65,9 @@ CacheAdminArgHandler::CacheAdminArgHandler() | |||
| arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize; | |||
| arg_map_["-l"] = ArgValue::kArgLogLevel; | |||
| arg_map_["--minloglevel"] = ArgValue::kArgLogLevel; | |||
| arg_map_["-r"] = ArgValue::kArgMemoryCapRatio; | |||
| arg_map_["--memory_cap_ratio"] = ArgValue::kArgMemoryCapRatio; | |||
| arg_map_["--list_sessions"] = ArgValue::kArgListSessions; | |||
| // Initialize argument tracker with false values | |||
| for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) { | |||
| ArgValue currAV = static_cast<ArgValue>(i); | |||
| @@ -69,6 +75,8 @@ CacheAdminArgHandler::CacheAdminArgHandler() | |||
| } | |||
| } | |||
| CacheAdminArgHandler::~CacheAdminArgHandler() = default; | |||
| Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, | |||
| CommandId command_id) { | |||
| // Detect if the user tried to provide this argument more than once | |||
| @@ -102,7 +110,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std | |||
| return Status(StatusCode::kSyntaxError, err_msg); | |||
| } | |||
| // Now, attempt to convert the value into it's string format for output | |||
| // Now, attempt to convert the value into it's numeric format for output | |||
| try { | |||
| *out_arg = std::stoul(value_as_string); | |||
| } catch (const std::exception &e) { | |||
| @@ -140,7 +148,13 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, | |||
| // If there is no argument to get, such as the --start command, then out_arg will be a nullptr. | |||
| if (out_arg != nullptr) { | |||
| // Fetch the argument from the arg stream into a string | |||
| *arg_stream >> *out_arg; | |||
| if (arg_stream->rdbuf()->in_avail() != 0) { | |||
| *arg_stream >> *out_arg; | |||
| } else { | |||
| std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; | |||
| return Status(StatusCode::kSyntaxError, err_msg); | |||
| } | |||
| if (out_arg->empty()) { | |||
| std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; | |||
| return Status(StatusCode::kSyntaxError, err_msg); | |||
| @@ -150,12 +164,62 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, | |||
| return Status::OK(); | |||
| } | |||
| Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream, | |||
| CommandId command_id) { | |||
| // Detect if the user tried to provide this argument more than once | |||
| ArgValue selected_arg = arg_map_[option]; | |||
| if (used_args_[selected_arg]) { | |||
| std::string err_msg = "The " + option + " argument was given more than once."; | |||
| return Status(StatusCode::kSyntaxError, err_msg); | |||
| } | |||
| // Flag that this arg is used now | |||
| used_args_[selected_arg] = true; | |||
| // Some options are just arguments, for example "--hostname "127.0.0.1" is not a command, it's just an argument. | |||
| // Other options are actual commands, for example "--start". | |||
| // If this option is also a command, make sure there has not been multiple commands given before assigning it. | |||
| if (command_id != CommandId::kCmdUnknown) { | |||
| if (command_id_ != CommandId::kCmdUnknown) { | |||
| std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option; | |||
| return Status(StatusCode::kSyntaxError, err_msg); | |||
| } else { | |||
| command_id_ = command_id; | |||
| } | |||
| } | |||
| std::string value_as_string; | |||
| // Fetch the argument from the arg stream into a string | |||
| *arg_stream >> value_as_string; | |||
| if (value_as_string.empty()) { | |||
| std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; | |||
| return Status(StatusCode::kSyntaxError, err_msg); | |||
| } | |||
| // Now, attempt to convert the value into it's string format for output | |||
| try { | |||
| *out_arg = std::stof(value_as_string, nullptr); | |||
| } catch (const std::exception &e) { | |||
| std::string err_msg = "Invalid numeric value: " + value_as_string; | |||
| return Status(StatusCode::kSyntaxError, err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) { | |||
| std::string tok; | |||
| while (*arg_stream >> tok) { | |||
| switch (arg_map_[tok]) { | |||
| case ArgValue::kArgHost: { | |||
| RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream)); | |||
| // Temporary sanity check. We only support localhost for now | |||
| if (hostname_ != std::string(kCfgDefaultCacheHost)) { | |||
| std::string err_msg = | |||
| "Invalid host interface: " + hostname_ + ". Current limitation, only 127.0.0.1 can be used."; | |||
| return Status(StatusCode::kSyntaxError, err_msg); | |||
| } | |||
| break; | |||
| } | |||
| case ArgValue::kArgPort: { | |||
| @@ -203,6 +267,14 @@ Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) { | |||
| RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream)); | |||
| break; | |||
| } | |||
| case ArgValue::kArgMemoryCapRatio: { | |||
| RETURN_IF_NOT_OK(AssignArg(tok, &memory_cap_ratio_, arg_stream)); | |||
| break; | |||
| } | |||
| case ArgValue::kArgListSessions: { | |||
| RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdListSessions)); | |||
| break; | |||
| } | |||
| default: { | |||
| // Save space delimited trailing arguments | |||
| trailing_args_ += (" " + tok); | |||
| @@ -232,9 +304,12 @@ Status CacheAdminArgHandler::Validate() { | |||
| } | |||
| // Additional checks here | |||
| if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be positive value."); | |||
| if (num_workers_ < 1 || num_workers_ > 100) | |||
| return Status(StatusCode::kSyntaxError, "Number of workers must be in range of 1 and 100."); | |||
| if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3)."); | |||
| // port range check? | |||
| if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) | |||
| return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1"); | |||
| if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kSyntaxError, "Port must be in range (1025..65535)."); | |||
| return Status::OK(); | |||
| } | |||
| @@ -245,12 +320,9 @@ Status CacheAdminArgHandler::RunCommand() { | |||
| Help(); | |||
| break; | |||
| } | |||
| case CommandId::kCmdStart: { | |||
| RETURN_IF_NOT_OK(StartServer()); | |||
| break; | |||
| } | |||
| case CommandId::kCmdStart: | |||
| case CommandId::kCmdStop: { | |||
| RETURN_IF_NOT_OK(StopServer()); | |||
| RETURN_IF_NOT_OK(StartStopServer(command_id_)); | |||
| break; | |||
| } | |||
| case CommandId::kCmdGenerateSession: { | |||
| @@ -259,7 +331,7 @@ Status CacheAdminArgHandler::RunCommand() { | |||
| auto rq = std::make_shared<GenerateSessionIdRequest>(); | |||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | |||
| RETURN_IF_NOT_OK(rq->Wait()); | |||
| std::cout << rq->GetSessionId() << std::endl; | |||
| std::cout << "Session: " << rq->GetSessionId() << std::endl; | |||
| break; | |||
| } | |||
| case CommandId::kCmdDestroySession: { | |||
| @@ -273,6 +345,39 @@ Status CacheAdminArgHandler::RunCommand() { | |||
| std::cout << "Drop session successful" << std::endl; | |||
| break; | |||
| } | |||
| case CommandId::kCmdListSessions: { | |||
| CacheClientGreeter comm(hostname_, port_, 1); | |||
| RETURN_IF_NOT_OK(comm.ServiceStart()); | |||
| auto rq = std::make_shared<ListSessionsRequest>(); | |||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | |||
| RETURN_IF_NOT_OK(rq->Wait()); | |||
| std::vector<SessionCacheInfo> session_info = rq->GetSessionCacheInfo(); | |||
| if (!session_info.empty()) { | |||
| std::cout << std::setw(12) << "Session" << std::setw(12) << "Cache Id" << std::setw(12) << "Mem cached" | |||
| << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::endl; | |||
| for (auto curr_session : session_info) { | |||
| std::string cache_id; | |||
| std::string stat_mem_cached; | |||
| std::string stat_disk_cached; | |||
| std::string stat_avg_cached; | |||
| int32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF); | |||
| cache_id = (curr_session.connection_id == 0) ? "n/a" : std::to_string(crc); | |||
| stat_mem_cached = | |||
| (curr_session.stats.num_mem_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_mem_cached); | |||
| stat_disk_cached = | |||
| (curr_session.stats.num_disk_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_disk_cached); | |||
| stat_avg_cached = | |||
| (curr_session.stats.avg_cache_sz == 0) ? "n/a" : std::to_string(curr_session.stats.avg_cache_sz); | |||
| std::cout << std::setw(12) << curr_session.session_id << std::setw(12) << cache_id << std::setw(12) | |||
| << stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached | |||
| << std::endl; | |||
| } | |||
| } else { | |||
| std::cout << "No active sessions." << std::endl; | |||
| } | |||
| break; | |||
| } | |||
| default: { | |||
| RETURN_STATUS_UNEXPECTED("Invalid cache admin command id."); | |||
| break; | |||
| @@ -282,7 +387,7 @@ Status CacheAdminArgHandler::RunCommand() { | |||
| return Status::OK(); | |||
| } | |||
| Status CacheAdminArgHandler::StartServer() { | |||
| Status CacheAdminArgHandler::StartStopServer(CommandId command_id) { | |||
| // There currently does not exist any "install path" or method to identify which path the installed binaries will | |||
| // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the | |||
| // cache_admin binary (this process). | |||
| @@ -324,7 +429,10 @@ Status CacheAdminArgHandler::StartServer() { | |||
| close(fd[1]); | |||
| dup2(fd[0], 0); | |||
| close(fd[0]); | |||
| wait(nullptr); | |||
| int status; | |||
| if (waitpid(pid, &status, 0) == -1) { | |||
| RETURN_STATUS_UNEXPECTED("waitpid fails. errno = " + std::to_string(errno)); | |||
| } | |||
| std::string msg; | |||
| const int32_t buf_sz = 1024; | |||
| msg.resize(buf_sz); | |||
| @@ -335,6 +443,13 @@ Status CacheAdminArgHandler::StartServer() { | |||
| } | |||
| msg.resize(n); | |||
| std::cout << msg << std::endl; | |||
| if (WIFEXITED(status)) { | |||
| auto exit_status = WEXITSTATUS(status); | |||
| if (exit_status) { | |||
| std::string errMsg = "Child exit status " + std::to_string(exit_status); | |||
| return Status(StatusCode::kUnexpectedError, errMsg); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } else { | |||
| // Child here ... | |||
| @@ -350,19 +465,29 @@ Status CacheAdminArgHandler::StartServer() { | |||
| std::string shared_memory_string = std::to_string(shm_mem_sz_); | |||
| std::string minloglevel_string = std::to_string(log_level_); | |||
| std::string daemonize_string = "true"; | |||
| char *argv[8]; | |||
| argv[0] = cache_server_binary.data(); // First arg is usually the binary name | |||
| argv[1] = spill_dir_.data(); | |||
| argv[2] = workers_string.data(); | |||
| argv[3] = port_string.data(); | |||
| argv[4] = shared_memory_string.data(); | |||
| argv[5] = minloglevel_string.data(); | |||
| argv[6] = daemonize_string.data(); | |||
| argv[7] = nullptr; | |||
| std::string memory_cap_ratio_string = std::to_string(memory_cap_ratio_); | |||
| char *argv[9]; | |||
| if (command_id == CommandId::kCmdStart) { | |||
| argv[0] = cache_server_binary.data(); | |||
| argv[1] = spill_dir_.data(); | |||
| argv[2] = workers_string.data(); | |||
| argv[3] = port_string.data(); | |||
| argv[4] = shared_memory_string.data(); | |||
| argv[5] = minloglevel_string.data(); | |||
| argv[6] = daemonize_string.data(); | |||
| argv[7] = memory_cap_ratio_string.data(); | |||
| argv[8] = nullptr; | |||
| } else { | |||
| // We are doing a --stop. Change the name to '-' and we also need the port number. | |||
| // The rest we don't need. | |||
| argv[0] = std::string("-").data(); | |||
| argv[1] = port_string.data(); | |||
| argv[2] = nullptr; | |||
| } | |||
| // Now exec the binary | |||
| execv(argv[0], argv); | |||
| execv(cache_server_binary.data(), argv); | |||
| // If the exec was successful, this line will never be reached due to process image being replaced. | |||
| // ..unless exec failed. | |||
| std::string err_msg = "Failed to exec cache server: " + cache_server_binary; | |||
| @@ -371,16 +496,6 @@ Status CacheAdminArgHandler::StartServer() { | |||
| } | |||
| } | |||
| Status CacheAdminArgHandler::StopServer() { | |||
| CacheClientGreeter comm(hostname_, port_, 1); | |||
| RETURN_IF_NOT_OK(comm.ServiceStart()); | |||
| auto rq = std::make_shared<ShutdownRequest>(); | |||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | |||
| // We will ignore the rc because if the shutdown is successful, the server will not reply back. | |||
| (void)rq->Wait(); | |||
| return Status::OK(); | |||
| } | |||
| void CacheAdminArgHandler::Help() { | |||
| std::cerr << "Syntax:\n"; | |||
| std::cerr << " cache_admin [--start | --stop]\n"; | |||
| @@ -390,8 +505,12 @@ void CacheAdminArgHandler::Help() { | |||
| std::cerr << " [ [-d | --destroy_session] <session id> ]\n"; | |||
| std::cerr << " [ [-w | --workers] <number of workers> ]\n"; | |||
| std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n"; | |||
| std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n"; | |||
| std::cerr << " [ [-l | --minloglevel] <log level> ]\n"; | |||
| std::cerr << " [ --list_sessions ]\n"; | |||
| // Do not expose these option to the user via help or documentation, but the options do exist to aid with | |||
| // development and tuning. | |||
| // std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n"; | |||
| // std::cerr << " [ [-r | --memory_cap_ratio] <float percent value>]\n"; | |||
| std::cerr << " [--help]" << std::endl; | |||
| } | |||
| } // namespace dataset | |||
| @@ -32,6 +32,7 @@ class CacheAdminArgHandler { | |||
| static constexpr int32_t kDefaultNumWorkers = 32; | |||
| static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; | |||
| static constexpr int32_t kDefaultLogLevel = 1; | |||
| static constexpr float kMemoryCapRatio = 0.8; | |||
| static const char kServerBinary[]; | |||
| static const char kDefaultSpillDir[]; | |||
| @@ -42,12 +43,13 @@ class CacheAdminArgHandler { | |||
| kCmdStop = 2, | |||
| kCmdGenerateSession = 3, | |||
| kCmdDestroySession = 4, | |||
| kCmdListSessions = 5, | |||
| kCmdUnknown = 32767 | |||
| }; | |||
| CacheAdminArgHandler(); | |||
| ~CacheAdminArgHandler() = default; | |||
| virtual ~CacheAdminArgHandler(); | |||
| Status ParseArgStream(std::stringstream *arg_stream); | |||
| @@ -70,12 +72,12 @@ class CacheAdminArgHandler { | |||
| kArgNumWorkers = 9, | |||
| kArgSharedMemorySize = 10, | |||
| kArgLogLevel = 11, | |||
| kArgNumArgs = 12 // Must be the last position to provide a count | |||
| kArgMemoryCapRatio = 12, | |||
| kArgListSessions = 13, | |||
| kArgNumArgs = 14 // Must be the last position to provide a count | |||
| }; | |||
| Status StartServer(); | |||
| Status StopServer(); | |||
| Status StartStopServer(CommandId); | |||
| Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, | |||
| CommandId command_id = CommandId::kCmdUnknown); | |||
| @@ -83,6 +85,9 @@ class CacheAdminArgHandler { | |||
| Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream, | |||
| CommandId command_id = CommandId::kCmdUnknown); | |||
| Status AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream, | |||
| CommandId command_id = CommandId::kCmdUnknown); | |||
| Status Validate(); | |||
| CommandId command_id_; | |||
| @@ -90,6 +95,7 @@ class CacheAdminArgHandler { | |||
| int32_t num_workers_; | |||
| int32_t shm_mem_sz_; | |||
| int32_t log_level_; | |||
| float memory_cap_ratio_; | |||
| session_id_type session_id_; | |||
| std::string hostname_; | |||
| std::string spill_dir_; | |||
| @@ -17,27 +17,19 @@ | |||
| #include "minddata/dataset/util/path.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) | |||
| : ptr_(nullptr), val_in_GB_(val_in_GB), port_(port), shmid_(-1) {} | |||
| CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) { | |||
| // We create the shared memory and we will destroy it. All other client just detach only. | |||
| shm_.RemoveResourcesOnExit(); | |||
| } | |||
| CachedSharedMemoryArena::~CachedSharedMemoryArena() { | |||
| #if CACHE_LOCAL_CLIENT | |||
| if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) { | |||
| shmdt(this->ptr_); | |||
| } | |||
| this->ptr_ = nullptr; | |||
| if (shmid_ != -1) { | |||
| shmctl(shmid_, IPC_RMID, nullptr); | |||
| // Also remove the path we use to generate ftok. | |||
| Path p(PortToUnixSocketPath(port_)); | |||
| (void)p.Remove(); | |||
| } | |||
| #endif | |||
| // Also remove the path we use to generate ftok. | |||
| Path p(PortToUnixSocketPath(port_)); | |||
| (void)p.Remove(); | |||
| } | |||
| Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, | |||
| size_t val_in_GB) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| #if CACHE_LOCAL_CLIENT | |||
| auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB); | |||
| if (ba == nullptr) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| @@ -46,26 +38,13 @@ Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryAr | |||
| // the destructor of *out to deal. | |||
| (*out).reset(ba); | |||
| // Generate the ftok using a combination of port. | |||
| int err; | |||
| auto shm_key = PortToFtok(port, &err); | |||
| if (shm_key == (key_t)-1) { | |||
| std::string errMsg = "Ftok failed with errno " + std::to_string(err); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; | |||
| SharedMemory::shm_key_t shm_key; | |||
| RETURN_IF_NOT_OK(PortToFtok(port, &shm_key)); | |||
| ba->shm_.SetPublicKey(shm_key); | |||
| // Value is in GB. Convert into bytes. | |||
| int64_t sz = val_in_GB * 1073741824L; | |||
| ba->shmid_ = shmget(shm_key, sz, IPC_CREAT | IPC_EXCL | access_mode); | |||
| if (ba->shmid_) { | |||
| ba->ptr_ = shmat(ba->shmid_, nullptr, 0); | |||
| if (ba->ptr_ == reinterpret_cast<void *>(-1)) { | |||
| RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); | |||
| } | |||
| ba->impl_ = std::make_unique<ArenaImpl>(ba->ptr_, sz); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno)); | |||
| } | |||
| #endif | |||
| RETURN_IF_NOT_OK(ba->shm_.Create(sz)); | |||
| ba->impl_ = std::make_unique<ArenaImpl>(ba->shm_.SharedMemoryBaseAddr(), sz); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| @@ -21,6 +21,7 @@ | |||
| #include <string> | |||
| #include "minddata/dataset/util/arena.h" | |||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// This is a derived class of Arena but resides in shared memory | |||
| @@ -73,10 +74,9 @@ class CachedSharedMemoryArena : public MemoryPool { | |||
| private: | |||
| mutable std::mutex mux_; | |||
| void *ptr_; | |||
| int32_t val_in_GB_; | |||
| int32_t port_; | |||
| int shmid_; | |||
| SharedMemory shm_; | |||
| std::unique_ptr<ArenaImpl> impl_; | |||
| /// Private constructor. Not to be called directly. | |||
| CachedSharedMemoryArena(int32_t port, size_t val_in_GB); | |||
| @@ -24,26 +24,26 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CacheClient::Builder::Builder() | |||
| : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) { | |||
| : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_connections_(0), prefetch_size_(0) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| hostname_ = cfg->cache_host(); | |||
| port_ = cfg->cache_port(); | |||
| num_workers_ = cfg->num_parallel_workers(); | |||
| prefetch_size_ = 20; // rows_per_buf is too small (1 by default). | |||
| num_connections_ = cfg->num_connections(); // number of async tcp/ip connections | |||
| prefetch_size_ = cfg->prefetch_size(); // prefetch size | |||
| } | |||
| Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *out = | |||
| std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_); | |||
| *out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_connections_, | |||
| prefetch_size_); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClient::Builder::SanityCheck() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive"); | |||
| @@ -55,26 +55,32 @@ Status CacheClient::Builder::SanityCheck() { | |||
| // Constructor | |||
| CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, | |||
| int32_t port, int32_t num_workers, int32_t prefetch_size) | |||
| int32_t port, int32_t num_connections, int32_t prefetch_size) | |||
| : server_connection_id_(0), | |||
| cache_mem_sz_(cache_mem_sz), | |||
| spill_(spill), | |||
| local_bypass_(false), | |||
| hostname_(std::move(hostname)), | |||
| port_(port), | |||
| num_workers_(num_workers), | |||
| prefetch_size_(prefetch_size) { | |||
| num_connections_(num_connections), | |||
| prefetch_size_(prefetch_size), | |||
| fetch_all_keys_(true) { | |||
| cinfo_.set_session_id(session_id); | |||
| comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_); | |||
| comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_connections_); | |||
| } | |||
// Destructor. Wake any prefetcher blocked on the cache-miss wait post first so
// those threads can exit before the communication layer is torn down.
CacheClient::~CacheClient() {
  // Unblocks threads parked in KeyIsCacheMiss()'s Wait().
  cache_miss_keys_wp_.Set();
  // Best-effort shutdown of the gRPC service; return code intentionally ignored.
  (void)comm_->ServiceStop();
}
| // print method for display cache details | |||
| void CacheClient::Print(std::ostream &out) const { | |||
| out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc() | |||
| << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << getCacheMemSz() | |||
| << "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << getHostname() | |||
| << "\n Port: " << getPort() << "\n Number of rpc workers: " << getNumWorkers() | |||
| << "\n Prefetch size: " << getPrefetchSize() << "\n Local client support: " << std::boolalpha | |||
| << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << GetCacheMemSz() | |||
| << "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << GetHostname() | |||
| << "\n Port: " << GetPort() << "\n Number of rpc workers: " << GetNumConnections() | |||
| << "\n Prefetch size: " << GetPrefetchSize() << "\n Local client support: " << std::boolalpha | |||
| << SupportLocalClient(); | |||
| } | |||
| @@ -199,14 +205,6 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClient::PurgeCache() { | |||
| UniqueLock lck(&mux_); | |||
| auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_); | |||
| RETURN_IF_NOT_OK(PushRequest(rq)); | |||
| RETURN_IF_NOT_OK(rq->Wait()); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClient::DestroyCache() { | |||
| UniqueLock lck(&mux_); | |||
| auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_); | |||
| @@ -253,5 +251,71 @@ Status CacheClient::BuildPhaseDone() const { | |||
| } | |||
// Hand the request to the gRPC layer; callers Wait() on the request to get the reply.
Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
| void CacheClient::ServerRunningOutOfResources() { | |||
| bool expected = true; | |||
| if (fetch_all_keys_.compare_exchange_strong(expected, false)) { | |||
| Status rc; | |||
| // Server runs out of memory or disk space to cache any more rows. | |||
| // First of all, we will turn off the locking. | |||
| auto toggle_write_mode_rq = std::make_shared<ToggleWriteModeRequest>(server_connection_id_, false); | |||
| rc = PushRequest(toggle_write_mode_rq); | |||
| if (rc.IsError()) { | |||
| return; | |||
| } | |||
| // Wait until we can toggle the state of the server to non-locking | |||
| rc = toggle_write_mode_rq->Wait(); | |||
| if (rc.IsError()) { | |||
| return; | |||
| } | |||
| // Now we get a list of all the keys not cached at the server so | |||
| // we can filter out at the prefetch level. | |||
| auto cache_miss_rq = std::make_shared<GetCacheMissKeysRequest>(server_connection_id_); | |||
| rc = PushRequest(cache_miss_rq); | |||
| if (rc.IsError()) { | |||
| return; | |||
| } | |||
| rc = cache_miss_rq->Wait(); | |||
| if (rc.IsError()) { | |||
| return; | |||
| } | |||
| // We will get back a vector of row id between [min,max] that are absent in the server. | |||
| auto &row_id_buf = cache_miss_rq->reply_.result(); | |||
| auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data()); | |||
| std::vector<row_id_type> row_ids; | |||
| auto sz = p->row_id()->size(); | |||
| row_ids.reserve(sz); | |||
| for (auto i = 0; i < sz; ++i) { | |||
| row_ids.push_back(p->row_id()->Get(i)); | |||
| } | |||
| cache_miss_keys_ = std::make_unique<CacheMissKeys>(row_ids); | |||
| // We are all set. | |||
| cache_miss_keys_wp_.Set(); | |||
| } | |||
| } | |||
// Build the miss lookup from the server reply, laid out as [min, max, gap...]:
// elements 0 and 1 bound the range of row ids the server holds, and the
// remaining elements are the ids inside that range that are missing
// (see KeyIsCacheMiss: the two bounds themselves are treated as cached).
// NOTE(review): assumes v has at least two elements — TODO confirm the server
// always sends the two bounds even when nothing is missing.
CacheClient::CacheMissKeys::CacheMissKeys(const std::vector<row_id_type> &v) {
  auto it = v.begin();
  min_ = *it;
  ++it;
  max_ = *it;
  ++it;
  // Everything after the two bounds goes into the gap set for O(log n) lookup.
  while (it != v.end()) {
    gap_.insert(*it);
    ++it;
  }
  MS_LOG(WARNING) << "# of cache miss keys between min(" << min_ << ") and max(" << max_ << ") is " << gap_.size();
}
| bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) { | |||
| if (key > max_ || key < min_) { | |||
| return true; | |||
| } else if (key == min_ || key == max_) { | |||
| return false; | |||
| } else { | |||
| auto it = gap_.find(key); | |||
| return it != gap_.end(); | |||
| } | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -16,8 +16,13 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ | |||
| #include <atomic> | |||
| #include <iostream> | |||
| #include <limits> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <mutex> | |||
| #include <set> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| @@ -31,6 +36,8 @@ | |||
| #endif | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/util/lock.h" | |||
| #include "minddata/dataset/util/cond_var.h" | |||
| #include "minddata/dataset/util/queue_map.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -89,10 +96,10 @@ class CacheClient { | |||
| } | |||
| /// Setter function to set number of async rpc workers | |||
| /// \param num_workers | |||
| /// \param num_connections | |||
| /// \return Builder object itself | |||
| Builder &SetNumWorkers(int32_t num_workers) { | |||
| num_workers_ = num_workers; | |||
| Builder &SetNumConnections(int32_t num_connections) { | |||
| num_connections_ = num_connections; | |||
| return *this; | |||
| } | |||
| @@ -105,13 +112,13 @@ class CacheClient { | |||
| } | |||
| /// Getter functions | |||
| session_id_type getSessionId() const { return session_id_; } | |||
| uint64_t getCacheMemSz() const { return cache_mem_sz_; } | |||
| session_id_type GetSessionId() const { return session_id_; } | |||
| uint64_t GetCacheMemSz() const { return cache_mem_sz_; } | |||
| bool isSpill() const { return spill_; } | |||
| const std::string &getHostname() const { return hostname_; } | |||
| int32_t getPort() const { return port_; } | |||
| int32_t getNumWorkers() const { return num_workers_; } | |||
| int32_t getPrefetchSize() const { return prefetch_size_; } | |||
| int32_t GetPort() const { return port_; } | |||
| int32_t GetNumConnections() const { return num_connections_; } | |||
| int32_t GetPrefetchSize() const { return prefetch_size_; } | |||
| Status SanityCheck(); | |||
| @@ -123,7 +130,7 @@ class CacheClient { | |||
| bool spill_; | |||
| std::string hostname_; | |||
| int32_t port_; | |||
| int32_t num_workers_; | |||
| int32_t num_connections_; | |||
| int32_t prefetch_size_; | |||
| }; | |||
| @@ -132,10 +139,10 @@ class CacheClient { | |||
| /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited | |||
| /// \param spill Spill to disk if out of memory | |||
| CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port, | |||
| int32_t num_workers, int32_t prefetch_size); | |||
| int32_t num_connections, int32_t prefetch_size); | |||
| /// \brief Destructor | |||
| ~CacheClient() { (void)comm_->ServiceStop(); } | |||
| ~CacheClient(); | |||
| /// \brief Send a TensorRow to the cache server | |||
| /// \param[in] row | |||
| @@ -161,10 +168,6 @@ class CacheClient { | |||
| /// \return Status object | |||
| Status CreateCache(uint32_t tree_crc, bool generate_id); | |||
| /// \brief Purge a cache. Cache can be reused after reset. | |||
| /// \return Status object | |||
| Status PurgeCache(); | |||
| /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused. | |||
| /// \return Status object | |||
| Status DestroyCache(); | |||
| @@ -218,12 +221,31 @@ class CacheClient { | |||
| /// Getter functions | |||
| session_id_type session_id() const { return cinfo_.session_id(); } | |||
| uint64_t getCacheMemSz() const { return cache_mem_sz_; } | |||
| uint64_t GetCacheMemSz() const { return cache_mem_sz_; } | |||
| bool isSpill() const { return spill_; } | |||
| const std::string &getHostname() const { return hostname_; } | |||
| int32_t getPort() const { return port_; } | |||
| int32_t getNumWorkers() const { return num_workers_; } | |||
| int32_t getPrefetchSize() const { return prefetch_size_; } | |||
| const std::string &GetHostname() const { return hostname_; } | |||
| int32_t GetPort() const { return port_; } | |||
| int32_t GetNumConnections() const { return num_connections_; } | |||
| int32_t GetPrefetchSize() const { return prefetch_size_; } | |||
| /// MergeOp will notify us when the server can't cache any more rows. | |||
| /// We will stop any attempt to fetch any rows that are most likely | |||
| /// not present at the server. | |||
| void ServerRunningOutOfResources(); | |||
| /// \brief Check if a row is 100% cache miss at the server by checking the local information | |||
| /// \param key row id to be test | |||
| /// \return true if not at the server | |||
  bool KeyIsCacheMiss(row_id_type key) {
    // Only consult the miss table once ServerRunningOutOfResources() has started
    // building it (pointer set). Otherwise assume the row may be cached.
    if (cache_miss_keys_) {
      // Make sure it is fully built even though the pointer is not null.
      // The wait post is Set() only after the table is complete (or at destruction).
      Status rc = cache_miss_keys_wp_.Wait();
      if (rc.IsOk()) {
        return cache_miss_keys_->KeyIsCacheMiss(key);
      }
    }
    // Default: not a known miss (e.g. interrupted wait, or table never built).
    return false;
  }
| private: | |||
| mutable RWLock mux_; | |||
| @@ -240,9 +262,27 @@ class CacheClient { | |||
| bool local_bypass_; | |||
| std::string hostname_; | |||
| int32_t port_; | |||
| int32_t num_workers_; | |||
| int32_t num_connections_; | |||
| int32_t prefetch_size_; | |||
| mutable std::shared_ptr<CacheClientGreeter> comm_; | |||
| std::atomic<bool> fetch_all_keys_; | |||
| WaitPost cache_miss_keys_wp_; | |||
| /// A structure shared by all the prefetchers to know what keys are missing at the server. | |||
| class CacheMissKeys { | |||
| public: | |||
| explicit CacheMissKeys(const std::vector<row_id_type> &v); | |||
| ~CacheMissKeys() = default; | |||
| /// This checks if a key is missing. | |||
| /// \param key | |||
| /// \return true if definitely a key miss | |||
| bool KeyIsCacheMiss(row_id_type key); | |||
| private: | |||
| row_id_type min_; | |||
| row_id_type max_; | |||
| std::set<row_id_type> gap_; | |||
| }; | |||
| std::unique_ptr<CacheMissKeys> cache_miss_keys_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -25,13 +25,6 @@ | |||
| #define CACHE_LOCAL_CLIENT 1 | |||
| #endif | |||
| #ifdef CACHE_LOCAL_CLIENT | |||
| #include <sys/types.h> | |||
| #include <sys/ipc.h> | |||
| #include <sys/shm.h> | |||
| #else | |||
| typedef int key_t; | |||
| #endif | |||
| #ifdef ENABLE_CACHE | |||
| #include <grpcpp/grpcpp.h> | |||
| #endif | |||
| @@ -54,6 +47,8 @@ constexpr static uint32_t kLocalClientSupport = 1; | |||
| /// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is | |||
| /// inline in the protobuf. This also implies kLocalClientSupport is also true. | |||
| constexpr static uint32_t kDataIsInSharedMemory = 2; | |||
| /// \brief Size of each message used in message queue. | |||
| constexpr static int32_t kSharedMessageSize = 2048; | |||
| /// \brief Convert a Status object into a protobuf | |||
| /// \param rc[in] Status object | |||
| @@ -62,29 +57,10 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) { | |||
| reply->set_rc(static_cast<int32_t>(rc.get_code())); | |||
| reply->set_msg(rc.ToString()); | |||
| } | |||
| /// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number | |||
| /// \param port | |||
| /// \return unix socket url | |||
| inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); } | |||
| /// \brief Generate a shared memory key using the tcp/ip port. | |||
| /// \note It must be called after the cache server generates the unix socket or ftok will fail. | |||
| /// \note Caller must check the return value. -1 means ftok failed. | |||
| /// \param[in] port | |||
| /// \param[out] err. If not null and ftok fails, this will contain the value of errno | |||
| /// \return key | |||
| inline key_t PortToFtok(int port, int *err) { | |||
| key_t shmkey = -1; | |||
| #ifdef CACHE_LOCAL_CLIENT | |||
| const std::string unix_path = PortToUnixSocketPath(port); | |||
| shmkey = ftok(unix_path.data(), 'a'); | |||
| if (err != nullptr && shmkey == (key_t)-1) { | |||
| *err = errno; | |||
| } | |||
| #endif | |||
| return shmkey; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_ | |||
| @@ -17,34 +17,10 @@ | |||
| #include <chrono> | |||
| namespace mindspore { | |||
| namespace dataset { | |||
/// \brief Fire one asynchronous request at the cache server.
/// \param stub gRPC stub owned by CacheClientGreeter
/// \param cq completion queue the reply (tagged with this object) will be posted to
/// \param tag owning pointer to the request tag; ownership is released into the
///        completion queue on success and reclaimed by WorkerEntry / queue drain
/// \return Status object
Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
                                       std::unique_ptr<CacheClientRequestTag> &&tag) {
  // If there is anything extra we need to do before we send.
  RETURN_IF_NOT_OK(tag->base_rq_->Prepare());
  // One minute timeout
  auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
  tag->ctx_.set_deadline(deadline);
  tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq);
  tag->rpc_->StartCall();
  // Last step is we release the ownership and transfer it to the completion queue.
  // The memory will be released by WorkerEntry or by the destructor when we drain the queue
  auto ccReqTag = tag.release();
  ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_,
                         ccReqTag);  // inject this object into the completion queue
  return Status::OK();
}
| CacheClientGreeter::~CacheClientGreeter() { (void)ServiceStop(); } | |||
| CacheClientGreeter::~CacheClientGreeter() { | |||
| (void)ServiceStop(); | |||
| // Detach from shared memory if any | |||
| if (shmat_addr_ != nullptr) { | |||
| shmdt(shmat_addr_); | |||
| shmat_addr_ = nullptr; | |||
| } | |||
| } | |||
| CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) | |||
| : num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) { | |||
| CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections) | |||
| : num_connections_(num_connections), request_cnt_(0) { | |||
| grpc::ChannelArguments args; | |||
| // We need to bump up the message size to unlimited. The default receiving | |||
| // message limit is 4MB which is not big enough. | |||
| @@ -68,21 +44,11 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port | |||
| Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) { | |||
| *local_bypass = false; | |||
| #if CACHE_LOCAL_CLIENT | |||
| int err; | |||
| shm_key_ = PortToFtok(port, &err); | |||
| if (shm_key_ == (key_t)-1) { | |||
| std::string errMsg = "Ftok failed with errno " + std::to_string(err); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| SharedMemory::shm_key_t shm_key; | |||
| RETURN_IF_NOT_OK(PortToFtok(port, &shm_key)); | |||
| // Attach to the shared memory | |||
| shm_id_ = shmget(shm_key_, 0, 0); | |||
| if (shm_id_ == -1) { | |||
| RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno)); | |||
| } | |||
| shmat_addr_ = shmat(shm_id_, nullptr, 0); | |||
| if (shmat_addr_ == reinterpret_cast<void *>(-1)) { | |||
| RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); | |||
| } | |||
| mem_.SetPublicKey(shm_key); | |||
| RETURN_IF_NOT_OK(mem_.Attach()); | |||
| *local_bypass = true; | |||
| #endif | |||
| return Status::OK(); | |||
| @@ -90,7 +56,7 @@ Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass | |||
| Status CacheClientGreeter::DoServiceStart() { | |||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | |||
| RETURN_IF_NOT_OK(DispatchWorkers(num_workers_)); | |||
| RETURN_IF_NOT_OK(DispatchWorkers(num_connections_)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -100,19 +66,40 @@ Status CacheClientGreeter::DoServiceStop() { | |||
| // Shutdown the TaskGroup. | |||
| vg_.interrupt_all(); | |||
| vg_.join_all(Task::WaitFlag::kNonBlocking); | |||
| // Drain the queue | |||
| bool success; | |||
| void *tag; | |||
| while (cq_.Next(&tag, &success)) { | |||
| auto r = reinterpret_cast<CacheClientRequestTag *>(tag); | |||
| delete r; | |||
| // Drain the queue. We know how many requests we send out | |||
| while (!req_.empty()) { | |||
| bool success; | |||
| void *tag; | |||
| while (cq_.Next(&tag, &success)) { | |||
| auto r = reinterpret_cast<CacheClientRequestTag *>(tag); | |||
| req_.erase(r->seqNo_); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) { | |||
| auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq)); | |||
| return tag->MakeCall(stub_.get(), &cq_, std::move(tag)); | |||
| // If there is anything extra we need to do before we send. | |||
| RETURN_IF_NOT_OK(rq->Prepare()); | |||
| auto seqNo = request_cnt_.fetch_add(1); | |||
| auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq), seqNo); | |||
| // One minute timeout | |||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); | |||
| tag->ctx_.set_deadline(deadline); | |||
| tag->rpc_ = stub_->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, &cq_); | |||
| tag->rpc_->StartCall(); | |||
| auto ccReqTag = tag.get(); | |||
| // Insert it into the map. | |||
| { | |||
| std::unique_lock<std::mutex> lck(mux_); | |||
| auto r = req_.emplace(seqNo, std::move(tag)); | |||
| if (!r.second) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__); | |||
| } | |||
| } | |||
| // Last step is to tag the request. | |||
| ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, ccReqTag); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheClientGreeter::WorkerEntry() { | |||
| @@ -129,15 +116,26 @@ Status CacheClientGreeter::WorkerEntry() { | |||
| auto &rc = rq->rc_; | |||
| if (!rc.ok()) { | |||
| auto error_code = rq->rc_.error_code(); | |||
| std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code); | |||
| Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| std::string err_msg; | |||
| if (error_code == grpc::StatusCode::UNAVAILABLE) { | |||
| err_msg = | |||
| "Cache server is unreachable. Make sure the server is running. GRPC Code" + std::to_string(error_code); | |||
| } else { | |||
| err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code); | |||
| } | |||
| Status remote_rc = Status(StatusCode::kNetWorkError, __LINE__, __FILE__, err_msg); | |||
| Status2CacheReply(remote_rc, &rq->base_rq_->reply_); | |||
| } | |||
| // Notify the waiting thread. | |||
| rq->Notify(); | |||
| } | |||
| // We can now free the memory | |||
| delete rq; | |||
| { | |||
| // We can now free the memory | |||
| std::unique_lock<std::mutex> lck(mux_); | |||
| auto seqNo = rq->seqNo_; | |||
| auto n = req_.erase(seqNo); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(n == 1, "Sequence " + std::to_string(seqNo) + " not found"); | |||
| } | |||
| } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) { | |||
| // If we are interrupted, exit. Otherwise wait again. | |||
| RETURN_IF_INTERRUPTED(); | |||
| @@ -16,10 +16,14 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ | |||
| #include <atomic> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||
| #include "minddata/dataset/util/service.h" | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| @@ -34,16 +38,10 @@ namespace dataset { | |||
| class CacheClientRequestTag { | |||
| public: | |||
| friend class CacheClientGreeter; | |||
| explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {} | |||
| explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq, int64_t seqNo) | |||
| : base_rq_(std::move(rq)), seqNo_(seqNo) {} | |||
| ~CacheClientRequestTag() = default; | |||
| /// \brief Make a RPC call | |||
| /// \param stub from CacheClientGreeter | |||
| /// \param cq from CacheClientGreeter | |||
| /// \return Status object | |||
| static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq, | |||
| std::unique_ptr<CacheClientRequestTag> &&tag); | |||
| /// \brief Notify the client that a result has come back from the server | |||
| void Notify() { base_rq_->wp_.Set(); } | |||
| @@ -52,6 +50,7 @@ class CacheClientRequestTag { | |||
| grpc::Status rc_; | |||
| grpc::ClientContext ctx_; | |||
| std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_; | |||
| int64_t seqNo_; | |||
| }; | |||
| /// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC | |||
| @@ -60,7 +59,7 @@ class CacheClientGreeter : public Service { | |||
| friend class CacheClient; | |||
| public: | |||
| explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers); | |||
| explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections); | |||
| ~CacheClientGreeter(); | |||
| /// Override base Service class | |||
| @@ -85,17 +84,18 @@ class CacheClientGreeter : public Service { | |||
| /// \brief This returns where we attach to the shared memory. | |||
| /// \return Base address of the shared memory. | |||
| const void *SharedMemoryBaseAddr() const { return shmat_addr_; } | |||
| const void *SharedMemoryBaseAddr() const { return mem_.SharedMemoryBaseAddr(); } | |||
| private: | |||
| std::shared_ptr<grpc::Channel> channel_; | |||
| std::unique_ptr<CacheServerGreeter::Stub> stub_; | |||
| grpc::CompletionQueue cq_; | |||
| TaskGroup vg_; | |||
| int32_t num_workers_; | |||
| key_t shm_key_; | |||
| int32_t shm_id_; | |||
| void *shmat_addr_; | |||
| int32_t num_connections_; | |||
| std::atomic<int64_t> request_cnt_; | |||
| mutable std::mutex mux_; | |||
| std::map<int64_t, std::unique_ptr<CacheClientRequestTag>> req_; | |||
| SharedMemory mem_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -47,53 +47,10 @@ void CacheServerGreeterImpl::Shutdown() { | |||
| CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); } | |||
| Status CacheServerGreeterImpl::IpcResourceCleanup() { | |||
| #if CACHE_LOCAL_CLIENT | |||
| int err; | |||
| auto shm_key = PortToFtok(port_, &err); | |||
| // We are expecting the unix path doesn't exist. | |||
| if (shm_key == (key_t)-1) { | |||
| return Status::OK(); | |||
| } | |||
| // Attach to the shared memory | |||
| auto shm_id = shmget(shm_key, 0, 0); | |||
| if (shm_id == -1) { | |||
| return Status::OK(); | |||
| } | |||
| struct shmid_ds ds {}; | |||
| auto inx = shmctl(shm_id, IPC_STAT, &ds); | |||
| if (inx == -1) { | |||
| std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id); | |||
| errMsg += "\nPlesae remove it manually using ipcrm -m command"; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| if (ds.shm_nattch == 0) { | |||
| // Stale shared memory from last time. | |||
| // Remove both the memory and the socket path | |||
| inx = shmctl(shm_id, IPC_RMID, nullptr); | |||
| if (inx == -1) { | |||
| std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id); | |||
| errMsg += ". Errno :" + std::to_string(errno); | |||
| errMsg += "\nPlesae remove it manually using ipcrm -m command"; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| Path p(unix_socket_); | |||
| (void)p.Remove(); | |||
| } else { | |||
| // Server is already up. | |||
| MS_LOG(ERROR) << "Cache server is already up and running"; | |||
| // We return a duplicate error. The main() will intercept | |||
| // and output a proper message | |||
| return Status(StatusCode::kDuplicateKey); | |||
| } | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServerGreeterImpl::Run() { | |||
| // To listen on all interfaces, use 0.0.0.0 | |||
| // Use 127.0.0.1 if just locally on the same machine. | |||
| std::string host("0.0.0.0"); // listen on all interfaces. | |||
| // Future, allow the user to choose listening interface. For now, default to localhost | |||
| std::string host("127.0.0.1"); | |||
| std::string server_address = host + ":" + std::to_string(port_); | |||
| grpc::ServerBuilder builder; | |||
| // Default message size for gRPC is 4MB. Increase it to 2g-1 | |||
| @@ -101,9 +58,6 @@ Status CacheServerGreeterImpl::Run() { | |||
| int port_tcpip = 0; | |||
| #if CACHE_LOCAL_CLIENT | |||
| int port_local = 0; | |||
| // Check if we need to do clean up on the shared memory if the server | |||
| // came down unexpectedly like SEGV | |||
| RETURN_IF_NOT_OK(IpcResourceCleanup()); | |||
| // We also optimize on local clients on the same machine using unix socket | |||
| builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local); | |||
| #endif | |||
| @@ -41,7 +41,7 @@ class CacheServerRequest : public BaseRequest { | |||
| st_(STATE::CREATE), | |||
| responder_(&ctx_) {} | |||
| ~CacheServerRequest() = default; | |||
| ~CacheServerRequest() override = default; | |||
| /// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this | |||
| /// functor will translate each protobuf into some form understood by by CacheService class. | |||
| @@ -87,8 +87,6 @@ class CacheServerGreeterImpl final { | |||
| void Shutdown(); | |||
| Status IpcResourceCleanup(); | |||
| private: | |||
| int32_t port_; | |||
| size_t shm_pool_sz_in_gb_; | |||
| @@ -0,0 +1,163 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||
| #include <sys/stat.h> | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status PortToFtok(int port, SharedMemory::shm_key_t *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| key_t shmkey = -1; | |||
| const std::string unix_path = PortToUnixSocketPath(port); | |||
| shmkey = ftok(unix_path.data(), 'a'); | |||
| if (shmkey == (key_t)-1) { | |||
| std::string errMsg = "Unable to create a ftok token. Errno = " + std::to_string(errno); | |||
| return Status(errno == ENOENT ? StatusCode::kFileNotExist : StatusCode::kUnexpectedError, errMsg); | |||
| } | |||
| *out = shmkey; | |||
| return Status::OK(); | |||
| } | |||
// Destructor. Removes the underlying SysV message queue only when this
// instance was marked as the owner (remove_ipc_on_exit_).
SharedMessage::~SharedMessage() {
  // Only remove the queue if we are asked to.
  if (remove_ipc_on_exit_ && msg_qid_ != -1) {
    // Remove the message queue; the return code is deliberately ignored since
    // there is nothing useful to do with a failure in a destructor.
    (void)msgctl(msg_qid_, IPC_RMID, nullptr);
    msg_qid_ = -1;
  }
}
| Status SharedMessage::Create() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ == -1, "Message queue already created"); | |||
| auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; | |||
| msg_qid_ = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode); | |||
| if (msg_qid_ == -1) { | |||
| std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status SharedMessage::SendStatus(const Status &rc) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id"); | |||
| StatusMsgBuf msg{ | |||
| 1, | |||
| }; | |||
| msg.body.status.err_code = static_cast<int32_t>(rc.get_code()); | |||
| auto err = memcpy_s(msg.body.status.err_msg, kSharedMessageSize, rc.ToString().data(), rc.ToString().size()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(err == EOK, "memcpy_s failed. err = " + std::to_string(err)); | |||
| msg.body.status.err_msg[rc.ToString().size()] = '\0'; | |||
| err = msgsnd(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), IPC_NOWAIT); | |||
| if (err == -1) { | |||
| std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status SharedMessage::ReceiveStatus(Status *rc) { | |||
| RETURN_UNEXPECTED_IF_NULL(rc); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id"); | |||
| struct StatusMsgBuf msg {}; | |||
| auto err = msgrcv(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), 0, MSG_NOERROR); | |||
| if (err == -1) { | |||
| std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| Status rc_recv(static_cast<StatusCode>(msg.body.status.err_code), msg.body.status.err_msg); | |||
| *rc = std::move(rc_recv); | |||
| return Status::OK(); | |||
| } | |||
// Destructor. Detach our own mapping first, then remove the segment from the
// system only if this instance owns it (remove_ipc_on_exit_).
SharedMemory::~SharedMemory() {
  if (shmat_addr_) {
    // Best effort: failure to detach in a destructor is not actionable.
    (void)Detach();
  }
  if (remove_ipc_on_exit_ && shm_id_ != -1) {
    // Remove the shared memory; a failure is logged but cannot be propagated.
    Status rc = Destroy();
    if (rc.IsError()) {
      MS_LOG(ERROR) << rc.ToString();
    }
  }
  shm_id_ = -1;
  shmat_addr_ = nullptr;
}
| Status SharedMemory::Create(int64_t sz) { | |||
| auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; | |||
| shm_id_ = shmget(shm_key_, sz, IPC_CREAT | IPC_EXCL | access_mode); | |||
| if (shm_id_ == -1) { | |||
| RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno)); | |||
| } else { | |||
| shmat_addr_ = shmat(shm_id_, nullptr, 0); | |||
| if (shmat_addr_ == reinterpret_cast<void *>(-1)) { | |||
| RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status SharedMemory::Attach() { | |||
| shm_id_ = shmget(shm_key_, 0, 0); | |||
| if (shm_id_ == -1) { | |||
| RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno)); | |||
| } | |||
| shmat_addr_ = shmat(shm_id_, nullptr, 0); | |||
| if (shmat_addr_ == reinterpret_cast<void *>(-1)) { | |||
| RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Detach this process from the shared memory segment. Safe to call when not attached; | |||
| // the attach address is always cleared on success. | |||
| Status SharedMemory::Detach() { | |||
| if (shmat_addr_) { | |||
| auto err = shmdt(shmat_addr_); | |||
| if (err == -1) { | |||
| RETURN_STATUS_UNEXPECTED("Shared memory detach failed. Errno " + std::to_string(errno)); | |||
| } | |||
| } | |||
| shmat_addr_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| Status SharedMemory::Destroy() { | |||
| // Remove the shared memory and never mind about the return code. | |||
| auto err = shmctl(shm_id_, IPC_RMID, nullptr); | |||
| if (err == -1) { | |||
| std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_); | |||
| errMsg += ". Errno :" + std::to_string(errno); | |||
| errMsg += "\nPlesae remove it manually using ipcrm -m command"; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Report how many processes are currently attached to the shared memory segment. | |||
| // \param[out] num Number of attached processes (shm_nattch from IPC_STAT). | |||
| Status SharedMemory::GetNumAttached(int32_t *num) { | |||
| RETURN_UNEXPECTED_IF_NULL(num); | |||
| struct shmid_ds ds {}; | |||
| auto err = shmctl(shm_id_, IPC_STAT, &ds); | |||
| if (err == -1) { | |||
| std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id_); | |||
| // Include errno for diagnosis, consistent with the other error paths in this file. | |||
| errMsg += ". Errno :" + std::to_string(errno); | |||
| errMsg += "\nPlease remove it manually using ipcrm -m command"; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| *num = ds.shm_nattch; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,207 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_ | |||
| #include <sys/types.h> | |||
| #include <sys/ipc.h> | |||
| #include <sys/shm.h> | |||
| #include <sys/msg.h> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// A message queue structure between the parent and the child process | |||
| struct StatusMsgBuf { | |||
| // Message type for msgsnd/msgrcv; System V requires a positive value here. | |||
| int64_t mtype; | |||
| union { | |||
| char mtext[1]; | |||
| // A serialized Status: numeric code plus a fixed-size copy of the error text. | |||
| struct { | |||
| int32_t err_code; | |||
| char err_msg[kSharedMessageSize]; | |||
| } status; | |||
| } body; | |||
| }; | |||
| /// \brief Common base for the SysV ipc wrappers below. It tracks whether this instance | |||
| /// is responsible for removing the underlying ipc resource on destruction. | |||
| class BaseIPC { | |||
| public: | |||
| BaseIPC() : remove_ipc_on_exit_(false) {} | |||
| virtual ~BaseIPC() {} | |||
| /// Indicate if we should remove the ipc resource on exit. Usually this is done by parent process. | |||
| void RemoveResourcesOnExit() { remove_ipc_on_exit_ = true; } | |||
| /// Copy constructors | |||
| /// A copy never inherits ownership; only one object may remove the resource on exit. | |||
| BaseIPC(const BaseIPC &rhs) : remove_ipc_on_exit_(false) {} | |||
| BaseIPC &operator=(const BaseIPC &rhs) { | |||
| if (&rhs != this) { | |||
| remove_ipc_on_exit_ = false; | |||
| } | |||
| return *this; | |||
| } | |||
| /// Move constructors | |||
| /// A move transfers ownership: the source gives up its cleanup responsibility. | |||
| BaseIPC(BaseIPC &&rhs) noexcept : remove_ipc_on_exit_(rhs.remove_ipc_on_exit_) { rhs.remove_ipc_on_exit_ = false; } | |||
| BaseIPC &operator=(BaseIPC &&rhs) noexcept { | |||
| if (&rhs != this) { | |||
| remove_ipc_on_exit_ = rhs.remove_ipc_on_exit_; | |||
| rhs.remove_ipc_on_exit_ = false; | |||
| } | |||
| return *this; | |||
| } | |||
| protected: | |||
| bool remove_ipc_on_exit_; | |||
| }; | |||
| /// \brief This wraps a shared message for the communication between processes. It is used primarily | |||
| /// for starting and stopping a server. | |||
| class SharedMessage : public BaseIPC { | |||
| public: | |||
| using queue_id_t = int; | |||
| // msg_qid_ of -1 means "no queue yet"; call Create() before Send/ReceiveStatus. | |||
| SharedMessage() : msg_qid_(-1) {} | |||
| explicit SharedMessage(queue_id_t qid) : msg_qid_(qid) {} | |||
| ~SharedMessage() override; | |||
| /// Copy constructors | |||
| /// Copies share the queue id but never the cleanup responsibility (see BaseIPC). | |||
| SharedMessage(const SharedMessage &rhs) : BaseIPC(rhs), msg_qid_(rhs.msg_qid_) {} | |||
| SharedMessage &operator=(const SharedMessage &rhs) { | |||
| if (&rhs != this) { | |||
| msg_qid_ = rhs.msg_qid_; | |||
| BaseIPC::operator=(rhs); | |||
| } | |||
| return *this; | |||
| } | |||
| /// Move constructors | |||
| /// Moves take over the queue id and leave the source with the invalid id -1. | |||
| SharedMessage(SharedMessage &&rhs) noexcept : BaseIPC(std::move(rhs)) { | |||
| msg_qid_ = rhs.msg_qid_; | |||
| rhs.msg_qid_ = -1; | |||
| } | |||
| SharedMessage &operator=(SharedMessage &&rhs) noexcept { | |||
| if (&rhs != this) { | |||
| msg_qid_ = rhs.msg_qid_; | |||
| rhs.msg_qid_ = -1; | |||
| BaseIPC::operator=(std::move(rhs)); | |||
| } | |||
| return *this; | |||
| } | |||
| /// Return the private id | |||
| queue_id_t GetMsgQueueId() const { return msg_qid_; } | |||
| /// \brief Create a private message queue | |||
| Status Create(); | |||
| /// Send a Status object | |||
| Status SendStatus(const Status &rc); | |||
| /// Retrieve a Status object | |||
| Status ReceiveStatus(Status *rc); | |||
| private: | |||
| queue_id_t msg_qid_; | |||
| }; | |||
| /// \brief This wraps a shared memory for the communication between processes. It is used primarily | |||
| /// for transporting large tensor rows. | |||
| class SharedMemory : public BaseIPC { | |||
| public: | |||
| using shm_key_t = int; | |||
| using shm_id_t = int; | |||
| // -1 for key/id and nullptr for the attach address mean "not created/attached yet". | |||
| SharedMemory() : shm_id_(-1), shm_key_(-1), shmat_addr_(nullptr) {} | |||
| explicit SharedMemory(shm_key_t public_key) : shm_id_(-1), shm_key_(public_key), shmat_addr_(nullptr) {} | |||
| ~SharedMemory() override; | |||
| /// Copy constructors | |||
| /// Copies share key/id/address but never the cleanup responsibility (see BaseIPC). | |||
| SharedMemory(const SharedMemory &rhs) | |||
| : BaseIPC(rhs), shm_id_(rhs.shm_id_), shm_key_(rhs.shm_key_), shmat_addr_(rhs.shmat_addr_) {} | |||
| SharedMemory &operator=(const SharedMemory &rhs) { | |||
| if (&rhs != this) { | |||
| shm_id_ = rhs.shm_id_; | |||
| shm_key_ = rhs.shm_key_; | |||
| shmat_addr_ = rhs.shmat_addr_; | |||
| BaseIPC::operator=(rhs); | |||
| } | |||
| return *this; | |||
| } | |||
| /// Move constructors | |||
| /// Moves take over the segment and reset the source to the "not attached" state. | |||
| SharedMemory(SharedMemory &&rhs) noexcept : BaseIPC(std::move(rhs)) { | |||
| shm_id_ = rhs.shm_id_; | |||
| shm_key_ = rhs.shm_key_; | |||
| shmat_addr_ = rhs.shmat_addr_; | |||
| rhs.shm_id_ = -1; | |||
| rhs.shm_key_ = -1; | |||
| rhs.shmat_addr_ = nullptr; | |||
| } | |||
| SharedMemory &operator=(SharedMemory &&rhs) noexcept { | |||
| if (&rhs != this) { | |||
| shm_id_ = rhs.shm_id_; | |||
| shm_key_ = rhs.shm_key_; | |||
| shmat_addr_ = rhs.shmat_addr_; | |||
| rhs.shm_id_ = -1; | |||
| rhs.shm_key_ = -1; | |||
| rhs.shmat_addr_ = nullptr; | |||
| BaseIPC::operator=(std::move(rhs)); | |||
| } | |||
| return *this; | |||
| } | |||
| /// \brief Set the public key | |||
| // NOTE(review): parameter type key_t differs from the shm_key_t alias used elsewhere — consider unifying. | |||
| void SetPublicKey(key_t public_key) { shm_key_ = public_key; } | |||
| /// \brief This returns where we attach to the shared memory. | |||
| /// \return Base address of the shared memory. | |||
| const void *SharedMemoryBaseAddr() const { return shmat_addr_; } | |||
| void *SharedMemoryBaseAddr() { return shmat_addr_; } | |||
| /// \brief Attach to shared memory | |||
| /// \return Status object | |||
| Status Attach(); | |||
| /// Detach from shared memory | |||
| /// \return Status object | |||
| Status Detach(); | |||
| /// Create shared memory | |||
| /// \return Status object | |||
| Status Create(int64_t sz); | |||
| /// Destroy shared memory | |||
| /// \return Status object | |||
| Status Destroy(); | |||
| /// \brief Return the shared memory id | |||
| shm_id_t GetSharedMemoryId() const { return shm_id_; } | |||
| /// \brief Get number of processes attached to the shared memory | |||
| /// \return Status object | |||
| Status GetNumAttached(int32_t *num); | |||
| private: | |||
| shm_id_t shm_id_; | |||
| shm_key_t shm_key_; | |||
| void *shmat_addr_; | |||
| }; | |||
| /// \brief Generate a shared memory key using the tcp/ip port. | |||
| /// \note It must be called after the cache server generates the unix socket or ftok will fail. | |||
| /// \note Caller must check the returned Status. On failure it carries the ftok errno information. | |||
| /// \param[in] port The tcp/ip port used to derive the key | |||
| /// \param[out] The generated shared memory key | |||
| /// \return Status object | |||
| Status PortToFtok(int port, SharedMemory::shm_key_t *); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_ | |||
| @@ -21,24 +21,136 @@ | |||
| #include <glog/logging.h> | |||
| #endif | |||
| #include <cstdlib> | |||
| #include <thread> | |||
| #include <chrono> | |||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||
| namespace ds = mindspore::dataset; | |||
| int main(int argc, char **argv) { | |||
| /// Send a synchronous command to the local server using tcp/ip. | |||
| /// We aren't using any client code because this binary is not necessarily linked with the client library. | |||
| /// So just using grpc call directly. | |||
| /// \param port tcp/ip port to use | |||
| /// \param type Type of command. | |||
| /// \param out grpc result | |||
| /// \return Status object | |||
| ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::CacheRequest *rq, ds::CacheReply *reply, | |||
| grpc::Status *out) { | |||
| // Validate all out/in-out pointers up front so the grpc path can use them unchecked. | |||
| if (rq == nullptr) { | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, "pointer rq is null"); | |||
| } | |||
| if (reply == nullptr) { | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, "pointer reply is null"); | |||
| } | |||
| if (out == nullptr) { | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, "pointer out is null"); | |||
| } | |||
| const std::string hostname = "127.0.0.1"; | |||
| auto unix_socket = ds::PortToUnixSocketPath(port); | |||
| #if CACHE_LOCAL_CLIENT | |||
| const std::string target = "unix://" + unix_socket; | |||
| #else | |||
| const std::string target = hostname + ":" + std::to_string(port); | |||
| #endif | |||
| try { | |||
| rq->set_type(static_cast<int16_t>(type)); | |||
| grpc::ChannelArguments args; | |||
| grpc::ClientContext ctx; | |||
| grpc::CompletionQueue cq; | |||
| // Standard async rpc call | |||
| std::shared_ptr<grpc::Channel> channel = | |||
| grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args); | |||
| std::unique_ptr<ds::CacheServerGreeter::Stub> stub = ds::CacheServerGreeter::NewStub(channel); | |||
| std::unique_ptr<grpc::ClientAsyncResponseReader<ds::CacheReply>> rpc = | |||
| stub->PrepareAsyncCacheServerRequest(&ctx, *rq, &cq); | |||
| rpc->StartCall(); | |||
| // We need to pass a tag. But since this is the only request in the completion queue and so we | |||
| // just pass a dummy | |||
| int64_t dummy; | |||
| void *tag; | |||
| bool success; | |||
| rpc->Finish(reply, out, &dummy); | |||
| // Now we wait on the completion queue synchronously. | |||
| auto r = cq.Next(&tag, &success); | |||
| if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) { | |||
| // The returned tag must be the dummy we registered above; anything else is a bug. | |||
| if (!success || tag != &dummy) { | |||
| std::string errMsg = "Unexpected programming error "; | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| if (out->ok()) { | |||
| // Transport succeeded; surface the application-level rc/msg carried in the reply. | |||
| return ds::Status(static_cast<ds::StatusCode>(reply->rc()), reply->msg()); | |||
| } else { | |||
| auto error_code = out->error_code(); | |||
| std::string errMsg = out->error_message() + ". GRPC Code " + std::to_string(error_code); | |||
| return ds::Status(ds::StatusCode::kNetWorkError, errMsg); | |||
| } | |||
| } else { | |||
| std::string errMsg = "Unexpected queue rc = " + std::to_string(r); | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| } catch (const std::exception &e) { | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what()); | |||
| } | |||
| } | |||
| /// Stop the server | |||
| /// \param argv | |||
| /// \return Status object | |||
| // Stop a running cache server listening on the port given in argv[1]. | |||
| ds::Status StopServer(int argc, char **argv) { | |||
| ds::Status rc; | |||
| ds::CacheServer::Builder builder; | |||
| std::string errMsg; | |||
| if (argc != 2) { | |||
| return ds::Status(ds::StatusCode::kSyntaxError); | |||
| } | |||
| int32_t port = strtol(argv[1], nullptr, 10); | |||
| // We will go through the builder to do some sanity check. We only need the port number | |||
| // to shut down the server. Null the root directory as we don't trigger the sanity code to write out anything | |||
| // to the spill directory. | |||
| builder.SetPort(port).SetRootDirectory(""); | |||
| // Part of the sanity check is check the shared memory. If the server is up and running, we expect | |||
| // the return code is kDuplicate. | |||
| rc = builder.SanityCheck(); | |||
| if (rc.IsOk()) { | |||
| errMsg = "Server is not up or has been shutdown already."; | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, errMsg); | |||
| } else if (rc.get_code() != ds::StatusCode::kDuplicateKey) { | |||
| // Not OK, and no duplicate, just return the rc whatever it is. | |||
| return rc; | |||
| } else { | |||
| // Now we get some work to do. We will send a tcp/ip request to the given port. | |||
| // This binary is not linked with client side of code, so we will just call grpc directly. | |||
| ds::CacheRequest rq; | |||
| ds::CacheReply reply; | |||
| grpc::Status grpc_rc; | |||
| rc = SendSyncCommand(port, ds::BaseRequest::RequestType::kStopService, &rq, &reply, &grpc_rc); | |||
| // The request is like a self destruct message, the server will not send anything back and | |||
| // shutdown all incoming request. So we should expect some unexpected network error if | |||
| // all goes well and we expect to GRPC code 14. | |||
| auto err_code = grpc_rc.error_code(); | |||
| // UNAVAILABLE (grpc code 14) here means the server went down as requested. | |||
| if (rc.get_code() != ds::StatusCode::kNetWorkError || err_code != grpc::StatusCode::UNAVAILABLE) { | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__); | |||
| } | |||
| } | |||
| return ds::Status::OK(); | |||
| } | |||
| // This executable is not to be called directly, and should be invoked by cache_admin executable. | |||
| if (argc != 7) { | |||
| rc = ds::Status(ds::StatusCode::kSyntaxError); | |||
| std::cerr << rc.ToString() << std::endl; | |||
| return static_cast<int>(rc.get_code()); | |||
| /// Start the server | |||
| /// \param argv | |||
| /// \return Status object | |||
| ds::Status StartServer(int argc, char **argv) { | |||
| ds::Status rc; | |||
| ds::CacheServer::Builder builder; | |||
| if (argc != 8) { | |||
| return ds::Status(ds::StatusCode::kSyntaxError); | |||
| } | |||
| int32_t port = strtol(argv[3], nullptr, 10); | |||
| builder.SetRootDirectory(argv[1]) | |||
| .SetNumWorkers(strtol(argv[2], nullptr, 10)) | |||
| .SetPort(strtol(argv[3], nullptr, 10)) | |||
| .SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10)); | |||
| .SetPort(port) | |||
| .SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10)) | |||
| .SetMemoryCapRatio(strtof(argv[7], nullptr)); | |||
| #ifdef USE_GLOG | |||
| FLAGS_minloglevel = strtol(argv[5], nullptr, 10); | |||
| @@ -52,36 +164,42 @@ int main(int argc, char **argv) { | |||
| // is called. This is a standard procedure for daemonize a process on unix. | |||
| if (chdir("/") == -1) { | |||
| std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno); | |||
| std::cerr << errMsg << std::endl; | |||
| return -1; | |||
| } | |||
| // Simple check of the parameters before we move on. | |||
| rc = builder.SanityCheck(); | |||
| if (rc.IsError()) { | |||
| std::cerr << rc.ToString() << std::endl; | |||
| return static_cast<int>(rc.get_code()); | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| // A message queue for communication between parent and child (if we fork). | |||
| ds::SharedMessage msg; | |||
| if (daemonize) { | |||
| #ifdef USE_GLOG | |||
| FLAGS_log_dir = "/tmp"; | |||
| google::InitGoogleLogging(argv[0]); | |||
| FLAGS_log_dir = "/tmp"; | |||
| google::InitGoogleLogging(argv[0]); | |||
| #endif | |||
| if (daemonize) { | |||
| // fork the child process to become the daemon | |||
| rc = msg.Create(); | |||
| if (rc.IsError()) { | |||
| return rc; | |||
| } | |||
| pid_t pid = fork(); | |||
| // failed to fork | |||
| if (pid < 0) { | |||
| std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno); | |||
| std::cerr << err_msg << std::endl; | |||
| return errno; | |||
| std::string errMsg = "Failed to fork process for cache server. Errno = " + std::to_string(errno); | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else if (pid > 0) { | |||
| // Parent | |||
| // Parent and will be responsible for remove the queue on exit. | |||
| msg.RemoveResourcesOnExit(); | |||
| // Sleep one second and we attach to the msg que | |||
| std::this_thread::sleep_for(std::chrono::seconds(1)); | |||
| ds::Status child_rc; | |||
| rc = msg.ReceiveStatus(&child_rc); | |||
| if (rc.IsError()) { | |||
| return rc; | |||
| } | |||
| if (child_rc.IsError()) { | |||
| return child_rc; | |||
| } | |||
| std::cerr << "cache server daemon process has been created as process id: " << pid | |||
| << "\nCheck log file for any start up error" << std::endl; | |||
| signal(SIGCHLD, SIG_IGN); // ignore sig child signal. | |||
| return 0; | |||
| return ds::Status::OK(); | |||
| } else { | |||
| // Child process will continue from here if daemonize and parent has already exited. | |||
| // If we are running in the foreground, none of the code in block below will be run. | |||
| @@ -89,8 +207,8 @@ int main(int argc, char **argv) { | |||
| umask(0); | |||
| sid = setsid(); | |||
| if (sid < 0) { | |||
| MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno); | |||
| return errno; | |||
| std::string errMsg = "Failed to setsid(). Errno = " + std::to_string(errno); | |||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } | |||
| close(0); | |||
| close(1); | |||
| @@ -100,22 +218,36 @@ int main(int argc, char **argv) { | |||
| // Dump the summary | |||
| MS_LOG(INFO) << builder << std::endl; | |||
| // Create the instance with some sanity checks built in | |||
| rc = builder.Build(); | |||
| if (rc.IsOk()) { | |||
| // If all goes well, kick off the threads. Loop forever and never return unless error. | |||
| ds::CacheServer &cs = ds::CacheServer::GetInstance(); | |||
| // Kick off the threads. Loop forever and never return unless error. | |||
| rc = cs.Run(); | |||
| if (rc.get_code() == ds::StatusCode::kDuplicateKey) { | |||
| std::string errMsg = "Server is already started"; | |||
| MS_LOG(ERROR) << errMsg; | |||
| std::cerr << errMsg << std::endl; | |||
| return 0; | |||
| } | |||
| rc = cs.Run(msg.GetMsgQueueId()); | |||
| } else if (daemonize) { | |||
| // If we didn't pass the sanity check to at least create the instance, use | |||
| // the message queue to return the error message if this is the child daemon. | |||
| return msg.SendStatus(rc); | |||
| } | |||
| return rc; | |||
| } | |||
| int main(int argc, char **argv) { | |||
| ds::Status rc; | |||
| ds::CacheServer::Builder builder; | |||
| // This executable is not to be called directly, and should be invoked by cache_admin executable. | |||
| if (strcmp(argv[0], "-") == 0) { | |||
| rc = StopServer(argc, argv); | |||
| } else { | |||
| rc = StartServer(argc, argv); | |||
| } | |||
| // Check result | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << rc.ToString(); | |||
| std::cerr << rc.ToString() << std::endl; | |||
| return static_cast<int>(rc.get_code()); | |||
| auto errCode = rc.get_code(); | |||
| auto errMsg = rc.ToString(); | |||
| std::cerr << errMsg << std::endl; | |||
| return static_cast<int>(errCode); | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -250,5 +250,27 @@ Status GetStatRequest::PostReply() { | |||
| stat_.cache_service_state = msg->state(); | |||
| return Status::OK(); | |||
| } | |||
| // Deserialize the flatbuffer reply from the server into session_info_list_. | |||
| Status ListSessionsRequest::PostReply() { | |||
| auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data()); | |||
| auto session_vector = msg->sessions(); | |||
| // Use the flatbuffers unsigned offset type to avoid a signed/unsigned comparison warning. | |||
| for (flatbuffers::uoffset_t i = 0; i < session_vector->size(); ++i) { | |||
| SessionCacheInfo current_info; | |||
| CacheServiceStat stats; | |||
| auto current_session_info = session_vector->Get(i); | |||
| current_info.session_id = current_session_info->session_id(); | |||
| current_info.connection_id = current_session_info->connection_id(); | |||
| stats.num_mem_cached = current_session_info->stats()->num_mem_cached(); | |||
| stats.num_disk_cached = current_session_info->stats()->num_disk_cached(); | |||
| stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz(); | |||
| stats.min_row_id = current_session_info->stats()->min_row_id(); | |||
| stats.max_row_id = current_session_info->stats()->max_row_id(); | |||
| stats.cache_service_state = current_session_info->stats()->state(); | |||
| current_info.stats = stats; // fixed length struct. = operator is safe | |||
| session_info_list_.push_back(current_info); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -46,6 +46,13 @@ struct CacheServiceStat { | |||
| int8_t cache_service_state; | |||
| }; | |||
| /// \brief Info structure ListSessionsRequest | |||
| struct SessionCacheInfo { | |||
| session_id_type session_id; | |||
| connection_id_type connection_id; | |||
| CacheServiceStat stats; | |||
| }; | |||
| /// \brief CacheClient communicates with CacheServer using Requests. | |||
| class BaseRequest { | |||
| public: | |||
| @@ -54,7 +61,7 @@ class BaseRequest { | |||
| kCacheRow = 0, | |||
| kBatchFetchRows = 1, | |||
| kCreateCache = 2, | |||
| kPurgeCache = 3, | |||
| kGetCacheMissKeys = 3, | |||
| kDestroyCache = 4, | |||
| kGetStat = 5, | |||
| kCacheSchema = 6, | |||
| @@ -65,6 +72,9 @@ class BaseRequest { | |||
| kAllocateSharedBlock = 11, | |||
| kFreeSharedBlock = 12, | |||
| kStopService = 13, | |||
| kHeartBeat = 14, | |||
| kToggleWriteMode = 15, | |||
| kListSessions = 16, | |||
| // Add new request before it. | |||
| kRequestUnknown = 32767 | |||
| }; | |||
| @@ -73,6 +83,7 @@ class BaseRequest { | |||
| friend class CacheServerRequest; | |||
| friend class CacheClientGreeter; | |||
| friend class CacheClientRequestTag; | |||
| friend class CacheClient; | |||
| /// \brief Base class of a cache server request | |||
| /// \param type Type of the request | |||
| @@ -119,7 +130,7 @@ class FreeSharedBlockRequest : public BaseRequest { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.add_buf_data(std::to_string(addr)); | |||
| } | |||
| ~FreeSharedBlockRequest() = default; | |||
| ~FreeSharedBlockRequest() override = default; | |||
| }; | |||
| /// \brief Request to cache a single TensorRow | |||
| @@ -136,7 +147,7 @@ class CacheRowRequest : public BaseRequest { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.add_buf_data(cookie); | |||
| } | |||
| ~CacheRowRequest() = default; | |||
| ~CacheRowRequest() override = default; | |||
| /// \brief Serialize a TensorRow for streaming to the cache server | |||
| /// \param row TensorRow | |||
| @@ -183,7 +194,7 @@ class BatchFetchRequest : public BaseRequest { | |||
| friend class CacheServer; | |||
| friend class CacheService; | |||
| BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass); | |||
| ~BatchFetchRequest() = default; | |||
| ~BatchFetchRequest() override = default; | |||
| Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); | |||
| private: | |||
| @@ -203,7 +214,7 @@ class CreateCacheRequest : public BaseRequest { | |||
| /// \param flag Attributes of the cache. | |||
| explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | |||
| CreateCacheFlag flag = CreateCacheFlag::kNone); | |||
| ~CreateCacheRequest() = default; | |||
| ~CreateCacheRequest() override = default; | |||
| void ParseResult(connection_id_type *id, std::string *out) { | |||
| auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data()); | |||
| *id = p->connection_id(); | |||
| @@ -218,14 +229,15 @@ class CreateCacheRequest : public BaseRequest { | |||
| CreateCacheFlag flag_; | |||
| }; | |||
| /// \brief Request to purge a cache. | |||
| class PurgeCacheRequest : public BaseRequest { | |||
| /// \brief Request to get all the keys not present at the server. | |||
| /// \note Only applicable to mappable case | |||
| class GetCacheMissKeysRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) { | |||
| explicit GetCacheMissKeysRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetCacheMissKeys) { | |||
| rq_.set_connection_id(connection_id); | |||
| } | |||
| ~PurgeCacheRequest() = default; | |||
| ~GetCacheMissKeysRequest() override = default; | |||
| }; | |||
| /// \brief Request to destroy a cache | |||
| @@ -235,7 +247,7 @@ class DestroyCacheRequest : public BaseRequest { | |||
| explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) { | |||
| rq_.set_connection_id(connection_id); | |||
| } | |||
| ~DestroyCacheRequest() = default; | |||
| ~DestroyCacheRequest() override = default; | |||
| }; | |||
| /// \brief Obtain the statistics of the current connection | |||
| @@ -247,7 +259,7 @@ class GetStatRequest : public BaseRequest { | |||
| rq_.set_connection_id(connection_id); | |||
| } | |||
| ~GetStatRequest() = default; | |||
| ~GetStatRequest() override = default; | |||
| /// \brief Override base function to process the result. | |||
| Status PostReply() override; | |||
| @@ -269,7 +281,7 @@ class CacheSchemaRequest : public BaseRequest { | |||
| explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) { | |||
| rq_.set_connection_id(connection_id); | |||
| } | |||
| ~CacheSchemaRequest() = default; | |||
| ~CacheSchemaRequest() override = default; | |||
| Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map); | |||
| }; | |||
| @@ -281,7 +293,7 @@ class FetchSchemaRequest : public BaseRequest { | |||
| explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) { | |||
| rq_.set_connection_id(connection_id); | |||
| } | |||
| ~FetchSchemaRequest() = default; | |||
| ~FetchSchemaRequest() override = default; | |||
| Status PostReply() override; | |||
| @@ -300,7 +312,7 @@ class BuildPhaseDoneRequest : public BaseRequest { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.add_buf_data(cookie_); | |||
| } | |||
| ~BuildPhaseDoneRequest() = default; | |||
| ~BuildPhaseDoneRequest() override = default; | |||
| private: | |||
| std::string cookie_; | |||
| @@ -313,7 +325,7 @@ class DropSessionRequest : public BaseRequest { | |||
| explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) { | |||
| rq_.mutable_connection_info()->operator=(cinfo); | |||
| } | |||
| ~DropSessionRequest() = default; | |||
| ~DropSessionRequest() override = default; | |||
| }; | |||
| class GenerateSessionIdRequest : public BaseRequest { | |||
| @@ -325,11 +337,36 @@ class GenerateSessionIdRequest : public BaseRequest { | |||
| rq_.set_connection_id(0); | |||
| } | |||
| ~GenerateSessionIdRequest() = default; | |||
| ~GenerateSessionIdRequest() override = default; | |||
| session_id_type GetSessionId() { return atoi(reply_.result().data()); } | |||
| }; | |||
| /// \brief Request to list all the sessions (and their caches) known to the server. | |||
| class ListSessionsRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| ListSessionsRequest() : BaseRequest(RequestType::kListSessions) { | |||
| // This request is not specific to any cache or session | |||
| rq_.set_connection_id(0); | |||
| } | |||
| ~ListSessionsRequest() override = default; | |||
| /// \brief Override base function to process the result. | |||
| Status PostReply() override; | |||
| /// \brief Copy the parsed session info into the caller-supplied vector. | |||
| /// A null pointer is silently ignored. | |||
| void GetSessionCacheInfo(std::vector<SessionCacheInfo> *info) { | |||
| if (info != nullptr) { | |||
| (*info) = session_info_list_; | |||
| } | |||
| } | |||
| /// \brief Convenience overload returning the session info by value. | |||
| std::vector<SessionCacheInfo> GetSessionCacheInfo() { return session_info_list_; } | |||
| private: | |||
| // Populated by PostReply() from the server's flatbuffer response. | |||
| std::vector<SessionCacheInfo> session_info_list_; | |||
| }; | |||
| class AllocateSharedBlockRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| @@ -338,7 +375,7 @@ class AllocateSharedBlockRequest : public BaseRequest { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.add_buf_data(std::to_string(requestedSz)); | |||
| } | |||
| ~AllocateSharedBlockRequest() = default; | |||
| ~AllocateSharedBlockRequest() override = default; | |||
| /// \brief On return from the server, we get the (relative) address where | |||
| /// the free block is located. | |||
| @@ -349,11 +386,15 @@ class AllocateSharedBlockRequest : public BaseRequest { | |||
| } | |||
| }; | |||
| class ShutdownRequest : public BaseRequest { | |||
| class ToggleWriteModeRequest : public BaseRequest { | |||
| public: | |||
| friend class CacheServer; | |||
| ShutdownRequest() : BaseRequest(RequestType::kStopService) {} | |||
| ~ShutdownRequest() = default; | |||
| explicit ToggleWriteModeRequest(connection_id_type connection_id, bool on_off) | |||
| : BaseRequest(RequestType::kToggleWriteMode) { | |||
| rq_.set_connection_id(connection_id); | |||
| rq_.add_buf_data(on_off ? "on" : "off"); | |||
| } | |||
| ~ToggleWriteModeRequest() override = default; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -18,6 +18,7 @@ | |||
| #include <functional> | |||
| #include <limits> | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/cache/cache_request.h" | |||
| #include "minddata/dataset/util/bit.h" | |||
| @@ -107,6 +108,8 @@ Status CacheServer::DoServiceStop() { | |||
| // First stop all the threads. | |||
| RETURN_IF_NOT_OK(vg_.ServiceStop()); | |||
| // Clean up all the caches if any. | |||
| // Dump a message how much memory we have consumed in total. | |||
| MS_LOG(INFO) << "Memory usage for the current server: " << GetMemoryUsage() << " bytes."; | |||
| UniqueLock lck(&rwLock_); | |||
| auto it = all_caches_.begin(); | |||
| while (it != all_caches_.end()) { | |||
| @@ -121,7 +124,6 @@ Status CacheServer::DoServiceStop() { | |||
| } | |||
| CacheService *CacheServer::GetService(connection_id_type id) const { | |||
| SharedLock lck(&rwLock_); | |||
| auto it = all_caches_.find(id); | |||
| if (it != all_caches_.end()) { | |||
| return it->second.get(); | |||
| @@ -134,6 +136,16 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| std::string cookie; | |||
| auto session_id = rq->connection_info().session_id(); | |||
| auto crc = rq->connection_info().crc(); | |||
| // Before allowing the creation, make sure the session had already been created by the user | |||
| // Our intention is to add this cache to the active sessions list so leave the list locked during | |||
| // this entire function. | |||
| UniqueLock lock(&sessions_lock_); | |||
| auto session_it = active_sessions_.find(session_id); | |||
| if (session_it == active_sessions_.end()) { | |||
| RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!"); | |||
| } | |||
| // We concat both numbers to form the internal connection id. | |||
| auto connection_id = GetConnectionID(session_id, crc); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache"); | |||
| @@ -172,10 +184,15 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| // Add the cache into the active session tracking. | |||
| // We have already validated that the session exists and that this is a new cache created. | |||
| session_it->second.insert(connection_id); | |||
| } else { | |||
| duplicate = true; | |||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; | |||
| } | |||
| off_cookie = fbb.CreateString(cookie); | |||
| CreateCacheReplyMsgBuilder bld(fbb); | |||
| bld.add_connection_id(connection_id); | |||
| @@ -183,19 +200,18 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||
| auto off = bld.Finish(); | |||
| fbb.Finish(off); | |||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| // Track the history of all the sessions that we have created so far. | |||
| history_sessions_.insert(session_id); | |||
| // We can return OK but we will return a duplicate key so the user can act accordingly: either ignore it | |||
| // or treat it as OK. | |||
| return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK(); | |||
| } | |||
| Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) { | |||
| Status CacheServer::DestroyCache(CacheRequest *rq) { | |||
| // We need a strong lock to protect the map. | |||
| UniqueLock lck(&rwLock_); | |||
| auto id = rq->connection_id(); | |||
| CacheService *cs = GetService(id); | |||
| // it is already destroyed. Ignore it. | |||
| if (cs != nullptr) { | |||
| auto id = rq->connection_id(); | |||
| MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); | |||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||
| auto n = all_caches_.erase(id); | |||
| @@ -204,11 +220,34 @@ Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) { | |||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | |||
| } | |||
| } | |||
| // Now that this cache is removed, we need to also remove it's connection id from active session tracking | |||
| auto session_id = GetSessionID(id); | |||
| UniqueLock sess_lck(&sessions_lock_); | |||
| auto it = active_sessions_.find(session_id); | |||
| if (it == active_sessions_.end()) { | |||
| // The session was not found in the active sessions | |||
| RETURN_STATUS_UNEXPECTED("A destroy cache request has been completed but it had a stale session id!"); | |||
| } | |||
| auto connection_it = it->second.find(id); | |||
| if (connection_it == it->second.end()) { | |||
| RETURN_STATUS_UNEXPECTED("A destroy cache request could not find the connection in the activate sessions!"); | |||
| } | |||
| // remove that connection id from the set | |||
| it->second.erase(connection_it); | |||
| MS_LOG(INFO) << "Destroyed cache " << id << " and removed from active session " << session_id; | |||
| return Status::OK(); | |||
| } | |||
| inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| // Hold the shared lock to prevent the cache from being dropped. | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| @@ -236,8 +275,11 @@ inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| // Hold the shared lock to prevent the cache from being dropped. | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | |||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | |||
| // Ensure we got 3 pieces of data coming in | |||
| @@ -270,8 +312,11 @@ Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply | |||
| return rc; | |||
| } | |||
| Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| // Hold the shared lock to prevent the cache from being dropped. | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| @@ -325,8 +370,11 @@ Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheRepl | |||
| return Status::OK(); | |||
| } | |||
| inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| // Hold the shared lock to prevent the cache from being dropped. | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| @@ -338,8 +386,8 @@ inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | |||
| bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | |||
| bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz); | |||
| bld.add_max_row_id(svc_stat.max_); | |||
| bld.add_min_row_id(svc_stat.min_); | |||
| bld.add_max_row_id(svc_stat.stat_.max_key); | |||
| bld.add_min_row_id(svc_stat.stat_.min_key); | |||
| bld.add_state(svc_stat.state_); | |||
| auto offset = bld.Finish(); | |||
| fbb.Finish(offset); | |||
| @@ -348,8 +396,11 @@ inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| return Status::OK(); | |||
| } | |||
| inline Status CacheSchema(CacheService *cs, CacheRequest *rq) { | |||
| Status CacheServer::CacheSchema(CacheRequest *rq) { | |||
| auto connection_id = rq->connection_id(); | |||
| // Hold the shared lock to prevent the cache from being dropped. | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| @@ -361,8 +412,11 @@ inline Status CacheSchema(CacheService *cs, CacheRequest *rq) { | |||
| return Status::OK(); | |||
| } | |||
| inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||
| Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| // Hold the shared lock to prevent the cache from being dropped. | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| @@ -377,8 +431,11 @@ inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) | |||
| return Status::OK(); | |||
| } | |||
| inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) { | |||
| Status CacheServer::BuildPhaseDone(CacheRequest *rq) { | |||
| auto connection_id = rq->connection_id(); | |||
| // Hold the shared lock to prevent the cache from being dropped. | |||
| SharedLock lck(&rwLock_); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| @@ -396,15 +453,24 @@ inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) { | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::PurgeCache(CacheService *cs) { | |||
| Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) { | |||
| auto connection_id = rq->connection_id(); | |||
| // Hold the shared lock to prevent the cache from being dropped. | |||
| SharedLock lck(&rwLock_); | |||
| // If shutdown in progress, ignore the command. | |||
| if (global_shutdown_) { | |||
| return Status::OK(); | |||
| } | |||
| // it is already purged. Ignore it. | |||
| if (cs != nullptr) { | |||
| RETURN_IF_NOT_OK(cs->Purge()); | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| std::vector<row_id_type> gap; | |||
| RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap)); | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| auto off_t = fbb.CreateVector(gap); | |||
| TensorRowIdsBuilder bld(fbb); | |||
| bld.add_row_id(off_t); | |||
| auto off = bld.Finish(); | |||
| fbb.Finish(off); | |||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -414,6 +480,72 @@ inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *re | |||
| return Status::OK(); | |||
| } | |||
| /// \brief Handle a kToggleWriteMode request: turn a cache's write mode on or off. | |||
| /// \param rq Request; buf_data(0) must hold the literal string "on" or "off". | |||
| /// \return Status object; error if the connection id is unknown or the flag is unrecognized. | |||
| Status CacheServer::ToggleWriteMode(CacheRequest *rq) { | |||
| auto connection_id = rq->connection_id(); | |||
| // Hold the shared lock to prevent the cache from being dropped. | |||
| SharedLock lck(&rwLock_); | |||
| // NOTE(review): GetService() also acquires rwLock_ in shared mode while we hold it here — | |||
| // confirm RWLock tolerates recursive shared acquisition (e.g. no writer-preference deadlock). | |||
| CacheService *cs = GetService(connection_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| // First piece of data is the on/off flag | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag"); | |||
| const auto &action = rq->buf_data(0); | |||
| bool on_off = false; | |||
| // NOTE(review): action is a std::string; `action == "on"` would avoid the C-string comparison | |||
| // and treat embedded NULs correctly — confirm the wire format before changing. | |||
| if (strcmp(action.data(), "on") == 0) { | |||
| on_off = true; | |||
| } else if (strcmp(action.data(), "off") == 0) { | |||
| on_off = false; | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Unknown request: " + action); | |||
| } | |||
| RETURN_IF_NOT_OK(cs->ToggleWriteMode(on_off)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| /// \brief Handle a kListSessions request: serialize stats for every (session, cache) pair into the reply. | |||
| /// Sessions with no cache yet are reported with connection id 0 and zeroed stats. | |||
| /// \param reply Output: flatbuffer-encoded ListSessionsMsg. | |||
| /// \return Status object | |||
| Status CacheServer::ListSessions(CacheReply *reply) { | |||
| SharedLock lck(&sessions_lock_); | |||
| // NOTE(review): we hold sessions_lock_ and GetService() below takes rwLock_; DestroyCache() acquires | |||
| // these two locks in the opposite order (rwLock_ first, then sessions_lock_) — confirm this cannot deadlock. | |||
| flatbuffers::FlatBufferBuilder fbb; | |||
| std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector; | |||
| for (auto it = active_sessions_.begin(); it != active_sessions_.end(); it++) { | |||
| session_id_type current_session_id = it->first; | |||
| // Loop over each cache inside this session | |||
| if (!it->second.empty()) { | |||
| for (auto current_conn_id : it->second) { | |||
| CacheService *cs = GetService(current_conn_id); | |||
| if (cs == nullptr) { | |||
| std::string errMsg = "Connection " + std::to_string(current_conn_id) + " not found during list sessions."; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||
| } else { | |||
| // Gather the cache's current statistics and append one message per cache. | |||
| CacheService::ServiceStat svc_stat; | |||
| RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); | |||
| auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached, | |||
| svc_stat.stat_.average_cache_sz, svc_stat.stat_.min_key, | |||
| svc_stat.stat_.max_key, svc_stat.state_); | |||
| auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats); | |||
| session_msgs_vector.push_back(current_session_info); | |||
| } | |||
| } | |||
| } else { | |||
| // If there is no cache created yet, assign a connection id of 0 along with empty stats | |||
| auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0); | |||
| auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats); | |||
| session_msgs_vector.push_back(current_session_info); | |||
| } | |||
| } | |||
| // Finalize the flatbuffer and hand it back through the reply. | |||
| auto session_msgs = fbb.CreateVector(session_msgs_vector); | |||
| ListSessionsMsgBuilder s_builder(fbb); | |||
| s_builder.add_sessions(session_msgs); | |||
| auto offset = s_builder.Finish(); | |||
| fbb.Finish(offset); | |||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||
| return Status::OK(); | |||
| } | |||
| /// \brief This is the main loop the cache server thread(s) are running. | |||
| /// Each thread will pop a request and send the result back to the client using grpc | |||
| /// \return | |||
| @@ -426,12 +558,6 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||
| RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); | |||
| auto &rq = cache_req->rq_; | |||
| auto &reply = cache_req->reply_; | |||
| CacheService *cs = nullptr; | |||
| // Request comes in roughly two sets. One set is at the cache level with a connection id. | |||
| // The other set is working at a high level and without a connection id | |||
| if (!rq.has_connection_info()) { | |||
| cs = GetService(rq.connection_id()); | |||
| } | |||
| // Except for creating a new session, we expect cs is not null. | |||
| switch (cache_req->type_) { | |||
| case BaseRequest::RequestType::kCacheRow: { | |||
| @@ -439,42 +565,42 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||
| // call the appropriate method. | |||
| auto flag = rq.flag(); | |||
| if (BitTest(flag, kDataIsInSharedMemory)) { | |||
| cache_req->rc_ = FastCacheRow(cs, &rq, &reply); | |||
| cache_req->rc_ = FastCacheRow(&rq, &reply); | |||
| } else { | |||
| cache_req->rc_ = CacheRow(cs, &rq, &reply); | |||
| cache_req->rc_ = CacheRow(&rq, &reply); | |||
| } | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kBatchFetchRows: { | |||
| cache_req->rc_ = BatchFetchRows(cs, &rq, &reply); | |||
| cache_req->rc_ = BatchFetchRows(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kCreateCache: { | |||
| cache_req->rc_ = CreateService(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kPurgeCache: { | |||
| cache_req->rc_ = PurgeCache(cs); | |||
| case BaseRequest::RequestType::kGetCacheMissKeys: { | |||
| cache_req->rc_ = GetCacheMissKeys(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kDestroyCache: { | |||
| cache_req->rc_ = DestroyCache(cs, &rq); | |||
| cache_req->rc_ = DestroyCache(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kGetStat: { | |||
| cache_req->rc_ = GetStat(cs, &rq, &reply); | |||
| cache_req->rc_ = GetStat(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kCacheSchema: { | |||
| cache_req->rc_ = CacheSchema(cs, &rq); | |||
| cache_req->rc_ = CacheSchema(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kFetchSchema: { | |||
| cache_req->rc_ = FetchSchema(cs, &rq, &reply); | |||
| cache_req->rc_ = FetchSchema(&rq, &reply); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kBuildPhaseDone: { | |||
| cache_req->rc_ = BuildPhaseDone(cs, &rq); | |||
| cache_req->rc_ = BuildPhaseDone(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kDropSession: { | |||
| @@ -498,6 +624,18 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||
| cache_req->rc_ = GlobalShutdown(); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kHeartBeat: { | |||
| cache_req->rc_ = Status::OK(); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kToggleWriteMode: { | |||
| cache_req->rc_ = ToggleWriteMode(&rq); | |||
| break; | |||
| } | |||
| case BaseRequest::RequestType::kListSessions: { | |||
| cache_req->rc_ = ListSessions(&reply); | |||
| break; | |||
| } | |||
| default: | |||
| std::string errMsg("Unknown request type : "); | |||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | |||
| @@ -526,18 +664,32 @@ session_id_type CacheServer::GetSessionID(connection_id_type connection_id) cons | |||
| } | |||
| CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, | |||
| int32_t shared_meory_sz_in_gb) | |||
| int32_t shared_meory_sz_in_gb, float memory_cap_ratio) | |||
| : top_(spill_path), | |||
| num_workers_(num_workers), | |||
| port_(port), | |||
| shared_memory_sz_in_gb_(shared_meory_sz_in_gb), | |||
| global_shutdown_(false) {} | |||
| global_shutdown_(false), | |||
| memory_cap_ratio_(memory_cap_ratio), | |||
| cur_mem_usage_(0) { | |||
| memory_cap_ = CacheServer::GetTotalSystemMemory() * memory_cap_ratio_; | |||
| } | |||
| Status CacheServer::Run() { | |||
| RETURN_IF_NOT_OK(ServiceStart()); | |||
| Status CacheServer::Run(int msg_qid) { | |||
| Status rc = ServiceStart(); | |||
| // If there is a message que, return the status now before we call join_all which will never return | |||
| if (msg_qid != -1) { | |||
| SharedMessage msg(msg_qid); | |||
| RETURN_IF_NOT_OK(msg.SendStatus(rc)); | |||
| } | |||
| if (rc.IsError()) { | |||
| return rc; | |||
| } | |||
| // This is called by the main function and we shouldn't exit. Otherwise the main thread | |||
| // will just shutdown. So we will call some function that never return unless error. | |||
| // One good case will be simply to wait for all threads to return. | |||
| // note that after we have sent the initial status using the msg_qid, parent process will exit and | |||
| // remove it. So we can't use it again. | |||
| RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -567,32 +719,51 @@ Status CacheServer::ReturnRequestTag(CacheServerRequest *p) { | |||
| Status CacheServer::DestroySession(CacheRequest *rq) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id"); | |||
| auto drop_session_id = rq->connection_info().session_id(); | |||
| UniqueLock lck(&rwLock_); | |||
| for (auto &cs : all_caches_) { | |||
| auto connection_id = cs.first; | |||
| auto session_id = GetSessionID(connection_id); | |||
| // We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock. | |||
| // So we will just manually do it. | |||
| if (session_id == drop_session_id) { | |||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||
| auto n = all_caches_.erase(connection_id); | |||
| MS_LOG(INFO) << "Destroy " << n << " copies of cache with id " << connection_id; | |||
| UniqueLock lck(&sessions_lock_); | |||
| // First validate that this session exists | |||
| auto it = active_sessions_.find(drop_session_id); | |||
| if (it == active_sessions_.end()) { | |||
| RETURN_STATUS_UNEXPECTED("A destroy session command has been requested but the session was not found!"); | |||
| } | |||
| // Iterate over the set of connection id's for this session that we're dropping and erase each one. | |||
| { | |||
| UniqueLock rwlck(&rwLock_); | |||
| for (auto drop_connection_id : it->second) { | |||
| auto cache_drop_it = all_caches_.find(drop_connection_id); | |||
| if (cache_drop_it == all_caches_.end()) { | |||
| RETURN_STATUS_UNEXPECTED("active session tracking had stale or incorrect cache entry."); | |||
| } | |||
| all_caches_.erase(cache_drop_it); | |||
| MS_LOG(INFO) << "Session destroy: Destroy cache with id " << drop_connection_id; | |||
| // Do not bother to remove the cache connection id from the active session because we will soon remove the | |||
| // entire session. | |||
| } | |||
| } | |||
| // Finally remove the session itself | |||
| active_sessions_.erase(it); | |||
| MS_LOG(INFO) << "Session destroyed with id " << drop_session_id; | |||
| return Status::OK(); | |||
| } | |||
| session_id_type CacheServer::GenerateSessionID() const { | |||
| SharedLock lock(&rwLock_); | |||
| session_id_type CacheServer::GenerateSessionID() { | |||
| UniqueLock lock(&sessions_lock_); | |||
| auto mt = GetRandomDevice(); | |||
| std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max()); | |||
| session_id_type session_id; | |||
| bool duplicate = false; | |||
| do { | |||
| session_id = distribution(mt); | |||
| auto it = history_sessions_.find(session_id); | |||
| duplicate = (it != history_sessions_.end()); | |||
| auto it = active_sessions_.find(session_id); | |||
| duplicate = (it != active_sessions_.end()); | |||
| } while (duplicate); | |||
| // Add this session to our tracking of active sessions with initialized empty set of connections. | |||
| active_sessions_[session_id] = std::set<connection_id_type>(); | |||
| return session_id; | |||
| } | |||
| @@ -637,19 +808,59 @@ Status CacheServer::GlobalShutdown() { | |||
| vg_.interrupt_all(); | |||
| // The next thing to do drop all the caches. | |||
| UniqueLock lck(&rwLock_); | |||
| for (auto &it : all_caches_) { | |||
| auto id = it.first; | |||
| for (auto it = all_caches_.begin(); it != all_caches_.end();) { | |||
| auto id = it->first; | |||
| MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); | |||
| // Wait for all outstanding work to be finished. | |||
| auto &cs = it.second; | |||
| auto &cs = it->second; | |||
| UniqueLock cs_lock(&cs->rw_lock_); | |||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||
| (void)all_caches_.erase(id); | |||
| it = all_caches_.erase(it); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| /// \brief Query the total physical RAM on the machine via POSIX sysconf. | |||
| /// \return Size in bytes (pages * page size). | |||
| int64_t CacheServer::GetTotalSystemMemory() { | |||
| // NOTE(review): sysconf() returns -1 on failure; that case is unchecked here and would | |||
| // produce a bogus (positive) total — confirm the supported platforms always succeed. | |||
| auto pages = sysconf(_SC_PHYS_PAGES); | |||
| auto page_size = sysconf(_SC_PAGE_SIZE); | |||
| auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size); | |||
| MS_LOG(INFO) << "Total physical RAM in bytes: " << total; | |||
| return total; | |||
| } | |||
| /// \brief Detect and clean up stale IPC resources (shared memory segment and unix socket path) | |||
| /// left behind by a previous server run on the same port. | |||
| /// \return OK if nothing to clean or cleanup succeeded; kDuplicateKey if a live server already | |||
| /// holds the shared memory (i.e. another instance is running on this port). | |||
| Status CacheServer::Builder::IpcResourceCleanup() { | |||
| Status rc; | |||
| SharedMemory::shm_key_t shm_key; | |||
| auto unix_socket = PortToUnixSocketPath(port_); | |||
| rc = PortToFtok(port_, &shm_key); | |||
| // If the key can't be derived, the unix socket path doesn't exist: nothing to clean up. | |||
| if (rc.IsError()) { | |||
| return Status::OK(); | |||
| } | |||
| // Try to attach to the shared memory, which we expect doesn't exist. | |||
| SharedMemory mem(shm_key); | |||
| rc = mem.Attach(); | |||
| if (rc.IsError()) { | |||
| return Status::OK(); | |||
| } | |||
| // NOTE(review): mem is attached at this point and never explicitly detached on either branch — | |||
| // confirm Destroy()/process teardown accounts for this attachment. | |||
| int32_t num_attached; | |||
| RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached)); | |||
| if (num_attached == 0) { | |||
| // Stale shared memory from last time. | |||
| // Remove both the memory and the socket path | |||
| RETURN_IF_NOT_OK(mem.Destroy()); | |||
| Path p(unix_socket); | |||
| (void)p.Remove(); | |||
| } else { | |||
| // Server is already up. | |||
| std::string errMsg = "Cache server is already up and running"; | |||
| // We return a duplicate error. The main() will intercept | |||
| // and output a proper message | |||
| return Status(StatusCode::kDuplicateKey, errMsg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheServer::Builder::SanityCheck() { | |||
| if (shared_memory_sz_in_gb_ <= 0) { | |||
| RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive"); | |||
| @@ -673,6 +884,12 @@ Status CacheServer::Builder::SanityCheck() { | |||
| RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString()); | |||
| } | |||
| } | |||
| if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) { | |||
| RETURN_STATUS_UNEXPECTED("Memory cap ratio should be positive and no greater than 1"); | |||
| } | |||
| // Clean up any stale shared memory / IPC resources left behind by a previous server instance. | |||
| RETURN_IF_NOT_OK(IpcResourceCleanup()); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| @@ -17,6 +17,8 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | |||
| #include <string.h> | |||
| #include <unistd.h> | |||
| #include <algorithm> | |||
| #include <atomic> | |||
| #include <memory> | |||
| @@ -47,15 +49,16 @@ class CacheServer : public Service { | |||
| using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; | |||
| class Builder { | |||
| public: | |||
| Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {} | |||
| Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4), memory_cap_ratio_(0.8) {} | |||
| ~Builder() = default; | |||
| /// \brief Getter functions | |||
| const std::string &getTop() const { return top_; } | |||
| int32_t getNumWorkers() const { return num_workers_; } | |||
| int32_t getPort() const { return port_; } | |||
| int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; } | |||
| const std::string &GetTop() const { return top_; } | |||
| int32_t GetNumWorkers() const { return num_workers_; } | |||
| int32_t GetPort() const { return port_; } | |||
| int32_t GetSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; } | |||
| float GetMemoryCapRatio() const { return memory_cap_ratio_; } | |||
| Builder &SetRootDirectory(std::string root) { | |||
| top_ = std::move(root); | |||
| @@ -73,15 +76,20 @@ class CacheServer : public Service { | |||
| shared_memory_sz_in_gb_ = sz; | |||
| return *this; | |||
| } | |||
| Builder &SetMemoryCapRatio(float ratio) { | |||
| memory_cap_ratio_ = ratio; | |||
| return *this; | |||
| } | |||
| Status SanityCheck(); | |||
| void Print(std::ostream &out) const { | |||
| out << "Summary of the cache server configuration\n" | |||
| << "Spill directory: " << getTop() << "\n" | |||
| << "Number of parallel workers: " << getNumWorkers() << "\n" | |||
| << "Tcp/ip port: " << getPort() << "\n" | |||
| << "Shared memory size (in GB): " << getSharedMemorySzInGb(); | |||
| << "Spill directory: " << GetTop() << "\n" | |||
| << "Number of parallel workers: " << GetNumWorkers() << "\n" | |||
| << "Tcp/ip port: " << GetPort() << "\n" | |||
| << "Shared memory size (in GB): " << GetSharedMemorySzInGb() << "\n" | |||
| << "Memory cap ratio: " << GetMemoryCapRatio(); | |||
| } | |||
| friend std::ostream &operator<<(std::ostream &out, const Builder &bld) { | |||
| @@ -93,7 +101,8 @@ class CacheServer : public Service { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| // We need to bring up the Task Manager by bringing up the Services singleton. | |||
| RETURN_IF_NOT_OK(Services::CreateInstance()); | |||
| RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_)); | |||
| RETURN_IF_NOT_OK( | |||
| CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_, memory_cap_ratio_)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -102,20 +111,27 @@ class CacheServer : public Service { | |||
| int32_t num_workers_; | |||
| int32_t port_; | |||
| int32_t shared_memory_sz_in_gb_; | |||
| float memory_cap_ratio_; | |||
| /// \brief Sanity checks on the shared memory. | |||
| /// \return Status object | |||
| Status IpcResourceCleanup(); | |||
| }; | |||
| CacheServer(const CacheServer &) = delete; | |||
| CacheServer &operator=(const CacheServer &) = delete; | |||
| CacheServer(CacheServer &&) = delete; | |||
| CacheServer &operator=(CacheServer &) = delete; | |||
| Status DoServiceStart() override; | |||
| Status DoServiceStop() override; | |||
| ~CacheServer() { (void)ServiceStop(); } | |||
| ~CacheServer() override { (void)ServiceStop(); } | |||
| static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port, | |||
| int32_t shared_memory_sz) { | |||
| int32_t shared_memory_sz, float memory_cap_ratio) { | |||
| std::call_once(init_instance_flag_, [&]() -> Status { | |||
| auto &svcManager = Services::GetInstance(); | |||
| RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz)); | |||
| auto &SvcManager = Services::GetInstance(); | |||
| RETURN_IF_NOT_OK( | |||
| SvcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz, memory_cap_ratio)); | |||
| return Status::OK(); | |||
| }); | |||
| return Status::OK(); | |||
| @@ -133,7 +149,7 @@ class CacheServer : public Service { | |||
| } | |||
| /// \\brief Kick off server threads. Never return unless error out. | |||
| Status Run(); | |||
| Status Run(SharedMessage::queue_id_t msg_qid); | |||
| /// \brief Get a free tag | |||
| /// \param q[in] pointer to a pointer to a CacheServerRequest | |||
| @@ -145,13 +161,35 @@ class CacheServer : public Service { | |||
| /// \return Status object | |||
| static Status ReturnRequestTag(CacheServerRequest *p); | |||
| /// \brief This returns the size (in bytes) of the physical RAM on the machine. | |||
| /// \return the size (in bytes) of the physical RAM on the machine. | |||
| static int64_t GetTotalSystemMemory(); | |||
| /// \brief Internally this is how much we will try to use without exceeding the limit | |||
| /// \return Internal cap maximum | |||
| int64_t GetAvailableSystemMemory() { return memory_cap_; } | |||
| /// \brief Find out the current memory usage | |||
| int64_t GetMemoryUsage() { return cur_mem_usage_; } | |||
| /// \brief This updates our current memory usage. | |||
| /// \param sz Number of bytes allocated or freed. | |||
| /// \param op kAllocate to add sz to the running total, kFree to subtract it. | |||
| /// \note cur_mem_usage_ is a std::atomic<int64_t>, so each += / -= is itself atomic. | |||
| enum MemUsageOp : int8_t { kAllocate = 1, kFree = 2 }; | |||
| void UpdateMemoryUsage(int64_t sz, MemUsageOp op) { | |||
| if (op == MemUsageOp::kAllocate) { | |||
| cur_mem_usage_ += sz; | |||
| } else { | |||
| cur_mem_usage_ -= sz; | |||
| } | |||
| } | |||
| private: | |||
| static std::once_flag init_instance_flag_; | |||
| static CacheServer *instance_; | |||
| mutable RWLock rwLock_; | |||
| mutable RWLock sessions_lock_; | |||
| std::string top_; | |||
| cache_index all_caches_; | |||
| std::set<session_id_type> history_sessions_; | |||
| std::map<session_id_type, std::set<connection_id_type>> active_sessions_; | |||
| std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_; | |||
| std::shared_ptr<QueueList<CacheServerRequest *>> free_list_; | |||
| std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_; | |||
| @@ -162,11 +200,15 @@ class CacheServer : public Service { | |||
| int32_t port_; | |||
| int32_t shared_memory_sz_in_gb_; | |||
| std::atomic<bool> global_shutdown_; | |||
| float memory_cap_ratio_; | |||
| int64_t memory_cap_; | |||
| std::atomic<int64_t> cur_mem_usage_; | |||
| /// \brief Constructor | |||
| /// \param spill_path Top directory for spilling buffers to. | |||
| /// \param num_workers Number of threads for handling requests. | |||
| explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb); | |||
| explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb, | |||
| float memory_cap_ratio); | |||
| /// \brief Locate a cache service from connection id. | |||
| /// \return Pointer to cache service. Null if not found | |||
| @@ -179,11 +221,9 @@ class CacheServer : public Service { | |||
| Status CreateService(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Destroy a cache service | |||
| /// \param cs | |||
| /// \param rq | |||
| /// \return | |||
| Status DestroyCache(CacheService *cs, CacheRequest *rq); | |||
| Status PurgeCache(CacheService *cs); | |||
| /// \return Status object | |||
| Status DestroyCache(CacheRequest *rq); | |||
| /// \brief Entry point for all internal server threads. | |||
| Status ServerRequest(int32_t worker_id); | |||
| @@ -207,7 +247,7 @@ class CacheServer : public Service { | |||
| /// \brief Generate a session ID for the client | |||
| /// \return Session ID | |||
| session_id_type GenerateSessionID() const; | |||
| session_id_type GenerateSessionID(); | |||
| /// \brief Handle kAllocateSharedBlock request | |||
| /// \param rq CacheRequest | |||
| @@ -220,20 +260,55 @@ class CacheServer : public Service { | |||
| /// \return Status object | |||
| Status FreeSharedMemory(CacheRequest *rq); | |||
| /// \brief Handle kFastCacheRow request | |||
| /// \brief Handle CacheRow request | |||
| /// \note There are two different implementation depends if shared memory is used for transportation. | |||
| /// \return Status object | |||
| Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply); | |||
| Status FastCacheRow(CacheRequest *rq, CacheReply *reply); | |||
| Status CacheRow(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Internal function to do row batch fetch | |||
| /// \param cs CacheService | |||
| /// \param rq Request | |||
| /// \param reply Reply | |||
| /// \return | |||
| Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply); | |||
| /// \return Status object | |||
| Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Internal function to get statistics | |||
| /// \param rq | |||
| /// \param reply | |||
| /// \return Status object | |||
| Status GetStat(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Cache a schema request | |||
| /// \param rq | |||
| /// \return Status object | |||
| Status CacheSchema(CacheRequest *rq); | |||
| /// \brief Fetch a schema request | |||
| /// \param rq | |||
| /// \param reply | |||
| /// \return Status object | |||
| Status FetchSchema(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Mark Build phase done (for non-mappable case) | |||
| /// \param rq | |||
| /// \return Status object | |||
| Status BuildPhaseDone(CacheRequest *rq); | |||
| /// \brief A proper shutdown of the server | |||
| /// \return Status object | |||
| Status GlobalShutdown(); | |||
| /// \brief Find keys that will be cache miss | |||
| /// \return Status object | |||
| Status GetCacheMissKeys(CacheRequest *rq, CacheReply *reply); | |||
| /// \brief Toggle write mode for a service | |||
| Status ToggleWriteMode(CacheRequest *rq); | |||
| /// \brief List the sessions and their caches | |||
| /// \param reply | |||
| /// \return Status object | |||
| Status ListSessions(CacheReply *reply); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/cache/cache_server.h" | |||
| #include "minddata/dataset/util/slice.h" | |||
| namespace mindspore { | |||
| @@ -22,42 +23,62 @@ CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool genera | |||
| : root_(root), | |||
| cache_mem_sz_(mem_sz), | |||
| cp_(nullptr), | |||
| map_(nullptr), | |||
| next_id_(0), | |||
| generate_id_(generate_id), | |||
| schema_key_(-1), | |||
| st_(generate_id ? State::kBuildPhase : State::kNone) {} | |||
| st_(generate_id ? State::kBuildPhase : State::kNone), | |||
| cur_mem_usage_(0), | |||
| cur_disk_usage_(0) {} | |||
| CacheService::~CacheService() { (void)ServiceStop(); } | |||
| bool CacheService::UseArena() { | |||
| // If fixed size, use Arena instead of the pool from global context. | |||
| return (cache_mem_sz_ > 0); | |||
| } | |||
| Status CacheService::DoServiceStart() { | |||
| std::shared_ptr<MemoryPool> mp_; | |||
| CacheServer &cs = CacheServer::GetInstance(); | |||
| if (UseArena()) { | |||
| auto avail_mem = cs.GetAvailableSystemMemory() / 1048576L; | |||
| if (cache_mem_sz_ > avail_mem) { | |||
| // Output a warning that we use more than recommended. If we fail to allocate, we will fail anyway. | |||
| MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " MB while available system memory " << avail_mem | |||
| << " MB"; | |||
| } | |||
| // Create a fixed size arena based on the parameter. | |||
| std::shared_ptr<Arena> arena; | |||
| RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); | |||
| mp_ = std::move(arena); | |||
| // update the global usage only. | |||
| cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kAllocate); | |||
| } else { | |||
| // Unlimited size. Simply use a system pool. Another choice is CircularPool. | |||
| mp_ = std::make_shared<SystemPool>(); | |||
| } | |||
| // Put together a CachePool for backing up the Tensor | |||
| cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), root_); | |||
| cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), UseArena(), root_); | |||
| RETURN_IF_NOT_OK(cp_->ServiceStart()); | |||
| // Set up the B+ tree as well. But use the system pool instead. | |||
| map_ = std::make_shared<row_map>(); | |||
| // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. | |||
| cookie_ = cp_->MyName(); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::DoServiceStop() { | |||
| if (cp_ != nullptr) { | |||
| RETURN_IF_NOT_OK(cp_->ServiceStop()); | |||
| } | |||
| CacheServer &cs = CacheServer::GetInstance(); | |||
| if (UseArena()) { | |||
| cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kFree); | |||
| } else { | |||
| MS_LOG(INFO) << "Memory/disk usage for the current service: " << GetMemoryUsage() << " bytes and " << GetDiskUsage() | |||
| << " bytes."; | |||
| cs.UpdateMemoryUsage(GetMemoryUsage(), CacheServer::MemUsageOp::kFree); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) { | |||
| SharedLock rw(&rw_lock_); | |||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | |||
| @@ -66,6 +87,11 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||
| // allow other to cache more rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| if (st_ == State::kNoLocking) { | |||
| // We ignore write this request once we turn off locking on the B+ tree. So we will just | |||
| // return out of memory from now on. | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| try { | |||
| // The first buffer is a flatbuffer which describes the rest of the buffers follow | |||
| auto fb = buf.front(); | |||
| @@ -86,6 +112,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||
| *row_id_generated = msg->row_id(); | |||
| } | |||
| auto size_of_this = msg->size_of_this(); | |||
| size_t total_sz = size_of_this; | |||
| auto column_hdr = msg->column(); | |||
| // Number of tensor buffer should match the number of columns plus one. | |||
| if (buf.size() != column_hdr->size() + 1) { | |||
| @@ -99,16 +126,28 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||
| all_data.emplace_back(fb, size_of_this); | |||
| for (auto i = 0; i < column_hdr->size(); ++i) { | |||
| all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); | |||
| total_sz += msg->data_sz()->Get(i); | |||
| } | |||
| // Now we cache the flat buffer. | |||
| CachePool::key_type key; | |||
| RETURN_IF_NOT_OK(cp_->Insert(all_data, &key)); | |||
| Status rc = map_->DoInsert(*row_id_generated, key); | |||
| // Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it. | |||
| // Otherwise, we check how much (globally) how much we use and may simply spill to disk | |||
| // directly. | |||
| CacheServer &cs = CacheServer::GetInstance(); | |||
| bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory(); | |||
| Status rc = cp_->Insert(*row_id_generated, all_data, write_to_disk_directly); | |||
| if (rc == Status(StatusCode::kDuplicateKey)) { | |||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | |||
| } else { | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| // All good, then update the memory usage local and global (if not using arena) | |||
| if (write_to_disk_directly) { | |||
| cur_disk_usage_ += total_sz; | |||
| } else { | |||
| cur_mem_usage_ += total_sz; | |||
| if (!UseArena()) { | |||
| cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| @@ -123,6 +162,11 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ | |||
| // allow other to cache more rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| if (st_ == State::kNoLocking) { | |||
| // We ignore write this request once we turn off locking on the B+ tree. So we will just | |||
| // return out of memory from now on. | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| try { | |||
| // If we don't need to generate id, we need to find it from the buffer. | |||
| if (generate_id_) { | |||
| @@ -139,20 +183,33 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ | |||
| } | |||
| *row_id_generated = msg->row_id(); | |||
| } | |||
| // Now we cache the flat buffer. | |||
| CachePool::key_type key; | |||
| RETURN_IF_NOT_OK(cp_->Insert({src}, &key)); | |||
| Status rc = map_->DoInsert(*row_id_generated, key); | |||
| // Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it. | |||
| // Otherwise, we check how much (globally) how much we use and may simply spill to disk | |||
| // directly. | |||
| auto total_sz = src.GetSize(); | |||
| CacheServer &cs = CacheServer::GetInstance(); | |||
| bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory(); | |||
| Status rc = cp_->Insert(*row_id_generated, {src}, write_to_disk_directly); | |||
| if (rc == Status(StatusCode::kDuplicateKey)) { | |||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | |||
| } else { | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| // All good, then update the memory usage local and global (if not using arena) | |||
| if (write_to_disk_directly) { | |||
| cur_disk_usage_ += total_sz; | |||
| } else { | |||
| cur_mem_usage_ += total_sz; | |||
| if (!UseArena()) { | |||
| cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } catch (const std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| } | |||
| std::ostream &operator<<(std::ostream &out, const CacheService &cs) { | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nCache memory size: " << cs.cache_mem_sz_; | |||
| @@ -164,34 +221,29 @@ std::ostream &operator<<(std::ostream &out, const CacheService &cs) { | |||
| } | |||
| return out; | |||
| } | |||
| Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); } | |||
| Status CacheService::Purge() { | |||
| // First we must lock exclusively. No one else can cache/restore anything. | |||
| UniqueLock rw(&rw_lock_); | |||
| RETURN_IF_NOT_OK(cp_->ServiceStop()); | |||
| auto new_map = std::make_shared<row_map>(); | |||
| map_.reset(); | |||
| map_ = std::move(new_map); | |||
| next_id_ = 0; | |||
| RETURN_IF_NOT_OK(cp_->ServiceStart()); | |||
| Status CacheService::FindKeysMiss(std::vector<row_id_type> *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| std::unique_lock<std::mutex> lock(get_key_miss_mux_); | |||
| if (key_miss_results_ == nullptr) { | |||
| // Just do it once. | |||
| key_miss_results_ = std::make_shared<std::vector<row_id_type>>(); | |||
| auto stat = cp_->GetStat(true); | |||
| key_miss_results_->push_back(stat.min_key); | |||
| key_miss_results_->push_back(stat.max_key); | |||
| key_miss_results_->insert(key_miss_results_->end(), stat.gap.begin(), stat.gap.end()); | |||
| } | |||
| out->insert(out->end(), key_miss_results_->begin(), key_miss_results_->end()); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::GetStat(CacheService::ServiceStat *out) { | |||
| SharedLock rw(&rw_lock_); | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| if (st_ == State::kNone || st_ == State::kFetchPhase) { | |||
| out->stat_ = cp_->GetStat(); | |||
| out->state_ = static_cast<ServiceStat::state_type>(st_); | |||
| auto it = map_->begin(); | |||
| if (it != map_->end()) { | |||
| out->min_ = it.key(); | |||
| auto end_it = map_->end(); | |||
| --end_it; | |||
| out->max_ = end_it.key(); | |||
| } | |||
| } else { | |||
| out->state_ = static_cast<ServiceStat::state_type>(st_); | |||
| } | |||
| out->stat_ = cp_->GetStat(); | |||
| out->state_ = static_cast<ServiceStat::state_type>(st_); | |||
| return Status::OK(); | |||
| } | |||
| @@ -204,19 +256,12 @@ Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vecto | |||
| *mem_sz = (num_elements + 1) * sizeof(int64_t); | |||
| (*out).reserve(num_elements); | |||
| for (auto row_id : v) { | |||
| auto r = map_->Search(row_id); | |||
| if (r.second) { | |||
| auto &it = r.first; | |||
| CachePool::key_type key = it.value(); | |||
| auto sz = cp_->GetSize(key); | |||
| if (sz == 0) { | |||
| std::string errMsg = "Key not found: "; | |||
| errMsg += std::to_string(key); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| (*out).emplace_back(key, sz); | |||
| auto sz = cp_->GetSize(row_id); | |||
| if (sz > 0) { | |||
| (*out).emplace_back(row_id, sz); | |||
| (*mem_sz) += sz; | |||
| } else { | |||
| // key not found | |||
| (*out).emplace_back(-1, 0); | |||
| } | |||
| } | |||
| @@ -252,27 +297,19 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::ve | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::CacheSchema(const void *buf, int64_t len) { | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == State::kFetchPhase) { | |||
| // For this kind of cache service, once we are done with the build phase into fetch phase, we can't | |||
| // allow other to cache more rows. | |||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||
| } | |||
| // This is a special request and we need to remember where we store it. | |||
| UniqueLock rw(&rw_lock_); | |||
| // In case we are calling the same function from multiple threads, only | |||
| // the first one is considered. Rest is ignored. | |||
| CachePool::key_type cur_key = schema_key_; | |||
| CachePool::key_type key; | |||
| if (cur_key < 0) { | |||
| RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key)); | |||
| auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key); | |||
| MS_LOG(DEBUG) << "Caching Schema. Result = " << result; | |||
| if (schema_.empty()) { | |||
| schema_.assign(static_cast<const char *>(buf), len); | |||
| } else { | |||
| MS_LOG(DEBUG) << "Caching Schema already done"; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::FetchSchema(std::string *out) const { | |||
| SharedLock rw(&rw_lock_); | |||
| if (st_ == State::kBuildPhase) { | |||
| @@ -283,32 +320,44 @@ Status CacheService::FetchSchema(std::string *out) const { | |||
| // We are going to use std::string to allocate and hold the result which will be eventually | |||
| // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose | |||
| // to minimize memory copy. | |||
| std::string mem; | |||
| if (schema_key_ >= 0) { | |||
| auto len = cp_->GetSize(schema_key_); | |||
| try { | |||
| mem.resize(len); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error"); | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| auto slice = WritableSlice(mem.data(), len); | |||
| RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); | |||
| std::string mem(schema_); | |||
| if (!mem.empty()) { | |||
| *out = std::move(mem); | |||
| } else { | |||
| return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CacheService::BuildPhaseDone() { | |||
| if (HasBuildPhase()) { | |||
| // Exclusive lock to switch phase | |||
| UniqueLock rw(&rw_lock_); | |||
| st_ = State::kFetchPhase; | |||
| cp_->SetLocking(false); | |||
| return Status::OK(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); | |||
| } | |||
| } | |||
| Status CacheService::ToggleWriteMode(bool on_off) { | |||
| UniqueLock rw(&rw_lock_); | |||
| if (HasBuildPhase()) { | |||
| RETURN_STATUS_UNEXPECTED("Not applicable to non-mappable dataset"); | |||
| } else { | |||
| // If we stop accepting write request, we turn off locking for the | |||
| // underlying B+ tree. All future write request we will return kOutOfMemory. | |||
| if (st_ == State::kNone && !on_off) { | |||
| st_ = State::kNoLocking; | |||
| cp_->SetLocking(on_off); | |||
| MS_LOG(WARNING) << "Locking mode is switched off."; | |||
| } else if (st_ == State::kNoLocking && on_off) { | |||
| st_ = State::kNone; | |||
| cp_->SetLocking(on_off); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -20,6 +20,7 @@ | |||
| #include <algorithm> | |||
| #include <atomic> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <type_traits> | |||
| #include <utility> | |||
| @@ -44,9 +45,8 @@ using key_size_pair = std::pair<CachePool::key_type, size_t>; | |||
| class CacheService : public Service { | |||
| public: | |||
| friend class CacheServer; | |||
| using row_map = BPlusTree<row_id_type, CachePool::key_type>; | |||
| enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase }; | |||
| enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking }; | |||
| /// \brief Constructor | |||
| /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited | |||
| @@ -97,11 +97,9 @@ class CacheService : public Service { | |||
| class ServiceStat { | |||
| public: | |||
| using state_type = std::underlying_type<State>::type; | |||
| ServiceStat() : min_(0), max_(0), state_(0) {} | |||
| ServiceStat() : state_(0) {} | |||
| ~ServiceStat() = default; | |||
| CachePool::CacheStat stat_{}; | |||
| row_id_type min_; | |||
| row_id_type max_; | |||
| state_type state_; | |||
| }; | |||
| /// \brief Statistics for the current service | |||
| @@ -117,9 +115,9 @@ class CacheService : public Service { | |||
| /// \param out A contiguous memory that contains the serialized form of schema. | |||
| /// \return Status object | |||
| Status FetchSchema(std::string *out) const; | |||
| /// \brief Purge the content of a cache | |||
| /// \brief Return a set of keys that are definitely cache miss | |||
| /// \return Status object | |||
| Status Purge(); | |||
| Status FindKeysMiss(std::vector<row_id_type> *out); | |||
| /// \brief Overload the << operator to print a cache service | |||
| /// \param out std::ostream | |||
| /// \param cs A cache service | |||
| @@ -136,19 +134,33 @@ class CacheService : public Service { | |||
| /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. | |||
| /// \return Status object | |||
| Status BuildPhaseDone(); | |||
| /// \brief Find out the current memory usage | |||
| int64_t GetMemoryUsage() { return cur_mem_usage_; } | |||
| /// \brief Find out the current disk usage | |||
| int64_t GetDiskUsage() { return cur_disk_usage_; } | |||
| /// \brief For kToggleWriteMode request | |||
| Status ToggleWriteMode(bool on_off); | |||
| private: | |||
| mutable RWLock rw_lock_; | |||
| std::string root_; | |||
| uint64_t cache_mem_sz_; | |||
| std::shared_ptr<CachePool> cp_; | |||
| std::shared_ptr<row_map> map_; | |||
| std::atomic<row_id_type> next_id_; | |||
| bool generate_id_; | |||
| std::atomic<CachePool::key_type> schema_key_; | |||
| std::string cookie_; | |||
| State st_; | |||
| std::string schema_; | |||
| // If we use an Arena, cur_disk_usage is always 0 as we don't know how CachePool manages it. | |||
| // Otherwise we track how much is in memory and how much is on disk (if root_ is not empty). | |||
| // We use them to control when we should stop caching in memory in the case when there is no | |||
| // Arena. | |||
| std::atomic<int64_t> cur_mem_usage_; | |||
| std::atomic<int64_t> cur_disk_usage_; | |||
| // We also cache the result from calling FindKeysMiss because it is expensive. Besides user make | |||
| // this request after we hit memory full or disk full. So the result is unlikely to change. | |||
| std::mutex get_key_miss_mux_; | |||
| std::shared_ptr<std::vector<row_id_type>> key_miss_results_; | |||
| /// \brief Private function to generate a row id | |||
| /// \return Row id assigned. | |||
| row_id_type GetNextRowId() { return next_id_.fetch_add(1); } | |||
| @@ -92,3 +92,13 @@ table CreateCacheReplyMsg { | |||
| connection_id:int64; | |||
| cookie:string; | |||
| } | |||
| table ListSessionMsg { | |||
| session_id:uint32; | |||
| connection_id:uint64; | |||
| stats:ServiceStatMsg; | |||
| } | |||
| table ListSessionsMsg { | |||
| sessions:[ListSessionMsg]; | |||
| } | |||
| @@ -53,22 +53,35 @@ CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t row | |||
| num_cache_miss_(0), | |||
| cache_client_(std::move(cache_client)), | |||
| rows_per_buffer_(rows_per_buf), | |||
| // We can cause deadlock if this internal Connector size is too small. | |||
| keys_miss_(num_workers_, 1, connector_capacity_), | |||
| prefetch_size_(cache_client_->getPrefetchSize()) { | |||
| prefetch_size_(rows_per_buffer_), | |||
| num_prefetchers_(num_workers_) { | |||
| // Adjust the prefetch size based on the number of workers. | |||
| auto prefetch_sz_per_thread = cache_client_->GetPrefetchSize() / num_prefetchers_; | |||
| if (prefetch_size_ < prefetch_sz_per_thread) { | |||
| prefetch_size_ = prefetch_sz_per_thread; | |||
| MS_LOG(DEBUG) << "Per worker prefetch size : " << prefetch_size_; | |||
| } | |||
| io_block_queues_.Init(num_workers, op_connector_size); | |||
| prefetch_queues_.Init(num_workers, op_connector_size); | |||
| sampler_queue_ = std::make_unique<Queue<std::shared_ptr<Tensor>>>(op_connector_size); | |||
| prefetch_queues_.Init(num_prefetchers_, op_connector_size); | |||
| // We can cause deadlock if this internal Connector size is too small. | |||
| keys_miss_ = std::make_unique<Connector<std::vector<row_id_type>>>(num_prefetchers_, 1, connector_capacity_); | |||
| } | |||
| // Common function to fetch samples from the sampler and send them using the io_block_queues to | |||
| // the parallel workers | |||
| Status CacheBase::FetchSamplesToWorkers() { | |||
| int64_t buf_cnt = 0; | |||
| int64_t wait_cnt = 0; | |||
| int64_t prefetch_cnt = 0; | |||
| // Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffers_ | |||
| // is too small (1 by default) and won't help performance. | |||
| RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this))); | |||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1))); | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1))); | |||
| auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id, | |||
| std::vector<row_id_type> &keys) -> Status { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(qList[worker_id]->Add(std::move(blk))); | |||
| return Status::OK(); | |||
| }; | |||
| // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them | |||
| // to the WorkerEntry. | |||
| do { | |||
| @@ -82,33 +95,54 @@ Status CacheBase::FetchSamplesToWorkers() { | |||
| ++wait_cnt; | |||
| std::vector<row_id_type> keys; | |||
| keys.reserve(rows_per_buffer_); | |||
| std::vector<row_id_type> prefetch_keys; | |||
| prefetch_keys.reserve(prefetch_size_); | |||
| std::unique_ptr<DataBuffer> sampler_buffer; | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| while (!sampler_buffer->eoe()) { | |||
| TensorRow sample_row; | |||
| RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); | |||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | |||
| // Send the sampler tensor to other thread for prefetching. We are using shared pointer so it | |||
| // won't go out scope until it is really not in use. | |||
| RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids)); | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | |||
| keys.push_back(*itr); | |||
| ++row_cnt_; | |||
| if (row_cnt_ % rows_per_buffer_ == 0) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||
| keys.clear(); | |||
| prefetch_keys.push_back(*itr); | |||
| // Batch enough rows for performance reason. | |||
| if (row_cnt_ % prefetch_size_ == 0) { | |||
| RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys)); | |||
| // Now we tell the WorkerEntry to wait for them to come back. If prefetch_size_ is a multiple | |||
| // of rows_per_buffer_, the keys vector will always be empty. But it can be partially filled. | |||
| // The only requirement we set up is rows_per_buffer_ is less than or equal to prefetch_size_. | |||
| for (auto row_id : prefetch_keys) { | |||
| keys.push_back(row_id); | |||
| if (keys.size() == rows_per_buffer_) { | |||
| RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys)); | |||
| keys.clear(); | |||
| } | |||
| } | |||
| prefetch_keys.clear(); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | |||
| } | |||
| // Deal with any partial keys left. | |||
| if (!prefetch_keys.empty()) { | |||
| RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys)); | |||
| for (auto row_id : prefetch_keys) { | |||
| keys.push_back(row_id); | |||
| if (keys.size() == rows_per_buffer_) { | |||
| RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys)); | |||
| keys.clear(); | |||
| } | |||
| } | |||
| } | |||
| if (!keys.empty()) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||
| RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys)); | |||
| } | |||
| // send the eoe | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| RETURN_IF_NOT_OK(prefetch_queues_[(prefetch_cnt++) % num_prefetchers_]->Add( | |||
| std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||
| // If repeat but the not last repeat, wait for reset. | |||
| if (!IsLastIteration()) { | |||
| MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt; | |||
| @@ -123,8 +157,6 @@ Status CacheBase::FetchSamplesToWorkers() { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | |||
| // Shutdown threads | |||
| std::shared_ptr<Tensor> empty; | |||
| RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty))); | |||
| for (int32_t i = 0; i < num_workers_; i++) { | |||
| RETURN_IF_NOT_OK( | |||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| @@ -145,13 +177,6 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||
| if (blk->eof()) { | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))); | |||
| } else if (blk->eoe()) { | |||
| if (AllowCacheMiss()) { | |||
| // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from | |||
| // a sampler, send a eoe to physical leaf op as well. | |||
| std::vector<row_id_type> eoe; | |||
| eoe.push_back(eoe_row_id); | |||
| RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe)); | |||
| } | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | |||
| } else { | |||
| std::vector<int64_t> keys; | |||
| @@ -162,22 +187,21 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||
| } | |||
| std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone); | |||
| std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>(); | |||
| std::vector<row_id_type> cache_miss; | |||
| cache_miss.reserve(keys.size()); | |||
| for (auto row_id : keys) { | |||
| TensorRow row; | |||
| // Block until the row shows up in the pool. | |||
| RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row)); | |||
| RETURN_IF_NOT_OK(GetPrefetchRow(row_id, &row)); | |||
| if (row.empty()) { | |||
| cache_miss.push_back(row_id); | |||
| if (AllowCacheMiss()) { | |||
| ++num_cache_miss_; | |||
| } else { | |||
| std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| } | |||
| que->push_back(std::move(row)); | |||
| } | |||
| db->set_tensor_table(std::move(que)); | |||
| if (AllowCacheMiss()) { | |||
| // Because of the way connector works, we push unconditionally even cache_miss can be empty. | |||
| RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss)); | |||
| } | |||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); | |||
| buffer_id += num_workers_; | |||
| } | |||
| @@ -189,7 +213,6 @@ Status CacheBase::RegisterResources() { | |||
| RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks())); | |||
| return Status::OK(); | |||
| } | |||
| @@ -208,73 +231,97 @@ Status CacheBase::UpdateColumnMapFromCache() { | |||
| return rc; | |||
| } | |||
| Status CacheBase::Dispatcher() { | |||
| TaskManager::FindMe()->Post(); | |||
| int64_t buf_cnt = 0; | |||
| int64_t num_row = 0; | |||
| std::vector<row_id_type> keys; | |||
| keys.reserve(prefetch_size_); | |||
| do { | |||
| keys.clear(); | |||
| std::shared_ptr<Tensor> sample_ids; | |||
| RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids)); | |||
| if (sample_ids == nullptr) { | |||
| // A null shared pointer signal times to quit. | |||
| // Also signal all prefetchers to quit. | |||
| for (int32_t i = 0; i < num_workers_; i++) { | |||
| RETURN_IF_NOT_OK( | |||
| prefetch_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||
| } | |||
| break; | |||
| Status CacheBase::GetPrefetchRow(row_id_type row_id, TensorRow *out) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(row_id >= 0, "Expect positive row id"); | |||
| RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, out)); | |||
| return Status::OK(); | |||
| } | |||
| Status CacheBase::PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss) { | |||
| RETURN_UNEXPECTED_IF_NULL(cache_miss); | |||
| std::vector<row_id_type> prefetch_keys; | |||
| prefetch_keys.reserve(keys.size()); | |||
| // Filter out all those keys that unlikely we will find at the server | |||
| for (auto row_id : keys) { | |||
| if (cache_client_->KeyIsCacheMiss(row_id)) { | |||
| // Just put an empty row in the cache. | |||
| TensorRow row; | |||
| row.setId(row_id); | |||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||
| cache_miss->push_back(row_id); | |||
| } else { | |||
| prefetch_keys.push_back(row_id); | |||
| } | |||
| // Now we distribute the sampler ids to each prefetcher according to the prefetch size. | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | |||
| keys.push_back(*itr); | |||
| ++num_row; | |||
| if (num_row % prefetch_size_ == 0) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||
| keys.clear(); | |||
| } | |||
| // Early exit if nothing to fetch | |||
| if (prefetch_keys.empty()) { | |||
| return Status::OK(); | |||
| } | |||
| // Get the rows from the server | |||
| TensorTable ttbl; | |||
| Status rc = cache_client_->GetRows(prefetch_keys, &ttbl); | |||
| if (rc.IsOk()) { | |||
| auto row_it = ttbl.begin(); | |||
| for (auto row_id : prefetch_keys) { | |||
| auto &row = *row_it; | |||
| if (row.empty()) { | |||
| cache_miss->push_back(row_id); | |||
| } | |||
| // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row | |||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||
| ++row_it; | |||
| } | |||
| // Send the remaining sample id | |||
| if (!keys.empty()) { | |||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||
| RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||
| } else { | |||
| // In case any thread is waiting for the rows to come back and blocked on a semaphore, | |||
| // we will put an empty row in the local cache. | |||
| for (auto row_id : prefetch_keys) { | |||
| TensorRow row; | |||
| row.setId(row_id); | |||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||
| cache_miss->push_back(row_id); | |||
| } | |||
| } while (true); | |||
| return Status::OK(); | |||
| } | |||
| return rc; | |||
| } | |||
| Status CacheBase::Prefetcher(int32_t worker_id) { | |||
| TaskManager::FindMe()->Post(); | |||
| std::vector<row_id_type> prefetch_keys; | |||
| prefetch_keys.reserve(prefetch_size_); | |||
| std::vector<row_id_type> cache_miss; | |||
| cache_miss.reserve(prefetch_size_); | |||
| do { | |||
| prefetch_keys.clear(); | |||
| cache_miss.clear(); | |||
| std::unique_ptr<IOBlock> blk; | |||
| RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk)); | |||
| RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys)); | |||
| if (prefetch_keys.empty()) { | |||
| // Empty keys mean time to quit. | |||
| break; | |||
| } | |||
| TensorTable ttbl; | |||
| RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl)); | |||
| auto row_it = ttbl.begin(); | |||
| for (auto row_id : prefetch_keys) { | |||
| auto &row = *row_it; | |||
| if (row.empty()) { | |||
| if (AllowCacheMiss()) { | |||
| ++num_cache_miss_; | |||
| } else { | |||
| std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!blk->eof(), "Expect eoe or a regular io block"); | |||
| if (!blk->eoe()) { | |||
| RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys)); | |||
| Status rc; | |||
| const int32_t max_retries = 5; | |||
| int32_t retry_count = 0; | |||
| do { | |||
| rc = PrefetchRows(prefetch_keys, &cache_miss); | |||
| if (rc.IsNetWorkError() && retry_count < max_retries) { | |||
| // If we get some network error, we will attempt some retries | |||
| retry_count++; | |||
| } else if (rc.IsError()) { | |||
| return rc; | |||
| } | |||
| } while (rc.IsNetWorkError()); | |||
| } else { | |||
| if (AllowCacheMiss()) { | |||
| // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from | |||
| // a sampler, send a eoe to physical leaf op as well. | |||
| cache_miss.push_back(eoe_row_id); | |||
| } | |||
| // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row | |||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||
| ++row_it; | |||
| } | |||
| if (AllowCacheMiss()) { | |||
| // Because of the way connector works, we push unconditionally even cache_miss can be empty. | |||
| RETURN_IF_NOT_OK(keys_miss_->Push(worker_id, cache_miss)); | |||
| } | |||
| } while (true); | |||
| return Status::OK(); | |||
| @@ -22,6 +22,7 @@ | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/connector.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/engine/cache/cache_service.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| @@ -90,8 +91,7 @@ class CacheBase : public ParallelOp { | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| WaitPost epoch_sync_; | |||
| int32_t rows_per_buffer_; | |||
| Connector<std::vector<row_id_type>> keys_miss_; | |||
| QueueMap<row_id_type, TensorRow> prefetch_; | |||
| std::unique_ptr<Connector<std::vector<row_id_type>>> keys_miss_; | |||
| /// \brief Common function to register resources for interrupt | |||
| /// \note Derived should override this function for extra resources to be registered | |||
| @@ -111,13 +111,16 @@ class CacheBase : public ParallelOp { | |||
| constexpr static int32_t connector_capacity_ = 1024; | |||
| int32_t prefetch_size_; | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| int32_t num_prefetchers_; | |||
| QueueList<std::unique_ptr<IOBlock>> prefetch_queues_; | |||
| std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_; | |||
| QueueMap<row_id_type, TensorRow> prefetch_; | |||
| Status Dispatcher(); | |||
| /// \brief Prefetcher. It prefetch the rows from cache server | |||
| /// \return Status object. | |||
| Status Prefetcher(int32_t worker_id); | |||
| /// \brief Functions used by prefetcher and WorkerEntry | |||
| Status PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss); | |||
| Status GetPrefetchRow(row_id_type row_id, TensorRow *out); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -87,10 +87,10 @@ Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); } | |||
| void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } | |||
| Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| std::vector<row_id_type> cache_miss; | |||
| RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); | |||
| RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss)); | |||
| // Ignore the case we have no cache miss, we can't return empty samples. | |||
| while (cache_miss.empty()) { | |||
| RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); | |||
| RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss)); | |||
| } | |||
| // Special code for eoe | |||
| if (cache_miss.at(0) == eoe_row_id) { | |||
| @@ -25,6 +25,7 @@ | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/util/system_pool.h" | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| @@ -48,7 +49,8 @@ CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t | |||
| std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler) | |||
| : ParallelOp(numWorkers, opConnectorSize, sampler), | |||
| num_cleaners_(numCleaners), | |||
| cache_client_(std::move(cache_client)) {} | |||
| cache_client_(std::move(cache_client)), | |||
| cache_missing_rows_(true) {} | |||
| Status CacheMergeOp::operator()() { | |||
| // A queue of row id to let cleaner send cache miss rows to the cache server | |||
| @@ -129,17 +131,19 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | |||
| std::string errMsg = "Expect positive row id: " + std::to_string(row_id); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| // Technically number of this row shows up in the cache miss stream is equal to the number | |||
| // of P() call. However the cleaner wants it too. So we need an extra copy. | |||
| TensorRowCacheRequest *rq; | |||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||
| if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) { | |||
| // We will send the request async. But any error we most | |||
| // likely ignore and continue. | |||
| Status rc; | |||
| rc = rq->AsyncSendCacheRequest(cache_client_, row); | |||
| if (rc.IsOk()) { | |||
| RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); | |||
| if (cache_missing_rows_) { | |||
| // Technically number of this row shows up in the cache miss stream is equal to the number | |||
| // of P() call. However the cleaner wants it too. So we need an extra copy. | |||
| TensorRowCacheRequest *rq; | |||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||
| if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) { | |||
| // We will send the request async. But any error we most | |||
| // likely ignore and continue. | |||
| Status rc; | |||
| rc = rq->AsyncSendCacheRequest(cache_client_, row); | |||
| if (rc.IsOk()) { | |||
| RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); | |||
| } | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row))); | |||
| @@ -168,13 +172,18 @@ Status CacheMergeOp::Cleaner() { | |||
| Status rc = rq->CheckCacheResult(); | |||
| if (rc.IsError()) { | |||
| // If interrupt, time to quit. | |||
| if (rc.get_code() == StatusCode::kInterrupted) { | |||
| if (rc.IsInterrupted()) { | |||
| return Status::OK(); | |||
| } else if (rc.IsOutofMemory() || rc.IsNoSpace()) { | |||
| // The server is hitting some limit and we will turn off caching from now on. | |||
| cache_missing_rows_ = false; | |||
| cache_client_->ServerRunningOutOfResources(); | |||
| } else { | |||
| MS_LOG(INFO) << "Cache row not successful: " << rc.ToString(); | |||
| // Bad rc should not bring down the pipeline. We will simply continue and | |||
| // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY. | |||
| rq->SetState(TensorRowCacheRequest::State::kEmpty); | |||
| } | |||
| MS_LOG(INFO) << "Cache row not successful: " << rc.ToString(); | |||
| // Bad rc should not bring down the pipeline. We will simply continue and | |||
| // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY. | |||
| rq->SetState(TensorRowCacheRequest::State::kEmpty); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| @@ -253,7 +262,7 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) { | |||
| Status CacheMergeOp::EoeReceived(int32_t worker_id) { | |||
| // If we are in a repeat path, send the eoe up. | |||
| // Otherwise ignore it. | |||
| if (op_total_repeats_ > 1) { | |||
| if (op_total_repeats_ != 1) { | |||
| return DatasetOp::EoeReceived(worker_id); | |||
| } | |||
| return Status::OK(); | |||
| @@ -281,7 +290,7 @@ Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheReque | |||
| *out = it->second.GetMutablePointer(); | |||
| } else { | |||
| // We will create a new one. | |||
| auto alloc = Services::GetAllocator<TensorRowCacheRequest>(); | |||
| auto alloc = SystemPool::GetAllocator<TensorRowCacheRequest>(); | |||
| auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc)); | |||
| if (r.second) { | |||
| auto &mem = r.first->second; | |||
| @@ -202,6 +202,7 @@ class CacheMergeOp : public ParallelOp { | |||
| std::unique_ptr<Queue<row_id_type>> io_que_; | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| int32_t num_cleaners_; | |||
| std::atomic<bool> cache_missing_rows_; | |||
| /// \brief Locate the cache request from the io_request_ map | |||
| /// \param row_id | |||
| @@ -16,6 +16,7 @@ | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| @@ -64,7 +65,7 @@ Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) { | |||
| // Constructor of CacheOp | |||
| CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | |||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | |||
| : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), | |||
| : CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)), | |||
| num_guys_in_(0), | |||
| phase_(Phase::kBuildPhase) {} | |||
| @@ -174,7 +175,7 @@ Status CacheOp::WorkerEntry(int32_t worker_id) { | |||
| Status CacheOp::RegisterResources() { | |||
| RETURN_IF_NOT_OK(CacheBase::RegisterResources()); | |||
| RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(keys_miss_->Register(tree_->AllTasks())); | |||
| return Status::OK(); | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/engine/data_buffer.h" | |||
| #include "minddata/dataset/engine/datasetops/concat_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/db_connector.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| @@ -188,5 +189,11 @@ Status ConcatOp::ComputeColMap() { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Visitor pre-accept method for NodePass | |||
| Status ConcatOp::PreAccept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -105,6 +105,12 @@ class ConcatOp : public PipelineOp { | |||
| // @return - Status | |||
| Status ComputeColMap() override; | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status PreAccept(NodePass *p, bool *modified) override; | |||
| private: | |||
| Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf); | |||
| @@ -243,6 +243,7 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { | |||
| out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_ | |||
| << "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_; | |||
| if (sampler_) { | |||
| out << "\nSampler:\n"; | |||
| sampler_->Print(out, show_all); | |||
| } | |||
| } | |||
| @@ -268,5 +268,11 @@ Status FilterOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->RunOnNode(shared_from_base<FilterOp>(), modified); | |||
| } | |||
| // Visitor pre-accept method for NodePass | |||
| Status FilterOp::PreAccept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->PreRunOnNode(shared_from_base<FilterOp>(), modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -121,6 +121,12 @@ class FilterOp : public ParallelOp { | |||
| // @param show_all A bool to control if you want to show all info or just a summary. | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status PreAccept(NodePass *p, bool *modified) override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| @@ -458,6 +458,12 @@ Status MapOp::Accept(NodePass *p, bool *modified) { | |||
| return p->RunOnNode(shared_from_base<MapOp>(), modified); | |||
| } | |||
| // Visitor pre-accept method for NodePass | |||
| Status MapOp::PreAccept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->PreRunOnNode(shared_from_base<MapOp>(), modified); | |||
| } | |||
| Status MapOp::WaitForWorkers() { | |||
| // reset num_paused workers to 0 | |||
| num_workers_paused_ = 0; | |||
| @@ -177,10 +177,16 @@ class MapOp : public ParallelOp { | |||
| // @return the number of threads consuming data from previous op's output Connector. | |||
| int32_t num_consumers() const override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status PreAccept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for NodePass visitor acceptor. | |||
| /// \param[in] p Pointer to the NodePass to be accepted. | |||
| /// \param[out] modified Whether this node visit modified the pipeline. | |||
| /// \return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| // Op name getter | |||
| @@ -52,9 +52,9 @@ Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { | |||
| void ParallelOp::Print(std::ostream &out, bool show_all) const { | |||
| // Summary 1-liner print | |||
| if (!show_all) { | |||
| out << " [workers: " << num_workers_ << "]"; | |||
| // Call super class printer | |||
| DatasetOp::Print(out, show_all); | |||
| out << " [workers: " << num_workers_ << "]"; | |||
| } else { | |||
| // Detailed print | |||
| DatasetOp::Print(out, show_all); | |||
| @@ -27,14 +27,14 @@ PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampl | |||
| void PipelineOp::Print(std::ostream &out, bool show_all) const { | |||
| // Summary 1-liner print | |||
| if (!show_all) { | |||
| // Call super class printer | |||
| DatasetOp::Print(out, show_all); | |||
| out << " [workers: "; | |||
| if (this->inlined()) { | |||
| out << "0 (inlined)]"; | |||
| } else { | |||
| out << "1]"; // Pipeline ops only have 1 worker | |||
| } | |||
| // Call super class printer | |||
| DatasetOp::Print(out, show_all); | |||
| } else { | |||
| // Detailed print | |||
| DatasetOp::Print(out, show_all); | |||
| @@ -235,6 +235,12 @@ Status ZipOp::EoeReceived(int32_t) { | |||
| return Status::OK(); | |||
| } | |||
| // Visitor pre-accept method for NodePass | |||
| Status ZipOp::PreAccept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->PreRunOnNode(shared_from_base<ZipOp>(), modified); | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status ZipOp::Accept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| @@ -104,10 +104,16 @@ class ZipOp : public PipelineOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| /// \param[out] modified Indicator if the node was modified | |||
| /// \return Status of the node visit | |||
| Status PreAccept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for NodePass visitor acceptor. | |||
| /// \param[in] p Pointer to the NodePass to be accepted. | |||
| /// \param[out] modified Whether this node visit modified the pipeline. | |||
| /// \return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| // Op name getter | |||
| @@ -26,6 +26,7 @@ | |||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "minddata/dataset/engine/opt/post/repeat_pass.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/opt/pre/cache_error_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" | |||
| #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | |||
| #include "minddata/dataset/engine/perf/profiling.h" | |||
| @@ -235,6 +236,7 @@ Status ExecutionTree::PrepareTreePreAction() { | |||
| std::vector<std::unique_ptr<Pass>> pre_actions; | |||
| // Construct pre actions | |||
| MS_LOG(INFO) << "Running pre pass loops."; | |||
| pre_actions.push_back(std::make_unique<CacheErrorPass>()); | |||
| pre_actions.push_back(std::make_unique<EpochInjectionPass>()); | |||
| pre_actions.push_back(std::make_unique<RemovalPass>()); | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE | |||
| add_library(engine-opt OBJECT | |||
| pass.cc | |||
| post/repeat_pass.cc | |||
| pre/cache_error_pass.cc | |||
| pre/cache_transform_pass.cc | |||
| pre/epoch_injection_pass.cc | |||
| pre/removal_pass.cc | |||
| @@ -23,6 +23,7 @@ | |||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/datasetops/concat_op.h" | |||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | |||
| @@ -143,125 +144,122 @@ Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| @@ -271,24 +269,38 @@ Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #endif | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) { | |||
| #ifdef ENABLE_PYTHON | |||
| Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) { | |||
| Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -37,18 +37,6 @@ class SkipOp; | |||
| class ShuffleOp; | |||
| #ifndef ENABLE_ANDROID | |||
| class MindRecordOp; | |||
| class TFReaderOp; | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| class FilterOp; | |||
| class GeneratorOp; | |||
| #endif | |||
| class AlbumOp; | |||
| class RandomDataOp; | |||
| @@ -63,10 +51,6 @@ class DeviceQueueOp; | |||
| class ImageFolderOp; | |||
| #ifndef ENABLE_ANDROID | |||
| class CacheOp; | |||
| #endif | |||
| class MnistOp; | |||
| class ManifestOp; | |||
| @@ -79,18 +63,30 @@ class CocoOp; | |||
| class CelebAOp; | |||
| class EpochCtrlOp; | |||
| class BuildVocabOp; | |||
| class ConcatOp; | |||
| #ifndef ENABLE_ANDROID | |||
| class MindRecordOp; | |||
| class TFReaderOp; | |||
| class CacheOp; | |||
| class CacheMergeOp; | |||
| class CacheLookupOp; | |||
| #endif | |||
| class EpochCtrlOp; | |||
| class BuildSentencePieceVocabOp; | |||
| #endif | |||
| class BuildVocabOp; | |||
| #ifdef ENABLE_PYTHON | |||
| class FilterOp; | |||
| #ifndef ENABLE_ANDROID | |||
| class BuildSentencePieceVocabOp; | |||
| class GeneratorOp; | |||
| #endif | |||
| // The base class Pass is the basic unit of tree transformation. | |||
| @@ -168,22 +164,6 @@ class NodePass : public Pass { | |||
| virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); | |||
| #endif | |||
| virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified); | |||
| @@ -194,10 +174,6 @@ class NodePass : public Pass { | |||
| virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | |||
| #endif | |||
| virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified); | |||
| @@ -210,30 +186,48 @@ class NodePass : public Pass { | |||
| virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified); | |||
| #endif | |||
| virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | |||
| #endif | |||
| virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified); | |||
| #endif | |||
| virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||
| #ifdef ENABLE_PYTHON | |||
| virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||
| #endif | |||
| private: | |||
| @@ -225,13 +225,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||
| // Turns off the tracking for operations under merge op | |||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | |||
| // would not have been consumed yet. In that case, we need to set its total repeats for it. | |||
| if (cache_lookup_) { | |||
| cache_lookup_->set_total_repeats(num_repeats_); | |||
| cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| } | |||
| // Setting the flag is needed since we didn't call the base class DatasetOp version | |||
| if (is_repeated_) { | |||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | |||
| // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack | |||
| if (cache_lookup_) { | |||
| cache_lookup_->set_total_repeats(num_repeats_); | |||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||
| AddToEOEOpStack(std::move(cache_lookup_)); | |||
| } | |||
| } | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/datasetops/zip_op.h" | |||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | |||
| #include "minddata/dataset/engine/opt/pre/cache_error_pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| CacheErrorPass::CacheErrorPass() : is_cached_(false) {} | |||
| // Identifies the subtree below this node as being cached | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| // Turn on the flag that we're under a merge op | |||
| is_cached_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if ZipOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("ZipOp is currently not supported as a descendant operator under a cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if MapOp with non-deterministic TensorOps exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | |||
| if (is_cached_) { | |||
| auto tfuncs = node->TFuncs(); | |||
| for (size_t i = 0; i < tfuncs.size(); i++) { | |||
| if (!tfuncs[i]->Deterministic()) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache."); | |||
| } | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if ConcatOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) { | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("ConcatOp is currently not supported as a descendant operator under a cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #ifdef ENABLE_PYTHON | |||
| // Returns an error if FilterOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||
| if (is_cached_) { | |||
| RETURN_STATUS_UNEXPECTED("FilterOp is currently not supported as a descendant operator under a cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_ | |||
| #include <memory> | |||
| #include <stack> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \class CacheErrorPass cache_error_pass.h | |||
| /// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures. | |||
| class CacheErrorPass : public NodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| CacheErrorPass(); | |||
| /// \brief Destructor | |||
| ~CacheErrorPass() = default; | |||
| /// \brief Identifies the subtree below this node as being cached | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Returns an error if ZipOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override; | |||
| /// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | |||
| /// \brief Returns an error if ConcatOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override; | |||
| #ifdef ENABLE_PYTHON | |||
| /// \brief Returns an error if FilterOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; | |||
| #endif | |||
| private: | |||
| bool is_cached_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_POST_CACHE_ERROR_PASS_ | |||
| @@ -155,50 +155,77 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> n | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for AlbumOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for MindRecordOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for GeneratorOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| @@ -40,13 +40,6 @@ Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSe | |||
| injection_point_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| // Temporary code to prevent the injection of epoch control when cache op is present | |||
| // Remove this code in cache op phase 2 | |||
| Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||
| injection_point_ = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||
| @@ -54,13 +54,6 @@ class EpochInjectionPass : public TreePass { | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override; | |||
| /// \brief Temporary code to prevent the injection of epoch control when cache op is present. | |||
| /// Remove this code in cache op phase 2 | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| #endif | |||
| /// \brief Register the DeviceQueueOp for further action. | |||
| @@ -62,6 +62,7 @@ Status RandomApplyOp::Compute(const TensorRow &input, TensorRow *output) { | |||
| RandomApplyOp::RandomApplyOp(double prob, const std::vector<std::shared_ptr<TensorOp>> &ops) | |||
| : prob_(prob), gen_(GetSeed()), rand_double_(0, 1) { | |||
| compose_ = std::make_unique<ComposeOp>(ops); | |||
| is_deterministic_ = false; | |||
| } | |||
| } // namespace dataset | |||
| @@ -92,6 +92,7 @@ RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops | |||
| } else if (ops_.size() == 1) { | |||
| MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time."; | |||
| } | |||
| is_deterministic_ = false; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -44,6 +44,7 @@ RandomAffineOp::RandomAffineOp(std::vector<float_t> degrees, std::vector<float_t | |||
| interpolation_ = interpolation; | |||
| fill_value_ = fill_value; | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| Status RandomAffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| @@ -36,6 +36,7 @@ RandomColorAdjustOp::RandomColorAdjustOp(float s_bright_factor, float e_bright_f | |||
| hue_factor_start_(s_hue_factor), | |||
| hue_factor_end_(e_hue_factor) { | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| Status RandomColorAdjustOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| @@ -19,7 +19,9 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {} | |||
| RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) { | |||
| is_deterministic_ = false; | |||
| } | |||
| Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) { | |||
| IO_CHECK(in, out); | |||
| @@ -41,6 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ | |||
| aspect_ub_(aspect_ub), | |||
| max_iter_(max_iter) { | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| @@ -46,6 +46,7 @@ RandomCropOp::RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_ | |||
| fill_g_(fill_g), | |||
| fill_b_(fill_b) { | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| Status RandomCropOp::ImagePadding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *pad_image, | |||
| @@ -33,6 +33,7 @@ class RandomHorizontalFlipOp : public TensorOp { | |||
| static const float kDefProbability; | |||
| explicit RandomHorizontalFlipOp(float probability = kDefProbability) : distribution_(probability) { | |||
| is_deterministic_ = false; | |||
| rnd_.seed(GetSeed()); | |||
| } | |||
| @@ -35,6 +35,7 @@ class RandomHorizontalFlipWithBBoxOp : public TensorOp { | |||
| explicit RandomHorizontalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| ~RandomHorizontalFlipWithBBoxOp() override = default; | |||
| @@ -29,6 +29,7 @@ const std::vector<uint8_t> RandomPosterizeOp::kBitRange = {4, 8}; | |||
| RandomPosterizeOp::RandomPosterizeOp(const std::vector<uint8_t> &bit_range) | |||
| : PosterizeOp(bit_range[0]), bit_range_(bit_range) { | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| Status RandomPosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| @@ -35,6 +35,7 @@ class RandomResizeOp : public ResizeOp { | |||
| explicit RandomResizeOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeOp(size_1, size_2) { | |||
| random_generator_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| ~RandomResizeOp() = default; | |||
| @@ -36,6 +36,7 @@ class RandomResizeWithBBoxOp : public ResizeWithBBoxOp { | |||
| static const int32_t kDefTargetWidth; | |||
| explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) { | |||
| random_generator_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| ~RandomResizeWithBBoxOp() = default; | |||
| @@ -46,6 +46,7 @@ RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, float c | |||
| fill_g_(fill_g), | |||
| fill_b_(fill_b) { | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| // main function call for random rotation : Generate the random degrees | |||
| @@ -90,6 +90,7 @@ RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector<Subpolicy> &p | |||
| if (policy_.empty()) { | |||
| MS_LOG(ERROR) << "policy in RandomSelectSubpolicyOp is empty."; | |||
| } | |||
| is_deterministic_ = false; | |||
| } | |||
| } // namespace dataset | |||
| @@ -31,6 +31,7 @@ const float RandomSharpnessOp::kDefEndDegree = 1.9; | |||
| RandomSharpnessOp::RandomSharpnessOp(float start_degree, float end_degree) | |||
| : start_degree_(start_degree), end_degree_(end_degree) { | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| /// main function call for random sharpness : Generate the random degrees | |||
| @@ -32,7 +32,10 @@ namespace dataset { | |||
| class RandomSolarizeOp : public SolarizeOp { | |||
| public: | |||
| // Pick a random threshold value to solarize the image with | |||
| explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) { rnd_.seed(GetSeed()); } | |||
| explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) { | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| ~RandomSolarizeOp() = default; | |||
| @@ -34,6 +34,7 @@ class RandomVerticalFlipOp : public TensorOp { | |||
| explicit RandomVerticalFlipOp(float probability = kDefProbability) : distribution_(probability) { | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| ~RandomVerticalFlipOp() override = default; | |||
| @@ -34,6 +34,7 @@ class RandomVerticalFlipWithBBoxOp : public TensorOp { | |||
| // @param probability: Probablity of Image flipping, 0.5 by default | |||
| explicit RandomVerticalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { | |||
| rnd_.seed(GetSeed()); | |||
| is_deterministic_ = false; | |||
| } | |||
| ~RandomVerticalFlipWithBBoxOp() override = default; | |||
| @@ -168,6 +168,10 @@ class TensorOp { | |||
| // @return true/false | |||
| bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; } | |||
| // Returns true oif the TensorOp produces deterministic result. | |||
| // @return true/false | |||
| bool Deterministic() { return is_deterministic_; } | |||
| // Function to determine the number of inputs the TensorOp can take. 0: means undefined. | |||
| // @return uint32_t | |||
| virtual uint32_t NumInput() { return 1; } | |||
| @@ -191,6 +195,9 @@ class TensorOp { | |||
| virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs); | |||
| virtual std::string Name() const = 0; | |||
| protected: | |||
| bool is_deterministic_{true}; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -88,21 +88,21 @@ class Allocator { | |||
| std::shared_ptr<MemoryPool> pool_; | |||
| }; | |||
| /// \brief It is a wrapper of unique_ptr with a custom Allocator class defined above | |||
| template <typename T, typename... Args> | |||
| Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, Allocator<T> alloc, size_t n, Args &&... args) { | |||
| template <typename T, typename C = std::allocator<T>, typename... Args> | |||
| Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, C alloc, size_t n, Args &&... args) { | |||
| RETURN_UNEXPECTED_IF_NULL(out); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive"); | |||
| try { | |||
| T *data = alloc.allocate(n); | |||
| if (!std::is_arithmetic<T>::value) { | |||
| for (auto i = 0; i < n; i++) { | |||
| std::allocator_traits<Allocator<T>>::construct(alloc, &(data[i]), std::forward<Args>(args)...); | |||
| std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...); | |||
| } | |||
| } | |||
| auto deleter = [](T *p, Allocator<T> f_alloc, size_t f_n) { | |||
| auto deleter = [](T *p, C f_alloc, size_t f_n) { | |||
| if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) { | |||
| for (auto i = 0; i < f_n; ++i) { | |||
| std::allocator_traits<Allocator<T>>::destroy(f_alloc, &p[i]); | |||
| std::allocator_traits<C>::destroy(f_alloc, &p[i]); | |||
| } | |||
| } | |||
| f_alloc.deallocate(p, f_n); | |||
| @@ -129,7 +129,7 @@ class MemGuard { | |||
| MemGuard(const MemGuard &) = delete; | |||
| MemGuard &operator=(const MemGuard &) = delete; | |||
| // On the other hand, We can support move constructor | |||
| MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} | |||
| MemGuard(MemGuard &&lhs) noexcept : n_(lhs.n_), alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)) {} | |||
| MemGuard &operator=(MemGuard &&lhs) noexcept { | |||
| if (this != &lhs) { | |||
| this->deallocate(); | |||
| @@ -37,7 +37,8 @@ struct MemHdr { | |||
| ArenaImpl::ArenaImpl(void *ptr, size_t sz) : size_in_bytes_(sz), ptr_(ptr) { | |||
| // Divide the memory into blocks. Ignore the last partial block. | |||
| uint64_t num_blks = size_in_bytes_ / ARENA_BLK_SZ; | |||
| MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << "."; | |||
| MS_LOG(DEBUG) << "Arena memory pool is created. Number of blocks : " << num_blks << ". Block size : " << ARENA_BLK_SZ | |||
| << "."; | |||
| tr_.Insert(0, num_blks); | |||
| } | |||
| @@ -233,9 +234,9 @@ std::ostream &operator<<(std::ostream &os, const ArenaImpl &s) { | |||
| Status Arena::Init() { | |||
| try { | |||
| auto sz = size_in_MB_ * 1048576L; | |||
| mem_ = std::make_unique<uint8_t[]>(sz); | |||
| impl_ = std::make_unique<ArenaImpl>(mem_.get(), sz); | |||
| int64_t sz = size_in_MB_ * 1048576L; | |||
| RETURN_IF_NOT_OK(mem_.allocate(sz)); | |||
| impl_ = std::make_unique<ArenaImpl>(mem_.GetMutablePointer(), sz); | |||
| } catch (std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| @@ -19,6 +19,7 @@ | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <utility> | |||
| #include "minddata/dataset/util/allocator.h" | |||
| #include "minddata/dataset/util/memory_pool.h" | |||
| #include "minddata/dataset/util/treap.h" | |||
| @@ -140,7 +141,7 @@ class Arena : public MemoryPool { | |||
| protected: | |||
| mutable std::mutex mux_; | |||
| std::unique_ptr<ArenaImpl> impl_; | |||
| std::unique_ptr<uint8_t[]> mem_; | |||
| MemGuard<uint8_t> mem_; | |||
| size_t size_in_MB_; | |||
| explicit Arena(size_t val_in_MB = 4096); | |||
| @@ -131,6 +131,23 @@ class BPlusTree { | |||
| tree_stats() : size_(0), leaves_(0), inner_nodes_(0), level_(0) {} | |||
| }; | |||
| /// \brief Statistics functions | |||
| /// \return Return the height of the tree | |||
| auto GetHeight() const { return empty() ? 0 : stats_.level_ + 1; } | |||
| /// \return Order of the B+ tree | |||
| auto GetOrder() const { return traits::kLeafSlots; } | |||
| /// \return Number of leaves nodes | |||
| auto GetNumLeaves() const { return stats_.leaves_; } | |||
| /// \return Number of inner nodes | |||
| auto GetNumInnerNodes() const { return stats_.inner_nodes_; } | |||
| /// \brief Toggle locking | |||
| /// \note Once locking is off. It is user's responsibility to ensure concurrency | |||
| void SetLocking(bool on_off) { | |||
| UniqueLock lck(&rw_lock_); | |||
| acquire_lock_ = on_off; | |||
| } | |||
| private: | |||
| // Abstract class of a node (leaf or inner) | |||
| class BaseNode { | |||
| @@ -288,6 +305,17 @@ class BPlusTree { | |||
| key_compare key_less_; | |||
| // Stat | |||
| tree_stats stats_; | |||
| // lock mode | |||
| bool acquire_lock_; | |||
| void Init() { | |||
| typename LeafNode::alloc_type alloc(alloc_); | |||
| auto *p = alloc.allocate(1); | |||
| root_ = new (p) LeafNode(alloc_); | |||
| all_.Prepend(p); | |||
| leaf_nodes_.Append(p); | |||
| stats_.leaves_++; | |||
| } | |||
| bool LessThan(const key_type &a, const key_type &b) const { return key_less_(a, b); } | |||
| @@ -350,11 +378,11 @@ class BPlusTree { | |||
| ~Iterator(); | |||
| explicit Iterator(const Iterator &); | |||
| Iterator(const Iterator &); | |||
| Iterator &operator=(const Iterator &lhs); | |||
| explicit Iterator(Iterator &&); | |||
| Iterator(Iterator &&) noexcept; | |||
| Iterator &operator=(Iterator &&lhs); | |||
| @@ -399,11 +427,11 @@ class BPlusTree { | |||
| ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false) | |||
| : cur_(leaf), slot_(slot), locked_(locked) {} | |||
| explicit ConstIterator(const ConstIterator &); | |||
| ConstIterator(const ConstIterator &); | |||
| ConstIterator &operator=(const ConstIterator &lhs); | |||
| explicit ConstIterator(ConstIterator &&); | |||
| ConstIterator(ConstIterator &&) noexcept; | |||
| ConstIterator &operator=(ConstIterator &&lhs); | |||
| @@ -413,11 +413,16 @@ typename BPlusTree<K, V, A, C, T>::IndexRc BPlusTree<K, V, A, C, T>::Locate(RWLo | |||
| } | |||
| template <typename K, typename V, typename A, typename C, typename T> | |||
| BPlusTree<K, V, A, C, T>::BPlusTree() : leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {} | |||
| BPlusTree<K, V, A, C, T>::BPlusTree() | |||
| : leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr), acquire_lock_(true) { | |||
| Init(); | |||
| } | |||
| template <typename K, typename V, typename A, typename C, typename T> | |||
| BPlusTree<K, V, A, C, T>::BPlusTree(const Allocator<V> &alloc) | |||
| : alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {} | |||
| : alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr), acquire_lock_(true) { | |||
| Init(); | |||
| } | |||
| template <typename K, typename V, typename A, typename C, typename T> | |||
| BPlusTree<K, V, A, C, T>::~BPlusTree() noexcept { | |||
| @@ -446,20 +451,6 @@ BPlusTree<K, V, A, C, T>::~BPlusTree() noexcept { | |||
| template <typename K, typename V, typename A, typename C, typename T> | |||
| Status BPlusTree<K, V, A, C, T>::DoInsert(const key_type &key, std::unique_ptr<value_type> &&value) { | |||
| IndexRc rc; | |||
| if (root_ == nullptr) { | |||
| UniqueLock lck(&rw_lock_); | |||
| // Check again after we get the lock. Other thread may have created the root node already. | |||
| if (root_ == nullptr) { | |||
| LeafNode *leaf = nullptr; | |||
| rc = AllocateLeaf(&leaf); | |||
| if (rc != IndexRc::kOk) { | |||
| return IndexRc2Status(rc); | |||
| } | |||
| leaf_nodes_.Append(leaf); | |||
| root_ = leaf; | |||
| } | |||
| // lock will be unlocked when it goes out of scope. | |||
| } | |||
| bool retry = false; | |||
| do { | |||
| // Track all the paths to the target and lock each internal node in S. | |||
| @@ -468,7 +459,7 @@ Status BPlusTree<K, V, A, C, T>::DoInsert(const key_type &key, std::unique_ptr<v | |||
| retry = false; | |||
| BaseNode *new_child = nullptr; | |||
| key_type new_key = key_type(); | |||
| rc = InsertKeyValue(&InsCB, root_, key, std::move(value), &new_key, &new_child); | |||
| rc = InsertKeyValue(acquire_lock_ ? &InsCB : nullptr, root_, key, std::move(value), &new_key, &new_child); | |||
| if (rc == IndexRc::kRetry) { | |||
| retry = true; | |||
| } else if (rc != IndexRc::kOk) { | |||
| @@ -511,9 +502,12 @@ std::unique_ptr<V> BPlusTree<K, V, A, C, T>::DoUpdate(const key_type &key, std:: | |||
| if (root_ != nullptr) { | |||
| LeafNode *leaf = nullptr; | |||
| slot_type slot; | |||
| RWLock *myLock = &this->rw_lock_; | |||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||
| myLock->LockShared(); | |||
| RWLock *myLock = nullptr; | |||
| if (acquire_lock_) { | |||
| myLock = &this->rw_lock_; | |||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||
| myLock->LockShared(); | |||
| } | |||
| IndexRc rc = Locate(myLock, true, root_, key, &leaf, &slot); | |||
| if (rc == IndexRc::kOk) { | |||
| // All locks from the tree to the parent of leaf are all gone. We still have a X lock | |||
| @@ -521,7 +515,9 @@ std::unique_ptr<V> BPlusTree<K, V, A, C, T>::DoUpdate(const key_type &key, std:: | |||
| // Swap out the old value and replace it with new value. | |||
| std::unique_ptr<value_type> old = std::move(leaf->data_[leaf->slot_dir_[slot]]); | |||
| leaf->data_[leaf->slot_dir_[slot]] = std::move(new_value); | |||
| leaf->rw_lock_.Unlock(); | |||
| if (acquire_lock_) { | |||
| leaf->rw_lock_.Unlock(); | |||
| } | |||
| return old; | |||
| } else { | |||
| MS_LOG(DEBUG) << "Key not found. rc = " << static_cast<int>(rc) << "."; | |||
| @@ -109,7 +109,7 @@ BPlusTree<K, V, A, C, T>::Iterator::Iterator(const BPlusTree<K, V, A, C, T>::Ite | |||
| } | |||
| template <typename K, typename V, typename A, typename C, typename T> | |||
| BPlusTree<K, V, A, C, T>::Iterator::Iterator(BPlusTree<K, V, A, C, T>::Iterator &&lhs) { | |||
| BPlusTree<K, V, A, C, T>::Iterator::Iterator(BPlusTree<K, V, A, C, T>::Iterator &&lhs) noexcept { | |||
| this->cur_ = lhs.cur_; | |||
| this->slot_ = lhs.slot_; | |||
| this->locked_ = lhs.locked_; | |||
| @@ -241,7 +241,7 @@ BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(const BPlusTree<K, V, A, | |||
| } | |||
| template <typename K, typename V, typename A, typename C, typename T> | |||
| BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(BPlusTree<K, V, A, C, T>::ConstIterator &&lhs) { | |||
| BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(BPlusTree<K, V, A, C, T>::ConstIterator &&lhs) noexcept { | |||
| this->cur_ = lhs.cur_; | |||
| this->slot_ = lhs.slot_; | |||
| this->locked_ = lhs.locked_; | |||
| @@ -290,9 +290,12 @@ std::pair<typename BPlusTree<K, V, A, C, T>::ConstIterator, bool> BPlusTree<K, V | |||
| if (root_ != nullptr) { | |||
| LeafNode *leaf = nullptr; | |||
| slot_type slot; | |||
| RWLock *myLock = &this->rw_lock_; | |||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||
| myLock->LockShared(); | |||
| RWLock *myLock = nullptr; | |||
| if (acquire_lock_) { | |||
| myLock = &this->rw_lock_; | |||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||
| myLock->LockShared(); | |||
| } | |||
| IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot); | |||
| bool find = (rc == IndexRc::kOk); | |||
| return std::make_pair(ConstIterator(leaf, slot, find), find); | |||
| @@ -306,9 +309,12 @@ std::pair<typename BPlusTree<K, V, A, C, T>::Iterator, bool> BPlusTree<K, V, A, | |||
| if (root_ != nullptr) { | |||
| LeafNode *leaf = nullptr; | |||
| slot_type slot; | |||
| RWLock *myLock = &this->rw_lock_; | |||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||
| myLock->LockShared(); | |||
| RWLock *myLock = nullptr; | |||
| if (acquire_lock_) { | |||
| myLock = &this->rw_lock_; | |||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||
| myLock->LockShared(); | |||
| } | |||
| IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot); | |||
| bool find = (rc == IndexRc::kOk); | |||
| return std::make_pair(Iterator(leaf, slot, find), find); | |||
| @@ -69,7 +69,7 @@ Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) n | |||
| *p = addr; | |||
| return Status::OK(); | |||
| } else { | |||
| return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore."); | |||
| return Status(StatusCode::kBuddySpaceFull, "BuddySpace full. Not an error. Please ignore."); | |||
| } | |||
| } | |||
| @@ -20,8 +20,13 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CachePool::CachePool(const value_allocator &alloc, const std::string &root) | |||
| : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} | |||
| CachePool::CachePool(const value_allocator &alloc, bool ourOwnArena, const std::string &root) | |||
| : alloc_(alloc), | |||
| root_(root), | |||
| subfolder_(Services::GetUniqueID()), | |||
| sm_(nullptr), | |||
| tree_(nullptr), | |||
| custom_arena_(ourOwnArena) {} | |||
| Status CachePool::DoServiceStart() { | |||
| tree_ = std::make_shared<data_index>(); | |||
| @@ -45,9 +50,12 @@ Status CachePool::DoServiceStop() { | |||
| } | |||
| } | |||
| sm_.reset(); | |||
| for (auto &bl : *tree_) { | |||
| if (bl.ptr != nullptr) { | |||
| alloc_.deallocate(bl.ptr, bl.sz); | |||
| // If it is our own arena, skip freeing individual pieces. | |||
| if (!custom_arena_) { | |||
| for (auto &bl : *tree_) { | |||
| if (bl.ptr != nullptr) { | |||
| alloc_.deallocate(bl.ptr, bl.sz); | |||
| } | |||
| } | |||
| } | |||
| tree_.reset(); | |||
| @@ -68,7 +76,7 @@ Status CachePool::DoServiceStop() { | |||
| return rc2; | |||
| } | |||
| CachePool::~CachePool() noexcept { (void)ServiceStop(); } | |||
| Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_type *key) { | |||
| Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly) { | |||
| DataLocator bl; | |||
| Status rc; | |||
| size_t sz = 0; | |||
| @@ -78,22 +86,31 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t | |||
| } | |||
| bl.sz = sz; | |||
| try { | |||
| bl.ptr = alloc_.allocate(sz); | |||
| // We will do a piecewise copy. | |||
| WritableSlice dest(bl.ptr, bl.sz); | |||
| size_t pos = 0; | |||
| for (auto &v : buf) { | |||
| WritableSlice out(dest, pos); | |||
| rc = WritableSlice::Copy(&out, v); | |||
| if (!writeToDiskDirectly) { | |||
| bl.ptr = alloc_.allocate(sz); | |||
| // We will do a piecewise copy. | |||
| WritableSlice dest(bl.ptr, bl.sz); | |||
| size_t pos = 0; | |||
| for (auto &v : buf) { | |||
| WritableSlice out(dest, pos); | |||
| rc = WritableSlice::Copy(&out, v); | |||
| if (rc.IsError()) { | |||
| break; | |||
| } | |||
| pos += v.GetSize(); | |||
| } | |||
| if (rc.IsError()) { | |||
| break; | |||
| alloc_.deallocate(bl.ptr, sz); | |||
| bl.ptr = nullptr; | |||
| return rc; | |||
| } | |||
| pos += v.GetSize(); | |||
| } | |||
| if (rc.IsError()) { | |||
| alloc_.deallocate(bl.ptr, sz); | |||
| bl.ptr = nullptr; | |||
| return rc; | |||
| } else if (sm_ != nullptr) { | |||
| MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes."; | |||
| RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); | |||
| } else { | |||
| // If asked to spill to disk instead but there is no storage set up, simply return no memory | |||
| // instead. | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } catch (std::bad_alloc &e) { | |||
| if (sm_ != nullptr) { | |||
| @@ -102,7 +119,13 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } | |||
| rc = tree_->insert(bl, key); | |||
| // Insert into the B+ tree. We may still get out of memory error. So need to catch it. | |||
| try { | |||
| rc = tree_->DoInsert(key, bl); | |||
| } catch (const std::bad_alloc &e) { | |||
| rc = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| // Duplicate key is treated as error and we will also free the memory. | |||
| if (rc.IsError() && bl.ptr != nullptr) { | |||
| alloc_.deallocate(bl.ptr, sz); | |||
| } | |||
| @@ -138,15 +161,26 @@ Path CachePool::GetSpillPath() const { | |||
| auto spill = Path(root_) / subfolder_; | |||
| return spill; | |||
| } | |||
| CachePool::CacheStat CachePool::GetStat() const { | |||
| CacheStat cs{0}; | |||
| CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { | |||
| CacheStat cs{-1, -1, 0, 0, 0}; | |||
| int64_t total_sz = 0; | |||
| for (auto &it : *tree_) { | |||
| total_sz += it.sz; | |||
| if (it.ptr != nullptr) { | |||
| ++cs.num_mem_cached; | |||
| } else { | |||
| ++cs.num_disk_cached; | |||
| if (tree_->begin() != tree_->end()) { | |||
| cs.min_key = tree_->begin().key(); | |||
| cs.max_key = cs.min_key; // will adjust later. | |||
| for (auto it = tree_->begin(); it != tree_->end(); ++it) { | |||
| total_sz += it.value().sz; | |||
| if (it.value().ptr != nullptr) { | |||
| ++cs.num_mem_cached; | |||
| } else { | |||
| ++cs.num_disk_cached; | |||
| } | |||
| auto cur_key = it.key(); | |||
| if (GetMissingKeys) { | |||
| for (auto i = cs.max_key + 1; i < cur_key; ++i) { | |||
| cs.gap.push_back((i)); | |||
| } | |||
| } | |||
| cs.max_key = cur_key; | |||
| } | |||
| } | |||
| if (total_sz > 0) { | |||
| @@ -25,13 +25,13 @@ | |||
| #include "minddata/dataset/util/slice.h" | |||
| #include "minddata/dataset/util/storage_manager.h" | |||
| #include "minddata/dataset/util/auto_index.h" | |||
| #include "minddata/dataset/util/btree.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of | |||
| /// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to | |||
| /// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to | |||
| /// restore the buffer. | |||
| /// disk (if a disk directory is provided). User must provide a key to insert the buffer. | |||
| /// \see ReadableSlice | |||
| class CachePool : public Service { | |||
| public: | |||
| @@ -73,22 +73,25 @@ class CachePool : public Service { | |||
| StorageManager::key_type storage_key; | |||
| }; | |||
| using data_index = AutoIndexObj<DataLocator>; | |||
| using data_index = BPlusTree<int64_t, DataLocator>; | |||
| using key_type = data_index::key_type; | |||
| using bl_alloc_type = typename value_allocator::template rebind<DataLocator>::other; | |||
| /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and | |||
| /// how many elements are spilled to disk. | |||
| struct CacheStat { | |||
| key_type min_key; | |||
| key_type max_key; | |||
| int64_t num_mem_cached; | |||
| int64_t num_disk_cached; | |||
| int64_t average_cache_sz; | |||
| std::vector<key_type> gap; | |||
| }; | |||
| /// \brief Constructor | |||
| /// \param alloc Allocator to allocate memory from | |||
| /// \param root Optional disk folder to spill | |||
| explicit CachePool(const value_allocator &alloc, const std::string &root = ""); | |||
| explicit CachePool(const value_allocator &alloc, bool customArena, const std::string &root = ""); | |||
| CachePool(const CachePool &) = delete; | |||
| CachePool(CachePool &&) = delete; | |||
| @@ -103,10 +106,11 @@ class CachePool : public Service { | |||
| /// \brief Insert a sequence of ReadableSlice objects into the pool. | |||
| /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk. | |||
| /// \param[in] key User supplied key | |||
| /// \param[in] buf A sequence of ReadableSlice objects. | |||
| /// \param[out] key Generated key | |||
| /// \param[in] writeToDiskDirectly If true, no spill to disk if spill is enabled, or return no memory | |||
| /// \return Error code | |||
| Status Insert(const std::vector<ReadableSlice> &buf, key_type *key); | |||
| Status Insert(key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly); | |||
| /// \brief Restore a cached buffer (from memory or disk) | |||
| /// \param[in] key A previous key returned from Insert | |||
| /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice | |||
| @@ -122,18 +126,23 @@ class CachePool : public Service { | |||
| /// \brief Get statistics. | |||
| /// \return CacheStat object | |||
| CacheStat GetStat() const; | |||
| CacheStat GetStat(bool GetMissingKeys = false) const; | |||
| const value_allocator &get_allocator() const; | |||
| std::string MyName() const { return subfolder_; } | |||
| /// \brief Toggle locking | |||
| /// \note Once locking is off, it is the user's responsibility to ensure concurrency | |||
| void SetLocking(bool on_off) { tree_->SetLocking(on_off); } | |||
| private: | |||
| value_allocator alloc_; | |||
| Path root_; | |||
| const std::string subfolder_; | |||
| std::shared_ptr<StorageManager> sm_; | |||
| std::shared_ptr<data_index> tree_; | |||
| bool custom_arena_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -133,12 +133,13 @@ void CircularPool::Deallocate(void *p) { | |||
| // Lock in the chain in shared mode and find out which | |||
| // segment it comes from | |||
| SharedLock lock(&rw_lock_); | |||
| auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr<Arena> &b) -> bool { | |||
| auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr<Arena> &b) -> bool { | |||
| char *q = reinterpret_cast<char *>(p); | |||
| char *base = const_cast<char *>(reinterpret_cast<const char *>(b->get_base_addr())); | |||
| return (q > base && q < base + b->get_max_size()); | |||
| auto *base = reinterpret_cast<const char *>(b->get_base_addr()); | |||
| return (q > base && q < base + arena_size_ * 1048576L); | |||
| }); | |||
| lock.Unlock(); | |||
| MS_ASSERT(it != mem_segments_.end()); | |||
| it->get()->Deallocate(p); | |||
| } | |||
| @@ -150,10 +151,10 @@ Status CircularPool::Reallocate(void **pp, size_t old_sz, size_t new_sz) { | |||
| } | |||
| void *p = *pp; | |||
| SharedLock lock(&rw_lock_); | |||
| auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr<Arena> &b) -> bool { | |||
| auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr<Arena> &b) -> bool { | |||
| char *q = reinterpret_cast<char *>(p); | |||
| char *base = const_cast<char *>(reinterpret_cast<const char *>(b->get_base_addr())); | |||
| return (q > base && q < base + b->get_max_size()); | |||
| auto *base = reinterpret_cast<const char *>(b->get_base_addr()); | |||
| return (q > base && q < base + arena_size_ * 1048576L); | |||
| }); | |||
| lock.Unlock(); | |||
| MS_ASSERT(it != mem_segments_.end()); | |||
| @@ -16,11 +16,14 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | |||
| #include <atomic> | |||
| #include <deque> | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include "minddata/dataset/util/allocator.h" | |||
| #include "minddata/dataset/util/system_pool.h" | |||
| #include "minddata/dataset/util/semaphore.h" | |||
| #include "minddata/dataset/util/services.h" | |||
| namespace mindspore { | |||
| @@ -37,7 +40,7 @@ class QueueMap { | |||
| using key_type = K; | |||
| using value_type = T; | |||
| QueueMap() = default; | |||
| QueueMap() : num_rows_(0) {} | |||
| virtual ~QueueMap() = default; | |||
| /// Add an element <key, T> to the map and wake up any consumer that is waiting | |||
| @@ -48,6 +51,7 @@ class QueueMap { | |||
| RequestQueue *rq = nullptr; | |||
| RETURN_IF_NOT_OK(GetRq(key, &rq)); | |||
| RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload))); | |||
| ++num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| @@ -56,9 +60,35 @@ class QueueMap { | |||
| RequestQueue *rq = nullptr; | |||
| RETURN_IF_NOT_OK(GetRq(key, &rq)); | |||
| RETURN_IF_NOT_OK(rq->Wait(out)); | |||
| --num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| /// Get the number of elements in the container | |||
| /// \return The number of elements in the container | |||
| int64_t size() const { return num_rows_; } | |||
| /// \return if the container is empty | |||
| bool empty() const { return num_rows_ == 0; } | |||
| /// Print out some useful information about the container | |||
| friend std::ostream &operator<<(std::ostream &out, const QueueMap &qm) { | |||
| std::unique_lock<std::mutex> lck(qm.mux_); | |||
| out << "Number of elements: " << qm.num_rows_ << "\n"; | |||
| out << "Dumping internal info:\n"; | |||
| int64_t k = 0; | |||
| for (auto &it : qm.all_) { | |||
| auto key = it.first; | |||
| const RequestQueue *rq = it.second.GetPointer(); | |||
| out << "(k:" << key << "," << *rq << ") "; | |||
| ++k; | |||
| if (k % 6 == 0) { | |||
| out << "\n"; | |||
| } | |||
| } | |||
| return out; | |||
| } | |||
| protected: | |||
| /// This is a handshake structure between producer and consumer | |||
| class RequestQueue { | |||
| @@ -86,8 +116,13 @@ class QueueMap { | |||
| return Status::OK(); | |||
| } | |||
| friend std::ostream &operator<<(std::ostream &out, const RequestQueue &rq) { | |||
| out << "sz:" << rq.row_.size() << ",uc:" << rq.use_count_.Peek(); | |||
| return out; | |||
| } | |||
| private: | |||
| std::mutex dq_mux_; | |||
| mutable std::mutex dq_mux_; | |||
| Semaphore use_count_; | |||
| std::deque<T> row_; | |||
| }; | |||
| @@ -104,7 +139,7 @@ class QueueMap { | |||
| *out = it->second.GetMutablePointer(); | |||
| } else { | |||
| // We will create a new one. | |||
| auto alloc = Services::GetAllocator<RequestQueue>(); | |||
| auto alloc = SystemPool::GetAllocator<RequestQueue>(); | |||
| auto r = all_.emplace(key, MemGuard<RequestQueue, Allocator<RequestQueue>>(alloc)); | |||
| if (r.second) { | |||
| auto &mem = r.first->second; | |||
| @@ -118,8 +153,9 @@ class QueueMap { | |||
| } | |||
| private: | |||
| std::mutex mux_; | |||
| mutable std::mutex mux_; | |||
| std::map<K, MemGuard<RequestQueue, Allocator<RequestQueue>>> all_; | |||
| std::atomic<int64_t> num_rows_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -29,10 +29,7 @@ void Semaphore::V() { | |||
| ++value_; | |||
| wait_cond_.NotifyOne(); | |||
| } | |||
| int Semaphore::Peek() { | |||
| std::unique_lock<std::mutex> lck(mutex_); | |||
| return value_; | |||
| } | |||
| int Semaphore::Peek() const { return value_; } | |||
| Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } | |||
| Status Semaphore::Deregister() { return (wait_cond_.Deregister()); } | |||
| void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); } | |||
| @@ -38,7 +38,7 @@ class Semaphore { | |||
| void V(); | |||
| /// \brief Peek the internal value | |||
| /// \return The internal value | |||
| int Peek(); | |||
| int Peek() const; | |||
| Status Register(TaskGroup *vg); | |||
| Status Deregister(); | |||
| void ResetIntrpState(); | |||
| @@ -51,6 +51,12 @@ std::string CodeAsString(const StatusCode c) { | |||
| case StatusCode::kSyntaxError: | |||
| s = "Syntax error"; | |||
| break; | |||
| case StatusCode::kBuddySpaceFull: | |||
| s = "BuddySpace full"; | |||
| break; | |||
| case StatusCode::kNetWorkError: | |||
| s = "Network error"; | |||
| break; | |||
| case StatusCode::kUnexpectedError: | |||
| default: | |||
| s = "Unexpected error"; | |||
| @@ -82,6 +82,8 @@ enum class StatusCode : char { | |||
| kBoundingBoxInvalidShape = 12, | |||
| kSyntaxError = 13, | |||
| kTimeOut = 14, | |||
| kBuddySpaceFull = 14, | |||
| kNetWorkError = 15, | |||
| // Make this error code the last one. Add new error code above it. | |||
| kUnexpectedError = 127 | |||
| }; | |||
| @@ -137,6 +139,8 @@ class Status { | |||
| bool IsNoSpace() const { return (get_code() == StatusCode::kNoSpace); } | |||
| bool IsNetWorkError() const { return (get_code() == StatusCode::kNetWorkError); } | |||
| private: | |||
| StatusCode code_; | |||
| std::string err_msg_; | |||
| @@ -99,7 +99,11 @@ Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const | |||
| #endif | |||
| if (r_sz != sz) { | |||
| errno_t err = (r_sz == 0) ? EOF : errno; | |||
| RETURN_STATUS_UNEXPECTED(strerror(err)); | |||
| if (errno == ENOSPC) { | |||
| return Status(StatusCode::kNoSpace, __LINE__, __FILE__); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED(strerror(err)); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -71,10 +71,11 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu | |||
| key_type out_key; | |||
| value_type out_value; | |||
| bool create_new_container = false; | |||
| size_t last_num_container = -1; | |||
| do { | |||
| SharedLock lock_s(&rw_lock_); | |||
| size_t num_containers = containers_.size(); | |||
| if (create_new_container) { | |||
| if (create_new_container && (num_containers == last_num_container)) { | |||
| // Upgrade to exclusive lock. | |||
| lock_s.Upgrade(); | |||
| create_new_container = false; | |||
| @@ -95,8 +96,11 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu | |||
| cont = containers_.at(num_containers - 1); | |||
| off64_t offset; | |||
| Status rc = cont->Insert(buf, &offset); | |||
| if (rc.IsNoSpace()) { | |||
| if (rc.get_code() == StatusCode::kBuddySpaceFull) { | |||
| create_new_container = true; | |||
| // Remember how many containers we saw. In the next iteration we will do a comparison to see | |||
| // if someone has already created it. | |||
| last_num_container = num_containers; | |||
| } else if (rc.IsOk()) { | |||
| out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz)); | |||
| RETURN_IF_NOT_OK(index_.insert(out_value, &out_key)); | |||
| @@ -15,6 +15,7 @@ | |||
| """Cache client | |||
| """ | |||
| import os | |||
| import copy | |||
| from ..core.validator_helpers import type_check, check_uint32, check_uint64 | |||
| @@ -25,11 +26,11 @@ class DatasetCache: | |||
| A client to interface with tensor caching service | |||
| """ | |||
| def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, prefetch_size=20): | |||
| def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None, | |||
| prefetch_size=None): | |||
| check_uint32(session_id, "session_id") | |||
| check_uint64(size, "size") | |||
| type_check(spilling, (bool,), "spilling") | |||
| check_uint32(prefetch_size, "prefetch size") | |||
| self.session_id = session_id | |||
| self.size = size | |||
| @@ -37,8 +38,13 @@ class DatasetCache: | |||
| self.hostname = hostname | |||
| self.port = port | |||
| self.prefetch_size = prefetch_size | |||
| # temporary disable cache feature in the current release | |||
| self.cache_client = None | |||
| self.num_connections = num_connections | |||
| if os.getenv('MS_ENABLE_CACHE') != 'TRUE': | |||
| # temporarily disable the cache feature in the current release | |||
| self.cache_client = None | |||
| else: | |||
| from mindspore._c_dataengine import CacheClient | |||
| self.cache_client = CacheClient(session_id, size, spilling, hostname, port, num_connections, prefetch_size) | |||
| def GetStat(self): | |||
| return self.cache_client.GetStat() | |||
| @@ -55,5 +61,6 @@ class DatasetCache: | |||
| new_cache.hostname = copy.deepcopy(self.hostname, memodict) | |||
| new_cache.port = copy.deepcopy(self.port, memodict) | |||
| new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) | |||
| new_cache.num_connections = copy.deepcopy(self.num_connections, memodict) | |||
| new_cache.cache_client = self.cache_client | |||
| return new_cache | |||
| @@ -1234,5 +1234,8 @@ def check_paddeddataset(method): | |||
| def check_cache_option(cache): | |||
| """Sanity check for cache parameter""" | |||
| if cache is not None: | |||
| # temporary disable cache feature in the current release | |||
| raise ValueError("Caching is disabled in the current release") | |||
| if os.getenv('MS_ENABLE_CACHE') != 'TRUE': | |||
| # temporarily disable the cache feature in the current release | |||
| raise ValueError("Caching is disabled in the current release") | |||
| from . import cache_client | |||
| type_check(cache, (cache_client.DatasetCache,), "cache") | |||
| @@ -156,6 +156,24 @@ def update_permissions(path): | |||
| if filename == "ms_serving": | |||
| os.chmod(file_fullpath, stat.S_IREAD | stat.S_IEXEC) | |||
| def bin_files(): | |||
| """ | |||
| Gets the binary files to be installed. | |||
| """ | |||
| data_files = [] | |||
| binary_files = [] | |||
| cache_server_bin = os.path.join('mindspore', 'bin', 'cache_server') | |||
| if not os.path.exists(cache_server_bin): | |||
| return data_files | |||
| binary_files.append(cache_server_bin) | |||
| cache_admin_bin = os.path.join('mindspore', 'bin', 'cache_admin') | |||
| if not os.path.exists(cache_admin_bin): | |||
| return data_files | |||
| binary_files.append(cache_admin_bin) | |||
| data_files.append(('bin', binary_files)) | |||
| return data_files | |||
| class EggInfo(egg_info): | |||
| """Egg info.""" | |||
| @@ -192,6 +210,7 @@ setup( | |||
| 'framework that could be used for mobile, edge and cloud scenarios.', | |||
| long_description="\n\n".join([readme, release]), | |||
| long_description_content_type="text/markdown", | |||
| data_files=bin_files(), | |||
| packages=find_packages(), | |||
| package_data=package_data, | |||
| include_package_data=True, | |||
| @@ -24,27 +24,26 @@ | |||
| #include "utils/log_adapter.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::MsLogLevel::INFO; | |||
| // For testing purposes, we will make the branching factor very low. | |||
| struct mytraits { | |||
| using slot_type = uint16_t; | |||
| static const slot_type kLeafSlots = 6; | |||
| static const slot_type kInnerSlots = 3; | |||
| using slot_type = uint16_t; | |||
| static const slot_type kLeafSlots = 6; | |||
| static const slot_type kInnerSlots = 3; | |||
| }; | |||
| class MindDataTestBPlusTree : public UT::Common { | |||
| public: | |||
| MindDataTestBPlusTree() = default; | |||
| MindDataTestBPlusTree() = default; | |||
| }; | |||
| // Test serial insert. | |||
| TEST_F(MindDataTestBPlusTree, Test1) { | |||
| Allocator<std::string> alloc(std::make_shared<SystemPool>()); | |||
| BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<uint64_t>, mytraits> btree(alloc); | |||
| BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<>, mytraits> btree(alloc); | |||
| Status rc; | |||
| for (int i = 0; i < 100; i++) { | |||
| uint64_t key = 2 * i; | |||
| @@ -109,23 +108,24 @@ TEST_F(MindDataTestBPlusTree, Test1) { | |||
| // Test concurrent insert. | |||
| TEST_F(MindDataTestBPlusTree, Test2) { | |||
| Allocator<std::string> alloc(std::make_shared<SystemPool>()); | |||
| BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<uint64_t>, mytraits> btree(alloc); | |||
| BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<>, mytraits> btree(alloc); | |||
| TaskGroup vg; | |||
| auto f = [&](int k) -> Status { | |||
| TaskManager::FindMe()->Post(); | |||
| for (int i = 0; i < 100; i++) { | |||
| uint64_t key = k * 100 + i; | |||
| std::ostringstream oss; | |||
| oss << "Hello World. I am " << key; | |||
| Status rc = btree.DoInsert(key, oss.str()); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| return Status::OK(); | |||
| for (int i = 0; i < 100; i++) { | |||
| uint64_t key = k * 100 + i; | |||
| std::ostringstream oss; | |||
| oss << "Hello World. I am " << key; | |||
| Status rc = btree.DoInsert(key, oss.str()); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| return Status::OK(); | |||
| }; | |||
| auto g = [&](int k) -> Status { | |||
| TaskManager::FindMe()->Post(); | |||
| for (int i = 0; i < 1000; i++) { | |||
| uint64_t key = rand() % 10000;; | |||
| uint64_t key = rand() % 10000; | |||
| ; | |||
| auto it = btree.Search(key); | |||
| } | |||
| return Status::OK(); | |||
| @@ -226,3 +226,22 @@ TEST_F(MindDataTestBPlusTree, Test4) { | |||
| EXPECT_EQ(cnt, 1000); | |||
| } | |||
| } | |||
| TEST_F(MindDataTestBPlusTree, TestPerfNoLocking) { | |||
| AutoIndexObj<int64_t> btree; | |||
| // No locking test | |||
| btree.SetLocking(false); | |||
| // Insert a million entries using the default traits. | |||
| for (auto i = 0; i < 1000000; ++i) { | |||
| ASSERT_TRUE(btree.insert(i)); | |||
| } | |||
| std::cout << "Tree height : " << btree.GetHeight() << std::endl; | |||
| std::cout << "Tree Order : " << btree.GetOrder() << std::endl; | |||
| std::cout << "Number of leaves : " << btree.GetNumLeaves() << std::endl; | |||
| std::cout << "Number of inner nodes : " << btree.GetNumInnerNodes() << std::endl; | |||
| auto r = btree.Search(3); | |||
| EXPECT_TRUE(r.second); | |||
| r = btree.Search(999999); | |||
| EXPECT_TRUE(r.second); | |||
| } | |||
| @@ -35,6 +35,23 @@ using mindspore::dataset::TaskGroup; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::MsLogLevel::INFO; | |||
| // Helper function to get the session id from SESSION_ID env variable | |||
| Status GetSessionFromEnv(session_id_type *session_id) { | |||
| RETURN_UNEXPECTED_IF_NULL(session_id); | |||
| if (const char *session_env = std::getenv("SESSION_ID")) { | |||
| std::string session_id_str(session_env); | |||
| try { | |||
| *session_id = std::stoul(session_id_str); | |||
| } catch (const std::exception &e) { | |||
| std::string err_msg = "Invalid numeric value for session id in env var: " + session_id_str; | |||
| return Status(StatusCode::kSyntaxError, err_msg); | |||
| } | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Test case requires a session id to be provided via SESSION_ID environment variable."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| class MindDataTestCacheOp : public UT::DatasetOpTesting { | |||
| public: | |||
| void SetUp() override { | |||
| @@ -46,8 +63,12 @@ class MindDataTestCacheOp : public UT::DatasetOpTesting { | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) { | |||
| Status rc; | |||
| CacheClient::Builder builder; | |||
| session_id_type env_session; | |||
| rc = GetSessionFromEnv(&env_session); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // use the session id from the environment variable, size of 0, spilling is true | |||
| builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||
| builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = builder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| @@ -118,9 +139,6 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) { | |||
| cmp = (map_out == map); | |||
| ASSERT_TRUE(cmp); | |||
| // Test Purge and Destroy | |||
| rc = myClient->PurgeCache(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myClient->DestroyCache(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| @@ -130,10 +148,15 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) { | |||
| (void)TaskManager::GetMasterThreadRc(); | |||
| TaskGroup vg; | |||
| Status rc; | |||
| session_id_type env_session; | |||
| rc = GetSessionFromEnv(&env_session); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // use arbitrary session of 1, size 1, spilling is true | |||
| CacheClient::Builder builder; | |||
| // use arbitrary session of 1, size of 0, spilling// is true | |||
| builder.SetSessionId(1).SetCacheMemSz(1).SetSpill(true); | |||
| builder.SetSessionId(env_session).SetCacheMemSz(1).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = builder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| @@ -199,8 +222,15 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) { | |||
| // RandomDataOp | |||
| // | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | |||
| // Clear the rc of the master thread if any | |||
| (void)TaskManager::GetMasterThreadRc(); | |||
| Status rc; | |||
| int32_t rank = 0; // not used | |||
| session_id_type env_session; | |||
| rc = GetSessionFromEnv(&env_session); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "UT test TestRandomDataCache1"; | |||
| // Start with an empty execution tree | |||
| auto myTree = std::make_shared<ExecutionTree>(); | |||
| @@ -236,8 +266,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | |||
| // CacheOp | |||
| // size of 0, spilling is true | |||
| CacheClient::Builder builder; | |||
| // use arbitrary session of 1, size of 0, spilling// is true | |||
| builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||
| builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = builder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| @@ -273,7 +302,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(); | |||
| rc = myTree->Prepare(1); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // quick check to see what tree looks like | |||
| @@ -314,9 +343,16 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | |||
| //// RandomDataOp | |||
| //// | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | |||
| // Clear the rc of the master thread if any | |||
| (void)TaskManager::GetMasterThreadRc(); | |||
| Status rc; | |||
| int32_t rank = 0; // not used | |||
| MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; | |||
| session_id_type env_session; | |||
| rc = GetSessionFromEnv(&env_session); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Start with an empty execution tree | |||
| auto myTree = std::make_shared<ExecutionTree>(); | |||
| @@ -353,8 +389,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| CacheClient::Builder builder; | |||
| // use arbitrary session of 1, size of 0, spilling// is true | |||
| builder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); | |||
| builder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = builder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| @@ -386,7 +421,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(); | |||
| rc = myTree->Prepare(1); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::cout << *myClient << std::endl; | |||
| @@ -413,14 +448,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | |||
| } | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||
| // Clear the rc of the master thread if any | |||
| (void)TaskManager::GetMasterThreadRc(); | |||
| Status rc; | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| session_id_type env_session; | |||
| rc = GetSessionFromEnv(&env_session); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| CacheClient::Builder ccbuilder; | |||
| // use the session id from the environment variable, size of 0, spilling is true | |||
| ccbuilder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||
| ccbuilder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = ccbuilder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| @@ -468,7 +509,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||
| rc = myCacheOp->AddChild(so); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->Prepare(); | |||
| rc = myTree->Prepare(1); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->Launch(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| @@ -507,10 +548,16 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||
| //// RandomDataOp | |||
| //// | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { | |||
| // Clear the rc of the master thread if any | |||
| (void)TaskManager::GetMasterThreadRc(); | |||
| Status rc; | |||
| int32_t rank = 0; // not used | |||
| MS_LOG(INFO) << "UT test TestCacheInheritSampler"; | |||
| session_id_type env_session; | |||
| rc = GetSessionFromEnv(&env_session); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| @@ -550,7 +597,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { | |||
| // CacheOp | |||
| CacheClient::Builder ccbuilder; | |||
| // use the session id from the environment variable, size of 4, spilling is true | |||
| ccbuilder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); | |||
| ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = ccbuilder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| @@ -577,7 +624,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(); | |||
| rc = myTree->Prepare(1); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::cout << *myClient << std::endl; | |||
| @@ -25,13 +25,13 @@ using namespace mindspore::dataset; | |||
| class MindDataTestMemoryPool : public UT::Common { | |||
| public: | |||
| std::shared_ptr<MemoryPool> mp_; | |||
| MindDataTestMemoryPool() {} | |||
| std::shared_ptr<MemoryPool> mp_; | |||
| MindDataTestMemoryPool() {} | |||
| void SetUp() { | |||
| Status rc = CircularPool::CreateCircularPool(&mp_, 1, 1, true); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| void SetUp() { | |||
| Status rc = CircularPool::CreateCircularPool(&mp_, 1, 1, true); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| }; | |||
| TEST_F(MindDataTestMemoryPool, DumpPoolInfo) { | |||
| @@ -40,7 +40,7 @@ TEST_F(MindDataTestMemoryPool, DumpPoolInfo) { | |||
| TEST_F(MindDataTestMemoryPool, TestOperator1) { | |||
| Status rc; | |||
| int *p = new(&rc, mp_) int; | |||
| int *p = new (&rc, mp_) int; | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| *p = 2048; | |||
| ::operator delete(p, mp_); | |||
| @@ -61,12 +61,11 @@ TEST_F(MindDataTestMemoryPool, TestOperator3) { | |||
| TEST_F(MindDataTestMemoryPool, TestAllocator) { | |||
| class A { | |||
| public: | |||
| explicit A (int x) : a(x) {} | |||
| int val_a() const { | |||
| return a; | |||
| } | |||
| explicit A(int x) : a(x) {} | |||
| int val_a() const { return a; } | |||
| private: | |||
| int a; | |||
| int a; | |||
| }; | |||
| Allocator<A> alloc(mp_); | |||
| std::shared_ptr<A> obj_a = std::allocate_shared<A>(alloc, 3); | |||
| @@ -74,3 +73,16 @@ TEST_F(MindDataTestMemoryPool, TestAllocator) { | |||
| ASSERT_EQ(v, 3); | |||
| MS_LOG(DEBUG) << *(std::dynamic_pointer_cast<CircularPool>(mp_)) << std::endl; | |||
| } | |||
| TEST_F(MindDataTestMemoryPool, TestMemGuard) { | |||
| MemGuard<uint8_t> mem; | |||
| // Try some large value. | |||
| int64_t sz = 5LL * 1024LL * 1024LL * 1024LL; | |||
| Status rc = mem.allocate(sz); | |||
| ASSERT_TRUE(rc.IsOk() || rc.IsOutofMemory()); | |||
| if (rc.IsOk()) { | |||
| // Try write a character half way. | |||
| auto *p = mem.GetMutablePointer(); | |||
| p[sz / 2] = 'a'; | |||
| } | |||
| } | |||
| @@ -0,0 +1,48 @@ | |||
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# This script is the driver of the individual test scenarios
CURRPATH=$(cd "$(dirname "$0")"; pwd)
# Accumulate failures across all scenarios so the driver exits nonzero when
# any scenario failed (previously the exit status was that of the last echo,
# i.e. always 0).
total_failures=0
echo "----------------------------------------------"
echo "Invalid syntax and cache_admin failure testing"
echo "----------------------------------------------"
echo
${CURRPATH}/cachetest_args.sh
num_failures=$?
total_failures=$((total_failures + num_failures))
echo
echo "Invalid syntax and cache_admin failure testing complete. Number of failures: $num_failures"
echo
echo "----------------------------------------------"
echo "Test pipelines with cache (python)"
echo "----------------------------------------------"
echo
${CURRPATH}/cachetest_py.sh
num_failures=$?
total_failures=$((total_failures + num_failures))
echo
echo "Test pipelines with cache complete. Number of failures: $num_failures"
echo
echo "----------------------------------------------"
echo "Cache cpp tests"
echo "----------------------------------------------"
echo
${CURRPATH}/cachetest_cpp.sh
num_failures=$?
total_failures=$((total_failures + num_failures))
echo
echo "Cache cpp tests complete. Number of failures: $num_failures"
echo
exit ${total_failures}
| @@ -0,0 +1,207 @@ | |||
| #!/bin/bash | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| # source the globals and functions for use with cache testing | |||
| SKIP_ADMIN_COUNTER=false | |||
| . cachetest_lib.sh | |||
| echo | |||
| ################################################################################ | |||
| # Cache testing: cache_admin argument testing # | |||
| # Summary: Various tests that expect to get failure messages returned # | |||
| ################################################################################ | |||
| # Double-command test | |||
| cmd="${CACHE_ADMIN} --start --stop" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 0 | |||
| # missing command test | |||
| cmd="${CACHE_ADMIN} --port 50082" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 0 | |||
| # bad arg test | |||
| cmd="${CACHE_ADMIN} -p abc --start" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 0 | |||
| # missing arg test | |||
| cmd="${CACHE_ADMIN} -p --start" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 0 | |||
| # invalid command | |||
| cmd="${CACHE_ADMIN} -p 50082 --start --not_exist_cmd" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 0 | |||
| # spill directory does not exist | |||
| cmd="${CACHE_ADMIN} --start --spilldir /path_that_does_not_exist" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 0 | |||
| # start cache server twice | |||
| StartServer | |||
| HandleRcExit $? 1 1 | |||
| # start the cache server again, however, this time we expect an error | |||
| cmd="${CACHE_ADMIN} --start" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| StopServer | |||
| HandleRcExit $? 1 1 | |||
| # start cache server twice with different ports | |||
| # this one starts with the default port 50052 | |||
| StartServer | |||
| HandleRcExit $? 1 1 | |||
| # this one starts with port 50053 | |||
| cmd="${CACHE_ADMIN} --start -p 50053" | |||
| CacheAdminCmd "${cmd}" 0 | |||
| HandleRcExit $? 1 1 | |||
| # stop the cache server with default port | |||
| StopServer | |||
| HandleRcExit $? 1 1 | |||
| # stop the cache server with port 50053 | |||
| cmd="${CACHE_ADMIN} --stop -p 50053" | |||
| CacheAdminCmd "${cmd}" 0 | |||
| HandleRcExit $? 1 1 | |||
| # stop the cache server without bringing it up | |||
| cmd="${CACHE_ADMIN} --stop" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| # start the cache server with illegal hostname | |||
| cmd="${CACHE_ADMIN} --start -h 0.0.0.0" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| cmd="${CACHE_ADMIN} --start -h illegal" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| cmd="${CACHE_ADMIN} --start -h" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| cmd="${CACHE_ADMIN} --start -h --hostname" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| cmd="${CACHE_ADMIN} --start -h --hostname 127.0.0.1" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| # start the cache server with illegal port | |||
| cmd="${CACHE_ADMIN} --start -p 0" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| cmd="${CACHE_ADMIN} --start -p -1" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| cmd="${CACHE_ADMIN} --start -p 65536" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| cmd="${CACHE_ADMIN} --start -p illegal" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
| cmd="${CACHE_ADMIN} --start -p" | |||
| CacheAdminCmd "${cmd}" 1 | |||
| HandleRcExit $? 0 1 | |||
# find a port that is occupied using netstat, then verify the server refuses
# to start on it
if [ -x "$(command -v netstat)" ]; then
  port=$(netstat -ntp 2>/dev/null | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1)
  # Guard against an empty pipeline result: with ${port} unset/empty the
  # unquoted test "[ -gt 1025 ]" is a shell syntax error, not a false result.
  if [ -n "${port}" ] && [ "${port}" -gt 1025 ]; then
    # start cache server with occupied port
    cmd="${CACHE_ADMIN} --start -p ${port}"
    CacheAdminCmd "${cmd}" 1
    HandleRcExit $? 0 1
  fi
fi
# generate session before starting the cache server (expected to fail: no
# server is running yet)
cmd="${CACHE_ADMIN} -g"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# illegal generate session command (-g takes no argument)
StartServer
HandleRcExit $? 1 1
cmd="${CACHE_ADMIN} -g 1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# illegal destroy session command (negative, non-numeric, missing session id)
cmd="${CACHE_ADMIN} -d -2"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} -d illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} -d"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# destroy a non-existing session
cmd="${CACHE_ADMIN} -d 99999"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# stop cache server at this point; the remaining cases only test argument
# validation, which presumably happens before the server is contacted
StopServer
HandleRcExit $? 1 1
# illegal number of workers (zero, negative, non-numeric, too large, missing)
cmd="${CACHE_ADMIN} --start -w 0"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w -1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w 101"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w 9999999"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# illegal spill path (missing value)
cmd="${CACHE_ADMIN} --start -s"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# spill path without writing perm (skipped for root, who can write anywhere)
if [ "$EUID" -ne 0 ]; then
  cmd="${CACHE_ADMIN} --start -s /"
  CacheAdminCmd "${cmd}" 1
  HandleRcExit $? 0 0
fi
# illegal log level (out of range high, negative, missing)
cmd="${CACHE_ADMIN} --start -l 4"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -l -1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -l"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
# failed_tests is maintained by the helpers in cachetest_lib.sh
exit ${failed_tests}