Merge pull request !5930 from Jamie/CacheOp_devtags/v1.1.0
| @@ -36,6 +36,7 @@ include(CPack) | |||||
| set(INSTALL_LIB_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Installation directory for libraries") | set(INSTALL_LIB_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Installation directory for libraries") | ||||
| set(INSTALL_PY_DIR ".") | set(INSTALL_PY_DIR ".") | ||||
| set(INSTALL_BASE_DIR ".") | set(INSTALL_BASE_DIR ".") | ||||
| set(INSTALL_BIN_DIR "bin") | |||||
| if (CMAKE_SYSTEM_NAME MATCHES "Windows") | if (CMAKE_SYSTEM_NAME MATCHES "Windows") | ||||
| set(INSTALL_LIB_DIR ".") | set(INSTALL_LIB_DIR ".") | ||||
| @@ -78,7 +79,14 @@ if (ENABLE_MINDDATA) | |||||
| DESTINATION ${INSTALL_BASE_DIR} | DESTINATION ${INSTALL_BASE_DIR} | ||||
| COMPONENT mindspore | COMPONENT mindspore | ||||
| ) | ) | ||||
| if (CMAKE_SYSTEM_NAME MATCHES "Linux") | |||||
| install( | |||||
| TARGETS cache_admin cache_server | |||||
| OPTIONAL | |||||
| DESTINATION ${INSTALL_BIN_DIR} | |||||
| COMPONENT mindspore | |||||
| ) | |||||
| endif() | |||||
| file(GLOB_RECURSE OPENCV_LIB_LIST | file(GLOB_RECURSE OPENCV_LIB_LIST | ||||
| ${opencv_LIBPATH}/libopencv_core* | ${opencv_LIBPATH}/libopencv_core* | ||||
| ${opencv_LIBPATH}/libopencv_imgcodecs* | ${opencv_LIBPATH}/libopencv_imgcodecs* | ||||
| @@ -14,6 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <optional> | |||||
| #include "minddata/dataset/api/python/pybind_register.h" | #include "minddata/dataset/api/python/pybind_register.h" | ||||
| #include "minddata/dataset/engine/cache/cache_client.h" | #include "minddata/dataset/engine/cache/cache_client.h" | ||||
| @@ -22,17 +23,19 @@ namespace dataset { | |||||
| PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { | PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { | ||||
| (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient") | (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient") | ||||
| .def( | |||||
| py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname, | |||||
| std::optional<int32_t> port, int32_t prefetch_sz) { | |||||
| std::shared_ptr<CacheClient> cc; | |||||
| CacheClient::Builder builder; | |||||
| builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz); | |||||
| if (hostname) builder.SetHostname(hostname.value()); | |||||
| if (port) builder.SetPort(port.value()); | |||||
| THROW_IF_ERROR(builder.Build(&cc)); | |||||
| return cc; | |||||
| })) | |||||
| .def(py::init([](session_id_type id, uint64_t mem_sz, bool spill, | |||||
| std::optional<std::string> hostname, std::optional<int32_t> port, | |||||
| std::optional<int32_t> num_connections, std::optional<int32_t> prefetch_sz) { | |||||
| std::shared_ptr<CacheClient> cc; | |||||
| CacheClient::Builder builder; | |||||
| builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill); | |||||
| if (hostname) builder.SetHostname(hostname.value()); | |||||
| if (port) builder.SetPort(port.value()); | |||||
| if (num_connections) builder.SetNumConnections(num_connections.value()); | |||||
| if (prefetch_sz) builder.SetPrefetchSize(prefetch_sz.value()); | |||||
| THROW_IF_ERROR(builder.Build(&cc)); | |||||
| return cc; | |||||
| })) | |||||
| .def("GetStat", [](CacheClient &cc) { | .def("GetStat", [](CacheClient &cc) { | ||||
| CacheServiceStat stat{}; | CacheServiceStat stat{}; | ||||
| THROW_IF_ERROR(cc.GetStat(&stat)); | THROW_IF_ERROR(cc.GetStat(&stat)); | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <fstream> | #include <fstream> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include "mindspore/core/utils/log_adapter.h" | #include "mindspore/core/utils/log_adapter.h" | ||||
| #include "minddata/dataset/util/system_pool.h" | #include "minddata/dataset/util/system_pool.h" | ||||
| @@ -33,7 +34,9 @@ ConfigManager::ConfigManager() | |||||
| monitor_sampling_interval_(kCfgMonitorSamplingInterval), | monitor_sampling_interval_(kCfgMonitorSamplingInterval), | ||||
| callback_timout_(kCfgCallbackTimeout), | callback_timout_(kCfgCallbackTimeout), | ||||
| cache_host_(kCfgDefaultCacheHost), | cache_host_(kCfgDefaultCacheHost), | ||||
| cache_port_(kCfgDefaultCachePort) { | |||||
| cache_port_(kCfgDefaultCachePort), | |||||
| num_connections_(kDftNumConnections), | |||||
| prefetch_size_(kDftPrefetchSize) { | |||||
| auto env_cache_host = std::getenv("MS_CACHE_HOST"); | auto env_cache_host = std::getenv("MS_CACHE_HOST"); | ||||
| auto env_cache_port = std::getenv("MS_CACHE_PORT"); | auto env_cache_port = std::getenv("MS_CACHE_PORT"); | ||||
| if (env_cache_host != nullptr) { | if (env_cache_host != nullptr) { | ||||
| @@ -71,6 +74,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) { | |||||
| set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); | set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); | ||||
| set_cache_host(j.value("cacheHost", cache_host_)); | set_cache_host(j.value("cacheHost", cache_host_)); | ||||
| set_cache_port(j.value("cachePort", cache_port_)); | set_cache_port(j.value("cachePort", cache_port_)); | ||||
| set_num_connections(j.value("numConnections", num_connections_)); | |||||
| set_prefetch_size(j.value("prefetchSize", prefetch_size_)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -120,8 +125,12 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s | |||||
| void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } | void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } | ||||
| void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = cache_host; } | |||||
| void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = std::move(cache_host); } | |||||
| void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; } | void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; } | ||||
| void ConfigManager::set_num_connections(int32_t num_connections) { num_connections_ = num_connections; } | |||||
| void ConfigManager::set_prefetch_size(int32_t prefetch_size) { prefetch_size_ = prefetch_size; } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -97,6 +97,14 @@ class ConfigManager { | |||||
| // @return The port of cache server | // @return The port of cache server | ||||
| int32_t cache_port() const { return cache_port_; } | int32_t cache_port() const { return cache_port_; } | ||||
| /// getter function | |||||
| /// \return Number of tcp/ip connection | |||||
| int32_t num_connections() const { return num_connections_; } | |||||
| /// getter function | |||||
| /// \return Prefetch size | |||||
| int32_t prefetch_size() const { return prefetch_size_; } | |||||
| // setter function | // setter function | ||||
| // @param rows_per_buffer - The setting to apply to the config | // @param rows_per_buffer - The setting to apply to the config | ||||
| void set_rows_per_buffer(int32_t rows_per_buffer); | void set_rows_per_buffer(int32_t rows_per_buffer); | ||||
| @@ -121,6 +129,14 @@ class ConfigManager { | |||||
| // @param cache_port - The port of cache server | // @param cache_port - The port of cache server | ||||
| void set_cache_port(int32_t cache_port); | void set_cache_port(int32_t cache_port); | ||||
| /// setter function | |||||
| /// \param num_connections | |||||
| void set_num_connections(int32_t num_connections); | |||||
| /// setter function | |||||
| /// \param prefetch_size | |||||
| void set_prefetch_size(int32_t prefetch_size); | |||||
| uint32_t seed() const; | uint32_t seed() const; | ||||
| // setter function | // setter function | ||||
| @@ -153,6 +169,8 @@ class ConfigManager { | |||||
| uint32_t callback_timout_; | uint32_t callback_timout_; | ||||
| std::string cache_host_; | std::string cache_host_; | ||||
| int32_t cache_port_; | int32_t cache_port_; | ||||
| int32_t num_connections_; | |||||
| int32_t prefetch_size_; | |||||
| // Private helper function that takes a nlohmann json format and populates the settings | // Private helper function that takes a nlohmann json format and populates the settings | ||||
| // @param j - The json nlohmann json info | // @param j - The json nlohmann json info | ||||
| @@ -71,6 +71,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10; | |||||
| constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds | constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds | ||||
| constexpr int32_t kCfgDefaultCachePort = 50052; | constexpr int32_t kCfgDefaultCachePort = 50052; | ||||
| constexpr char kCfgDefaultCacheHost[] = "127.0.0.1"; | constexpr char kCfgDefaultCacheHost[] = "127.0.0.1"; | ||||
| constexpr int32_t kDftPrefetchSize = 20; | |||||
| constexpr int32_t kDftNumConnections = 12; | |||||
| // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) | // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) | ||||
| constexpr uint8_t kCVInvalidType = 255; | constexpr uint8_t kCVInvalidType = 255; | ||||
| @@ -79,6 +79,14 @@ class TensorRow { | |||||
| const vector_type &getRow() const { return row_; } | const vector_type &getRow() const { return row_; } | ||||
| int64_t SizeInBytes() const { | |||||
| size_t sz = 0; | |||||
| for (auto &it : row_) { | |||||
| sz += it->SizeInBytes(); | |||||
| } | |||||
| return sz; | |||||
| } | |||||
| // Wrapper functions to support vector operations | // Wrapper functions to support vector operations | ||||
| void emplace_back(value_type t) { row_.emplace_back(t); } | void emplace_back(value_type t) { row_.emplace_back(t); } | ||||
| @@ -12,7 +12,9 @@ add_library(engine-cache-client OBJECT | |||||
| if (ENABLE_CACHE) | if (ENABLE_CACHE) | ||||
| ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto) | ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto) | ||||
| target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc) | |||||
| target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} | |||||
| cache_grpc_client.cc | |||||
| cache_ipc.cc) | |||||
| add_library(engine-cache-server OBJECT | add_library(engine-cache-server OBJECT | ||||
| ${CACHE_GRPC_SRCS} | ${CACHE_GRPC_SRCS} | ||||
| @@ -37,12 +37,17 @@ int main(int argc, char **argv) { | |||||
| warningMsg += "WARNING:\n"; | warningMsg += "WARNING:\n"; | ||||
| warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research"; | warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research"; | ||||
| warningMsg += " purposes at this time.\n"; | warningMsg += " purposes at this time.\n"; | ||||
| warningMsg += "This command is currently disabled. Quitting.\n"; | |||||
| auto env_enable_cache = std::getenv("MS_ENABLE_CACHE"); | |||||
| if (env_enable_cache == nullptr || strcmp(env_enable_cache, "TRUE") != 0) { | |||||
| // temporary disable cache feature in the current release | |||||
| warningMsg += "This command is currently disabled. Quitting.\n"; | |||||
| std::cerr << warningMsg << std::endl; | |||||
| return 0; | |||||
| } | |||||
| warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n"; | |||||
| // A warning message until the code is mature enough. | // A warning message until the code is mature enough. | ||||
| std::cerr << warningMsg << std::endl; | std::cerr << warningMsg << std::endl; | ||||
| // temporary disable cache feature in the current release | |||||
| return 0; | |||||
| if (argc == 1) { | if (argc == 1) { | ||||
| args.Help(); | args.Help(); | ||||
| @@ -19,9 +19,11 @@ | |||||
| #include <sys/wait.h> | #include <sys/wait.h> | ||||
| #include <unistd.h> | #include <unistd.h> | ||||
| #include <cerrno> | #include <cerrno> | ||||
| #include <iomanip> | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <string> | #include <string> | ||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/cache/cache_request.h" | #include "minddata/dataset/engine/cache/cache_request.h" | ||||
| #include "minddata/dataset/engine/cache/cache_client.h" | #include "minddata/dataset/engine/cache/cache_client.h" | ||||
| #include "minddata/dataset/util/path.h" | #include "minddata/dataset/util/path.h" | ||||
| @@ -39,6 +41,7 @@ CacheAdminArgHandler::CacheAdminArgHandler() | |||||
| num_workers_(kDefaultNumWorkers), | num_workers_(kDefaultNumWorkers), | ||||
| shm_mem_sz_(kDefaultSharedMemorySizeInGB), | shm_mem_sz_(kDefaultSharedMemorySizeInGB), | ||||
| log_level_(kDefaultLogLevel), | log_level_(kDefaultLogLevel), | ||||
| memory_cap_ratio_(kMemoryCapRatio), | |||||
| hostname_(kCfgDefaultCacheHost), | hostname_(kCfgDefaultCacheHost), | ||||
| spill_dir_(kDefaultSpillDir), | spill_dir_(kDefaultSpillDir), | ||||
| command_id_(CommandId::kCmdUnknown) { | command_id_(CommandId::kCmdUnknown) { | ||||
| @@ -62,6 +65,9 @@ CacheAdminArgHandler::CacheAdminArgHandler() | |||||
| arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize; | arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize; | ||||
| arg_map_["-l"] = ArgValue::kArgLogLevel; | arg_map_["-l"] = ArgValue::kArgLogLevel; | ||||
| arg_map_["--minloglevel"] = ArgValue::kArgLogLevel; | arg_map_["--minloglevel"] = ArgValue::kArgLogLevel; | ||||
| arg_map_["-r"] = ArgValue::kArgMemoryCapRatio; | |||||
| arg_map_["--memory_cap_ratio"] = ArgValue::kArgMemoryCapRatio; | |||||
| arg_map_["--list_sessions"] = ArgValue::kArgListSessions; | |||||
| // Initialize argument tracker with false values | // Initialize argument tracker with false values | ||||
| for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) { | for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) { | ||||
| ArgValue currAV = static_cast<ArgValue>(i); | ArgValue currAV = static_cast<ArgValue>(i); | ||||
| @@ -69,6 +75,8 @@ CacheAdminArgHandler::CacheAdminArgHandler() | |||||
| } | } | ||||
| } | } | ||||
| CacheAdminArgHandler::~CacheAdminArgHandler() = default; | |||||
| Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, | Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, | ||||
| CommandId command_id) { | CommandId command_id) { | ||||
| // Detect if the user tried to provide this argument more than once | // Detect if the user tried to provide this argument more than once | ||||
| @@ -102,7 +110,7 @@ Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std | |||||
| return Status(StatusCode::kSyntaxError, err_msg); | return Status(StatusCode::kSyntaxError, err_msg); | ||||
| } | } | ||||
| // Now, attempt to convert the value into it's string format for output | |||||
| // Now, attempt to convert the value into it's numeric format for output | |||||
| try { | try { | ||||
| *out_arg = std::stoul(value_as_string); | *out_arg = std::stoul(value_as_string); | ||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| @@ -140,7 +148,13 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, | |||||
| // If there is no argument to get, such as the --start command, then out_arg will be a nullptr. | // If there is no argument to get, such as the --start command, then out_arg will be a nullptr. | ||||
| if (out_arg != nullptr) { | if (out_arg != nullptr) { | ||||
| // Fetch the argument from the arg stream into a string | // Fetch the argument from the arg stream into a string | ||||
| *arg_stream >> *out_arg; | |||||
| if (arg_stream->rdbuf()->in_avail() != 0) { | |||||
| *arg_stream >> *out_arg; | |||||
| } else { | |||||
| std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; | |||||
| return Status(StatusCode::kSyntaxError, err_msg); | |||||
| } | |||||
| if (out_arg->empty()) { | if (out_arg->empty()) { | ||||
| std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; | std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; | ||||
| return Status(StatusCode::kSyntaxError, err_msg); | return Status(StatusCode::kSyntaxError, err_msg); | ||||
| @@ -150,12 +164,62 @@ Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheAdminArgHandler::AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream, | |||||
| CommandId command_id) { | |||||
| // Detect if the user tried to provide this argument more than once | |||||
| ArgValue selected_arg = arg_map_[option]; | |||||
| if (used_args_[selected_arg]) { | |||||
| std::string err_msg = "The " + option + " argument was given more than once."; | |||||
| return Status(StatusCode::kSyntaxError, err_msg); | |||||
| } | |||||
| // Flag that this arg is used now | |||||
| used_args_[selected_arg] = true; | |||||
| // Some options are just arguments, for example "--hostname "127.0.0.1" is not a command, it's just an argument. | |||||
| // Other options are actual commands, for example "--start". | |||||
| // If this option is also a command, make sure there has not been multiple commands given before assigning it. | |||||
| if (command_id != CommandId::kCmdUnknown) { | |||||
| if (command_id_ != CommandId::kCmdUnknown) { | |||||
| std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option; | |||||
| return Status(StatusCode::kSyntaxError, err_msg); | |||||
| } else { | |||||
| command_id_ = command_id; | |||||
| } | |||||
| } | |||||
| std::string value_as_string; | |||||
| // Fetch the argument from the arg stream into a string | |||||
| *arg_stream >> value_as_string; | |||||
| if (value_as_string.empty()) { | |||||
| std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>"; | |||||
| return Status(StatusCode::kSyntaxError, err_msg); | |||||
| } | |||||
| // Now, attempt to convert the value into it's string format for output | |||||
| try { | |||||
| *out_arg = std::stof(value_as_string, nullptr); | |||||
| } catch (const std::exception &e) { | |||||
| std::string err_msg = "Invalid numeric value: " + value_as_string; | |||||
| return Status(StatusCode::kSyntaxError, err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) { | Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) { | ||||
| std::string tok; | std::string tok; | ||||
| while (*arg_stream >> tok) { | while (*arg_stream >> tok) { | ||||
| switch (arg_map_[tok]) { | switch (arg_map_[tok]) { | ||||
| case ArgValue::kArgHost: { | case ArgValue::kArgHost: { | ||||
| RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream)); | RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream)); | ||||
| // Temporary sanity check. We only support localhost for now | |||||
| if (hostname_ != std::string(kCfgDefaultCacheHost)) { | |||||
| std::string err_msg = | |||||
| "Invalid host interface: " + hostname_ + ". Current limitation, only 127.0.0.1 can be used."; | |||||
| return Status(StatusCode::kSyntaxError, err_msg); | |||||
| } | |||||
| break; | break; | ||||
| } | } | ||||
| case ArgValue::kArgPort: { | case ArgValue::kArgPort: { | ||||
| @@ -203,6 +267,14 @@ Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) { | |||||
| RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream)); | RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream)); | ||||
| break; | break; | ||||
| } | } | ||||
| case ArgValue::kArgMemoryCapRatio: { | |||||
| RETURN_IF_NOT_OK(AssignArg(tok, &memory_cap_ratio_, arg_stream)); | |||||
| break; | |||||
| } | |||||
| case ArgValue::kArgListSessions: { | |||||
| RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdListSessions)); | |||||
| break; | |||||
| } | |||||
| default: { | default: { | ||||
| // Save space delimited trailing arguments | // Save space delimited trailing arguments | ||||
| trailing_args_ += (" " + tok); | trailing_args_ += (" " + tok); | ||||
| @@ -232,9 +304,12 @@ Status CacheAdminArgHandler::Validate() { | |||||
| } | } | ||||
| // Additional checks here | // Additional checks here | ||||
| if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be positive value."); | |||||
| if (num_workers_ < 1 || num_workers_ > 100) | |||||
| return Status(StatusCode::kSyntaxError, "Number of workers must be in range of 1 and 100."); | |||||
| if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3)."); | if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3)."); | ||||
| // port range check? | |||||
| if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) | |||||
| return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1"); | |||||
| if (port_ < 1025 || port_ > 65535) return Status(StatusCode::kSyntaxError, "Port must be in range (1025..65535)."); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -245,12 +320,9 @@ Status CacheAdminArgHandler::RunCommand() { | |||||
| Help(); | Help(); | ||||
| break; | break; | ||||
| } | } | ||||
| case CommandId::kCmdStart: { | |||||
| RETURN_IF_NOT_OK(StartServer()); | |||||
| break; | |||||
| } | |||||
| case CommandId::kCmdStart: | |||||
| case CommandId::kCmdStop: { | case CommandId::kCmdStop: { | ||||
| RETURN_IF_NOT_OK(StopServer()); | |||||
| RETURN_IF_NOT_OK(StartStopServer(command_id_)); | |||||
| break; | break; | ||||
| } | } | ||||
| case CommandId::kCmdGenerateSession: { | case CommandId::kCmdGenerateSession: { | ||||
| @@ -259,7 +331,7 @@ Status CacheAdminArgHandler::RunCommand() { | |||||
| auto rq = std::make_shared<GenerateSessionIdRequest>(); | auto rq = std::make_shared<GenerateSessionIdRequest>(); | ||||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | ||||
| RETURN_IF_NOT_OK(rq->Wait()); | RETURN_IF_NOT_OK(rq->Wait()); | ||||
| std::cout << rq->GetSessionId() << std::endl; | |||||
| std::cout << "Session: " << rq->GetSessionId() << std::endl; | |||||
| break; | break; | ||||
| } | } | ||||
| case CommandId::kCmdDestroySession: { | case CommandId::kCmdDestroySession: { | ||||
| @@ -273,6 +345,39 @@ Status CacheAdminArgHandler::RunCommand() { | |||||
| std::cout << "Drop session successful" << std::endl; | std::cout << "Drop session successful" << std::endl; | ||||
| break; | break; | ||||
| } | } | ||||
| case CommandId::kCmdListSessions: { | |||||
| CacheClientGreeter comm(hostname_, port_, 1); | |||||
| RETURN_IF_NOT_OK(comm.ServiceStart()); | |||||
| auto rq = std::make_shared<ListSessionsRequest>(); | |||||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | |||||
| RETURN_IF_NOT_OK(rq->Wait()); | |||||
| std::vector<SessionCacheInfo> session_info = rq->GetSessionCacheInfo(); | |||||
| if (!session_info.empty()) { | |||||
| std::cout << std::setw(12) << "Session" << std::setw(12) << "Cache Id" << std::setw(12) << "Mem cached" | |||||
| << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::endl; | |||||
| for (auto curr_session : session_info) { | |||||
| std::string cache_id; | |||||
| std::string stat_mem_cached; | |||||
| std::string stat_disk_cached; | |||||
| std::string stat_avg_cached; | |||||
| int32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF); | |||||
| cache_id = (curr_session.connection_id == 0) ? "n/a" : std::to_string(crc); | |||||
| stat_mem_cached = | |||||
| (curr_session.stats.num_mem_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_mem_cached); | |||||
| stat_disk_cached = | |||||
| (curr_session.stats.num_disk_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_disk_cached); | |||||
| stat_avg_cached = | |||||
| (curr_session.stats.avg_cache_sz == 0) ? "n/a" : std::to_string(curr_session.stats.avg_cache_sz); | |||||
| std::cout << std::setw(12) << curr_session.session_id << std::setw(12) << cache_id << std::setw(12) | |||||
| << stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached | |||||
| << std::endl; | |||||
| } | |||||
| } else { | |||||
| std::cout << "No active sessions." << std::endl; | |||||
| } | |||||
| break; | |||||
| } | |||||
| default: { | default: { | ||||
| RETURN_STATUS_UNEXPECTED("Invalid cache admin command id."); | RETURN_STATUS_UNEXPECTED("Invalid cache admin command id."); | ||||
| break; | break; | ||||
| @@ -282,7 +387,7 @@ Status CacheAdminArgHandler::RunCommand() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheAdminArgHandler::StartServer() { | |||||
| Status CacheAdminArgHandler::StartStopServer(CommandId command_id) { | |||||
| // There currently does not exist any "install path" or method to identify which path the installed binaries will | // There currently does not exist any "install path" or method to identify which path the installed binaries will | ||||
| // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the | // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the | ||||
| // cache_admin binary (this process). | // cache_admin binary (this process). | ||||
| @@ -324,7 +429,10 @@ Status CacheAdminArgHandler::StartServer() { | |||||
| close(fd[1]); | close(fd[1]); | ||||
| dup2(fd[0], 0); | dup2(fd[0], 0); | ||||
| close(fd[0]); | close(fd[0]); | ||||
| wait(nullptr); | |||||
| int status; | |||||
| if (waitpid(pid, &status, 0) == -1) { | |||||
| RETURN_STATUS_UNEXPECTED("waitpid fails. errno = " + std::to_string(errno)); | |||||
| } | |||||
| std::string msg; | std::string msg; | ||||
| const int32_t buf_sz = 1024; | const int32_t buf_sz = 1024; | ||||
| msg.resize(buf_sz); | msg.resize(buf_sz); | ||||
| @@ -335,6 +443,13 @@ Status CacheAdminArgHandler::StartServer() { | |||||
| } | } | ||||
| msg.resize(n); | msg.resize(n); | ||||
| std::cout << msg << std::endl; | std::cout << msg << std::endl; | ||||
| if (WIFEXITED(status)) { | |||||
| auto exit_status = WEXITSTATUS(status); | |||||
| if (exit_status) { | |||||
| std::string errMsg = "Child exit status " + std::to_string(exit_status); | |||||
| return Status(StatusCode::kUnexpectedError, errMsg); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } else { | } else { | ||||
| // Child here ... | // Child here ... | ||||
| @@ -350,19 +465,29 @@ Status CacheAdminArgHandler::StartServer() { | |||||
| std::string shared_memory_string = std::to_string(shm_mem_sz_); | std::string shared_memory_string = std::to_string(shm_mem_sz_); | ||||
| std::string minloglevel_string = std::to_string(log_level_); | std::string minloglevel_string = std::to_string(log_level_); | ||||
| std::string daemonize_string = "true"; | std::string daemonize_string = "true"; | ||||
| char *argv[8]; | |||||
| argv[0] = cache_server_binary.data(); // First arg is usually the binary name | |||||
| argv[1] = spill_dir_.data(); | |||||
| argv[2] = workers_string.data(); | |||||
| argv[3] = port_string.data(); | |||||
| argv[4] = shared_memory_string.data(); | |||||
| argv[5] = minloglevel_string.data(); | |||||
| argv[6] = daemonize_string.data(); | |||||
| argv[7] = nullptr; | |||||
| std::string memory_cap_ratio_string = std::to_string(memory_cap_ratio_); | |||||
| char *argv[9]; | |||||
| if (command_id == CommandId::kCmdStart) { | |||||
| argv[0] = cache_server_binary.data(); | |||||
| argv[1] = spill_dir_.data(); | |||||
| argv[2] = workers_string.data(); | |||||
| argv[3] = port_string.data(); | |||||
| argv[4] = shared_memory_string.data(); | |||||
| argv[5] = minloglevel_string.data(); | |||||
| argv[6] = daemonize_string.data(); | |||||
| argv[7] = memory_cap_ratio_string.data(); | |||||
| argv[8] = nullptr; | |||||
| } else { | |||||
| // We are doing a --stop. Change the name to '-' and we also need the port number. | |||||
| // The rest we don't need. | |||||
| argv[0] = std::string("-").data(); | |||||
| argv[1] = port_string.data(); | |||||
| argv[2] = nullptr; | |||||
| } | |||||
| // Now exec the binary | // Now exec the binary | ||||
| execv(argv[0], argv); | |||||
| execv(cache_server_binary.data(), argv); | |||||
| // If the exec was successful, this line will never be reached due to process image being replaced. | // If the exec was successful, this line will never be reached due to process image being replaced. | ||||
| // ..unless exec failed. | // ..unless exec failed. | ||||
| std::string err_msg = "Failed to exec cache server: " + cache_server_binary; | std::string err_msg = "Failed to exec cache server: " + cache_server_binary; | ||||
| @@ -371,16 +496,6 @@ Status CacheAdminArgHandler::StartServer() { | |||||
| } | } | ||||
| } | } | ||||
| Status CacheAdminArgHandler::StopServer() { | |||||
| CacheClientGreeter comm(hostname_, port_, 1); | |||||
| RETURN_IF_NOT_OK(comm.ServiceStart()); | |||||
| auto rq = std::make_shared<ShutdownRequest>(); | |||||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | |||||
| // We will ignore the rc because if the shutdown is successful, the server will not reply back. | |||||
| (void)rq->Wait(); | |||||
| return Status::OK(); | |||||
| } | |||||
| void CacheAdminArgHandler::Help() { | void CacheAdminArgHandler::Help() { | ||||
| std::cerr << "Syntax:\n"; | std::cerr << "Syntax:\n"; | ||||
| std::cerr << " cache_admin [--start | --stop]\n"; | std::cerr << " cache_admin [--start | --stop]\n"; | ||||
| @@ -390,8 +505,12 @@ void CacheAdminArgHandler::Help() { | |||||
| std::cerr << " [ [-d | --destroy_session] <session id> ]\n"; | std::cerr << " [ [-d | --destroy_session] <session id> ]\n"; | ||||
| std::cerr << " [ [-w | --workers] <number of workers> ]\n"; | std::cerr << " [ [-w | --workers] <number of workers> ]\n"; | ||||
| std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n"; | std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n"; | ||||
| std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n"; | |||||
| std::cerr << " [ [-l | --minloglevel] <log level> ]\n"; | std::cerr << " [ [-l | --minloglevel] <log level> ]\n"; | ||||
| std::cerr << " [ --list_sessions ]\n"; | |||||
| // Do not expose these option to the user via help or documentation, but the options do exist to aid with | |||||
| // development and tuning. | |||||
| // std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n"; | |||||
| // std::cerr << " [ [-r | --memory_cap_ratio] <float percent value>]\n"; | |||||
| std::cerr << " [--help]" << std::endl; | std::cerr << " [--help]" << std::endl; | ||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -32,6 +32,7 @@ class CacheAdminArgHandler { | |||||
| static constexpr int32_t kDefaultNumWorkers = 32; | static constexpr int32_t kDefaultNumWorkers = 32; | ||||
| static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; | static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; | ||||
| static constexpr int32_t kDefaultLogLevel = 1; | static constexpr int32_t kDefaultLogLevel = 1; | ||||
| static constexpr float kMemoryCapRatio = 0.8; | |||||
| static const char kServerBinary[]; | static const char kServerBinary[]; | ||||
| static const char kDefaultSpillDir[]; | static const char kDefaultSpillDir[]; | ||||
| @@ -42,12 +43,13 @@ class CacheAdminArgHandler { | |||||
| kCmdStop = 2, | kCmdStop = 2, | ||||
| kCmdGenerateSession = 3, | kCmdGenerateSession = 3, | ||||
| kCmdDestroySession = 4, | kCmdDestroySession = 4, | ||||
| kCmdListSessions = 5, | |||||
| kCmdUnknown = 32767 | kCmdUnknown = 32767 | ||||
| }; | }; | ||||
| CacheAdminArgHandler(); | CacheAdminArgHandler(); | ||||
| ~CacheAdminArgHandler() = default; | |||||
| virtual ~CacheAdminArgHandler(); | |||||
| Status ParseArgStream(std::stringstream *arg_stream); | Status ParseArgStream(std::stringstream *arg_stream); | ||||
| @@ -70,12 +72,12 @@ class CacheAdminArgHandler { | |||||
| kArgNumWorkers = 9, | kArgNumWorkers = 9, | ||||
| kArgSharedMemorySize = 10, | kArgSharedMemorySize = 10, | ||||
| kArgLogLevel = 11, | kArgLogLevel = 11, | ||||
| kArgNumArgs = 12 // Must be the last position to provide a count | |||||
| kArgMemoryCapRatio = 12, | |||||
| kArgListSessions = 13, | |||||
| kArgNumArgs = 14 // Must be the last position to provide a count | |||||
| }; | }; | ||||
| Status StartServer(); | |||||
| Status StopServer(); | |||||
| Status StartStopServer(CommandId); | |||||
| Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, | Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, | ||||
| CommandId command_id = CommandId::kCmdUnknown); | CommandId command_id = CommandId::kCmdUnknown); | ||||
| @@ -83,6 +85,9 @@ class CacheAdminArgHandler { | |||||
| Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream, | Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream, | ||||
| CommandId command_id = CommandId::kCmdUnknown); | CommandId command_id = CommandId::kCmdUnknown); | ||||
| Status AssignArg(std::string option, float *out_arg, std::stringstream *arg_stream, | |||||
| CommandId command_id = CommandId::kCmdUnknown); | |||||
| Status Validate(); | Status Validate(); | ||||
| CommandId command_id_; | CommandId command_id_; | ||||
| @@ -90,6 +95,7 @@ class CacheAdminArgHandler { | |||||
| int32_t num_workers_; | int32_t num_workers_; | ||||
| int32_t shm_mem_sz_; | int32_t shm_mem_sz_; | ||||
| int32_t log_level_; | int32_t log_level_; | ||||
| float memory_cap_ratio_; | |||||
| session_id_type session_id_; | session_id_type session_id_; | ||||
| std::string hostname_; | std::string hostname_; | ||||
| std::string spill_dir_; | std::string spill_dir_; | ||||
| @@ -17,27 +17,19 @@ | |||||
| #include "minddata/dataset/util/path.h" | #include "minddata/dataset/util/path.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) | |||||
| : ptr_(nullptr), val_in_GB_(val_in_GB), port_(port), shmid_(-1) {} | |||||
| CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) { | |||||
| // We create the shared memory and we will destroy it. All other client just detach only. | |||||
| shm_.RemoveResourcesOnExit(); | |||||
| } | |||||
| CachedSharedMemoryArena::~CachedSharedMemoryArena() { | CachedSharedMemoryArena::~CachedSharedMemoryArena() { | ||||
| #if CACHE_LOCAL_CLIENT | |||||
| if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) { | |||||
| shmdt(this->ptr_); | |||||
| } | |||||
| this->ptr_ = nullptr; | |||||
| if (shmid_ != -1) { | |||||
| shmctl(shmid_, IPC_RMID, nullptr); | |||||
| // Also remove the path we use to generate ftok. | |||||
| Path p(PortToUnixSocketPath(port_)); | |||||
| (void)p.Remove(); | |||||
| } | |||||
| #endif | |||||
| // Also remove the path we use to generate ftok. | |||||
| Path p(PortToUnixSocketPath(port_)); | |||||
| (void)p.Remove(); | |||||
| } | } | ||||
| Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, | Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, | ||||
| size_t val_in_GB) { | size_t val_in_GB) { | ||||
| RETURN_UNEXPECTED_IF_NULL(out); | RETURN_UNEXPECTED_IF_NULL(out); | ||||
| #if CACHE_LOCAL_CLIENT | |||||
| auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB); | auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB); | ||||
| if (ba == nullptr) { | if (ba == nullptr) { | ||||
| return Status(StatusCode::kOutOfMemory); | return Status(StatusCode::kOutOfMemory); | ||||
| @@ -46,26 +38,13 @@ Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryAr | |||||
| // the destructor of *out to deal. | // the destructor of *out to deal. | ||||
| (*out).reset(ba); | (*out).reset(ba); | ||||
| // Generate the ftok using a combination of port. | // Generate the ftok using a combination of port. | ||||
| int err; | |||||
| auto shm_key = PortToFtok(port, &err); | |||||
| if (shm_key == (key_t)-1) { | |||||
| std::string errMsg = "Ftok failed with errno " + std::to_string(err); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; | |||||
| SharedMemory::shm_key_t shm_key; | |||||
| RETURN_IF_NOT_OK(PortToFtok(port, &shm_key)); | |||||
| ba->shm_.SetPublicKey(shm_key); | |||||
| // Value is in GB. Convert into bytes. | // Value is in GB. Convert into bytes. | ||||
| int64_t sz = val_in_GB * 1073741824L; | int64_t sz = val_in_GB * 1073741824L; | ||||
| ba->shmid_ = shmget(shm_key, sz, IPC_CREAT | IPC_EXCL | access_mode); | |||||
| if (ba->shmid_) { | |||||
| ba->ptr_ = shmat(ba->shmid_, nullptr, 0); | |||||
| if (ba->ptr_ == reinterpret_cast<void *>(-1)) { | |||||
| RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); | |||||
| } | |||||
| ba->impl_ = std::make_unique<ArenaImpl>(ba->ptr_, sz); | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno)); | |||||
| } | |||||
| #endif | |||||
| RETURN_IF_NOT_OK(ba->shm_.Create(sz)); | |||||
| ba->impl_ = std::make_unique<ArenaImpl>(ba->shm_.SharedMemoryBaseAddr(), sz); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "minddata/dataset/util/arena.h" | #include "minddata/dataset/util/arena.h" | ||||
| #include "minddata/dataset/engine/cache/cache_common.h" | #include "minddata/dataset/engine/cache/cache_common.h" | ||||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| /// This is a derived class of Arena but resides in shared memory | /// This is a derived class of Arena but resides in shared memory | ||||
| @@ -73,10 +74,9 @@ class CachedSharedMemoryArena : public MemoryPool { | |||||
| private: | private: | ||||
| mutable std::mutex mux_; | mutable std::mutex mux_; | ||||
| void *ptr_; | |||||
| int32_t val_in_GB_; | int32_t val_in_GB_; | ||||
| int32_t port_; | int32_t port_; | ||||
| int shmid_; | |||||
| SharedMemory shm_; | |||||
| std::unique_ptr<ArenaImpl> impl_; | std::unique_ptr<ArenaImpl> impl_; | ||||
| /// Private constructor. Not to be called directly. | /// Private constructor. Not to be called directly. | ||||
| CachedSharedMemoryArena(int32_t port, size_t val_in_GB); | CachedSharedMemoryArena(int32_t port, size_t val_in_GB); | ||||
| @@ -24,26 +24,26 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| CacheClient::Builder::Builder() | CacheClient::Builder::Builder() | ||||
| : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) { | |||||
| : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_connections_(0), prefetch_size_(0) { | |||||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | ||||
| hostname_ = cfg->cache_host(); | hostname_ = cfg->cache_host(); | ||||
| port_ = cfg->cache_port(); | port_ = cfg->cache_port(); | ||||
| num_workers_ = cfg->num_parallel_workers(); | |||||
| prefetch_size_ = 20; // rows_per_buf is too small (1 by default). | |||||
| num_connections_ = cfg->num_connections(); // number of async tcp/ip connections | |||||
| prefetch_size_ = cfg->prefetch_size(); // prefetch size | |||||
| } | } | ||||
| Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) { | Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) { | ||||
| RETURN_UNEXPECTED_IF_NULL(out); | RETURN_UNEXPECTED_IF_NULL(out); | ||||
| RETURN_IF_NOT_OK(SanityCheck()); | RETURN_IF_NOT_OK(SanityCheck()); | ||||
| *out = | |||||
| std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_); | |||||
| *out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_connections_, | |||||
| prefetch_size_); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheClient::Builder::SanityCheck() { | Status CacheClient::Builder::SanityCheck() { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); | CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); | CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); | CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); | CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive"); | CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive"); | ||||
| @@ -55,26 +55,32 @@ Status CacheClient::Builder::SanityCheck() { | |||||
| // Constructor | // Constructor | ||||
| CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, | CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, | ||||
| int32_t port, int32_t num_workers, int32_t prefetch_size) | |||||
| int32_t port, int32_t num_connections, int32_t prefetch_size) | |||||
| : server_connection_id_(0), | : server_connection_id_(0), | ||||
| cache_mem_sz_(cache_mem_sz), | cache_mem_sz_(cache_mem_sz), | ||||
| spill_(spill), | spill_(spill), | ||||
| local_bypass_(false), | local_bypass_(false), | ||||
| hostname_(std::move(hostname)), | hostname_(std::move(hostname)), | ||||
| port_(port), | port_(port), | ||||
| num_workers_(num_workers), | |||||
| prefetch_size_(prefetch_size) { | |||||
| num_connections_(num_connections), | |||||
| prefetch_size_(prefetch_size), | |||||
| fetch_all_keys_(true) { | |||||
| cinfo_.set_session_id(session_id); | cinfo_.set_session_id(session_id); | ||||
| comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_); | |||||
| comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_connections_); | |||||
| } | |||||
| CacheClient::~CacheClient() { | |||||
| cache_miss_keys_wp_.Set(); | |||||
| (void)comm_->ServiceStop(); | |||||
| } | } | ||||
| // print method for display cache details | // print method for display cache details | ||||
| void CacheClient::Print(std::ostream &out) const { | void CacheClient::Print(std::ostream &out) const { | ||||
| out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc() | out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc() | ||||
| << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << getCacheMemSz() | |||||
| << "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << getHostname() | |||||
| << "\n Port: " << getPort() << "\n Number of rpc workers: " << getNumWorkers() | |||||
| << "\n Prefetch size: " << getPrefetchSize() << "\n Local client support: " << std::boolalpha | |||||
| << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << GetCacheMemSz() | |||||
| << "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << GetHostname() | |||||
| << "\n Port: " << GetPort() << "\n Number of rpc workers: " << GetNumConnections() | |||||
| << "\n Prefetch size: " << GetPrefetchSize() << "\n Local client support: " << std::boolalpha | |||||
| << SupportLocalClient(); | << SupportLocalClient(); | ||||
| } | } | ||||
| @@ -199,14 +205,6 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheClient::PurgeCache() { | |||||
| UniqueLock lck(&mux_); | |||||
| auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_); | |||||
| RETURN_IF_NOT_OK(PushRequest(rq)); | |||||
| RETURN_IF_NOT_OK(rq->Wait()); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheClient::DestroyCache() { | Status CacheClient::DestroyCache() { | ||||
| UniqueLock lck(&mux_); | UniqueLock lck(&mux_); | ||||
| auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_); | auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_); | ||||
| @@ -253,5 +251,71 @@ Status CacheClient::BuildPhaseDone() const { | |||||
| } | } | ||||
| Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); } | Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); } | ||||
| void CacheClient::ServerRunningOutOfResources() { | |||||
| bool expected = true; | |||||
| if (fetch_all_keys_.compare_exchange_strong(expected, false)) { | |||||
| Status rc; | |||||
| // Server runs out of memory or disk space to cache any more rows. | |||||
| // First of all, we will turn off the locking. | |||||
| auto toggle_write_mode_rq = std::make_shared<ToggleWriteModeRequest>(server_connection_id_, false); | |||||
| rc = PushRequest(toggle_write_mode_rq); | |||||
| if (rc.IsError()) { | |||||
| return; | |||||
| } | |||||
| // Wait until we can toggle the state of the server to non-locking | |||||
| rc = toggle_write_mode_rq->Wait(); | |||||
| if (rc.IsError()) { | |||||
| return; | |||||
| } | |||||
| // Now we get a list of all the keys not cached at the server so | |||||
| // we can filter out at the prefetch level. | |||||
| auto cache_miss_rq = std::make_shared<GetCacheMissKeysRequest>(server_connection_id_); | |||||
| rc = PushRequest(cache_miss_rq); | |||||
| if (rc.IsError()) { | |||||
| return; | |||||
| } | |||||
| rc = cache_miss_rq->Wait(); | |||||
| if (rc.IsError()) { | |||||
| return; | |||||
| } | |||||
| // We will get back a vector of row id between [min,max] that are absent in the server. | |||||
| auto &row_id_buf = cache_miss_rq->reply_.result(); | |||||
| auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data()); | |||||
| std::vector<row_id_type> row_ids; | |||||
| auto sz = p->row_id()->size(); | |||||
| row_ids.reserve(sz); | |||||
| for (auto i = 0; i < sz; ++i) { | |||||
| row_ids.push_back(p->row_id()->Get(i)); | |||||
| } | |||||
| cache_miss_keys_ = std::make_unique<CacheMissKeys>(row_ids); | |||||
| // We are all set. | |||||
| cache_miss_keys_wp_.Set(); | |||||
| } | |||||
| } | |||||
| CacheClient::CacheMissKeys::CacheMissKeys(const std::vector<row_id_type> &v) { | |||||
| auto it = v.begin(); | |||||
| min_ = *it; | |||||
| ++it; | |||||
| max_ = *it; | |||||
| ++it; | |||||
| while (it != v.end()) { | |||||
| gap_.insert(*it); | |||||
| ++it; | |||||
| } | |||||
| MS_LOG(WARNING) << "# of cache miss keys between min(" << min_ << ") and max(" << max_ << ") is " << gap_.size(); | |||||
| } | |||||
| bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) { | |||||
| if (key > max_ || key < min_) { | |||||
| return true; | |||||
| } else if (key == min_ || key == max_) { | |||||
| return false; | |||||
| } else { | |||||
| auto it = gap_.find(key); | |||||
| return it != gap_.end(); | |||||
| } | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,8 +16,13 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_ | ||||
| #include <atomic> | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <limits> | |||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include <mutex> | |||||
| #include <set> | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| @@ -31,6 +36,8 @@ | |||||
| #endif | #endif | ||||
| #include "minddata/dataset/engine/data_buffer.h" | #include "minddata/dataset/engine/data_buffer.h" | ||||
| #include "minddata/dataset/util/lock.h" | #include "minddata/dataset/util/lock.h" | ||||
| #include "minddata/dataset/util/cond_var.h" | |||||
| #include "minddata/dataset/util/queue_map.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -89,10 +96,10 @@ class CacheClient { | |||||
| } | } | ||||
| /// Setter function to set number of async rpc workers | /// Setter function to set number of async rpc workers | ||||
| /// \param num_workers | |||||
| /// \param num_connections | |||||
| /// \return Builder object itself | /// \return Builder object itself | ||||
| Builder &SetNumWorkers(int32_t num_workers) { | |||||
| num_workers_ = num_workers; | |||||
| Builder &SetNumConnections(int32_t num_connections) { | |||||
| num_connections_ = num_connections; | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -105,13 +112,13 @@ class CacheClient { | |||||
| } | } | ||||
| /// Getter functions | /// Getter functions | ||||
| session_id_type getSessionId() const { return session_id_; } | |||||
| uint64_t getCacheMemSz() const { return cache_mem_sz_; } | |||||
| session_id_type GetSessionId() const { return session_id_; } | |||||
| uint64_t GetCacheMemSz() const { return cache_mem_sz_; } | |||||
| bool isSpill() const { return spill_; } | bool isSpill() const { return spill_; } | ||||
| const std::string &getHostname() const { return hostname_; } | const std::string &getHostname() const { return hostname_; } | ||||
| int32_t getPort() const { return port_; } | |||||
| int32_t getNumWorkers() const { return num_workers_; } | |||||
| int32_t getPrefetchSize() const { return prefetch_size_; } | |||||
| int32_t GetPort() const { return port_; } | |||||
| int32_t GetNumConnections() const { return num_connections_; } | |||||
| int32_t GetPrefetchSize() const { return prefetch_size_; } | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| @@ -123,7 +130,7 @@ class CacheClient { | |||||
| bool spill_; | bool spill_; | ||||
| std::string hostname_; | std::string hostname_; | ||||
| int32_t port_; | int32_t port_; | ||||
| int32_t num_workers_; | |||||
| int32_t num_connections_; | |||||
| int32_t prefetch_size_; | int32_t prefetch_size_; | ||||
| }; | }; | ||||
| @@ -132,10 +139,10 @@ class CacheClient { | |||||
| /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited | /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited | ||||
| /// \param spill Spill to disk if out of memory | /// \param spill Spill to disk if out of memory | ||||
| CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port, | CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port, | ||||
| int32_t num_workers, int32_t prefetch_size); | |||||
| int32_t num_connections, int32_t prefetch_size); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CacheClient() { (void)comm_->ServiceStop(); } | |||||
| ~CacheClient(); | |||||
| /// \brief Send a TensorRow to the cache server | /// \brief Send a TensorRow to the cache server | ||||
| /// \param[in] row | /// \param[in] row | ||||
| @@ -161,10 +168,6 @@ class CacheClient { | |||||
| /// \return Status object | /// \return Status object | ||||
| Status CreateCache(uint32_t tree_crc, bool generate_id); | Status CreateCache(uint32_t tree_crc, bool generate_id); | ||||
| /// \brief Purge a cache. Cache can be reused after reset. | |||||
| /// \return Status object | |||||
| Status PurgeCache(); | |||||
| /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused. | /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused. | ||||
| /// \return Status object | /// \return Status object | ||||
| Status DestroyCache(); | Status DestroyCache(); | ||||
| @@ -218,12 +221,31 @@ class CacheClient { | |||||
| /// Getter functions | /// Getter functions | ||||
| session_id_type session_id() const { return cinfo_.session_id(); } | session_id_type session_id() const { return cinfo_.session_id(); } | ||||
| uint64_t getCacheMemSz() const { return cache_mem_sz_; } | |||||
| uint64_t GetCacheMemSz() const { return cache_mem_sz_; } | |||||
| bool isSpill() const { return spill_; } | bool isSpill() const { return spill_; } | ||||
| const std::string &getHostname() const { return hostname_; } | |||||
| int32_t getPort() const { return port_; } | |||||
| int32_t getNumWorkers() const { return num_workers_; } | |||||
| int32_t getPrefetchSize() const { return prefetch_size_; } | |||||
| const std::string &GetHostname() const { return hostname_; } | |||||
| int32_t GetPort() const { return port_; } | |||||
| int32_t GetNumConnections() const { return num_connections_; } | |||||
| int32_t GetPrefetchSize() const { return prefetch_size_; } | |||||
| /// MergeOp will notify us when the server can't cache any more rows. | |||||
| /// We will stop any attempt to fetch any rows that are most likely | |||||
| /// not present at the server. | |||||
| void ServerRunningOutOfResources(); | |||||
| /// \brief Check if a row is 100% cache miss at the server by checking the local information | |||||
| /// \param key row id to be test | |||||
| /// \return true if not at the server | |||||
| bool KeyIsCacheMiss(row_id_type key) { | |||||
| if (cache_miss_keys_) { | |||||
| // Make sure it is fully built even though the pointer is not null | |||||
| Status rc = cache_miss_keys_wp_.Wait(); | |||||
| if (rc.IsOk()) { | |||||
| return cache_miss_keys_->KeyIsCacheMiss(key); | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| private: | private: | ||||
| mutable RWLock mux_; | mutable RWLock mux_; | ||||
| @@ -240,9 +262,27 @@ class CacheClient { | |||||
| bool local_bypass_; | bool local_bypass_; | ||||
| std::string hostname_; | std::string hostname_; | ||||
| int32_t port_; | int32_t port_; | ||||
| int32_t num_workers_; | |||||
| int32_t num_connections_; | |||||
| int32_t prefetch_size_; | int32_t prefetch_size_; | ||||
| mutable std::shared_ptr<CacheClientGreeter> comm_; | mutable std::shared_ptr<CacheClientGreeter> comm_; | ||||
| std::atomic<bool> fetch_all_keys_; | |||||
| WaitPost cache_miss_keys_wp_; | |||||
| /// A structure shared by all the prefetchers to know what keys are missing at the server. | |||||
| class CacheMissKeys { | |||||
| public: | |||||
| explicit CacheMissKeys(const std::vector<row_id_type> &v); | |||||
| ~CacheMissKeys() = default; | |||||
| /// This checks if a key is missing. | |||||
| /// \param key | |||||
| /// \return true if definitely a key miss | |||||
| bool KeyIsCacheMiss(row_id_type key); | |||||
| private: | |||||
| row_id_type min_; | |||||
| row_id_type max_; | |||||
| std::set<row_id_type> gap_; | |||||
| }; | |||||
| std::unique_ptr<CacheMissKeys> cache_miss_keys_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,13 +25,6 @@ | |||||
| #define CACHE_LOCAL_CLIENT 1 | #define CACHE_LOCAL_CLIENT 1 | ||||
| #endif | #endif | ||||
| #ifdef CACHE_LOCAL_CLIENT | |||||
| #include <sys/types.h> | |||||
| #include <sys/ipc.h> | |||||
| #include <sys/shm.h> | |||||
| #else | |||||
| typedef int key_t; | |||||
| #endif | |||||
| #ifdef ENABLE_CACHE | #ifdef ENABLE_CACHE | ||||
| #include <grpcpp/grpcpp.h> | #include <grpcpp/grpcpp.h> | ||||
| #endif | #endif | ||||
| @@ -54,6 +47,8 @@ constexpr static uint32_t kLocalClientSupport = 1; | |||||
| /// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is | /// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is | ||||
| /// inline in the protobuf. This also implies kLocalClientSupport is also true. | /// inline in the protobuf. This also implies kLocalClientSupport is also true. | ||||
| constexpr static uint32_t kDataIsInSharedMemory = 2; | constexpr static uint32_t kDataIsInSharedMemory = 2; | ||||
| /// \brief Size of each message used in message queue. | |||||
| constexpr static int32_t kSharedMessageSize = 2048; | |||||
| /// \brief Convert a Status object into a protobuf | /// \brief Convert a Status object into a protobuf | ||||
| /// \param rc[in] Status object | /// \param rc[in] Status object | ||||
| @@ -62,29 +57,10 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) { | |||||
| reply->set_rc(static_cast<int32_t>(rc.get_code())); | reply->set_rc(static_cast<int32_t>(rc.get_code())); | ||||
| reply->set_msg(rc.ToString()); | reply->set_msg(rc.ToString()); | ||||
| } | } | ||||
| /// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number | /// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number | ||||
| /// \param port | /// \param port | ||||
| /// \return unix socket url | /// \return unix socket url | ||||
| inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); } | inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); } | ||||
| /// \brief Generate a shared memory key using the tcp/ip port. | |||||
| /// \note It must be called after the cache server generates the unix socket or ftok will fail. | |||||
| /// \note Caller must check the return value. -1 means ftok failed. | |||||
| /// \param[in] port | |||||
| /// \param[out] err. If not null and ftok fails, this will contain the value of errno | |||||
| /// \return key | |||||
| inline key_t PortToFtok(int port, int *err) { | |||||
| key_t shmkey = -1; | |||||
| #ifdef CACHE_LOCAL_CLIENT | |||||
| const std::string unix_path = PortToUnixSocketPath(port); | |||||
| shmkey = ftok(unix_path.data(), 'a'); | |||||
| if (err != nullptr && shmkey == (key_t)-1) { | |||||
| *err = errno; | |||||
| } | |||||
| #endif | |||||
| return shmkey; | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_ | ||||
| @@ -17,34 +17,10 @@ | |||||
| #include <chrono> | #include <chrono> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq, | |||||
| std::unique_ptr<CacheClientRequestTag> &&tag) { | |||||
| // If there is anything extra we need to do before we send. | |||||
| RETURN_IF_NOT_OK(tag->base_rq_->Prepare()); | |||||
| // One minute timeout | |||||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); | |||||
| tag->ctx_.set_deadline(deadline); | |||||
| tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq); | |||||
| tag->rpc_->StartCall(); | |||||
| // Last step is we release the ownership and transfer it to the completion queue. | |||||
| // The memory will be released by WorkerEntry or by the destructor when we drain the queue | |||||
| auto ccReqTag = tag.release(); | |||||
| ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, | |||||
| ccReqTag); // inject this object into the completion queue | |||||
| return Status::OK(); | |||||
| } | |||||
| CacheClientGreeter::~CacheClientGreeter() { (void)ServiceStop(); } | |||||
| CacheClientGreeter::~CacheClientGreeter() { | |||||
| (void)ServiceStop(); | |||||
| // Detach from shared memory if any | |||||
| if (shmat_addr_ != nullptr) { | |||||
| shmdt(shmat_addr_); | |||||
| shmat_addr_ = nullptr; | |||||
| } | |||||
| } | |||||
| CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) | |||||
| : num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) { | |||||
| CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections) | |||||
| : num_connections_(num_connections), request_cnt_(0) { | |||||
| grpc::ChannelArguments args; | grpc::ChannelArguments args; | ||||
| // We need to bump up the message size to unlimited. The default receiving | // We need to bump up the message size to unlimited. The default receiving | ||||
| // message limit is 4MB which is not big enough. | // message limit is 4MB which is not big enough. | ||||
| @@ -68,21 +44,11 @@ CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port | |||||
| Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) { | Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) { | ||||
| *local_bypass = false; | *local_bypass = false; | ||||
| #if CACHE_LOCAL_CLIENT | #if CACHE_LOCAL_CLIENT | ||||
| int err; | |||||
| shm_key_ = PortToFtok(port, &err); | |||||
| if (shm_key_ == (key_t)-1) { | |||||
| std::string errMsg = "Ftok failed with errno " + std::to_string(err); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| SharedMemory::shm_key_t shm_key; | |||||
| RETURN_IF_NOT_OK(PortToFtok(port, &shm_key)); | |||||
| // Attach to the shared memory | // Attach to the shared memory | ||||
| shm_id_ = shmget(shm_key_, 0, 0); | |||||
| if (shm_id_ == -1) { | |||||
| RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno)); | |||||
| } | |||||
| shmat_addr_ = shmat(shm_id_, nullptr, 0); | |||||
| if (shmat_addr_ == reinterpret_cast<void *>(-1)) { | |||||
| RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); | |||||
| } | |||||
| mem_.SetPublicKey(shm_key); | |||||
| RETURN_IF_NOT_OK(mem_.Attach()); | |||||
| *local_bypass = true; | *local_bypass = true; | ||||
| #endif | #endif | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -90,7 +56,7 @@ Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass | |||||
| Status CacheClientGreeter::DoServiceStart() { | Status CacheClientGreeter::DoServiceStart() { | ||||
| RETURN_IF_NOT_OK(vg_.ServiceStart()); | RETURN_IF_NOT_OK(vg_.ServiceStart()); | ||||
| RETURN_IF_NOT_OK(DispatchWorkers(num_workers_)); | |||||
| RETURN_IF_NOT_OK(DispatchWorkers(num_connections_)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -100,19 +66,40 @@ Status CacheClientGreeter::DoServiceStop() { | |||||
| // Shutdown the TaskGroup. | // Shutdown the TaskGroup. | ||||
| vg_.interrupt_all(); | vg_.interrupt_all(); | ||||
| vg_.join_all(Task::WaitFlag::kNonBlocking); | vg_.join_all(Task::WaitFlag::kNonBlocking); | ||||
| // Drain the queue | |||||
| bool success; | |||||
| void *tag; | |||||
| while (cq_.Next(&tag, &success)) { | |||||
| auto r = reinterpret_cast<CacheClientRequestTag *>(tag); | |||||
| delete r; | |||||
| // Drain the queue. We know how many requests we send out | |||||
| while (!req_.empty()) { | |||||
| bool success; | |||||
| void *tag; | |||||
| while (cq_.Next(&tag, &success)) { | |||||
| auto r = reinterpret_cast<CacheClientRequestTag *>(tag); | |||||
| req_.erase(r->seqNo_); | |||||
| } | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) { | Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) { | ||||
| auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq)); | |||||
| return tag->MakeCall(stub_.get(), &cq_, std::move(tag)); | |||||
| // If there is anything extra we need to do before we send. | |||||
| RETURN_IF_NOT_OK(rq->Prepare()); | |||||
| auto seqNo = request_cnt_.fetch_add(1); | |||||
| auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq), seqNo); | |||||
| // One minute timeout | |||||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); | |||||
| tag->ctx_.set_deadline(deadline); | |||||
| tag->rpc_ = stub_->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, &cq_); | |||||
| tag->rpc_->StartCall(); | |||||
| auto ccReqTag = tag.get(); | |||||
| // Insert it into the map. | |||||
| { | |||||
| std::unique_lock<std::mutex> lck(mux_); | |||||
| auto r = req_.emplace(seqNo, std::move(tag)); | |||||
| if (!r.second) { | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__); | |||||
| } | |||||
| } | |||||
| // Last step is to tag the request. | |||||
| ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_, ccReqTag); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| Status CacheClientGreeter::WorkerEntry() { | Status CacheClientGreeter::WorkerEntry() { | ||||
| @@ -129,15 +116,26 @@ Status CacheClientGreeter::WorkerEntry() { | |||||
| auto &rc = rq->rc_; | auto &rc = rq->rc_; | ||||
| if (!rc.ok()) { | if (!rc.ok()) { | ||||
| auto error_code = rq->rc_.error_code(); | auto error_code = rq->rc_.error_code(); | ||||
| std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code); | |||||
| Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| std::string err_msg; | |||||
| if (error_code == grpc::StatusCode::UNAVAILABLE) { | |||||
| err_msg = | |||||
| "Cache server is unreachable. Make sure the server is running. GRPC Code" + std::to_string(error_code); | |||||
| } else { | |||||
| err_msg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code); | |||||
| } | |||||
| Status remote_rc = Status(StatusCode::kNetWorkError, __LINE__, __FILE__, err_msg); | |||||
| Status2CacheReply(remote_rc, &rq->base_rq_->reply_); | Status2CacheReply(remote_rc, &rq->base_rq_->reply_); | ||||
| } | } | ||||
| // Notify the waiting thread. | // Notify the waiting thread. | ||||
| rq->Notify(); | rq->Notify(); | ||||
| } | } | ||||
| // We can now free the memory | |||||
| delete rq; | |||||
| { | |||||
| // We can now free the memory | |||||
| std::unique_lock<std::mutex> lck(mux_); | |||||
| auto seqNo = rq->seqNo_; | |||||
| auto n = req_.erase(seqNo); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(n == 1, "Sequence " + std::to_string(seqNo) + " not found"); | |||||
| } | |||||
| } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) { | } else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) { | ||||
| // If we are interrupted, exit. Otherwise wait again. | // If we are interrupted, exit. Otherwise wait again. | ||||
| RETURN_IF_INTERRUPTED(); | RETURN_IF_INTERRUPTED(); | ||||
| @@ -16,10 +16,14 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_ | ||||
| #include <atomic> | |||||
| #include <map> | |||||
| #include <memory> | #include <memory> | ||||
| #include <mutex> | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/engine/cache/cache_common.h" | #include "minddata/dataset/engine/cache/cache_common.h" | ||||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||||
| #include "minddata/dataset/util/service.h" | #include "minddata/dataset/util/service.h" | ||||
| #include "minddata/dataset/util/task_manager.h" | #include "minddata/dataset/util/task_manager.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -34,16 +38,10 @@ namespace dataset { | |||||
| class CacheClientRequestTag { | class CacheClientRequestTag { | ||||
| public: | public: | ||||
| friend class CacheClientGreeter; | friend class CacheClientGreeter; | ||||
| explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {} | |||||
| explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq, int64_t seqNo) | |||||
| : base_rq_(std::move(rq)), seqNo_(seqNo) {} | |||||
| ~CacheClientRequestTag() = default; | ~CacheClientRequestTag() = default; | ||||
| /// \brief Make a RPC call | |||||
| /// \param stub from CacheClientGreeter | |||||
| /// \param cq from CacheClientGreeter | |||||
| /// \return Status object | |||||
| static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq, | |||||
| std::unique_ptr<CacheClientRequestTag> &&tag); | |||||
| /// \brief Notify the client that a result has come back from the server | /// \brief Notify the client that a result has come back from the server | ||||
| void Notify() { base_rq_->wp_.Set(); } | void Notify() { base_rq_->wp_.Set(); } | ||||
| @@ -52,6 +50,7 @@ class CacheClientRequestTag { | |||||
| grpc::Status rc_; | grpc::Status rc_; | ||||
| grpc::ClientContext ctx_; | grpc::ClientContext ctx_; | ||||
| std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_; | std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_; | ||||
| int64_t seqNo_; | |||||
| }; | }; | ||||
| /// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC | /// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC | ||||
| @@ -60,7 +59,7 @@ class CacheClientGreeter : public Service { | |||||
| friend class CacheClient; | friend class CacheClient; | ||||
| public: | public: | ||||
| explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers); | |||||
| explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_connections); | |||||
| ~CacheClientGreeter(); | ~CacheClientGreeter(); | ||||
| /// Override base Service class | /// Override base Service class | ||||
| @@ -85,17 +84,18 @@ class CacheClientGreeter : public Service { | |||||
| /// \brief This returns where we attach to the shared memory. | /// \brief This returns where we attach to the shared memory. | ||||
| /// \return Base address of the shared memory. | /// \return Base address of the shared memory. | ||||
| const void *SharedMemoryBaseAddr() const { return shmat_addr_; } | |||||
| const void *SharedMemoryBaseAddr() const { return mem_.SharedMemoryBaseAddr(); } | |||||
| private: | private: | ||||
| std::shared_ptr<grpc::Channel> channel_; | std::shared_ptr<grpc::Channel> channel_; | ||||
| std::unique_ptr<CacheServerGreeter::Stub> stub_; | std::unique_ptr<CacheServerGreeter::Stub> stub_; | ||||
| grpc::CompletionQueue cq_; | grpc::CompletionQueue cq_; | ||||
| TaskGroup vg_; | TaskGroup vg_; | ||||
| int32_t num_workers_; | |||||
| key_t shm_key_; | |||||
| int32_t shm_id_; | |||||
| void *shmat_addr_; | |||||
| int32_t num_connections_; | |||||
| std::atomic<int64_t> request_cnt_; | |||||
| mutable std::mutex mux_; | |||||
| std::map<int64_t, std::unique_ptr<CacheClientRequestTag>> req_; | |||||
| SharedMemory mem_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -47,53 +47,10 @@ void CacheServerGreeterImpl::Shutdown() { | |||||
| CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); } | CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); } | ||||
| Status CacheServerGreeterImpl::IpcResourceCleanup() { | |||||
| #if CACHE_LOCAL_CLIENT | |||||
| int err; | |||||
| auto shm_key = PortToFtok(port_, &err); | |||||
| // We are expecting the unix path doesn't exist. | |||||
| if (shm_key == (key_t)-1) { | |||||
| return Status::OK(); | |||||
| } | |||||
| // Attach to the shared memory | |||||
| auto shm_id = shmget(shm_key, 0, 0); | |||||
| if (shm_id == -1) { | |||||
| return Status::OK(); | |||||
| } | |||||
| struct shmid_ds ds {}; | |||||
| auto inx = shmctl(shm_id, IPC_STAT, &ds); | |||||
| if (inx == -1) { | |||||
| std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id); | |||||
| errMsg += "\nPlesae remove it manually using ipcrm -m command"; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| if (ds.shm_nattch == 0) { | |||||
| // Stale shared memory from last time. | |||||
| // Remove both the memory and the socket path | |||||
| inx = shmctl(shm_id, IPC_RMID, nullptr); | |||||
| if (inx == -1) { | |||||
| std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id); | |||||
| errMsg += ". Errno :" + std::to_string(errno); | |||||
| errMsg += "\nPlesae remove it manually using ipcrm -m command"; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| Path p(unix_socket_); | |||||
| (void)p.Remove(); | |||||
| } else { | |||||
| // Server is already up. | |||||
| MS_LOG(ERROR) << "Cache server is already up and running"; | |||||
| // We return a duplicate error. The main() will intercept | |||||
| // and output a proper message | |||||
| return Status(StatusCode::kDuplicateKey); | |||||
| } | |||||
| #endif | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheServerGreeterImpl::Run() { | Status CacheServerGreeterImpl::Run() { | ||||
| // To listen on all interfaces, use 0.0.0.0 | // To listen on all interfaces, use 0.0.0.0 | ||||
| // Use 127.0.0.1 if just locally on the same machine. | |||||
| std::string host("0.0.0.0"); // listen on all interfaces. | |||||
| // Future, allow the user to choose listening interface. For now, default to localhost | |||||
| std::string host("127.0.0.1"); | |||||
| std::string server_address = host + ":" + std::to_string(port_); | std::string server_address = host + ":" + std::to_string(port_); | ||||
| grpc::ServerBuilder builder; | grpc::ServerBuilder builder; | ||||
| // Default message size for gRPC is 4MB. Increase it to 2g-1 | // Default message size for gRPC is 4MB. Increase it to 2g-1 | ||||
| @@ -101,9 +58,6 @@ Status CacheServerGreeterImpl::Run() { | |||||
| int port_tcpip = 0; | int port_tcpip = 0; | ||||
| #if CACHE_LOCAL_CLIENT | #if CACHE_LOCAL_CLIENT | ||||
| int port_local = 0; | int port_local = 0; | ||||
| // Check if we need to do clean up on the shared memory if the server | |||||
| // came down unexpectedly like SEGV | |||||
| RETURN_IF_NOT_OK(IpcResourceCleanup()); | |||||
| // We also optimize on local clients on the same machine using unix socket | // We also optimize on local clients on the same machine using unix socket | ||||
| builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local); | builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local); | ||||
| #endif | #endif | ||||
| @@ -41,7 +41,7 @@ class CacheServerRequest : public BaseRequest { | |||||
| st_(STATE::CREATE), | st_(STATE::CREATE), | ||||
| responder_(&ctx_) {} | responder_(&ctx_) {} | ||||
| ~CacheServerRequest() = default; | |||||
| ~CacheServerRequest() override = default; | |||||
| /// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this | /// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this | ||||
| /// functor will translate each protobuf into some form understood by by CacheService class. | /// functor will translate each protobuf into some form understood by by CacheService class. | ||||
| @@ -87,8 +87,6 @@ class CacheServerGreeterImpl final { | |||||
| void Shutdown(); | void Shutdown(); | ||||
| Status IpcResourceCleanup(); | |||||
| private: | private: | ||||
| int32_t port_; | int32_t port_; | ||||
| size_t shm_pool_sz_in_gb_; | size_t shm_pool_sz_in_gb_; | ||||
| @@ -0,0 +1,163 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||||
| #include <sys/stat.h> | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| Status PortToFtok(int port, SharedMemory::shm_key_t *out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| key_t shmkey = -1; | |||||
| const std::string unix_path = PortToUnixSocketPath(port); | |||||
| shmkey = ftok(unix_path.data(), 'a'); | |||||
| if (shmkey == (key_t)-1) { | |||||
| std::string errMsg = "Unable to create a ftok token. Errno = " + std::to_string(errno); | |||||
| return Status(errno == ENOENT ? StatusCode::kFileNotExist : StatusCode::kUnexpectedError, errMsg); | |||||
| } | |||||
| *out = shmkey; | |||||
| return Status::OK(); | |||||
| } | |||||
| SharedMessage::~SharedMessage() { | |||||
| // Only remove the queue if we are asked to. | |||||
| if (remove_ipc_on_exit_ && msg_qid_ != -1) { | |||||
| // Remove the message que and never mind about the return code. | |||||
| (void)msgctl(msg_qid_, IPC_RMID, nullptr); | |||||
| msg_qid_ = -1; | |||||
| } | |||||
| } | |||||
| Status SharedMessage::Create() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ == -1, "Message queue already created"); | |||||
| auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; | |||||
| msg_qid_ = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode); | |||||
| if (msg_qid_ == -1) { | |||||
| std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status SharedMessage::SendStatus(const Status &rc) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id"); | |||||
| StatusMsgBuf msg{ | |||||
| 1, | |||||
| }; | |||||
| msg.body.status.err_code = static_cast<int32_t>(rc.get_code()); | |||||
| auto err = memcpy_s(msg.body.status.err_msg, kSharedMessageSize, rc.ToString().data(), rc.ToString().size()); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(err == EOK, "memcpy_s failed. err = " + std::to_string(err)); | |||||
| msg.body.status.err_msg[rc.ToString().size()] = '\0'; | |||||
| err = msgsnd(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), IPC_NOWAIT); | |||||
| if (err == -1) { | |||||
| std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status SharedMessage::ReceiveStatus(Status *rc) { | |||||
| RETURN_UNEXPECTED_IF_NULL(rc); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id"); | |||||
| struct StatusMsgBuf msg {}; | |||||
| auto err = msgrcv(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), 0, MSG_NOERROR); | |||||
| if (err == -1) { | |||||
| std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| Status rc_recv(static_cast<StatusCode>(msg.body.status.err_code), msg.body.status.err_msg); | |||||
| *rc = std::move(rc_recv); | |||||
| return Status::OK(); | |||||
| } | |||||
| SharedMemory::~SharedMemory() { | |||||
| if (shmat_addr_) { | |||||
| (void)Detach(); | |||||
| } | |||||
| if (remove_ipc_on_exit_ && shm_id_ != -1) { | |||||
| // Remove the shared memory and never mind about the return code. | |||||
| Status rc = Destroy(); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << rc.ToString(); | |||||
| } | |||||
| } | |||||
| shm_id_ = -1; | |||||
| shmat_addr_ = nullptr; | |||||
| } | |||||
| Status SharedMemory::Create(int64_t sz) { | |||||
| auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; | |||||
| shm_id_ = shmget(shm_key_, sz, IPC_CREAT | IPC_EXCL | access_mode); | |||||
| if (shm_id_ == -1) { | |||||
| RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno)); | |||||
| } else { | |||||
| shmat_addr_ = shmat(shm_id_, nullptr, 0); | |||||
| if (shmat_addr_ == reinterpret_cast<void *>(-1)) { | |||||
| RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status SharedMemory::Attach() { | |||||
| shm_id_ = shmget(shm_key_, 0, 0); | |||||
| if (shm_id_ == -1) { | |||||
| RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno)); | |||||
| } | |||||
| shmat_addr_ = shmat(shm_id_, nullptr, 0); | |||||
| if (shmat_addr_ == reinterpret_cast<void *>(-1)) { | |||||
| RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status SharedMemory::Detach() { | |||||
| if (shmat_addr_) { | |||||
| auto err = shmdt(shmat_addr_); | |||||
| if (err == -1) { | |||||
| RETURN_STATUS_UNEXPECTED("Shared memory detach failed. Errno " + std::to_string(errno)); | |||||
| } | |||||
| } | |||||
| shmat_addr_ = nullptr; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status SharedMemory::Destroy() { | |||||
| // Remove the shared memory and never mind about the return code. | |||||
| auto err = shmctl(shm_id_, IPC_RMID, nullptr); | |||||
| if (err == -1) { | |||||
| std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_); | |||||
| errMsg += ". Errno :" + std::to_string(errno); | |||||
| errMsg += "\nPlesae remove it manually using ipcrm -m command"; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status SharedMemory::GetNumAttached(int32_t *num) { | |||||
| RETURN_UNEXPECTED_IF_NULL(num); | |||||
| struct shmid_ds ds {}; | |||||
| auto err = shmctl(shm_id_, IPC_STAT, &ds); | |||||
| if (err == -1) { | |||||
| std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id_); | |||||
| errMsg += "\nPlease remove it manually using ipcrm -m command"; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| *num = ds.shm_nattch; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,207 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_ | |||||
| #include <sys/types.h> | |||||
| #include <sys/ipc.h> | |||||
| #include <sys/shm.h> | |||||
| #include <sys/msg.h> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// A message queue structure between the parent and the child process | |||||
| struct StatusMsgBuf { | |||||
| int64_t mtype; | |||||
| union { | |||||
| char mtext[1]; | |||||
| struct { | |||||
| int32_t err_code; | |||||
| char err_msg[kSharedMessageSize]; | |||||
| } status; | |||||
| } body; | |||||
| }; | |||||
| class BaseIPC { | |||||
| public: | |||||
| BaseIPC() : remove_ipc_on_exit_(false) {} | |||||
| virtual ~BaseIPC() {} | |||||
| /// Indicate if we should remove the ipc resource on exit. Usually this is done by parent process. | |||||
| void RemoveResourcesOnExit() { remove_ipc_on_exit_ = true; } | |||||
| /// Copy constructors | |||||
| BaseIPC(const BaseIPC &rhs) : remove_ipc_on_exit_(false) {} | |||||
| BaseIPC &operator=(const BaseIPC &rhs) { | |||||
| if (&rhs != this) { | |||||
| remove_ipc_on_exit_ = false; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// Move constructors | |||||
| BaseIPC(BaseIPC &&rhs) noexcept : remove_ipc_on_exit_(rhs.remove_ipc_on_exit_) { rhs.remove_ipc_on_exit_ = false; } | |||||
| BaseIPC &operator=(BaseIPC &&rhs) noexcept { | |||||
| if (&rhs != this) { | |||||
| remove_ipc_on_exit_ = rhs.remove_ipc_on_exit_; | |||||
| rhs.remove_ipc_on_exit_ = false; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| protected: | |||||
| bool remove_ipc_on_exit_; | |||||
| }; | |||||
| /// \brief This wraps a shared message for the communication between processes. It is used primarily | |||||
| /// for starting and stopping a server. | |||||
| class SharedMessage : public BaseIPC { | |||||
| public: | |||||
| using queue_id_t = int; | |||||
| SharedMessage() : msg_qid_(-1) {} | |||||
| explicit SharedMessage(queue_id_t qid) : msg_qid_(qid) {} | |||||
| ~SharedMessage() override; | |||||
| /// Copy constructors | |||||
| SharedMessage(const SharedMessage &rhs) : BaseIPC(rhs), msg_qid_(rhs.msg_qid_) {} | |||||
| SharedMessage &operator=(const SharedMessage &rhs) { | |||||
| if (&rhs != this) { | |||||
| msg_qid_ = rhs.msg_qid_; | |||||
| BaseIPC::operator=(rhs); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// Move constructors | |||||
| SharedMessage(SharedMessage &&rhs) noexcept : BaseIPC(std::move(rhs)) { | |||||
| msg_qid_ = rhs.msg_qid_; | |||||
| rhs.msg_qid_ = -1; | |||||
| } | |||||
| SharedMessage &operator=(SharedMessage &&rhs) noexcept { | |||||
| if (&rhs != this) { | |||||
| msg_qid_ = rhs.msg_qid_; | |||||
| rhs.msg_qid_ = -1; | |||||
| BaseIPC::operator=(std::move(rhs)); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// Return the private id | |||||
| queue_id_t GetMsgQueueId() const { return msg_qid_; } | |||||
| /// \brief Create a private message queue | |||||
| Status Create(); | |||||
| /// Send a Status object | |||||
| Status SendStatus(const Status &rc); | |||||
| /// Retrieve a Status object | |||||
| Status ReceiveStatus(Status *rc); | |||||
| private: | |||||
| queue_id_t msg_qid_; | |||||
| }; | |||||
| /// \brief This wraps a shared memory for the communication between processes. It is used primarily | |||||
| /// for transporting large tensor rows. | |||||
| class SharedMemory : public BaseIPC { | |||||
| public: | |||||
| using shm_key_t = int; | |||||
| using shm_id_t = int; | |||||
| SharedMemory() : shm_id_(-1), shm_key_(-1), shmat_addr_(nullptr) {} | |||||
| explicit SharedMemory(shm_key_t public_key) : shm_id_(-1), shm_key_(public_key), shmat_addr_(nullptr) {} | |||||
| ~SharedMemory() override; | |||||
| /// Copy constructors | |||||
| SharedMemory(const SharedMemory &rhs) | |||||
| : BaseIPC(rhs), shm_id_(rhs.shm_id_), shm_key_(rhs.shm_key_), shmat_addr_(rhs.shmat_addr_) {} | |||||
| SharedMemory &operator=(const SharedMemory &rhs) { | |||||
| if (&rhs != this) { | |||||
| shm_id_ = rhs.shm_id_; | |||||
| shm_key_ = rhs.shm_key_; | |||||
| shmat_addr_ = rhs.shmat_addr_; | |||||
| BaseIPC::operator=(rhs); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// Move constructors | |||||
| SharedMemory(SharedMemory &&rhs) noexcept : BaseIPC(std::move(rhs)) { | |||||
| shm_id_ = rhs.shm_id_; | |||||
| shm_key_ = rhs.shm_key_; | |||||
| shmat_addr_ = rhs.shmat_addr_; | |||||
| rhs.shm_id_ = -1; | |||||
| rhs.shm_key_ = -1; | |||||
| rhs.shmat_addr_ = nullptr; | |||||
| } | |||||
| SharedMemory &operator=(SharedMemory &&rhs) noexcept { | |||||
| if (&rhs != this) { | |||||
| shm_id_ = rhs.shm_id_; | |||||
| shm_key_ = rhs.shm_key_; | |||||
| shmat_addr_ = rhs.shmat_addr_; | |||||
| rhs.shm_id_ = -1; | |||||
| rhs.shm_key_ = -1; | |||||
| rhs.shmat_addr_ = nullptr; | |||||
| BaseIPC::operator=(std::move(rhs)); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// \brief Set the public key | |||||
| void SetPublicKey(key_t public_key) { shm_key_ = public_key; } | |||||
| /// \brief This returns where we attach to the shared memory. | |||||
| /// \return Base address of the shared memory. | |||||
| const void *SharedMemoryBaseAddr() const { return shmat_addr_; } | |||||
| void *SharedMemoryBaseAddr() { return shmat_addr_; } | |||||
| /// \brief Attach to shared memory | |||||
| /// \return Status object | |||||
| Status Attach(); | |||||
| /// Detach from shared memory | |||||
| /// \return Status object | |||||
| Status Detach(); | |||||
| /// Create shared memory | |||||
| /// \return Status object | |||||
| Status Create(int64_t sz); | |||||
| /// Destroy shared memory | |||||
| /// \return Status object | |||||
| Status Destroy(); | |||||
| /// \brief Return the shared memory id | |||||
| shm_id_t GetSharedMemoryId() const { return shm_id_; } | |||||
| /// \brief Get number of processes attached to the shared memory | |||||
| /// \return Status object | |||||
| Status GetNumAttached(int32_t *num); | |||||
| private: | |||||
| shm_id_t shm_id_; | |||||
| shm_key_t shm_key_; | |||||
| void *shmat_addr_; | |||||
| }; | |||||
| /// \brief Generate a shared memory key using the tcp/ip port. | |||||
| /// \note It must be called after the cache server generates the unix socket or ftok will fail. | |||||
| /// \note Caller must check the return value. -1 means ftok failed. | |||||
| /// \param[in] port | |||||
| /// \param[out] err. If not null and ftok fails, this will contain the value of errno | |||||
| /// \return key | |||||
| Status PortToFtok(int port, SharedMemory::shm_key_t *); | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_IPC_H_ | |||||
| @@ -21,24 +21,136 @@ | |||||
| #include <glog/logging.h> | #include <glog/logging.h> | ||||
| #endif | #endif | ||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <thread> | |||||
| #include <chrono> | |||||
| #include "minddata/dataset/engine/cache/cache_common.h" | |||||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||||
| namespace ds = mindspore::dataset; | namespace ds = mindspore::dataset; | ||||
| int main(int argc, char **argv) { | |||||
| /// Send a synchronous command to the local server using tcp/ip. | |||||
| /// We aren't using any client code because this binary is not necessarily linked with the client library. | |||||
| /// So just using grpc call directly. | |||||
| /// \param port tcp/ip port to use | |||||
| /// \param type Type of command. | |||||
| /// \param out grpc result | |||||
| /// \return Status object | |||||
| ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::CacheRequest *rq, ds::CacheReply *reply, | |||||
| grpc::Status *out) { | |||||
| if (rq == nullptr) { | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, "pointer rq is null"); | |||||
| } | |||||
| if (reply == nullptr) { | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, "pointer reply is null"); | |||||
| } | |||||
| if (out == nullptr) { | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, "pointer out is null"); | |||||
| } | |||||
| const std::string hostname = "127.0.0.1"; | |||||
| auto unix_socket = ds::PortToUnixSocketPath(port); | |||||
| #if CACHE_LOCAL_CLIENT | |||||
| const std::string target = "unix://" + unix_socket; | |||||
| #else | |||||
| const std::string target = hostname + ":" + std::to_string(port); | |||||
| #endif | |||||
| try { | |||||
| rq->set_type(static_cast<int16_t>(type)); | |||||
| grpc::ChannelArguments args; | |||||
| grpc::ClientContext ctx; | |||||
| grpc::CompletionQueue cq; | |||||
| // Standard async rpc call | |||||
| std::shared_ptr<grpc::Channel> channel = | |||||
| grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args); | |||||
| std::unique_ptr<ds::CacheServerGreeter::Stub> stub = ds::CacheServerGreeter::NewStub(channel); | |||||
| std::unique_ptr<grpc::ClientAsyncResponseReader<ds::CacheReply>> rpc = | |||||
| stub->PrepareAsyncCacheServerRequest(&ctx, *rq, &cq); | |||||
| rpc->StartCall(); | |||||
| // We need to pass a tag. But since this is the only request in the completion queue and so we | |||||
| // just pass a dummy | |||||
| int64_t dummy; | |||||
| void *tag; | |||||
| bool success; | |||||
| rpc->Finish(reply, out, &dummy); | |||||
| // Now we wait on the completion queue synchronously. | |||||
| auto r = cq.Next(&tag, &success); | |||||
| if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) { | |||||
| if (!success || tag != &dummy) { | |||||
| std::string errMsg = "Unexpected programming error "; | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } | |||||
| if (out->ok()) { | |||||
| return ds::Status(static_cast<ds::StatusCode>(reply->rc()), reply->msg()); | |||||
| } else { | |||||
| auto error_code = out->error_code(); | |||||
| std::string errMsg = out->error_message() + ". GRPC Code " + std::to_string(error_code); | |||||
| return ds::Status(ds::StatusCode::kNetWorkError, errMsg); | |||||
| } | |||||
| } else { | |||||
| std::string errMsg = "Unexpected queue rc = " + std::to_string(r); | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } | |||||
| } catch (const std::exception &e) { | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what()); | |||||
| } | |||||
| } | |||||
| /// Stop the server | |||||
| /// \param argv | |||||
| /// \return Status object | |||||
| ds::Status StopServer(int argc, char **argv) { | |||||
| ds::Status rc; | ds::Status rc; | ||||
| ds::CacheServer::Builder builder; | ds::CacheServer::Builder builder; | ||||
| std::string errMsg; | |||||
| if (argc != 2) { | |||||
| return ds::Status(ds::StatusCode::kSyntaxError); | |||||
| } | |||||
| int32_t port = strtol(argv[1], nullptr, 10); | |||||
| // We will go through the builder to do some snaity check. We only need the port number | |||||
| // to shut down the server. Null the root directory as we don't trigger the sanity code to write out anything | |||||
| // to the spill directory. | |||||
| builder.SetPort(port).SetRootDirectory(""); | |||||
| // Part of the sanity check is check the shared memory. If the server is up and running, we expect | |||||
| // the return code is kDuplicate. | |||||
| rc = builder.SanityCheck(); | |||||
| if (rc.IsOk()) { | |||||
| errMsg = "Server is not up or has been shutdown already."; | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, errMsg); | |||||
| } else if (rc.get_code() != ds::StatusCode::kDuplicateKey) { | |||||
| // Not OK, and no duplicate, just return the rc whatever it is. | |||||
| return rc; | |||||
| } else { | |||||
| // Now we get some work to do. We will send a tcp/ip request to the given port. | |||||
| // This binary is not linked with client side of code, so we will just call grpc directly. | |||||
| ds::CacheRequest rq; | |||||
| ds::CacheReply reply; | |||||
| grpc::Status grpc_rc; | |||||
| rc = SendSyncCommand(port, ds::BaseRequest::RequestType::kStopService, &rq, &reply, &grpc_rc); | |||||
| // The request is like a self destruct message, the server will not send anything back and | |||||
| // shutdown all incoming request. So we should expect some unexpected network error if | |||||
| // all goes well and we expect to GRPC code 14. | |||||
| auto err_code = grpc_rc.error_code(); | |||||
| if (rc.get_code() != ds::StatusCode::kNetWorkError || err_code != grpc::StatusCode::UNAVAILABLE) { | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__); | |||||
| } | |||||
| } | |||||
| return ds::Status::OK(); | |||||
| } | |||||
| // This executable is not to be called directly, and should be invoked by cache_admin executable. | |||||
| if (argc != 7) { | |||||
| rc = ds::Status(ds::StatusCode::kSyntaxError); | |||||
| std::cerr << rc.ToString() << std::endl; | |||||
| return static_cast<int>(rc.get_code()); | |||||
| /// Start the server | |||||
| /// \param argv | |||||
| /// \return Status object | |||||
| ds::Status StartServer(int argc, char **argv) { | |||||
| ds::Status rc; | |||||
| ds::CacheServer::Builder builder; | |||||
| if (argc != 8) { | |||||
| return ds::Status(ds::StatusCode::kSyntaxError); | |||||
| } | } | ||||
| int32_t port = strtol(argv[3], nullptr, 10); | |||||
| builder.SetRootDirectory(argv[1]) | builder.SetRootDirectory(argv[1]) | ||||
| .SetNumWorkers(strtol(argv[2], nullptr, 10)) | .SetNumWorkers(strtol(argv[2], nullptr, 10)) | ||||
| .SetPort(strtol(argv[3], nullptr, 10)) | |||||
| .SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10)); | |||||
| .SetPort(port) | |||||
| .SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10)) | |||||
| .SetMemoryCapRatio(strtof(argv[7], nullptr)); | |||||
| #ifdef USE_GLOG | #ifdef USE_GLOG | ||||
| FLAGS_minloglevel = strtol(argv[5], nullptr, 10); | FLAGS_minloglevel = strtol(argv[5], nullptr, 10); | ||||
| @@ -52,36 +164,42 @@ int main(int argc, char **argv) { | |||||
| // is called. This is a standard procedure for daemonize a process on unix. | // is called. This is a standard procedure for daemonize a process on unix. | ||||
| if (chdir("/") == -1) { | if (chdir("/") == -1) { | ||||
| std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno); | std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno); | ||||
| std::cerr << errMsg << std::endl; | |||||
| return -1; | |||||
| } | |||||
| // Simple check of the parameters before we move on. | |||||
| rc = builder.SanityCheck(); | |||||
| if (rc.IsError()) { | |||||
| std::cerr << rc.ToString() << std::endl; | |||||
| return static_cast<int>(rc.get_code()); | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } | } | ||||
| // A message queue for communication between parent and child (if we fork). | |||||
| ds::SharedMessage msg; | |||||
| if (daemonize) { | |||||
| #ifdef USE_GLOG | #ifdef USE_GLOG | ||||
| FLAGS_log_dir = "/tmp"; | |||||
| google::InitGoogleLogging(argv[0]); | |||||
| FLAGS_log_dir = "/tmp"; | |||||
| google::InitGoogleLogging(argv[0]); | |||||
| #endif | #endif | ||||
| if (daemonize) { | |||||
| // fork the child process to become the daemon | |||||
| rc = msg.Create(); | |||||
| if (rc.IsError()) { | |||||
| return rc; | |||||
| } | |||||
| pid_t pid = fork(); | pid_t pid = fork(); | ||||
| // failed to fork | // failed to fork | ||||
| if (pid < 0) { | if (pid < 0) { | ||||
| std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno); | |||||
| std::cerr << err_msg << std::endl; | |||||
| return errno; | |||||
| std::string errMsg = "Failed to fork process for cache server. Errno = " + std::to_string(errno); | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else if (pid > 0) { | } else if (pid > 0) { | ||||
| // Parent | |||||
| // Parent and will be responsible for remove the queue on exit. | |||||
| msg.RemoveResourcesOnExit(); | |||||
| // Sleep one second and we attach to the msg que | |||||
| std::this_thread::sleep_for(std::chrono::seconds(1)); | |||||
| ds::Status child_rc; | |||||
| rc = msg.ReceiveStatus(&child_rc); | |||||
| if (rc.IsError()) { | |||||
| return rc; | |||||
| } | |||||
| if (child_rc.IsError()) { | |||||
| return child_rc; | |||||
| } | |||||
| std::cerr << "cache server daemon process has been created as process id: " << pid | std::cerr << "cache server daemon process has been created as process id: " << pid | ||||
| << "\nCheck log file for any start up error" << std::endl; | << "\nCheck log file for any start up error" << std::endl; | ||||
| signal(SIGCHLD, SIG_IGN); // ignore sig child signal. | signal(SIGCHLD, SIG_IGN); // ignore sig child signal. | ||||
| return 0; | |||||
| return ds::Status::OK(); | |||||
| } else { | } else { | ||||
| // Child process will continue from here if daemonize and parent has already exited. | // Child process will continue from here if daemonize and parent has already exited. | ||||
| // If we are running in the foreground, none of the code in block below will be run. | // If we are running in the foreground, none of the code in block below will be run. | ||||
| @@ -89,8 +207,8 @@ int main(int argc, char **argv) { | |||||
| umask(0); | umask(0); | ||||
| sid = setsid(); | sid = setsid(); | ||||
| if (sid < 0) { | if (sid < 0) { | ||||
| MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno); | |||||
| return errno; | |||||
| std::string errMsg = "Failed to setsid(). Errno = " + std::to_string(errno); | |||||
| return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } | } | ||||
| close(0); | close(0); | ||||
| close(1); | close(1); | ||||
| @@ -100,22 +218,36 @@ int main(int argc, char **argv) { | |||||
| // Dump the summary | // Dump the summary | ||||
| MS_LOG(INFO) << builder << std::endl; | MS_LOG(INFO) << builder << std::endl; | ||||
| // Create the instance with some sanity checks built in | |||||
| rc = builder.Build(); | rc = builder.Build(); | ||||
| if (rc.IsOk()) { | if (rc.IsOk()) { | ||||
| // If all goes well, kick off the threads. Loop forever and never return unless error. | |||||
| ds::CacheServer &cs = ds::CacheServer::GetInstance(); | ds::CacheServer &cs = ds::CacheServer::GetInstance(); | ||||
| // Kick off the threads. Loop forever and never return unless error. | |||||
| rc = cs.Run(); | |||||
| if (rc.get_code() == ds::StatusCode::kDuplicateKey) { | |||||
| std::string errMsg = "Server is already started"; | |||||
| MS_LOG(ERROR) << errMsg; | |||||
| std::cerr << errMsg << std::endl; | |||||
| return 0; | |||||
| } | |||||
| rc = cs.Run(msg.GetMsgQueueId()); | |||||
| } else if (daemonize) { | |||||
| // If we didn't pass the sanity check to at least create the instance, use | |||||
| // the message queue to return the error message if this is the child daemon. | |||||
| return msg.SendStatus(rc); | |||||
| } | |||||
| return rc; | |||||
| } | |||||
| int main(int argc, char **argv) { | |||||
| ds::Status rc; | |||||
| ds::CacheServer::Builder builder; | |||||
| // This executable is not to be called directly, and should be invoked by cache_admin executable. | |||||
| if (strcmp(argv[0], "-") == 0) { | |||||
| rc = StopServer(argc, argv); | |||||
| } else { | |||||
| rc = StartServer(argc, argv); | |||||
| } | } | ||||
| // Check result | |||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| MS_LOG(ERROR) << rc.ToString(); | |||||
| std::cerr << rc.ToString() << std::endl; | |||||
| return static_cast<int>(rc.get_code()); | |||||
| auto errCode = rc.get_code(); | |||||
| auto errMsg = rc.ToString(); | |||||
| std::cerr << errMsg << std::endl; | |||||
| return static_cast<int>(errCode); | |||||
| } | } | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -250,5 +250,27 @@ Status GetStatRequest::PostReply() { | |||||
| stat_.cache_service_state = msg->state(); | stat_.cache_service_state = msg->state(); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status ListSessionsRequest::PostReply() { | |||||
| auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data()); | |||||
| auto session_vector = msg->sessions(); | |||||
| for (auto i = 0; i < session_vector->size(); ++i) { | |||||
| SessionCacheInfo current_info; | |||||
| CacheServiceStat stats; | |||||
| auto current_session_info = session_vector->Get(i); | |||||
| current_info.session_id = current_session_info->session_id(); | |||||
| current_info.connection_id = current_session_info->connection_id(); | |||||
| stats.num_mem_cached = current_session_info->stats()->num_mem_cached(); | |||||
| stats.num_disk_cached = current_session_info->stats()->num_disk_cached(); | |||||
| stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz(); | |||||
| stats.min_row_id = current_session_info->stats()->min_row_id(); | |||||
| stats.max_row_id = current_session_info->stats()->max_row_id(); | |||||
| stats.cache_service_state = current_session_info->stats()->state(); | |||||
| current_info.stats = stats; // fixed length struct. = operator is safe | |||||
| session_info_list_.push_back(current_info); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -46,6 +46,13 @@ struct CacheServiceStat { | |||||
| int8_t cache_service_state; | int8_t cache_service_state; | ||||
| }; | }; | ||||
| /// \brief Info structure ListSessionsRequest | |||||
| struct SessionCacheInfo { | |||||
| session_id_type session_id; | |||||
| connection_id_type connection_id; | |||||
| CacheServiceStat stats; | |||||
| }; | |||||
| /// \brief CacheClient communicates with CacheServer using Requests. | /// \brief CacheClient communicates with CacheServer using Requests. | ||||
| class BaseRequest { | class BaseRequest { | ||||
| public: | public: | ||||
| @@ -54,7 +61,7 @@ class BaseRequest { | |||||
| kCacheRow = 0, | kCacheRow = 0, | ||||
| kBatchFetchRows = 1, | kBatchFetchRows = 1, | ||||
| kCreateCache = 2, | kCreateCache = 2, | ||||
| kPurgeCache = 3, | |||||
| kGetCacheMissKeys = 3, | |||||
| kDestroyCache = 4, | kDestroyCache = 4, | ||||
| kGetStat = 5, | kGetStat = 5, | ||||
| kCacheSchema = 6, | kCacheSchema = 6, | ||||
| @@ -65,6 +72,9 @@ class BaseRequest { | |||||
| kAllocateSharedBlock = 11, | kAllocateSharedBlock = 11, | ||||
| kFreeSharedBlock = 12, | kFreeSharedBlock = 12, | ||||
| kStopService = 13, | kStopService = 13, | ||||
| kHeartBeat = 14, | |||||
| kToggleWriteMode = 15, | |||||
| kListSessions = 16, | |||||
| // Add new request before it. | // Add new request before it. | ||||
| kRequestUnknown = 32767 | kRequestUnknown = 32767 | ||||
| }; | }; | ||||
| @@ -73,6 +83,7 @@ class BaseRequest { | |||||
| friend class CacheServerRequest; | friend class CacheServerRequest; | ||||
| friend class CacheClientGreeter; | friend class CacheClientGreeter; | ||||
| friend class CacheClientRequestTag; | friend class CacheClientRequestTag; | ||||
| friend class CacheClient; | |||||
| /// \brief Base class of a cache server request | /// \brief Base class of a cache server request | ||||
| /// \param type Type of the request | /// \param type Type of the request | ||||
| @@ -119,7 +130,7 @@ class FreeSharedBlockRequest : public BaseRequest { | |||||
| rq_.set_connection_id(connection_id); | rq_.set_connection_id(connection_id); | ||||
| rq_.add_buf_data(std::to_string(addr)); | rq_.add_buf_data(std::to_string(addr)); | ||||
| } | } | ||||
| ~FreeSharedBlockRequest() = default; | |||||
| ~FreeSharedBlockRequest() override = default; | |||||
| }; | }; | ||||
| /// \brief Request to cache a single TensorRow | /// \brief Request to cache a single TensorRow | ||||
| @@ -136,7 +147,7 @@ class CacheRowRequest : public BaseRequest { | |||||
| rq_.set_connection_id(connection_id); | rq_.set_connection_id(connection_id); | ||||
| rq_.add_buf_data(cookie); | rq_.add_buf_data(cookie); | ||||
| } | } | ||||
| ~CacheRowRequest() = default; | |||||
| ~CacheRowRequest() override = default; | |||||
| /// \brief Serialize a TensorRow for streaming to the cache server | /// \brief Serialize a TensorRow for streaming to the cache server | ||||
| /// \param row TensorRow | /// \param row TensorRow | ||||
| @@ -183,7 +194,7 @@ class BatchFetchRequest : public BaseRequest { | |||||
| friend class CacheServer; | friend class CacheServer; | ||||
| friend class CacheService; | friend class CacheService; | ||||
| BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass); | BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass); | ||||
| ~BatchFetchRequest() = default; | |||||
| ~BatchFetchRequest() override = default; | |||||
| Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); | Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); | ||||
| private: | private: | ||||
| @@ -203,7 +214,7 @@ class CreateCacheRequest : public BaseRequest { | |||||
| /// \param flag Attributes of the cache. | /// \param flag Attributes of the cache. | ||||
| explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, | ||||
| CreateCacheFlag flag = CreateCacheFlag::kNone); | CreateCacheFlag flag = CreateCacheFlag::kNone); | ||||
| ~CreateCacheRequest() = default; | |||||
| ~CreateCacheRequest() override = default; | |||||
| void ParseResult(connection_id_type *id, std::string *out) { | void ParseResult(connection_id_type *id, std::string *out) { | ||||
| auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data()); | auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data()); | ||||
| *id = p->connection_id(); | *id = p->connection_id(); | ||||
| @@ -218,14 +229,15 @@ class CreateCacheRequest : public BaseRequest { | |||||
| CreateCacheFlag flag_; | CreateCacheFlag flag_; | ||||
| }; | }; | ||||
| /// \brief Request to purge a cache. | |||||
| class PurgeCacheRequest : public BaseRequest { | |||||
| /// \brief Request to get all the keys not present at the server. | |||||
| /// \note Only applicable to mappable case | |||||
| class GetCacheMissKeysRequest : public BaseRequest { | |||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) { | |||||
| explicit GetCacheMissKeysRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetCacheMissKeys) { | |||||
| rq_.set_connection_id(connection_id); | rq_.set_connection_id(connection_id); | ||||
| } | } | ||||
| ~PurgeCacheRequest() = default; | |||||
| ~GetCacheMissKeysRequest() override = default; | |||||
| }; | }; | ||||
| /// \brief Request to destroy a cache | /// \brief Request to destroy a cache | ||||
| @@ -235,7 +247,7 @@ class DestroyCacheRequest : public BaseRequest { | |||||
| explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) { | explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) { | ||||
| rq_.set_connection_id(connection_id); | rq_.set_connection_id(connection_id); | ||||
| } | } | ||||
| ~DestroyCacheRequest() = default; | |||||
| ~DestroyCacheRequest() override = default; | |||||
| }; | }; | ||||
| /// \brief Obtain the statistics of the current connection | /// \brief Obtain the statistics of the current connection | ||||
| @@ -247,7 +259,7 @@ class GetStatRequest : public BaseRequest { | |||||
| rq_.set_connection_id(connection_id); | rq_.set_connection_id(connection_id); | ||||
| } | } | ||||
| ~GetStatRequest() = default; | |||||
| ~GetStatRequest() override = default; | |||||
| /// \brief Override base function to process the result. | /// \brief Override base function to process the result. | ||||
| Status PostReply() override; | Status PostReply() override; | ||||
| @@ -269,7 +281,7 @@ class CacheSchemaRequest : public BaseRequest { | |||||
| explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) { | explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) { | ||||
| rq_.set_connection_id(connection_id); | rq_.set_connection_id(connection_id); | ||||
| } | } | ||||
| ~CacheSchemaRequest() = default; | |||||
| ~CacheSchemaRequest() override = default; | |||||
| Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map); | Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map); | ||||
| }; | }; | ||||
| @@ -281,7 +293,7 @@ class FetchSchemaRequest : public BaseRequest { | |||||
| explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) { | explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) { | ||||
| rq_.set_connection_id(connection_id); | rq_.set_connection_id(connection_id); | ||||
| } | } | ||||
| ~FetchSchemaRequest() = default; | |||||
| ~FetchSchemaRequest() override = default; | |||||
| Status PostReply() override; | Status PostReply() override; | ||||
| @@ -300,7 +312,7 @@ class BuildPhaseDoneRequest : public BaseRequest { | |||||
| rq_.set_connection_id(connection_id); | rq_.set_connection_id(connection_id); | ||||
| rq_.add_buf_data(cookie_); | rq_.add_buf_data(cookie_); | ||||
| } | } | ||||
| ~BuildPhaseDoneRequest() = default; | |||||
| ~BuildPhaseDoneRequest() override = default; | |||||
| private: | private: | ||||
| std::string cookie_; | std::string cookie_; | ||||
| @@ -313,7 +325,7 @@ class DropSessionRequest : public BaseRequest { | |||||
| explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) { | explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) { | ||||
| rq_.mutable_connection_info()->operator=(cinfo); | rq_.mutable_connection_info()->operator=(cinfo); | ||||
| } | } | ||||
| ~DropSessionRequest() = default; | |||||
| ~DropSessionRequest() override = default; | |||||
| }; | }; | ||||
| class GenerateSessionIdRequest : public BaseRequest { | class GenerateSessionIdRequest : public BaseRequest { | ||||
| @@ -325,11 +337,36 @@ class GenerateSessionIdRequest : public BaseRequest { | |||||
| rq_.set_connection_id(0); | rq_.set_connection_id(0); | ||||
| } | } | ||||
| ~GenerateSessionIdRequest() = default; | |||||
| ~GenerateSessionIdRequest() override = default; | |||||
| session_id_type GetSessionId() { return atoi(reply_.result().data()); } | session_id_type GetSessionId() { return atoi(reply_.result().data()); } | ||||
| }; | }; | ||||
| class ListSessionsRequest : public BaseRequest { | |||||
| public: | |||||
| friend class CacheServer; | |||||
| ListSessionsRequest() : BaseRequest(RequestType::kListSessions) { | |||||
| // This request is not specific to any cache or session | |||||
| rq_.set_connection_id(0); | |||||
| } | |||||
| ~ListSessionsRequest() override = default; | |||||
| /// \brief Override base function to process the result. | |||||
| Status PostReply() override; | |||||
| void GetSessionCacheInfo(std::vector<SessionCacheInfo> *info) { | |||||
| if (info != nullptr) { | |||||
| (*info) = session_info_list_; | |||||
| } | |||||
| } | |||||
| std::vector<SessionCacheInfo> GetSessionCacheInfo() { return session_info_list_; } | |||||
| private: | |||||
| std::vector<SessionCacheInfo> session_info_list_; | |||||
| }; | |||||
| class AllocateSharedBlockRequest : public BaseRequest { | class AllocateSharedBlockRequest : public BaseRequest { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| @@ -338,7 +375,7 @@ class AllocateSharedBlockRequest : public BaseRequest { | |||||
| rq_.set_connection_id(connection_id); | rq_.set_connection_id(connection_id); | ||||
| rq_.add_buf_data(std::to_string(requestedSz)); | rq_.add_buf_data(std::to_string(requestedSz)); | ||||
| } | } | ||||
| ~AllocateSharedBlockRequest() = default; | |||||
| ~AllocateSharedBlockRequest() override = default; | |||||
| /// \brief On return from the server, we get the (relative) address where | /// \brief On return from the server, we get the (relative) address where | ||||
| /// the free block is located. | /// the free block is located. | ||||
| @@ -349,11 +386,15 @@ class AllocateSharedBlockRequest : public BaseRequest { | |||||
| } | } | ||||
| }; | }; | ||||
| class ShutdownRequest : public BaseRequest { | |||||
| class ToggleWriteModeRequest : public BaseRequest { | |||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| ShutdownRequest() : BaseRequest(RequestType::kStopService) {} | |||||
| ~ShutdownRequest() = default; | |||||
| explicit ToggleWriteModeRequest(connection_id_type connection_id, bool on_off) | |||||
| : BaseRequest(RequestType::kToggleWriteMode) { | |||||
| rq_.set_connection_id(connection_id); | |||||
| rq_.add_buf_data(on_off ? "on" : "off"); | |||||
| } | |||||
| ~ToggleWriteModeRequest() override = default; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <functional> | #include <functional> | ||||
| #include <limits> | #include <limits> | ||||
| #include "minddata/dataset/core/constants.h" | #include "minddata/dataset/core/constants.h" | ||||
| #include "minddata/dataset/engine/cache/cache_ipc.h" | |||||
| #include "minddata/dataset/engine/cache/cache_service.h" | #include "minddata/dataset/engine/cache/cache_service.h" | ||||
| #include "minddata/dataset/engine/cache/cache_request.h" | #include "minddata/dataset/engine/cache/cache_request.h" | ||||
| #include "minddata/dataset/util/bit.h" | #include "minddata/dataset/util/bit.h" | ||||
| @@ -107,6 +108,8 @@ Status CacheServer::DoServiceStop() { | |||||
| // First stop all the threads. | // First stop all the threads. | ||||
| RETURN_IF_NOT_OK(vg_.ServiceStop()); | RETURN_IF_NOT_OK(vg_.ServiceStop()); | ||||
| // Clean up all the caches if any. | // Clean up all the caches if any. | ||||
| // Dump a message how much memory we have consumed in total. | |||||
| MS_LOG(INFO) << "Memory usage for the current server: " << GetMemoryUsage() << " bytes."; | |||||
| UniqueLock lck(&rwLock_); | UniqueLock lck(&rwLock_); | ||||
| auto it = all_caches_.begin(); | auto it = all_caches_.begin(); | ||||
| while (it != all_caches_.end()) { | while (it != all_caches_.end()) { | ||||
| @@ -121,7 +124,6 @@ Status CacheServer::DoServiceStop() { | |||||
| } | } | ||||
| CacheService *CacheServer::GetService(connection_id_type id) const { | CacheService *CacheServer::GetService(connection_id_type id) const { | ||||
| SharedLock lck(&rwLock_); | |||||
| auto it = all_caches_.find(id); | auto it = all_caches_.find(id); | ||||
| if (it != all_caches_.end()) { | if (it != all_caches_.end()) { | ||||
| return it->second.get(); | return it->second.get(); | ||||
| @@ -134,6 +136,16 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||||
| std::string cookie; | std::string cookie; | ||||
| auto session_id = rq->connection_info().session_id(); | auto session_id = rq->connection_info().session_id(); | ||||
| auto crc = rq->connection_info().crc(); | auto crc = rq->connection_info().crc(); | ||||
| // Before allowing the creation, make sure the session had already been created by the user | |||||
| // Our intention is to add this cache to the active sessions list so leave the list locked during | |||||
| // this entire function. | |||||
| UniqueLock lock(&sessions_lock_); | |||||
| auto session_it = active_sessions_.find(session_id); | |||||
| if (session_it == active_sessions_.end()) { | |||||
| RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!"); | |||||
| } | |||||
| // We concat both numbers to form the internal connection id. | // We concat both numbers to form the internal connection id. | ||||
| auto connection_id = GetConnectionID(session_id, crc); | auto connection_id = GetConnectionID(session_id, crc); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache"); | CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache"); | ||||
| @@ -172,10 +184,15 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||||
| } catch (const std::bad_alloc &e) { | } catch (const std::bad_alloc &e) { | ||||
| return Status(StatusCode::kOutOfMemory); | return Status(StatusCode::kOutOfMemory); | ||||
| } | } | ||||
| // Add the cache into the active session tracking. | |||||
| // We have already validated that the session exists and that this is a new cache created. | |||||
| session_it->second.insert(connection_id); | |||||
| } else { | } else { | ||||
| duplicate = true; | duplicate = true; | ||||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; | MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; | ||||
| } | } | ||||
| off_cookie = fbb.CreateString(cookie); | off_cookie = fbb.CreateString(cookie); | ||||
| CreateCacheReplyMsgBuilder bld(fbb); | CreateCacheReplyMsgBuilder bld(fbb); | ||||
| bld.add_connection_id(connection_id); | bld.add_connection_id(connection_id); | ||||
| @@ -183,19 +200,18 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { | |||||
| auto off = bld.Finish(); | auto off = bld.Finish(); | ||||
| fbb.Finish(off); | fbb.Finish(off); | ||||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | ||||
| // Track the history of all the sessions that we have created so far. | |||||
| history_sessions_.insert(session_id); | |||||
| // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it | // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it | ||||
| // treat it as OK. | // treat it as OK. | ||||
| return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK(); | return duplicate ? Status(StatusCode::kDuplicateKey) : Status::OK(); | ||||
| } | } | ||||
| Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) { | |||||
| Status CacheServer::DestroyCache(CacheRequest *rq) { | |||||
| // We need a strong lock to protect the map. | // We need a strong lock to protect the map. | ||||
| UniqueLock lck(&rwLock_); | UniqueLock lck(&rwLock_); | ||||
| auto id = rq->connection_id(); | |||||
| CacheService *cs = GetService(id); | |||||
| // it is already destroyed. Ignore it. | // it is already destroyed. Ignore it. | ||||
| if (cs != nullptr) { | if (cs != nullptr) { | ||||
| auto id = rq->connection_id(); | |||||
| MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); | MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); | ||||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | ||||
| auto n = all_caches_.erase(id); | auto n = all_caches_.erase(id); | ||||
| @@ -204,11 +220,34 @@ Status CacheServer::DestroyCache(CacheService *cs, CacheRequest *rq) { | |||||
| MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; | ||||
| } | } | ||||
| } | } | ||||
| // Now that this cache is removed, we need to also remove it's connection id from active session tracking | |||||
| auto session_id = GetSessionID(id); | |||||
| UniqueLock sess_lck(&sessions_lock_); | |||||
| auto it = active_sessions_.find(session_id); | |||||
| if (it == active_sessions_.end()) { | |||||
| // The session was not found in the active sessions | |||||
| RETURN_STATUS_UNEXPECTED("A destroy cache request has been completed but it had a stale session id!"); | |||||
| } | |||||
| auto connection_it = it->second.find(id); | |||||
| if (connection_it == it->second.end()) { | |||||
| RETURN_STATUS_UNEXPECTED("A destroy cache request could not find the connection in the activate sessions!"); | |||||
| } | |||||
| // remove that connection id from the set | |||||
| it->second.erase(connection_it); | |||||
| MS_LOG(INFO) << "Destroyed cache " << id << " and removed from active session " << session_id; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | auto connection_id = rq->connection_id(); | ||||
| // Hold the shared lock to prevent the cache from being dropped. | |||||
| SharedLock lck(&rwLock_); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| if (cs == nullptr) { | if (cs == nullptr) { | ||||
| std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | ||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | ||||
| @@ -236,8 +275,11 @@ inline Status CacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | auto connection_id = rq->connection_id(); | ||||
| // Hold the shared lock to prevent the cache from being dropped. | |||||
| SharedLock lck(&rwLock_); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| auto shared_pool = comm_layer_->GetSharedMemoryPool(); | auto shared_pool = comm_layer_->GetSharedMemoryPool(); | ||||
| auto *base = shared_pool->SharedMemoryBaseAddr(); | auto *base = shared_pool->SharedMemoryBaseAddr(); | ||||
| // Ensure we got 3 pieces of data coming in | // Ensure we got 3 pieces of data coming in | ||||
| @@ -270,8 +312,11 @@ Status CacheServer::FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply | |||||
| return rc; | return rc; | ||||
| } | } | ||||
| Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | auto connection_id = rq->connection_id(); | ||||
| // Hold the shared lock to prevent the cache from being dropped. | |||||
| SharedLock lck(&rwLock_); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| if (cs == nullptr) { | if (cs == nullptr) { | ||||
| std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found"; | ||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | ||||
| @@ -325,8 +370,11 @@ Status CacheServer::BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheRepl | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | auto connection_id = rq->connection_id(); | ||||
| // Hold the shared lock to prevent the cache from being dropped. | |||||
| SharedLock lck(&rwLock_); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| if (cs == nullptr) { | if (cs == nullptr) { | ||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | ||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | ||||
| @@ -338,8 +386,8 @@ inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); | ||||
| bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); | ||||
| bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz); | bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz); | ||||
| bld.add_max_row_id(svc_stat.max_); | |||||
| bld.add_min_row_id(svc_stat.min_); | |||||
| bld.add_max_row_id(svc_stat.stat_.max_key); | |||||
| bld.add_min_row_id(svc_stat.stat_.min_key); | |||||
| bld.add_state(svc_stat.state_); | bld.add_state(svc_stat.state_); | ||||
| auto offset = bld.Finish(); | auto offset = bld.Finish(); | ||||
| fbb.Finish(offset); | fbb.Finish(offset); | ||||
| @@ -348,8 +396,11 @@ inline Status GetStat(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| inline Status CacheSchema(CacheService *cs, CacheRequest *rq) { | |||||
| Status CacheServer::CacheSchema(CacheRequest *rq) { | |||||
| auto connection_id = rq->connection_id(); | auto connection_id = rq->connection_id(); | ||||
| // Hold the shared lock to prevent the cache from being dropped. | |||||
| SharedLock lck(&rwLock_); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| if (cs == nullptr) { | if (cs == nullptr) { | ||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | ||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | ||||
| @@ -361,8 +412,11 @@ inline Status CacheSchema(CacheService *cs, CacheRequest *rq) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) { | |||||
| Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | auto connection_id = rq->connection_id(); | ||||
| // Hold the shared lock to prevent the cache from being dropped. | |||||
| SharedLock lck(&rwLock_); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| if (cs == nullptr) { | if (cs == nullptr) { | ||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | ||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | ||||
| @@ -377,8 +431,11 @@ inline Status FetchSchema(CacheService *cs, CacheRequest *rq, CacheReply *reply) | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) { | |||||
| Status CacheServer::BuildPhaseDone(CacheRequest *rq) { | |||||
| auto connection_id = rq->connection_id(); | auto connection_id = rq->connection_id(); | ||||
| // Hold the shared lock to prevent the cache from being dropped. | |||||
| SharedLock lck(&rwLock_); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| if (cs == nullptr) { | if (cs == nullptr) { | ||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | ||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | ||||
| @@ -396,15 +453,24 @@ inline Status BuildPhaseDone(CacheService *cs, CacheRequest *rq) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheServer::PurgeCache(CacheService *cs) { | |||||
| Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) { | |||||
| auto connection_id = rq->connection_id(); | |||||
| // Hold the shared lock to prevent the cache from being dropped. | |||||
| SharedLock lck(&rwLock_); | SharedLock lck(&rwLock_); | ||||
| // If shutdown in progress, ignore the command. | |||||
| if (global_shutdown_) { | |||||
| return Status::OK(); | |||||
| } | |||||
| // it is already purged. Ignore it. | |||||
| if (cs != nullptr) { | |||||
| RETURN_IF_NOT_OK(cs->Purge()); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| std::vector<row_id_type> gap; | |||||
| RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap)); | |||||
| flatbuffers::FlatBufferBuilder fbb; | |||||
| auto off_t = fbb.CreateVector(gap); | |||||
| TensorRowIdsBuilder bld(fbb); | |||||
| bld.add_row_id(off_t); | |||||
| auto off = bld.Finish(); | |||||
| fbb.Finish(off); | |||||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -414,6 +480,72 @@ inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *re | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheServer::ToggleWriteMode(CacheRequest *rq) { | |||||
| auto connection_id = rq->connection_id(); | |||||
| // Hold the shared lock to prevent the cache from being dropped. | |||||
| SharedLock lck(&rwLock_); | |||||
| CacheService *cs = GetService(connection_id); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| // First piece of data is the on/off flag | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag"); | |||||
| const auto &action = rq->buf_data(0); | |||||
| bool on_off = false; | |||||
| if (strcmp(action.data(), "on") == 0) { | |||||
| on_off = true; | |||||
| } else if (strcmp(action.data(), "off") == 0) { | |||||
| on_off = false; | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Unknown request: " + action); | |||||
| } | |||||
| RETURN_IF_NOT_OK(cs->ToggleWriteMode(on_off)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheServer::ListSessions(CacheReply *reply) { | |||||
| SharedLock lck(&sessions_lock_); | |||||
| flatbuffers::FlatBufferBuilder fbb; | |||||
| std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector; | |||||
| for (auto it = active_sessions_.begin(); it != active_sessions_.end(); it++) { | |||||
| session_id_type current_session_id = it->first; | |||||
| // Loop over each cache inside this session | |||||
| if (!it->second.empty()) { | |||||
| for (auto current_conn_id : it->second) { | |||||
| CacheService *cs = GetService(current_conn_id); | |||||
| if (cs == nullptr) { | |||||
| std::string errMsg = "Connection " + std::to_string(current_conn_id) + " not found during list sessions."; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); | |||||
| } else { | |||||
| CacheService::ServiceStat svc_stat; | |||||
| RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); | |||||
| auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached, | |||||
| svc_stat.stat_.average_cache_sz, svc_stat.stat_.min_key, | |||||
| svc_stat.stat_.max_key, svc_stat.state_); | |||||
| auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats); | |||||
| session_msgs_vector.push_back(current_session_info); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| // If there is no cache created yet, assign a connection id of 0 along with empty stats | |||||
| auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0); | |||||
| auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats); | |||||
| session_msgs_vector.push_back(current_session_info); | |||||
| } | |||||
| } | |||||
| auto session_msgs = fbb.CreateVector(session_msgs_vector); | |||||
| ListSessionsMsgBuilder s_builder(fbb); | |||||
| s_builder.add_sessions(session_msgs); | |||||
| auto offset = s_builder.Finish(); | |||||
| fbb.Finish(offset); | |||||
| reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); | |||||
| return Status::OK(); | |||||
| } | |||||
| /// \brief This is the main loop the cache server thread(s) are running. | /// \brief This is the main loop the cache server thread(s) are running. | ||||
| /// Each thread will pop a request and send the result back to the client using grpc | /// Each thread will pop a request and send the result back to the client using grpc | ||||
| /// \return | /// \return | ||||
| @@ -426,12 +558,6 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||||
| RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); | RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); | ||||
| auto &rq = cache_req->rq_; | auto &rq = cache_req->rq_; | ||||
| auto &reply = cache_req->reply_; | auto &reply = cache_req->reply_; | ||||
| CacheService *cs = nullptr; | |||||
| // Request comes in roughly two sets. One set is at the cache level with a connection id. | |||||
| // The other set is working at a high level and without a connection id | |||||
| if (!rq.has_connection_info()) { | |||||
| cs = GetService(rq.connection_id()); | |||||
| } | |||||
| // Except for creating a new session, we expect cs is not null. | // Except for creating a new session, we expect cs is not null. | ||||
| switch (cache_req->type_) { | switch (cache_req->type_) { | ||||
| case BaseRequest::RequestType::kCacheRow: { | case BaseRequest::RequestType::kCacheRow: { | ||||
| @@ -439,42 +565,42 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||||
| // call the appropriate method. | // call the appropriate method. | ||||
| auto flag = rq.flag(); | auto flag = rq.flag(); | ||||
| if (BitTest(flag, kDataIsInSharedMemory)) { | if (BitTest(flag, kDataIsInSharedMemory)) { | ||||
| cache_req->rc_ = FastCacheRow(cs, &rq, &reply); | |||||
| cache_req->rc_ = FastCacheRow(&rq, &reply); | |||||
| } else { | } else { | ||||
| cache_req->rc_ = CacheRow(cs, &rq, &reply); | |||||
| cache_req->rc_ = CacheRow(&rq, &reply); | |||||
| } | } | ||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kBatchFetchRows: { | case BaseRequest::RequestType::kBatchFetchRows: { | ||||
| cache_req->rc_ = BatchFetchRows(cs, &rq, &reply); | |||||
| cache_req->rc_ = BatchFetchRows(&rq, &reply); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kCreateCache: { | case BaseRequest::RequestType::kCreateCache: { | ||||
| cache_req->rc_ = CreateService(&rq, &reply); | cache_req->rc_ = CreateService(&rq, &reply); | ||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kPurgeCache: { | |||||
| cache_req->rc_ = PurgeCache(cs); | |||||
| case BaseRequest::RequestType::kGetCacheMissKeys: { | |||||
| cache_req->rc_ = GetCacheMissKeys(&rq, &reply); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kDestroyCache: { | case BaseRequest::RequestType::kDestroyCache: { | ||||
| cache_req->rc_ = DestroyCache(cs, &rq); | |||||
| cache_req->rc_ = DestroyCache(&rq); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kGetStat: { | case BaseRequest::RequestType::kGetStat: { | ||||
| cache_req->rc_ = GetStat(cs, &rq, &reply); | |||||
| cache_req->rc_ = GetStat(&rq, &reply); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kCacheSchema: { | case BaseRequest::RequestType::kCacheSchema: { | ||||
| cache_req->rc_ = CacheSchema(cs, &rq); | |||||
| cache_req->rc_ = CacheSchema(&rq); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kFetchSchema: { | case BaseRequest::RequestType::kFetchSchema: { | ||||
| cache_req->rc_ = FetchSchema(cs, &rq, &reply); | |||||
| cache_req->rc_ = FetchSchema(&rq, &reply); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kBuildPhaseDone: { | case BaseRequest::RequestType::kBuildPhaseDone: { | ||||
| cache_req->rc_ = BuildPhaseDone(cs, &rq); | |||||
| cache_req->rc_ = BuildPhaseDone(&rq); | |||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kDropSession: { | case BaseRequest::RequestType::kDropSession: { | ||||
| @@ -498,6 +624,18 @@ Status CacheServer::ServerRequest(int32_t worker_id) { | |||||
| cache_req->rc_ = GlobalShutdown(); | cache_req->rc_ = GlobalShutdown(); | ||||
| break; | break; | ||||
| } | } | ||||
| case BaseRequest::RequestType::kHeartBeat: { | |||||
| cache_req->rc_ = Status::OK(); | |||||
| break; | |||||
| } | |||||
| case BaseRequest::RequestType::kToggleWriteMode: { | |||||
| cache_req->rc_ = ToggleWriteMode(&rq); | |||||
| break; | |||||
| } | |||||
| case BaseRequest::RequestType::kListSessions: { | |||||
| cache_req->rc_ = ListSessions(&reply); | |||||
| break; | |||||
| } | |||||
| default: | default: | ||||
| std::string errMsg("Unknown request type : "); | std::string errMsg("Unknown request type : "); | ||||
| errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_)); | ||||
| @@ -526,18 +664,32 @@ session_id_type CacheServer::GetSessionID(connection_id_type connection_id) cons | |||||
| } | } | ||||
| CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, | CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, | ||||
| int32_t shared_meory_sz_in_gb) | |||||
| int32_t shared_meory_sz_in_gb, float memory_cap_ratio) | |||||
| : top_(spill_path), | : top_(spill_path), | ||||
| num_workers_(num_workers), | num_workers_(num_workers), | ||||
| port_(port), | port_(port), | ||||
| shared_memory_sz_in_gb_(shared_meory_sz_in_gb), | shared_memory_sz_in_gb_(shared_meory_sz_in_gb), | ||||
| global_shutdown_(false) {} | |||||
| global_shutdown_(false), | |||||
| memory_cap_ratio_(memory_cap_ratio), | |||||
| cur_mem_usage_(0) { | |||||
| memory_cap_ = CacheServer::GetTotalSystemMemory() * memory_cap_ratio_; | |||||
| } | |||||
| Status CacheServer::Run() { | |||||
| RETURN_IF_NOT_OK(ServiceStart()); | |||||
| Status CacheServer::Run(int msg_qid) { | |||||
| Status rc = ServiceStart(); | |||||
| // If there is a message que, return the status now before we call join_all which will never return | |||||
| if (msg_qid != -1) { | |||||
| SharedMessage msg(msg_qid); | |||||
| RETURN_IF_NOT_OK(msg.SendStatus(rc)); | |||||
| } | |||||
| if (rc.IsError()) { | |||||
| return rc; | |||||
| } | |||||
| // This is called by the main function and we shouldn't exit. Otherwise the main thread | // This is called by the main function and we shouldn't exit. Otherwise the main thread | ||||
| // will just shutdown. So we will call some function that never return unless error. | // will just shutdown. So we will call some function that never return unless error. | ||||
| // One good case will be simply to wait for all threads to return. | // One good case will be simply to wait for all threads to return. | ||||
| // note that after we have sent the initial status using the msg_qid, parent process will exit and | |||||
| // remove it. So we can't use it again. | |||||
| RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking)); | RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -567,32 +719,51 @@ Status CacheServer::ReturnRequestTag(CacheServerRequest *p) { | |||||
| Status CacheServer::DestroySession(CacheRequest *rq) { | Status CacheServer::DestroySession(CacheRequest *rq) { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id"); | CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id"); | ||||
| auto drop_session_id = rq->connection_info().session_id(); | auto drop_session_id = rq->connection_info().session_id(); | ||||
| UniqueLock lck(&rwLock_); | |||||
| for (auto &cs : all_caches_) { | |||||
| auto connection_id = cs.first; | |||||
| auto session_id = GetSessionID(connection_id); | |||||
| // We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock. | |||||
| // So we will just manually do it. | |||||
| if (session_id == drop_session_id) { | |||||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||||
| auto n = all_caches_.erase(connection_id); | |||||
| MS_LOG(INFO) << "Destroy " << n << " copies of cache with id " << connection_id; | |||||
| UniqueLock lck(&sessions_lock_); | |||||
| // First validate that this session exists | |||||
| auto it = active_sessions_.find(drop_session_id); | |||||
| if (it == active_sessions_.end()) { | |||||
| RETURN_STATUS_UNEXPECTED("A destroy session command has been requested but the session was not found!"); | |||||
| } | |||||
| // Iterate over the set of connection id's for this session that we're dropping and erase each one. | |||||
| { | |||||
| UniqueLock rwlck(&rwLock_); | |||||
| for (auto drop_connection_id : it->second) { | |||||
| auto cache_drop_it = all_caches_.find(drop_connection_id); | |||||
| if (cache_drop_it == all_caches_.end()) { | |||||
| RETURN_STATUS_UNEXPECTED("active session tracking had stale or incorrect cache entry."); | |||||
| } | |||||
| all_caches_.erase(cache_drop_it); | |||||
| MS_LOG(INFO) << "Session destroy: Destroy cache with id " << drop_connection_id; | |||||
| // **Do not bother to remove the cache connection id from the active session because we will soon remove the | |||||
| // entire session. | |||||
| } | } | ||||
| } | } | ||||
| // Finally remove the session itself | |||||
| active_sessions_.erase(it); | |||||
| MS_LOG(INFO) << "Session destroyed with id " << drop_session_id; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| session_id_type CacheServer::GenerateSessionID() const { | |||||
| SharedLock lock(&rwLock_); | |||||
| session_id_type CacheServer::GenerateSessionID() { | |||||
| UniqueLock lock(&sessions_lock_); | |||||
| auto mt = GetRandomDevice(); | auto mt = GetRandomDevice(); | ||||
| std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max()); | std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max()); | ||||
| session_id_type session_id; | session_id_type session_id; | ||||
| bool duplicate = false; | bool duplicate = false; | ||||
| do { | do { | ||||
| session_id = distribution(mt); | session_id = distribution(mt); | ||||
| auto it = history_sessions_.find(session_id); | |||||
| duplicate = (it != history_sessions_.end()); | |||||
| auto it = active_sessions_.find(session_id); | |||||
| duplicate = (it != active_sessions_.end()); | |||||
| } while (duplicate); | } while (duplicate); | ||||
| // Add this session to our tracking of active sessions with initialized empty set of connections. | |||||
| active_sessions_[session_id] = std::set<connection_id_type>(); | |||||
| return session_id; | return session_id; | ||||
| } | } | ||||
| @@ -637,19 +808,59 @@ Status CacheServer::GlobalShutdown() { | |||||
| vg_.interrupt_all(); | vg_.interrupt_all(); | ||||
| // The next thing to do drop all the caches. | // The next thing to do drop all the caches. | ||||
| UniqueLock lck(&rwLock_); | UniqueLock lck(&rwLock_); | ||||
| for (auto &it : all_caches_) { | |||||
| auto id = it.first; | |||||
| for (auto it = all_caches_.begin(); it != all_caches_.end();) { | |||||
| auto id = it->first; | |||||
| MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); | MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id); | ||||
| // Wait for all outstanding work to be finished. | // Wait for all outstanding work to be finished. | ||||
| auto &cs = it.second; | |||||
| auto &cs = it->second; | |||||
| UniqueLock cs_lock(&cs->rw_lock_); | UniqueLock cs_lock(&cs->rw_lock_); | ||||
| // std::map will invoke the destructor of CacheService. So we don't need to do anything here. | |||||
| (void)all_caches_.erase(id); | |||||
| it = all_caches_.erase(it); | |||||
| } | } | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| int64_t CacheServer::GetTotalSystemMemory() { | |||||
| auto pages = sysconf(_SC_PHYS_PAGES); | |||||
| auto page_size = sysconf(_SC_PAGE_SIZE); | |||||
| auto total = static_cast<int64_t>(pages) * static_cast<int64_t>(page_size); | |||||
| MS_LOG(INFO) << "Total physical RAM in bytes: " << total; | |||||
| return total; | |||||
| } | |||||
| Status CacheServer::Builder::IpcResourceCleanup() { | |||||
| Status rc; | |||||
| SharedMemory::shm_key_t shm_key; | |||||
| auto unix_socket = PortToUnixSocketPath(port_); | |||||
| rc = PortToFtok(port_, &shm_key); | |||||
| // We are expecting the unix path doesn't exist. | |||||
| if (rc.IsError()) { | |||||
| return Status::OK(); | |||||
| } | |||||
| // Attach to the shared memory which we expect don't exist | |||||
| SharedMemory mem(shm_key); | |||||
| rc = mem.Attach(); | |||||
| if (rc.IsError()) { | |||||
| return Status::OK(); | |||||
| } | |||||
| int32_t num_attached; | |||||
| RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached)); | |||||
| if (num_attached == 0) { | |||||
| // Stale shared memory from last time. | |||||
| // Remove both the memory and the socket path | |||||
| RETURN_IF_NOT_OK(mem.Destroy()); | |||||
| Path p(unix_socket); | |||||
| (void)p.Remove(); | |||||
| } else { | |||||
| // Server is already up. | |||||
| std::string errMsg = "Cache server is already up and running"; | |||||
| // We return a duplicate error. The main() will intercept | |||||
| // and output a proper message | |||||
| return Status(StatusCode::kDuplicateKey, errMsg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheServer::Builder::SanityCheck() { | Status CacheServer::Builder::SanityCheck() { | ||||
| if (shared_memory_sz_in_gb_ <= 0) { | if (shared_memory_sz_in_gb_ <= 0) { | ||||
| RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive"); | RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive"); | ||||
| @@ -673,6 +884,12 @@ Status CacheServer::Builder::SanityCheck() { | |||||
| RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString()); | RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString()); | ||||
| } | } | ||||
| } | } | ||||
| if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) { | |||||
| RETURN_STATUS_UNEXPECTED("Memory cap ratio should be positive and no greater than 1"); | |||||
| } | |||||
| // Check if the shared memory. | |||||
| RETURN_IF_NOT_OK(IpcResourceCleanup()); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -17,6 +17,8 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ | ||||
| #include <string.h> | |||||
| #include <unistd.h> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <atomic> | #include <atomic> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -47,15 +49,16 @@ class CacheServer : public Service { | |||||
| using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; | using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>; | ||||
| class Builder { | class Builder { | ||||
| public: | public: | ||||
| Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {} | |||||
| Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4), memory_cap_ratio_(0.8) {} | |||||
| ~Builder() = default; | ~Builder() = default; | ||||
| /// \brief Getter functions | /// \brief Getter functions | ||||
| const std::string &getTop() const { return top_; } | |||||
| int32_t getNumWorkers() const { return num_workers_; } | |||||
| int32_t getPort() const { return port_; } | |||||
| int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; } | |||||
| const std::string &GetTop() const { return top_; } | |||||
| int32_t GetNumWorkers() const { return num_workers_; } | |||||
| int32_t GetPort() const { return port_; } | |||||
| int32_t GetSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; } | |||||
| float GetMemoryCapRatio() const { return memory_cap_ratio_; } | |||||
| Builder &SetRootDirectory(std::string root) { | Builder &SetRootDirectory(std::string root) { | ||||
| top_ = std::move(root); | top_ = std::move(root); | ||||
| @@ -73,15 +76,20 @@ class CacheServer : public Service { | |||||
| shared_memory_sz_in_gb_ = sz; | shared_memory_sz_in_gb_ = sz; | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| Builder &SetMemoryCapRatio(float ratio) { | |||||
| memory_cap_ratio_ = ratio; | |||||
| return *this; | |||||
| } | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| void Print(std::ostream &out) const { | void Print(std::ostream &out) const { | ||||
| out << "Summary of the cache server configuration\n" | out << "Summary of the cache server configuration\n" | ||||
| << "Spill directory: " << getTop() << "\n" | |||||
| << "Number of parallel workers: " << getNumWorkers() << "\n" | |||||
| << "Tcp/ip port: " << getPort() << "\n" | |||||
| << "Shared memory size (in GB): " << getSharedMemorySzInGb(); | |||||
| << "Spill directory: " << GetTop() << "\n" | |||||
| << "Number of parallel workers: " << GetNumWorkers() << "\n" | |||||
| << "Tcp/ip port: " << GetPort() << "\n" | |||||
| << "Shared memory size (in GB): " << GetSharedMemorySzInGb() << "\n" | |||||
| << "Memory cap ratio: " << GetMemoryCapRatio(); | |||||
| } | } | ||||
| friend std::ostream &operator<<(std::ostream &out, const Builder &bld) { | friend std::ostream &operator<<(std::ostream &out, const Builder &bld) { | ||||
| @@ -93,7 +101,8 @@ class CacheServer : public Service { | |||||
| RETURN_IF_NOT_OK(SanityCheck()); | RETURN_IF_NOT_OK(SanityCheck()); | ||||
| // We need to bring up the Task Manager by bringing up the Services singleton. | // We need to bring up the Task Manager by bringing up the Services singleton. | ||||
| RETURN_IF_NOT_OK(Services::CreateInstance()); | RETURN_IF_NOT_OK(Services::CreateInstance()); | ||||
| RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_)); | |||||
| RETURN_IF_NOT_OK( | |||||
| CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_, memory_cap_ratio_)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -102,20 +111,27 @@ class CacheServer : public Service { | |||||
| int32_t num_workers_; | int32_t num_workers_; | ||||
| int32_t port_; | int32_t port_; | ||||
| int32_t shared_memory_sz_in_gb_; | int32_t shared_memory_sz_in_gb_; | ||||
| float memory_cap_ratio_; | |||||
| /// \brief Sanity checks on the shared memory. | |||||
| /// \return Status object | |||||
| Status IpcResourceCleanup(); | |||||
| }; | }; | ||||
| CacheServer(const CacheServer &) = delete; | CacheServer(const CacheServer &) = delete; | ||||
| CacheServer &operator=(const CacheServer &) = delete; | CacheServer &operator=(const CacheServer &) = delete; | ||||
| CacheServer(CacheServer &&) = delete; | CacheServer(CacheServer &&) = delete; | ||||
| CacheServer &operator=(CacheServer &) = delete; | CacheServer &operator=(CacheServer &) = delete; | ||||
| Status DoServiceStart() override; | Status DoServiceStart() override; | ||||
| Status DoServiceStop() override; | Status DoServiceStop() override; | ||||
| ~CacheServer() { (void)ServiceStop(); } | |||||
| ~CacheServer() override { (void)ServiceStop(); } | |||||
| static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port, | static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port, | ||||
| int32_t shared_memory_sz) { | |||||
| int32_t shared_memory_sz, float memory_cap_ratio) { | |||||
| std::call_once(init_instance_flag_, [&]() -> Status { | std::call_once(init_instance_flag_, [&]() -> Status { | ||||
| auto &svcManager = Services::GetInstance(); | |||||
| RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz)); | |||||
| auto &SvcManager = Services::GetInstance(); | |||||
| RETURN_IF_NOT_OK( | |||||
| SvcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz, memory_cap_ratio)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| }); | }); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -133,7 +149,7 @@ class CacheServer : public Service { | |||||
| } | } | ||||
| /// \\brief Kick off server threads. Never return unless error out. | /// \\brief Kick off server threads. Never return unless error out. | ||||
| Status Run(); | |||||
| Status Run(SharedMessage::queue_id_t msg_qid); | |||||
| /// \brief Get a free tag | /// \brief Get a free tag | ||||
| /// \param q[in] pointer to a pointer to a CacheServerRequest | /// \param q[in] pointer to a pointer to a CacheServerRequest | ||||
| @@ -145,13 +161,35 @@ class CacheServer : public Service { | |||||
| /// \return Status object | /// \return Status object | ||||
| static Status ReturnRequestTag(CacheServerRequest *p); | static Status ReturnRequestTag(CacheServerRequest *p); | ||||
| /// \brief This returns the size (in bytes) of the physical RAM on the machine. | |||||
| /// \return the size (in bytes) of the physical RAM on the machine. | |||||
| static int64_t GetTotalSystemMemory(); | |||||
| /// \brief Internally this is how much we will try to use without exceeding the limit | |||||
| /// \return Internal cap maximum | |||||
| int64_t GetAvailableSystemMemory() { return memory_cap_; } | |||||
| /// \brief Find out the current memory usage | |||||
| int64_t GetMemoryUsage() { return cur_mem_usage_; } | |||||
| /// \brief This updates our current memory usage. | |||||
| enum MemUsageOp : int8_t { kAllocate = 1, kFree = 2 }; | |||||
| void UpdateMemoryUsage(int64_t sz, MemUsageOp op) { | |||||
| if (op == MemUsageOp::kAllocate) { | |||||
| cur_mem_usage_ += sz; | |||||
| } else { | |||||
| cur_mem_usage_ -= sz; | |||||
| } | |||||
| } | |||||
| private: | private: | ||||
| static std::once_flag init_instance_flag_; | static std::once_flag init_instance_flag_; | ||||
| static CacheServer *instance_; | static CacheServer *instance_; | ||||
| mutable RWLock rwLock_; | mutable RWLock rwLock_; | ||||
| mutable RWLock sessions_lock_; | |||||
| std::string top_; | std::string top_; | ||||
| cache_index all_caches_; | cache_index all_caches_; | ||||
| std::set<session_id_type> history_sessions_; | |||||
| std::map<session_id_type, std::set<connection_id_type>> active_sessions_; | |||||
| std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_; | std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_; | ||||
| std::shared_ptr<QueueList<CacheServerRequest *>> free_list_; | std::shared_ptr<QueueList<CacheServerRequest *>> free_list_; | ||||
| std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_; | std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_; | ||||
| @@ -162,11 +200,15 @@ class CacheServer : public Service { | |||||
| int32_t port_; | int32_t port_; | ||||
| int32_t shared_memory_sz_in_gb_; | int32_t shared_memory_sz_in_gb_; | ||||
| std::atomic<bool> global_shutdown_; | std::atomic<bool> global_shutdown_; | ||||
| float memory_cap_ratio_; | |||||
| int64_t memory_cap_; | |||||
| std::atomic<int64_t> cur_mem_usage_; | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \param spill_path Top directory for spilling buffers to. | /// \param spill_path Top directory for spilling buffers to. | ||||
| /// \param num_workers Number of threads for handling requests. | /// \param num_workers Number of threads for handling requests. | ||||
| explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb); | |||||
| explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t share_memory_sz_in_gb, | |||||
| float memory_cap_ratio); | |||||
| /// \brief Locate a cache service from connection id. | /// \brief Locate a cache service from connection id. | ||||
| /// \return Pointer to cache service. Null if not found | /// \return Pointer to cache service. Null if not found | ||||
| @@ -179,11 +221,9 @@ class CacheServer : public Service { | |||||
| Status CreateService(CacheRequest *rq, CacheReply *reply); | Status CreateService(CacheRequest *rq, CacheReply *reply); | ||||
| /// \brief Destroy a cache service | /// \brief Destroy a cache service | ||||
| /// \param cs | |||||
| /// \param rq | /// \param rq | ||||
| /// \return | |||||
| Status DestroyCache(CacheService *cs, CacheRequest *rq); | |||||
| Status PurgeCache(CacheService *cs); | |||||
| /// \return Status object | |||||
| Status DestroyCache(CacheRequest *rq); | |||||
| /// \brief Entry point for all internal server threads. | /// \brief Entry point for all internal server threads. | ||||
| Status ServerRequest(int32_t worker_id); | Status ServerRequest(int32_t worker_id); | ||||
| @@ -207,7 +247,7 @@ class CacheServer : public Service { | |||||
| /// \brief Generate a session ID for the client | /// \brief Generate a session ID for the client | ||||
| /// \return Session ID | /// \return Session ID | ||||
| session_id_type GenerateSessionID() const; | |||||
| session_id_type GenerateSessionID(); | |||||
| /// \brief Handle kAllocateSharedBlock request | /// \brief Handle kAllocateSharedBlock request | ||||
| /// \param rq CacheRequest | /// \param rq CacheRequest | ||||
| @@ -220,20 +260,55 @@ class CacheServer : public Service { | |||||
| /// \return Status object | /// \return Status object | ||||
| Status FreeSharedMemory(CacheRequest *rq); | Status FreeSharedMemory(CacheRequest *rq); | ||||
| /// \brief Handle kFastCacheRow request | |||||
| /// \brief Handle CacheRow request | |||||
| /// \note There are two different implementation depends if shared memory is used for transportation. | |||||
| /// \return Status object | /// \return Status object | ||||
| Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply); | |||||
| Status FastCacheRow(CacheRequest *rq, CacheReply *reply); | |||||
| Status CacheRow(CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief Internal function to do row batch fetch | /// \brief Internal function to do row batch fetch | ||||
| /// \param cs CacheService | |||||
| /// \param rq Request | /// \param rq Request | ||||
| /// \param reply Reply | /// \param reply Reply | ||||
| /// \return | |||||
| Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply); | |||||
| /// \return Status object | |||||
| Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief Internal function to get statistics | |||||
| /// \param rq | |||||
| /// \param reply | |||||
| /// \return Status object | |||||
| Status GetStat(CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief Cache a schema request | |||||
| /// \param rq | |||||
| /// \return Status object | |||||
| Status CacheSchema(CacheRequest *rq); | |||||
| /// \brief Fetch a schema request | |||||
| /// \param rq | |||||
| /// \param reply | |||||
| /// \return Status object | |||||
| Status FetchSchema(CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief Mark Build phase done (for non-mappable case) | |||||
| /// \param rq | |||||
| /// \return Status object | |||||
| Status BuildPhaseDone(CacheRequest *rq); | |||||
| /// \brief A proper shutdown of the server | /// \brief A proper shutdown of the server | ||||
| /// \return Status object | /// \return Status object | ||||
| Status GlobalShutdown(); | Status GlobalShutdown(); | ||||
| /// \brief Find keys that will be cache miss | |||||
| /// \return Status object | |||||
| Status GetCacheMissKeys(CacheRequest *rq, CacheReply *reply); | |||||
| /// \brief Toggle write mode for a service | |||||
| Status ToggleWriteMode(CacheRequest *rq); | |||||
| /// \brief List the sessions and their caches | |||||
| /// \param reply | |||||
| /// \return Status object | |||||
| Status ListSessions(CacheReply *reply); | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,6 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "minddata/dataset/engine/cache/cache_service.h" | #include "minddata/dataset/engine/cache/cache_service.h" | ||||
| #include "minddata/dataset/engine/cache/cache_server.h" | |||||
| #include "minddata/dataset/util/slice.h" | #include "minddata/dataset/util/slice.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -22,42 +23,62 @@ CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool genera | |||||
| : root_(root), | : root_(root), | ||||
| cache_mem_sz_(mem_sz), | cache_mem_sz_(mem_sz), | ||||
| cp_(nullptr), | cp_(nullptr), | ||||
| map_(nullptr), | |||||
| next_id_(0), | next_id_(0), | ||||
| generate_id_(generate_id), | generate_id_(generate_id), | ||||
| schema_key_(-1), | |||||
| st_(generate_id ? State::kBuildPhase : State::kNone) {} | |||||
| st_(generate_id ? State::kBuildPhase : State::kNone), | |||||
| cur_mem_usage_(0), | |||||
| cur_disk_usage_(0) {} | |||||
| CacheService::~CacheService() { (void)ServiceStop(); } | CacheService::~CacheService() { (void)ServiceStop(); } | ||||
| bool CacheService::UseArena() { | bool CacheService::UseArena() { | ||||
| // If fixed size, use Arena instead of the pool from global context. | // If fixed size, use Arena instead of the pool from global context. | ||||
| return (cache_mem_sz_ > 0); | return (cache_mem_sz_ > 0); | ||||
| } | } | ||||
| Status CacheService::DoServiceStart() { | Status CacheService::DoServiceStart() { | ||||
| std::shared_ptr<MemoryPool> mp_; | std::shared_ptr<MemoryPool> mp_; | ||||
| CacheServer &cs = CacheServer::GetInstance(); | |||||
| if (UseArena()) { | if (UseArena()) { | ||||
| auto avail_mem = cs.GetAvailableSystemMemory() / 1048576L; | |||||
| if (cache_mem_sz_ > avail_mem) { | |||||
| // Output a warning that we use more than recommended. If we fail to allocate, we will fail anyway. | |||||
| MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " MB while available system memory " << avail_mem | |||||
| << " MB"; | |||||
| } | |||||
| // Create a fixed size arena based on the parameter. | // Create a fixed size arena based on the parameter. | ||||
| std::shared_ptr<Arena> arena; | std::shared_ptr<Arena> arena; | ||||
| RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); | RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); | ||||
| mp_ = std::move(arena); | mp_ = std::move(arena); | ||||
| // update the global usage only. | |||||
| cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kAllocate); | |||||
| } else { | } else { | ||||
| // Unlimited size. Simply use a system pool. Another choice is CircularPool. | // Unlimited size. Simply use a system pool. Another choice is CircularPool. | ||||
| mp_ = std::make_shared<SystemPool>(); | mp_ = std::make_shared<SystemPool>(); | ||||
| } | } | ||||
| // Put together a CachePool for backing up the Tensor | // Put together a CachePool for backing up the Tensor | ||||
| cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), root_); | |||||
| cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), UseArena(), root_); | |||||
| RETURN_IF_NOT_OK(cp_->ServiceStart()); | RETURN_IF_NOT_OK(cp_->ServiceStart()); | ||||
| // Set up the B+ tree as well. But use the system pool instead. | |||||
| map_ = std::make_shared<row_map>(); | |||||
| // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. | // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. | ||||
| cookie_ = cp_->MyName(); | cookie_ = cp_->MyName(); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::DoServiceStop() { | Status CacheService::DoServiceStop() { | ||||
| if (cp_ != nullptr) { | if (cp_ != nullptr) { | ||||
| RETURN_IF_NOT_OK(cp_->ServiceStop()); | RETURN_IF_NOT_OK(cp_->ServiceStop()); | ||||
| } | } | ||||
| CacheServer &cs = CacheServer::GetInstance(); | |||||
| if (UseArena()) { | |||||
| cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kFree); | |||||
| } else { | |||||
| MS_LOG(INFO) << "Memory/disk usage for the current service: " << GetMemoryUsage() << " bytes and " << GetDiskUsage() | |||||
| << " bytes."; | |||||
| cs.UpdateMemoryUsage(GetMemoryUsage(), CacheServer::MemUsageOp::kFree); | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) { | Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) { | ||||
| SharedLock rw(&rw_lock_); | SharedLock rw(&rw_lock_); | ||||
| RETURN_UNEXPECTED_IF_NULL(row_id_generated); | RETURN_UNEXPECTED_IF_NULL(row_id_generated); | ||||
| @@ -66,6 +87,11 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||||
| // allow other to cache more rows. | // allow other to cache more rows. | ||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | ||||
| } | } | ||||
| if (st_ == State::kNoLocking) { | |||||
| // We ignore write this request once we turn off locking on the B+ tree. So we will just | |||||
| // return out of memory from now on. | |||||
| return Status(StatusCode::kOutOfMemory); | |||||
| } | |||||
| try { | try { | ||||
| // The first buffer is a flatbuffer which describes the rest of the buffers follow | // The first buffer is a flatbuffer which describes the rest of the buffers follow | ||||
| auto fb = buf.front(); | auto fb = buf.front(); | ||||
| @@ -86,6 +112,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||||
| *row_id_generated = msg->row_id(); | *row_id_generated = msg->row_id(); | ||||
| } | } | ||||
| auto size_of_this = msg->size_of_this(); | auto size_of_this = msg->size_of_this(); | ||||
| size_t total_sz = size_of_this; | |||||
| auto column_hdr = msg->column(); | auto column_hdr = msg->column(); | ||||
| // Number of tensor buffer should match the number of columns plus one. | // Number of tensor buffer should match the number of columns plus one. | ||||
| if (buf.size() != column_hdr->size() + 1) { | if (buf.size() != column_hdr->size() + 1) { | ||||
| @@ -99,16 +126,28 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type | |||||
| all_data.emplace_back(fb, size_of_this); | all_data.emplace_back(fb, size_of_this); | ||||
| for (auto i = 0; i < column_hdr->size(); ++i) { | for (auto i = 0; i < column_hdr->size(); ++i) { | ||||
| all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); | all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); | ||||
| total_sz += msg->data_sz()->Get(i); | |||||
| } | } | ||||
| // Now we cache the flat buffer. | |||||
| CachePool::key_type key; | |||||
| RETURN_IF_NOT_OK(cp_->Insert(all_data, &key)); | |||||
| Status rc = map_->DoInsert(*row_id_generated, key); | |||||
| // Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it. | |||||
| // Otherwise, we check how much (globally) how much we use and may simply spill to disk | |||||
| // directly. | |||||
| CacheServer &cs = CacheServer::GetInstance(); | |||||
| bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory(); | |||||
| Status rc = cp_->Insert(*row_id_generated, all_data, write_to_disk_directly); | |||||
| if (rc == Status(StatusCode::kDuplicateKey)) { | if (rc == Status(StatusCode::kDuplicateKey)) { | ||||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | MS_LOG(DEBUG) << "Ignoring duplicate key."; | ||||
| } else { | } else { | ||||
| RETURN_IF_NOT_OK(rc); | RETURN_IF_NOT_OK(rc); | ||||
| } | } | ||||
| // All good, then update the memory usage local and global (if not using arena) | |||||
| if (write_to_disk_directly) { | |||||
| cur_disk_usage_ += total_sz; | |||||
| } else { | |||||
| cur_mem_usage_ += total_sz; | |||||
| if (!UseArena()) { | |||||
| cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| RETURN_STATUS_UNEXPECTED(e.what()); | RETURN_STATUS_UNEXPECTED(e.what()); | ||||
| @@ -123,6 +162,11 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ | |||||
| // allow other to cache more rows. | // allow other to cache more rows. | ||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | ||||
| } | } | ||||
| if (st_ == State::kNoLocking) { | |||||
| // We ignore write this request once we turn off locking on the B+ tree. So we will just | |||||
| // return out of memory from now on. | |||||
| return Status(StatusCode::kOutOfMemory); | |||||
| } | |||||
| try { | try { | ||||
| // If we don't need to generate id, we need to find it from the buffer. | // If we don't need to generate id, we need to find it from the buffer. | ||||
| if (generate_id_) { | if (generate_id_) { | ||||
| @@ -139,20 +183,33 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ | |||||
| } | } | ||||
| *row_id_generated = msg->row_id(); | *row_id_generated = msg->row_id(); | ||||
| } | } | ||||
| // Now we cache the flat buffer. | |||||
| CachePool::key_type key; | |||||
| RETURN_IF_NOT_OK(cp_->Insert({src}, &key)); | |||||
| Status rc = map_->DoInsert(*row_id_generated, key); | |||||
| // Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it. | |||||
| // Otherwise, we check how much (globally) how much we use and may simply spill to disk | |||||
| // directly. | |||||
| auto total_sz = src.GetSize(); | |||||
| CacheServer &cs = CacheServer::GetInstance(); | |||||
| bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory(); | |||||
| Status rc = cp_->Insert(*row_id_generated, {src}, write_to_disk_directly); | |||||
| if (rc == Status(StatusCode::kDuplicateKey)) { | if (rc == Status(StatusCode::kDuplicateKey)) { | ||||
| MS_LOG(DEBUG) << "Ignoring duplicate key."; | MS_LOG(DEBUG) << "Ignoring duplicate key."; | ||||
| } else { | } else { | ||||
| RETURN_IF_NOT_OK(rc); | RETURN_IF_NOT_OK(rc); | ||||
| } | } | ||||
| // All good, then update the memory usage local and global (if not using arena) | |||||
| if (write_to_disk_directly) { | |||||
| cur_disk_usage_ += total_sz; | |||||
| } else { | |||||
| cur_mem_usage_ += total_sz; | |||||
| if (!UseArena()) { | |||||
| cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| RETURN_STATUS_UNEXPECTED(e.what()); | RETURN_STATUS_UNEXPECTED(e.what()); | ||||
| } | } | ||||
| } | } | ||||
| std::ostream &operator<<(std::ostream &out, const CacheService &cs) { | std::ostream &operator<<(std::ostream &out, const CacheService &cs) { | ||||
| // Then show any custom derived-internal stuff | // Then show any custom derived-internal stuff | ||||
| out << "\nCache memory size: " << cs.cache_mem_sz_; | out << "\nCache memory size: " << cs.cache_mem_sz_; | ||||
| @@ -164,34 +221,29 @@ std::ostream &operator<<(std::ostream &out, const CacheService &cs) { | |||||
| } | } | ||||
| return out; | return out; | ||||
| } | } | ||||
| Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); } | Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); } | ||||
| Status CacheService::Purge() { | |||||
| // First we must lock exclusively. No one else can cache/restore anything. | |||||
| UniqueLock rw(&rw_lock_); | |||||
| RETURN_IF_NOT_OK(cp_->ServiceStop()); | |||||
| auto new_map = std::make_shared<row_map>(); | |||||
| map_.reset(); | |||||
| map_ = std::move(new_map); | |||||
| next_id_ = 0; | |||||
| RETURN_IF_NOT_OK(cp_->ServiceStart()); | |||||
| Status CacheService::FindKeysMiss(std::vector<row_id_type> *out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| std::unique_lock<std::mutex> lock(get_key_miss_mux_); | |||||
| if (key_miss_results_ == nullptr) { | |||||
| // Just do it once. | |||||
| key_miss_results_ = std::make_shared<std::vector<row_id_type>>(); | |||||
| auto stat = cp_->GetStat(true); | |||||
| key_miss_results_->push_back(stat.min_key); | |||||
| key_miss_results_->push_back(stat.max_key); | |||||
| key_miss_results_->insert(key_miss_results_->end(), stat.gap.begin(), stat.gap.end()); | |||||
| } | |||||
| out->insert(out->end(), key_miss_results_->begin(), key_miss_results_->end()); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::GetStat(CacheService::ServiceStat *out) { | Status CacheService::GetStat(CacheService::ServiceStat *out) { | ||||
| SharedLock rw(&rw_lock_); | SharedLock rw(&rw_lock_); | ||||
| RETURN_UNEXPECTED_IF_NULL(out); | RETURN_UNEXPECTED_IF_NULL(out); | ||||
| if (st_ == State::kNone || st_ == State::kFetchPhase) { | |||||
| out->stat_ = cp_->GetStat(); | |||||
| out->state_ = static_cast<ServiceStat::state_type>(st_); | |||||
| auto it = map_->begin(); | |||||
| if (it != map_->end()) { | |||||
| out->min_ = it.key(); | |||||
| auto end_it = map_->end(); | |||||
| --end_it; | |||||
| out->max_ = end_it.key(); | |||||
| } | |||||
| } else { | |||||
| out->state_ = static_cast<ServiceStat::state_type>(st_); | |||||
| } | |||||
| out->stat_ = cp_->GetStat(); | |||||
| out->state_ = static_cast<ServiceStat::state_type>(st_); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -204,19 +256,12 @@ Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vecto | |||||
| *mem_sz = (num_elements + 1) * sizeof(int64_t); | *mem_sz = (num_elements + 1) * sizeof(int64_t); | ||||
| (*out).reserve(num_elements); | (*out).reserve(num_elements); | ||||
| for (auto row_id : v) { | for (auto row_id : v) { | ||||
| auto r = map_->Search(row_id); | |||||
| if (r.second) { | |||||
| auto &it = r.first; | |||||
| CachePool::key_type key = it.value(); | |||||
| auto sz = cp_->GetSize(key); | |||||
| if (sz == 0) { | |||||
| std::string errMsg = "Key not found: "; | |||||
| errMsg += std::to_string(key); | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| (*out).emplace_back(key, sz); | |||||
| auto sz = cp_->GetSize(row_id); | |||||
| if (sz > 0) { | |||||
| (*out).emplace_back(row_id, sz); | |||||
| (*mem_sz) += sz; | (*mem_sz) += sz; | ||||
| } else { | } else { | ||||
| // key not found | |||||
| (*out).emplace_back(-1, 0); | (*out).emplace_back(-1, 0); | ||||
| } | } | ||||
| } | } | ||||
| @@ -252,27 +297,19 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::ve | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::CacheSchema(const void *buf, int64_t len) { | Status CacheService::CacheSchema(const void *buf, int64_t len) { | ||||
| SharedLock rw(&rw_lock_); | |||||
| if (st_ == State::kFetchPhase) { | |||||
| // For this kind of cache service, once we are done with the build phase into fetch phase, we can't | |||||
| // allow other to cache more rows. | |||||
| RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); | |||||
| } | |||||
| // This is a special request and we need to remember where we store it. | |||||
| UniqueLock rw(&rw_lock_); | |||||
| // In case we are calling the same function from multiple threads, only | // In case we are calling the same function from multiple threads, only | ||||
| // the first one is considered. Rest is ignored. | // the first one is considered. Rest is ignored. | ||||
| CachePool::key_type cur_key = schema_key_; | |||||
| CachePool::key_type key; | |||||
| if (cur_key < 0) { | |||||
| RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key)); | |||||
| auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key); | |||||
| MS_LOG(DEBUG) << "Caching Schema. Result = " << result; | |||||
| if (schema_.empty()) { | |||||
| schema_.assign(static_cast<const char *>(buf), len); | |||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "Caching Schema already done"; | MS_LOG(DEBUG) << "Caching Schema already done"; | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::FetchSchema(std::string *out) const { | Status CacheService::FetchSchema(std::string *out) const { | ||||
| SharedLock rw(&rw_lock_); | SharedLock rw(&rw_lock_); | ||||
| if (st_ == State::kBuildPhase) { | if (st_ == State::kBuildPhase) { | ||||
| @@ -283,32 +320,44 @@ Status CacheService::FetchSchema(std::string *out) const { | |||||
| // We are going to use std::string to allocate and hold the result which will be eventually | // We are going to use std::string to allocate and hold the result which will be eventually | ||||
| // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose | // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose | ||||
| // to minimize memory copy. | // to minimize memory copy. | ||||
| std::string mem; | |||||
| if (schema_key_ >= 0) { | |||||
| auto len = cp_->GetSize(schema_key_); | |||||
| try { | |||||
| mem.resize(len); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error"); | |||||
| } catch (const std::bad_alloc &e) { | |||||
| return Status(StatusCode::kOutOfMemory); | |||||
| } | |||||
| auto slice = WritableSlice(mem.data(), len); | |||||
| RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); | |||||
| std::string mem(schema_); | |||||
| if (!mem.empty()) { | |||||
| *out = std::move(mem); | *out = std::move(mem); | ||||
| } else { | } else { | ||||
| return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached"); | return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached"); | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CacheService::BuildPhaseDone() { | Status CacheService::BuildPhaseDone() { | ||||
| if (HasBuildPhase()) { | if (HasBuildPhase()) { | ||||
| // Exclusive lock to switch phase | // Exclusive lock to switch phase | ||||
| UniqueLock rw(&rw_lock_); | UniqueLock rw(&rw_lock_); | ||||
| st_ = State::kFetchPhase; | st_ = State::kFetchPhase; | ||||
| cp_->SetLocking(false); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } else { | } else { | ||||
| RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); | RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); | ||||
| } | } | ||||
| } | } | ||||
| Status CacheService::ToggleWriteMode(bool on_off) { | |||||
| UniqueLock rw(&rw_lock_); | |||||
| if (HasBuildPhase()) { | |||||
| RETURN_STATUS_UNEXPECTED("Not applicable to non-mappable dataset"); | |||||
| } else { | |||||
| // If we stop accepting write request, we turn off locking for the | |||||
| // underlying B+ tree. All future write request we will return kOutOfMemory. | |||||
| if (st_ == State::kNone && !on_off) { | |||||
| st_ = State::kNoLocking; | |||||
| cp_->SetLocking(on_off); | |||||
| MS_LOG(WARNING) << "Locking mode is switched off."; | |||||
| } else if (st_ == State::kNoLocking && on_off) { | |||||
| st_ = State::kNone; | |||||
| cp_->SetLocking(on_off); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <atomic> | #include <atomic> | ||||
| #include <memory> | #include <memory> | ||||
| #include <mutex> | |||||
| #include <string> | #include <string> | ||||
| #include <type_traits> | #include <type_traits> | ||||
| #include <utility> | #include <utility> | ||||
| @@ -44,9 +45,8 @@ using key_size_pair = std::pair<CachePool::key_type, size_t>; | |||||
| class CacheService : public Service { | class CacheService : public Service { | ||||
| public: | public: | ||||
| friend class CacheServer; | friend class CacheServer; | ||||
| using row_map = BPlusTree<row_id_type, CachePool::key_type>; | |||||
| enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase }; | |||||
| enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking }; | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited | /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited | ||||
| @@ -97,11 +97,9 @@ class CacheService : public Service { | |||||
| class ServiceStat { | class ServiceStat { | ||||
| public: | public: | ||||
| using state_type = std::underlying_type<State>::type; | using state_type = std::underlying_type<State>::type; | ||||
| ServiceStat() : min_(0), max_(0), state_(0) {} | |||||
| ServiceStat() : state_(0) {} | |||||
| ~ServiceStat() = default; | ~ServiceStat() = default; | ||||
| CachePool::CacheStat stat_{}; | CachePool::CacheStat stat_{}; | ||||
| row_id_type min_; | |||||
| row_id_type max_; | |||||
| state_type state_; | state_type state_; | ||||
| }; | }; | ||||
| /// \brief Statistics for the current service | /// \brief Statistics for the current service | ||||
| @@ -117,9 +115,9 @@ class CacheService : public Service { | |||||
| /// \param out A contiguous memory that contains the serialized form of schema. | /// \param out A contiguous memory that contains the serialized form of schema. | ||||
| /// \return Status object | /// \return Status object | ||||
| Status FetchSchema(std::string *out) const; | Status FetchSchema(std::string *out) const; | ||||
| /// \brief Purge the content of a cache | |||||
| /// \brief Return a set of keys that are definitely cache miss | |||||
| /// \return Status object | /// \return Status object | ||||
| Status Purge(); | |||||
| Status FindKeysMiss(std::vector<row_id_type> *out); | |||||
| /// \brief Overload the << operator to print a cache service | /// \brief Overload the << operator to print a cache service | ||||
| /// \param out std::ostream | /// \param out std::ostream | ||||
| /// \param cs A cache service | /// \param cs A cache service | ||||
| @@ -136,19 +134,33 @@ class CacheService : public Service { | |||||
| /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. | /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. | ||||
| /// \return Status object | /// \return Status object | ||||
| Status BuildPhaseDone(); | Status BuildPhaseDone(); | ||||
| /// \brief Find out the current memory usage | |||||
| int64_t GetMemoryUsage() { return cur_mem_usage_; } | |||||
| /// \brief Find out the current disk usage | |||||
| int64_t GetDiskUsage() { return cur_disk_usage_; } | |||||
| /// \brief For kToggleWriteMode request | |||||
| Status ToggleWriteMode(bool on_off); | |||||
| private: | private: | ||||
| mutable RWLock rw_lock_; | mutable RWLock rw_lock_; | ||||
| std::string root_; | std::string root_; | ||||
| uint64_t cache_mem_sz_; | uint64_t cache_mem_sz_; | ||||
| std::shared_ptr<CachePool> cp_; | std::shared_ptr<CachePool> cp_; | ||||
| std::shared_ptr<row_map> map_; | |||||
| std::atomic<row_id_type> next_id_; | std::atomic<row_id_type> next_id_; | ||||
| bool generate_id_; | bool generate_id_; | ||||
| std::atomic<CachePool::key_type> schema_key_; | |||||
| std::string cookie_; | std::string cookie_; | ||||
| State st_; | State st_; | ||||
| std::string schema_; | |||||
| // If we use an Arena, cur_disk_usage is always 0 as we don't know how CachePool manages it. | |||||
| // Otherwise we track how much is in memory and how much is on disk (if root_ is not empty). | |||||
| // We use them to control when we should stop caching in memory in the case when there is no | |||||
| // Arena. | |||||
| std::atomic<int64_t> cur_mem_usage_; | |||||
| std::atomic<int64_t> cur_disk_usage_; | |||||
| // We also cache the result from calling FindKeysMiss because it is expensive. Besides user make | |||||
| // this request after we hit memory full or disk full. So the result is unlikely to change. | |||||
| std::mutex get_key_miss_mux_; | |||||
| std::shared_ptr<std::vector<row_id_type>> key_miss_results_; | |||||
| /// \brief Private function to generate a row id | /// \brief Private function to generate a row id | ||||
| /// \return Row id assigned. | /// \return Row id assigned. | ||||
| row_id_type GetNextRowId() { return next_id_.fetch_add(1); } | row_id_type GetNextRowId() { return next_id_.fetch_add(1); } | ||||
| @@ -92,3 +92,13 @@ table CreateCacheReplyMsg { | |||||
| connection_id:int64; | connection_id:int64; | ||||
| cookie:string; | cookie:string; | ||||
| } | } | ||||
| table ListSessionMsg { | |||||
| session_id:uint32; | |||||
| connection_id:uint64; | |||||
| stats:ServiceStatMsg; | |||||
| } | |||||
| table ListSessionsMsg { | |||||
| sessions:[ListSessionMsg]; | |||||
| } | |||||
| @@ -53,22 +53,35 @@ CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t row | |||||
| num_cache_miss_(0), | num_cache_miss_(0), | ||||
| cache_client_(std::move(cache_client)), | cache_client_(std::move(cache_client)), | ||||
| rows_per_buffer_(rows_per_buf), | rows_per_buffer_(rows_per_buf), | ||||
| // We can cause deadlock if this internal Connector size is too small. | |||||
| keys_miss_(num_workers_, 1, connector_capacity_), | |||||
| prefetch_size_(cache_client_->getPrefetchSize()) { | |||||
| prefetch_size_(rows_per_buffer_), | |||||
| num_prefetchers_(num_workers_) { | |||||
| // Adjust the prefetch size based on the number of workers. | |||||
| auto prefetch_sz_per_thread = cache_client_->GetPrefetchSize() / num_prefetchers_; | |||||
| if (prefetch_size_ < prefetch_sz_per_thread) { | |||||
| prefetch_size_ = prefetch_sz_per_thread; | |||||
| MS_LOG(DEBUG) << "Per worker prefetch size : " << prefetch_size_; | |||||
| } | |||||
| io_block_queues_.Init(num_workers, op_connector_size); | io_block_queues_.Init(num_workers, op_connector_size); | ||||
| prefetch_queues_.Init(num_workers, op_connector_size); | |||||
| sampler_queue_ = std::make_unique<Queue<std::shared_ptr<Tensor>>>(op_connector_size); | |||||
| prefetch_queues_.Init(num_prefetchers_, op_connector_size); | |||||
| // We can cause deadlock if this internal Connector size is too small. | |||||
| keys_miss_ = std::make_unique<Connector<std::vector<row_id_type>>>(num_prefetchers_, 1, connector_capacity_); | |||||
| } | } | ||||
| // Common function to fetch samples from the sampler and send them using the io_block_queues to | // Common function to fetch samples from the sampler and send them using the io_block_queues to | ||||
| // the parallel workers | // the parallel workers | ||||
| Status CacheBase::FetchSamplesToWorkers() { | Status CacheBase::FetchSamplesToWorkers() { | ||||
| int64_t buf_cnt = 0; | int64_t buf_cnt = 0; | ||||
| int64_t wait_cnt = 0; | int64_t wait_cnt = 0; | ||||
| int64_t prefetch_cnt = 0; | |||||
| // Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffers_ | // Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffers_ | ||||
| // is too small (1 by default) and won't help performance. | // is too small (1 by default) and won't help performance. | ||||
| RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this))); | |||||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1))); | |||||
| RETURN_IF_NOT_OK( | |||||
| tree_->LaunchWorkers(num_prefetchers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1))); | |||||
| auto send_to_que = [](QueueList<std::unique_ptr<IOBlock>> &qList, int32_t worker_id, | |||||
| std::vector<row_id_type> &keys) -> Status { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(qList[worker_id]->Add(std::move(blk))); | |||||
| return Status::OK(); | |||||
| }; | |||||
| // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them | // Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them | ||||
| // to the WorkerEntry. | // to the WorkerEntry. | ||||
| do { | do { | ||||
| @@ -82,33 +95,54 @@ Status CacheBase::FetchSamplesToWorkers() { | |||||
| ++wait_cnt; | ++wait_cnt; | ||||
| std::vector<row_id_type> keys; | std::vector<row_id_type> keys; | ||||
| keys.reserve(rows_per_buffer_); | keys.reserve(rows_per_buffer_); | ||||
| std::vector<row_id_type> prefetch_keys; | |||||
| prefetch_keys.reserve(prefetch_size_); | |||||
| std::unique_ptr<DataBuffer> sampler_buffer; | std::unique_ptr<DataBuffer> sampler_buffer; | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | ||||
| while (!sampler_buffer->eoe()) { | while (!sampler_buffer->eoe()) { | ||||
| TensorRow sample_row; | TensorRow sample_row; | ||||
| RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); | RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); | ||||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | std::shared_ptr<Tensor> sample_ids = sample_row[0]; | ||||
| // Send the sampler tensor to other thread for prefetching. We are using shared pointer so it | |||||
| // won't go out scope until it is really not in use. | |||||
| RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids)); | |||||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | ||||
| keys.push_back(*itr); | |||||
| ++row_cnt_; | ++row_cnt_; | ||||
| if (row_cnt_ % rows_per_buffer_ == 0) { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||||
| keys.clear(); | |||||
| prefetch_keys.push_back(*itr); | |||||
| // Batch enough rows for performance reason. | |||||
| if (row_cnt_ % prefetch_size_ == 0) { | |||||
| RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys)); | |||||
| // Now we tell the WorkerEntry to wait for them to come back. If prefetch_size_ is a multiple | |||||
| // of rows_per_buffer_, the keys vector will always be empty. But it can be partially filled. | |||||
| // The only requirement we set up is rows_per_buffer_ is less than or equal to prefetch_size_. | |||||
| for (auto row_id : prefetch_keys) { | |||||
| keys.push_back(row_id); | |||||
| if (keys.size() == rows_per_buffer_) { | |||||
| RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys)); | |||||
| keys.clear(); | |||||
| } | |||||
| } | |||||
| prefetch_keys.clear(); | |||||
| } | } | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); | ||||
| } | } | ||||
| // Deal with any partial keys left. | |||||
| if (!prefetch_keys.empty()) { | |||||
| RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys)); | |||||
| for (auto row_id : prefetch_keys) { | |||||
| keys.push_back(row_id); | |||||
| if (keys.size() == rows_per_buffer_) { | |||||
| RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys)); | |||||
| keys.clear(); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!keys.empty()) { | if (!keys.empty()) { | ||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||||
| RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys)); | |||||
| } | } | ||||
| // send the eoe | // send the eoe | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | ||||
| RETURN_IF_NOT_OK(prefetch_queues_[(prefetch_cnt++) % num_prefetchers_]->Add( | |||||
| std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); | |||||
| // If repeat but the not last repeat, wait for reset. | // If repeat but the not last repeat, wait for reset. | ||||
| if (!IsLastIteration()) { | if (!IsLastIteration()) { | ||||
| MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt; | MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt; | ||||
| @@ -123,8 +157,6 @@ Status CacheBase::FetchSamplesToWorkers() { | |||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof))); | ||||
| // Shutdown threads | // Shutdown threads | ||||
| std::shared_ptr<Tensor> empty; | |||||
| RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty))); | |||||
| for (int32_t i = 0; i < num_workers_; i++) { | for (int32_t i = 0; i < num_workers_; i++) { | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | ||||
| @@ -145,13 +177,6 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||||
| if (blk->eof()) { | if (blk->eof()) { | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))); | RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))); | ||||
| } else if (blk->eoe()) { | } else if (blk->eoe()) { | ||||
| if (AllowCacheMiss()) { | |||||
| // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from | |||||
| // a sampler, send a eoe to physical leaf op as well. | |||||
| std::vector<row_id_type> eoe; | |||||
| eoe.push_back(eoe_row_id); | |||||
| RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe)); | |||||
| } | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))); | ||||
| } else { | } else { | ||||
| std::vector<int64_t> keys; | std::vector<int64_t> keys; | ||||
| @@ -162,22 +187,21 @@ Status CacheBase::FetchFromCache(int32_t worker_id) { | |||||
| } | } | ||||
| std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone); | std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone); | ||||
| std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>(); | std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>(); | ||||
| std::vector<row_id_type> cache_miss; | |||||
| cache_miss.reserve(keys.size()); | |||||
| for (auto row_id : keys) { | for (auto row_id : keys) { | ||||
| TensorRow row; | TensorRow row; | ||||
| // Block until the row shows up in the pool. | // Block until the row shows up in the pool. | ||||
| RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row)); | |||||
| RETURN_IF_NOT_OK(GetPrefetchRow(row_id, &row)); | |||||
| if (row.empty()) { | if (row.empty()) { | ||||
| cache_miss.push_back(row_id); | |||||
| if (AllowCacheMiss()) { | |||||
| ++num_cache_miss_; | |||||
| } else { | |||||
| std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| } | } | ||||
| que->push_back(std::move(row)); | que->push_back(std::move(row)); | ||||
| } | } | ||||
| db->set_tensor_table(std::move(que)); | db->set_tensor_table(std::move(que)); | ||||
| if (AllowCacheMiss()) { | |||||
| // Because of the way connector works, we push unconditionally even cache_miss can be empty. | |||||
| RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss)); | |||||
| } | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); | RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); | ||||
| buffer_id += num_workers_; | buffer_id += num_workers_; | ||||
| } | } | ||||
| @@ -189,7 +213,6 @@ Status CacheBase::RegisterResources() { | |||||
| RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); | ||||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | ||||
| RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks())); | ||||
| RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks())); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -208,73 +231,97 @@ Status CacheBase::UpdateColumnMapFromCache() { | |||||
| return rc; | return rc; | ||||
| } | } | ||||
| Status CacheBase::Dispatcher() { | |||||
| TaskManager::FindMe()->Post(); | |||||
| int64_t buf_cnt = 0; | |||||
| int64_t num_row = 0; | |||||
| std::vector<row_id_type> keys; | |||||
| keys.reserve(prefetch_size_); | |||||
| do { | |||||
| keys.clear(); | |||||
| std::shared_ptr<Tensor> sample_ids; | |||||
| RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids)); | |||||
| if (sample_ids == nullptr) { | |||||
| // A null shared pointer signal times to quit. | |||||
| // Also signal all prefetchers to quit. | |||||
| for (int32_t i = 0; i < num_workers_; i++) { | |||||
| RETURN_IF_NOT_OK( | |||||
| prefetch_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))); | |||||
| } | |||||
| break; | |||||
| Status CacheBase::GetPrefetchRow(row_id_type row_id, TensorRow *out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(row_id >= 0, "Expect positive row id"); | |||||
| RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, out)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheBase::PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss) { | |||||
| RETURN_UNEXPECTED_IF_NULL(cache_miss); | |||||
| std::vector<row_id_type> prefetch_keys; | |||||
| prefetch_keys.reserve(keys.size()); | |||||
| // Filter out all those keys that unlikely we will find at the server | |||||
| for (auto row_id : keys) { | |||||
| if (cache_client_->KeyIsCacheMiss(row_id)) { | |||||
| // Just put an empty row in the cache. | |||||
| TensorRow row; | |||||
| row.setId(row_id); | |||||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||||
| cache_miss->push_back(row_id); | |||||
| } else { | |||||
| prefetch_keys.push_back(row_id); | |||||
| } | } | ||||
| // Now we distribute the sampler ids to each prefetcher according to the prefetch size. | |||||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | |||||
| keys.push_back(*itr); | |||||
| ++num_row; | |||||
| if (num_row % prefetch_size_ == 0) { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||||
| keys.clear(); | |||||
| } | |||||
| // Early exit if nothing to fetch | |||||
| if (prefetch_keys.empty()) { | |||||
| return Status::OK(); | |||||
| } | |||||
| // Get the rows from the server | |||||
| TensorTable ttbl; | |||||
| Status rc = cache_client_->GetRows(prefetch_keys, &ttbl); | |||||
| if (rc.IsOk()) { | |||||
| auto row_it = ttbl.begin(); | |||||
| for (auto row_id : prefetch_keys) { | |||||
| auto &row = *row_it; | |||||
| if (row.empty()) { | |||||
| cache_miss->push_back(row_id); | |||||
| } | } | ||||
| // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row | |||||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||||
| ++row_it; | |||||
| } | } | ||||
| // Send the remaining sample id | |||||
| if (!keys.empty()) { | |||||
| auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)); | |||||
| RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); | |||||
| } else { | |||||
| // In case any thread is waiting for the rows to come back and blocked on a semaphore, | |||||
| // we will put an empty row in the local cache. | |||||
| for (auto row_id : prefetch_keys) { | |||||
| TensorRow row; | |||||
| row.setId(row_id); | |||||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||||
| cache_miss->push_back(row_id); | |||||
| } | } | ||||
| } while (true); | |||||
| return Status::OK(); | |||||
| } | |||||
| return rc; | |||||
| } | } | ||||
| Status CacheBase::Prefetcher(int32_t worker_id) { | Status CacheBase::Prefetcher(int32_t worker_id) { | ||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| std::vector<row_id_type> prefetch_keys; | std::vector<row_id_type> prefetch_keys; | ||||
| prefetch_keys.reserve(prefetch_size_); | prefetch_keys.reserve(prefetch_size_); | ||||
| std::vector<row_id_type> cache_miss; | |||||
| cache_miss.reserve(prefetch_size_); | |||||
| do { | do { | ||||
| prefetch_keys.clear(); | prefetch_keys.clear(); | ||||
| cache_miss.clear(); | |||||
| std::unique_ptr<IOBlock> blk; | std::unique_ptr<IOBlock> blk; | ||||
| RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk)); | RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk)); | ||||
| RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys)); | |||||
| if (prefetch_keys.empty()) { | |||||
| // Empty keys mean time to quit. | |||||
| break; | |||||
| } | |||||
| TensorTable ttbl; | |||||
| RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl)); | |||||
| auto row_it = ttbl.begin(); | |||||
| for (auto row_id : prefetch_keys) { | |||||
| auto &row = *row_it; | |||||
| if (row.empty()) { | |||||
| if (AllowCacheMiss()) { | |||||
| ++num_cache_miss_; | |||||
| } else { | |||||
| std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!blk->eof(), "Expect eoe or a regular io block"); | |||||
| if (!blk->eoe()) { | |||||
| RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys)); | |||||
| Status rc; | |||||
| const int32_t max_retries = 5; | |||||
| int32_t retry_count = 0; | |||||
| do { | |||||
| rc = PrefetchRows(prefetch_keys, &cache_miss); | |||||
| if (rc.IsNetWorkError() && retry_count < max_retries) { | |||||
| // If we get some network error, we will attempt some retries | |||||
| retry_count++; | |||||
| } else if (rc.IsError()) { | |||||
| return rc; | |||||
| } | } | ||||
| } while (rc.IsNetWorkError()); | |||||
| } else { | |||||
| if (AllowCacheMiss()) { | |||||
| // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from | |||||
| // a sampler, send a eoe to physical leaf op as well. | |||||
| cache_miss.push_back(eoe_row_id); | |||||
| } | } | ||||
| // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row | |||||
| RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); | |||||
| ++row_it; | |||||
| } | |||||
| if (AllowCacheMiss()) { | |||||
| // Because of the way connector works, we push unconditionally even cache_miss can be empty. | |||||
| RETURN_IF_NOT_OK(keys_miss_->Push(worker_id, cache_miss)); | |||||
| } | } | ||||
| } while (true); | } while (true); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/connector.h" | |||||
| #include "minddata/dataset/engine/cache/cache_client.h" | #include "minddata/dataset/engine/cache/cache_client.h" | ||||
| #include "minddata/dataset/engine/cache/cache_service.h" | #include "minddata/dataset/engine/cache/cache_service.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| @@ -90,8 +91,7 @@ class CacheBase : public ParallelOp { | |||||
| std::shared_ptr<CacheClient> cache_client_; | std::shared_ptr<CacheClient> cache_client_; | ||||
| WaitPost epoch_sync_; | WaitPost epoch_sync_; | ||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| Connector<std::vector<row_id_type>> keys_miss_; | |||||
| QueueMap<row_id_type, TensorRow> prefetch_; | |||||
| std::unique_ptr<Connector<std::vector<row_id_type>>> keys_miss_; | |||||
| /// \brief Common function to register resources for interrupt | /// \brief Common function to register resources for interrupt | ||||
| /// \note Derived should override this function for extra resources to be registered | /// \note Derived should override this function for extra resources to be registered | ||||
| @@ -111,13 +111,16 @@ class CacheBase : public ParallelOp { | |||||
| constexpr static int32_t connector_capacity_ = 1024; | constexpr static int32_t connector_capacity_ = 1024; | ||||
| int32_t prefetch_size_; | int32_t prefetch_size_; | ||||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | ||||
| int32_t num_prefetchers_; | |||||
| QueueList<std::unique_ptr<IOBlock>> prefetch_queues_; | QueueList<std::unique_ptr<IOBlock>> prefetch_queues_; | ||||
| std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_; | |||||
| QueueMap<row_id_type, TensorRow> prefetch_; | |||||
| Status Dispatcher(); | |||||
| /// \brief Prefetcher. It prefetch the rows from cache server | /// \brief Prefetcher. It prefetch the rows from cache server | ||||
| /// \return Status object. | /// \return Status object. | ||||
| Status Prefetcher(int32_t worker_id); | Status Prefetcher(int32_t worker_id); | ||||
| /// \brief Functions used by prefetcher and WorkerEntry | |||||
| Status PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss); | |||||
| Status GetPrefetchRow(row_id_type row_id, TensorRow *out); | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -87,10 +87,10 @@ Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); } | |||||
| void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } | void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } | ||||
| Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | ||||
| std::vector<row_id_type> cache_miss; | std::vector<row_id_type> cache_miss; | ||||
| RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); | |||||
| RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss)); | |||||
| // Ignore the case we have no cache miss, we can't return empty samples. | // Ignore the case we have no cache miss, we can't return empty samples. | ||||
| while (cache_miss.empty()) { | while (cache_miss.empty()) { | ||||
| RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); | |||||
| RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss)); | |||||
| } | } | ||||
| // Special code for eoe | // Special code for eoe | ||||
| if (cache_miss.at(0) == eoe_row_id) { | if (cache_miss.at(0) == eoe_row_id) { | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | #include "minddata/dataset/engine/opt/pass.h" | ||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/util/system_pool.h" | |||||
| #include "minddata/dataset/util/task_manager.h" | #include "minddata/dataset/util/task_manager.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -48,7 +49,8 @@ CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t | |||||
| std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler) | std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler) | ||||
| : ParallelOp(numWorkers, opConnectorSize, sampler), | : ParallelOp(numWorkers, opConnectorSize, sampler), | ||||
| num_cleaners_(numCleaners), | num_cleaners_(numCleaners), | ||||
| cache_client_(std::move(cache_client)) {} | |||||
| cache_client_(std::move(cache_client)), | |||||
| cache_missing_rows_(true) {} | |||||
| Status CacheMergeOp::operator()() { | Status CacheMergeOp::operator()() { | ||||
| // A queue of row id to let cleaner send cache miss rows to the cache server | // A queue of row id to let cleaner send cache miss rows to the cache server | ||||
| @@ -129,17 +131,19 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { | |||||
| std::string errMsg = "Expect positive row id: " + std::to_string(row_id); | std::string errMsg = "Expect positive row id: " + std::to_string(row_id); | ||||
| RETURN_STATUS_UNEXPECTED(errMsg); | RETURN_STATUS_UNEXPECTED(errMsg); | ||||
| } | } | ||||
| // Technically number of this row shows up in the cache miss stream is equal to the number | |||||
| // of P() call. However the cleaner wants it too. So we need an extra copy. | |||||
| TensorRowCacheRequest *rq; | |||||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||||
| if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) { | |||||
| // We will send the request async. But any error we most | |||||
| // likely ignore and continue. | |||||
| Status rc; | |||||
| rc = rq->AsyncSendCacheRequest(cache_client_, row); | |||||
| if (rc.IsOk()) { | |||||
| RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); | |||||
| if (cache_missing_rows_) { | |||||
| // Technically number of this row shows up in the cache miss stream is equal to the number | |||||
| // of P() call. However the cleaner wants it too. So we need an extra copy. | |||||
| TensorRowCacheRequest *rq; | |||||
| RETURN_IF_NOT_OK(GetRq(row_id, &rq)); | |||||
| if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) { | |||||
| // We will send the request async. But any error we most | |||||
| // likely ignore and continue. | |||||
| Status rc; | |||||
| rc = rq->AsyncSendCacheRequest(cache_client_, row); | |||||
| if (rc.IsOk()) { | |||||
| RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row))); | RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row))); | ||||
| @@ -168,13 +172,18 @@ Status CacheMergeOp::Cleaner() { | |||||
| Status rc = rq->CheckCacheResult(); | Status rc = rq->CheckCacheResult(); | ||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| // If interrupt, time to quit. | // If interrupt, time to quit. | ||||
| if (rc.get_code() == StatusCode::kInterrupted) { | |||||
| if (rc.IsInterrupted()) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } else if (rc.IsOutofMemory() || rc.IsNoSpace()) { | |||||
| // The server is hitting some limit and we will turn off caching from now on. | |||||
| cache_missing_rows_ = false; | |||||
| cache_client_->ServerRunningOutOfResources(); | |||||
| } else { | |||||
| MS_LOG(INFO) << "Cache row not successful: " << rc.ToString(); | |||||
| // Bad rc should not bring down the pipeline. We will simply continue and | |||||
| // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY. | |||||
| rq->SetState(TensorRowCacheRequest::State::kEmpty); | |||||
| } | } | ||||
| MS_LOG(INFO) << "Cache row not successful: " << rc.ToString(); | |||||
| // Bad rc should not bring down the pipeline. We will simply continue and | |||||
| // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY. | |||||
| rq->SetState(TensorRowCacheRequest::State::kEmpty); | |||||
| } | } | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -253,7 +262,7 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) { | |||||
| Status CacheMergeOp::EoeReceived(int32_t worker_id) { | Status CacheMergeOp::EoeReceived(int32_t worker_id) { | ||||
| // If we are in a repeat path, send the eoe up. | // If we are in a repeat path, send the eoe up. | ||||
| // Otherwise ignore it. | // Otherwise ignore it. | ||||
| if (op_total_repeats_ > 1) { | |||||
| if (op_total_repeats_ != 1) { | |||||
| return DatasetOp::EoeReceived(worker_id); | return DatasetOp::EoeReceived(worker_id); | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -281,7 +290,7 @@ Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheReque | |||||
| *out = it->second.GetMutablePointer(); | *out = it->second.GetMutablePointer(); | ||||
| } else { | } else { | ||||
| // We will create a new one. | // We will create a new one. | ||||
| auto alloc = Services::GetAllocator<TensorRowCacheRequest>(); | |||||
| auto alloc = SystemPool::GetAllocator<TensorRowCacheRequest>(); | |||||
| auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc)); | auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc)); | ||||
| if (r.second) { | if (r.second) { | ||||
| auto &mem = r.first->second; | auto &mem = r.first->second; | ||||
| @@ -202,6 +202,7 @@ class CacheMergeOp : public ParallelOp { | |||||
| std::unique_ptr<Queue<row_id_type>> io_que_; | std::unique_ptr<Queue<row_id_type>> io_que_; | ||||
| std::shared_ptr<CacheClient> cache_client_; | std::shared_ptr<CacheClient> cache_client_; | ||||
| int32_t num_cleaners_; | int32_t num_cleaners_; | ||||
| std::atomic<bool> cache_missing_rows_; | |||||
| /// \brief Locate the cache request from the io_request_ map | /// \brief Locate the cache request from the io_request_ map | ||||
| /// \param row_id | /// \param row_id | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | #include "minddata/dataset/engine/datasetops/cache_op.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/core/constants.h" | #include "minddata/dataset/core/constants.h" | ||||
| @@ -64,7 +65,7 @@ Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) { | |||||
| // Constructor of CacheOp | // Constructor of CacheOp | ||||
| CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, | ||||
| std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler) | ||||
| : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), | |||||
| : CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)), | |||||
| num_guys_in_(0), | num_guys_in_(0), | ||||
| phase_(Phase::kBuildPhase) {} | phase_(Phase::kBuildPhase) {} | ||||
| @@ -174,7 +175,7 @@ Status CacheOp::WorkerEntry(int32_t worker_id) { | |||||
| Status CacheOp::RegisterResources() { | Status CacheOp::RegisterResources() { | ||||
| RETURN_IF_NOT_OK(CacheBase::RegisterResources()); | RETURN_IF_NOT_OK(CacheBase::RegisterResources()); | ||||
| RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks())); | ||||
| RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks())); | |||||
| RETURN_IF_NOT_OK(keys_miss_->Register(tree_->AllTasks())); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/engine/data_buffer.h" | #include "minddata/dataset/engine/data_buffer.h" | ||||
| #include "minddata/dataset/engine/datasetops/concat_op.h" | #include "minddata/dataset/engine/datasetops/concat_op.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| @@ -188,5 +189,11 @@ Status ConcatOp::ComputeColMap() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor pre-accept method for NodePass | |||||
| Status ConcatOp::PreAccept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -105,6 +105,12 @@ class ConcatOp : public PipelineOp { | |||||
| // @return - Status | // @return - Status | ||||
| Status ComputeColMap() override; | Status ComputeColMap() override; | ||||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status PreAccept(NodePass *p, bool *modified) override; | |||||
| private: | private: | ||||
| Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf); | Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf); | ||||
| @@ -243,6 +243,7 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_ | out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_ | ||||
| << "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_; | << "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_; | ||||
| if (sampler_) { | if (sampler_) { | ||||
| out << "\nSampler:\n"; | |||||
| sampler_->Print(out, show_all); | sampler_->Print(out, show_all); | ||||
| } | } | ||||
| } | } | ||||
| @@ -268,5 +268,11 @@ Status FilterOp::Accept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| return p->RunOnNode(shared_from_base<FilterOp>(), modified); | return p->RunOnNode(shared_from_base<FilterOp>(), modified); | ||||
| } | } | ||||
| // Visitor pre-accept method for NodePass | |||||
| Status FilterOp::PreAccept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->PreRunOnNode(shared_from_base<FilterOp>(), modified); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -121,6 +121,12 @@ class FilterOp : public ParallelOp { | |||||
| // @param show_all A bool to control if you want to show all info or just a summary. | // @param show_all A bool to control if you want to show all info or just a summary. | ||||
| void Print(std::ostream &out, bool show_all) const override; | void Print(std::ostream &out, bool show_all) const override; | ||||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status PreAccept(NodePass *p, bool *modified) override; | |||||
| // Base-class override for NodePass visitor acceptor. | // Base-class override for NodePass visitor acceptor. | ||||
| // @param p - Pointer to the NodePass to be accepted. | // @param p - Pointer to the NodePass to be accepted. | ||||
| // @param modified - Whether this node visit modified the pipeline. | // @param modified - Whether this node visit modified the pipeline. | ||||
| @@ -458,6 +458,12 @@ Status MapOp::Accept(NodePass *p, bool *modified) { | |||||
| return p->RunOnNode(shared_from_base<MapOp>(), modified); | return p->RunOnNode(shared_from_base<MapOp>(), modified); | ||||
| } | } | ||||
| // Visitor pre-accept method for NodePass | |||||
| Status MapOp::PreAccept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->PreRunOnNode(shared_from_base<MapOp>(), modified); | |||||
| } | |||||
| Status MapOp::WaitForWorkers() { | Status MapOp::WaitForWorkers() { | ||||
| // reset num_paused workers to 0 | // reset num_paused workers to 0 | ||||
| num_workers_paused_ = 0; | num_workers_paused_ = 0; | ||||
| @@ -177,10 +177,16 @@ class MapOp : public ParallelOp { | |||||
| // @return the number of threads consuming data from previous op's output Connector. | // @return the number of threads consuming data from previous op's output Connector. | ||||
| int32_t num_consumers() const override; | int32_t num_consumers() const override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status PreAccept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for NodePass visitor acceptor. | |||||
| /// \param[in] p Pointer to the NodePass to be accepted. | |||||
| /// \param[out] modified Whether this node visit modified the pipeline. | |||||
| /// \return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | Status Accept(NodePass *p, bool *modified) override; | ||||
| // Op name getter | // Op name getter | ||||
| @@ -52,9 +52,9 @@ Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { | |||||
| void ParallelOp::Print(std::ostream &out, bool show_all) const { | void ParallelOp::Print(std::ostream &out, bool show_all) const { | ||||
| // Summary 1-liner print | // Summary 1-liner print | ||||
| if (!show_all) { | if (!show_all) { | ||||
| out << " [workers: " << num_workers_ << "]"; | |||||
| // Call super class printer | // Call super class printer | ||||
| DatasetOp::Print(out, show_all); | DatasetOp::Print(out, show_all); | ||||
| out << " [workers: " << num_workers_ << "]"; | |||||
| } else { | } else { | ||||
| // Detailed print | // Detailed print | ||||
| DatasetOp::Print(out, show_all); | DatasetOp::Print(out, show_all); | ||||
| @@ -27,14 +27,14 @@ PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampl | |||||
| void PipelineOp::Print(std::ostream &out, bool show_all) const { | void PipelineOp::Print(std::ostream &out, bool show_all) const { | ||||
| // Summary 1-liner print | // Summary 1-liner print | ||||
| if (!show_all) { | if (!show_all) { | ||||
| // Call super class printer | |||||
| DatasetOp::Print(out, show_all); | |||||
| out << " [workers: "; | out << " [workers: "; | ||||
| if (this->inlined()) { | if (this->inlined()) { | ||||
| out << "0 (inlined)]"; | out << "0 (inlined)]"; | ||||
| } else { | } else { | ||||
| out << "1]"; // Pipeline ops only have 1 worker | out << "1]"; // Pipeline ops only have 1 worker | ||||
| } | } | ||||
| // Call super class printer | |||||
| DatasetOp::Print(out, show_all); | |||||
| } else { | } else { | ||||
| // Detailed print | // Detailed print | ||||
| DatasetOp::Print(out, show_all); | DatasetOp::Print(out, show_all); | ||||
| @@ -235,6 +235,12 @@ Status ZipOp::EoeReceived(int32_t) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Visitor pre-accept method for NodePass | |||||
| Status ZipOp::PreAccept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer then call visitor | |||||
| return p->PreRunOnNode(shared_from_base<ZipOp>(), modified); | |||||
| } | |||||
| // Visitor accept method for NodePass | // Visitor accept method for NodePass | ||||
| Status ZipOp::Accept(NodePass *p, bool *modified) { | Status ZipOp::Accept(NodePass *p, bool *modified) { | ||||
| // Downcast shared pointer then call visitor | // Downcast shared pointer then call visitor | ||||
| @@ -104,10 +104,16 @@ class ZipOp : public PipelineOp { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status operator()() override; | Status operator()() override; | ||||
| // Base-class override for NodePass visitor acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||||
| /// \param[in] p The node to visit | |||||
| /// \param[out] modified Indicator if the node was modified | |||||
| /// \return Status of the node visit | |||||
| Status PreAccept(NodePass *p, bool *modified) override; | |||||
| /// \brief Base-class override for NodePass visitor acceptor. | |||||
| /// \param[in] p Pointer to the NodePass to be accepted. | |||||
| /// \param[out] modified Whether this node visit modified the pipeline. | |||||
| /// \return - Status of the node visit. | |||||
| Status Accept(NodePass *p, bool *modified) override; | Status Accept(NodePass *p, bool *modified) override; | ||||
| // Op name getter | // Op name getter | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | ||||
| #include "minddata/dataset/engine/opt/post/repeat_pass.h" | #include "minddata/dataset/engine/opt/post/repeat_pass.h" | ||||
| #endif | #endif | ||||
| #include "minddata/dataset/engine/opt/pre/cache_error_pass.h" | |||||
| #include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" | #include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h" | ||||
| #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" | ||||
| #include "minddata/dataset/engine/perf/profiling.h" | #include "minddata/dataset/engine/perf/profiling.h" | ||||
| @@ -235,6 +236,7 @@ Status ExecutionTree::PrepareTreePreAction() { | |||||
| std::vector<std::unique_ptr<Pass>> pre_actions; | std::vector<std::unique_ptr<Pass>> pre_actions; | ||||
| // Construct pre actions | // Construct pre actions | ||||
| MS_LOG(INFO) << "Running pre pass loops."; | MS_LOG(INFO) << "Running pre pass loops."; | ||||
| pre_actions.push_back(std::make_unique<CacheErrorPass>()); | |||||
| pre_actions.push_back(std::make_unique<EpochInjectionPass>()); | pre_actions.push_back(std::make_unique<EpochInjectionPass>()); | ||||
| pre_actions.push_back(std::make_unique<RemovalPass>()); | pre_actions.push_back(std::make_unique<RemovalPass>()); | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| @@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE | |||||
| add_library(engine-opt OBJECT | add_library(engine-opt OBJECT | ||||
| pass.cc | pass.cc | ||||
| post/repeat_pass.cc | post/repeat_pass.cc | ||||
| pre/cache_error_pass.cc | |||||
| pre/cache_transform_pass.cc | pre/cache_transform_pass.cc | ||||
| pre/epoch_injection_pass.cc | pre/epoch_injection_pass.cc | ||||
| pre/removal_pass.cc | pre/removal_pass.cc | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | ||||
| #endif | #endif | ||||
| #include "minddata/dataset/engine/datasetops/concat_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | #include "minddata/dataset/engine/datasetops/dataset_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | #include "minddata/dataset/engine/datasetops/device_queue_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | #include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" | ||||
| @@ -143,125 +144,122 @@ Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | |||||
| Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| #endif | |||||
| #ifdef ENABLE_PYTHON | |||||
| Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| #endif | |||||
| Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | |||||
| Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | } | ||||
| #endif | |||||
| Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| #endif | |||||
| Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| @@ -271,24 +269,38 @@ Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| #endif | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) { | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| #endif | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) { | |||||
| #ifdef ENABLE_PYTHON | |||||
| Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) { | |||||
| Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | |||||
| return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| Status NodePass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| #endif | #endif | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -37,18 +37,6 @@ class SkipOp; | |||||
| class ShuffleOp; | class ShuffleOp; | ||||
| #ifndef ENABLE_ANDROID | |||||
| class MindRecordOp; | |||||
| class TFReaderOp; | |||||
| #endif | |||||
| #ifdef ENABLE_PYTHON | |||||
| class FilterOp; | |||||
| class GeneratorOp; | |||||
| #endif | |||||
| class AlbumOp; | class AlbumOp; | ||||
| class RandomDataOp; | class RandomDataOp; | ||||
| @@ -63,10 +51,6 @@ class DeviceQueueOp; | |||||
| class ImageFolderOp; | class ImageFolderOp; | ||||
| #ifndef ENABLE_ANDROID | |||||
| class CacheOp; | |||||
| #endif | |||||
| class MnistOp; | class MnistOp; | ||||
| class ManifestOp; | class ManifestOp; | ||||
| @@ -79,18 +63,30 @@ class CocoOp; | |||||
| class CelebAOp; | class CelebAOp; | ||||
| class EpochCtrlOp; | |||||
| class BuildVocabOp; | |||||
| class ConcatOp; | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| class MindRecordOp; | |||||
| class TFReaderOp; | |||||
| class CacheOp; | |||||
| class CacheMergeOp; | class CacheMergeOp; | ||||
| class CacheLookupOp; | class CacheLookupOp; | ||||
| #endif | |||||
| class EpochCtrlOp; | |||||
| class BuildSentencePieceVocabOp; | |||||
| #endif | |||||
| class BuildVocabOp; | |||||
| #ifdef ENABLE_PYTHON | |||||
| class FilterOp; | |||||
| #ifndef ENABLE_ANDROID | |||||
| class BuildSentencePieceVocabOp; | |||||
| class GeneratorOp; | |||||
| #endif | #endif | ||||
| // The base class Pass is the basic unit of tree transformation. | // The base class Pass is the basic unit of tree transformation. | ||||
| @@ -168,22 +164,6 @@ class NodePass : public Pass { | |||||
| virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified); | ||||
| #ifndef ENABLE_ANDROID | |||||
| virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | |||||
| #endif | |||||
| #ifdef ENABLE_PYTHON | |||||
| virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); | |||||
| #endif | |||||
| virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified); | ||||
| @@ -194,10 +174,6 @@ class NodePass : public Pass { | |||||
| virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified); | ||||
| #ifndef ENABLE_ANDROID | |||||
| virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | |||||
| #endif | |||||
| virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified); | ||||
| @@ -210,30 +186,48 @@ class NodePass : public Pass { | |||||
| virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||||
| virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | |||||
| virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||||
| virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified); | |||||
| virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified); | |||||
| virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified); | |||||
| virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified); | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified); | ||||
| #endif | |||||
| virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | |||||
| #ifndef ENABLE_ANDROID | |||||
| virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified); | ||||
| #endif | |||||
| virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified); | |||||
| #ifndef ENABLE_ANDROID | |||||
| virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified); | |||||
| #endif | #endif | ||||
| virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified); | |||||
| #ifdef ENABLE_PYTHON | |||||
| virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||||
| virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified); | |||||
| #ifndef ENABLE_ANDROID | |||||
| virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified); | |||||
| virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified); | |||||
| virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified); | |||||
| #endif | #endif | ||||
| private: | private: | ||||
| @@ -225,13 +225,17 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { | |||||
| // Turns off the tracking for operations under merge op | // Turns off the tracking for operations under merge op | ||||
| Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) { | ||||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | |||||
| // would not have been consumed yet. In that case, we need to set its total repeats for it. | |||||
| if (cache_lookup_) { | |||||
| cache_lookup_->set_total_repeats(num_repeats_); | |||||
| cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||||
| } | |||||
| // Setting the flag is needed since we didn't call the base class DatasetOp version | // Setting the flag is needed since we didn't call the base class DatasetOp version | ||||
| if (is_repeated_) { | if (is_repeated_) { | ||||
| // If there was not any repeat in the merge cache miss leg, then the cache_lookup | // If there was not any repeat in the merge cache miss leg, then the cache_lookup | ||||
| // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack | // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack | ||||
| if (cache_lookup_) { | if (cache_lookup_) { | ||||
| cache_lookup_->set_total_repeats(num_repeats_); | |||||
| node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_); | |||||
| AddToEOEOpStack(std::move(cache_lookup_)); | AddToEOEOpStack(std::move(cache_lookup_)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,79 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/zip_op.h" | |||||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | |||||
| #include "minddata/dataset/engine/opt/pre/cache_error_pass.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // Constructor | |||||
| CacheErrorPass::CacheErrorPass() : is_cached_(false) {} | |||||
| // Identifies the subtree below this node as being cached | |||||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| // Turn on the flag that we're under a merge op | |||||
| is_cached_ = true; | |||||
| return Status::OK(); | |||||
| } | |||||
| // Returns an error if ZipOp exists under a cache | |||||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { | |||||
| if (is_cached_) { | |||||
| RETURN_STATUS_UNEXPECTED("ZipOp is currently not supported as a descendant operator under a cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Returns an error if MapOp with non-deterministic TensorOps exists under a cache | |||||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | |||||
| if (is_cached_) { | |||||
| auto tfuncs = node->TFuncs(); | |||||
| for (size_t i = 0; i < tfuncs.size(); i++) { | |||||
| if (!tfuncs[i]->Deterministic()) { | |||||
| RETURN_STATUS_UNEXPECTED( | |||||
| "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache."); | |||||
| } | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Returns an error if ConcatOp exists under a cache | |||||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) { | |||||
| if (is_cached_) { | |||||
| RETURN_STATUS_UNEXPECTED("ConcatOp is currently not supported as a descendant operator under a cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| #ifdef ENABLE_PYTHON | |||||
| // Returns an error if FilterOp exists under a cache | |||||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||||
| if (is_cached_) { | |||||
| RETURN_STATUS_UNEXPECTED("FilterOp is currently not supported as a descendant operator under a cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| #endif | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_ | |||||
| #include <memory> | |||||
| #include <stack> | |||||
| #include <utility> | |||||
| #include "minddata/dataset/engine/opt/pass.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// \class CacheErrorPass cache_error_pass.h | |||||
| /// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures. | |||||
| class CacheErrorPass : public NodePass { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| CacheErrorPass(); | |||||
| /// \brief Destructor | |||||
| ~CacheErrorPass() = default; | |||||
| /// \brief Identifies the subtree below this node as being cached | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| /// \brief Returns an error if ZipOp exists under a cache | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override; | |||||
| /// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | |||||
| /// \brief Returns an error if ConcatOp exists under a cache | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override; | |||||
| #ifdef ENABLE_PYTHON | |||||
| /// \brief Returns an error if FilterOp exists under a cache | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; | |||||
| #endif | |||||
| private: | |||||
| bool is_cached_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_POST_CACHE_ERROR_PASS_ | |||||
| @@ -155,50 +155,77 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> n | |||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) { | ||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for AlbumOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) { | ||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) { | ||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) { | ||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) { | ||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | ||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for MindRecordOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| #endif | #endif | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | ||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for GeneratorOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) { | ||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) { | ||||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||||
| if (is_caching_) { | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -40,13 +40,6 @@ Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSe | |||||
| injection_point_ = nullptr; | injection_point_ = nullptr; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Temporary code to prevent the injection of epoch control when cache op is present | |||||
| // Remove this code in cache op phase 2 | |||||
| Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) { | |||||
| injection_point_ = nullptr; | |||||
| return Status::OK(); | |||||
| } | |||||
| #endif | #endif | ||||
| Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { | ||||
| @@ -54,13 +54,6 @@ class EpochInjectionPass : public TreePass { | |||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override; | ||||
| /// \brief Temporary code to prevent the injection of epoch control when cache op is present. | |||||
| /// Remove this code in cache op phase 2 | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||||
| #endif | #endif | ||||
| /// \brief Register the DeviceQueueOp for further action. | /// \brief Register the DeviceQueueOp for further action. | ||||
| @@ -62,6 +62,7 @@ Status RandomApplyOp::Compute(const TensorRow &input, TensorRow *output) { | |||||
| RandomApplyOp::RandomApplyOp(double prob, const std::vector<std::shared_ptr<TensorOp>> &ops) | RandomApplyOp::RandomApplyOp(double prob, const std::vector<std::shared_ptr<TensorOp>> &ops) | ||||
| : prob_(prob), gen_(GetSeed()), rand_double_(0, 1) { | : prob_(prob), gen_(GetSeed()), rand_double_(0, 1) { | ||||
| compose_ = std::make_unique<ComposeOp>(ops); | compose_ = std::make_unique<ComposeOp>(ops); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -92,6 +92,7 @@ RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops | |||||
| } else if (ops_.size() == 1) { | } else if (ops_.size() == 1) { | ||||
| MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time."; | MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time."; | ||||
| } | } | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,6 +44,7 @@ RandomAffineOp::RandomAffineOp(std::vector<float_t> degrees, std::vector<float_t | |||||
| interpolation_ = interpolation; | interpolation_ = interpolation; | ||||
| fill_value_ = fill_value; | fill_value_ = fill_value; | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| Status RandomAffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status RandomAffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| @@ -36,6 +36,7 @@ RandomColorAdjustOp::RandomColorAdjustOp(float s_bright_factor, float e_bright_f | |||||
| hue_factor_start_(s_hue_factor), | hue_factor_start_(s_hue_factor), | ||||
| hue_factor_end_(e_hue_factor) { | hue_factor_end_(e_hue_factor) { | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| Status RandomColorAdjustOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status RandomColorAdjustOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| @@ -19,7 +19,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) {} | |||||
| RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_lb, t_ub), t_lb_(t_lb), t_ub_(t_ub) { | |||||
| is_deterministic_ = false; | |||||
| } | |||||
| Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) { | Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) { | ||||
| IO_CHECK(in, out); | IO_CHECK(in, out); | ||||
| @@ -41,6 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ | |||||
| aspect_ub_(aspect_ub), | aspect_ub_(aspect_ub), | ||||
| max_iter_(max_iter) { | max_iter_(max_iter) { | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| @@ -46,6 +46,7 @@ RandomCropOp::RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_ | |||||
| fill_g_(fill_g), | fill_g_(fill_g), | ||||
| fill_b_(fill_b) { | fill_b_(fill_b) { | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| Status RandomCropOp::ImagePadding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *pad_image, | Status RandomCropOp::ImagePadding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *pad_image, | ||||
| @@ -33,6 +33,7 @@ class RandomHorizontalFlipOp : public TensorOp { | |||||
| static const float kDefProbability; | static const float kDefProbability; | ||||
| explicit RandomHorizontalFlipOp(float probability = kDefProbability) : distribution_(probability) { | explicit RandomHorizontalFlipOp(float probability = kDefProbability) : distribution_(probability) { | ||||
| is_deterministic_ = false; | |||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| } | } | ||||
| @@ -35,6 +35,7 @@ class RandomHorizontalFlipWithBBoxOp : public TensorOp { | |||||
| explicit RandomHorizontalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { | explicit RandomHorizontalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| ~RandomHorizontalFlipWithBBoxOp() override = default; | ~RandomHorizontalFlipWithBBoxOp() override = default; | ||||
| @@ -29,6 +29,7 @@ const std::vector<uint8_t> RandomPosterizeOp::kBitRange = {4, 8}; | |||||
| RandomPosterizeOp::RandomPosterizeOp(const std::vector<uint8_t> &bit_range) | RandomPosterizeOp::RandomPosterizeOp(const std::vector<uint8_t> &bit_range) | ||||
| : PosterizeOp(bit_range[0]), bit_range_(bit_range) { | : PosterizeOp(bit_range[0]), bit_range_(bit_range) { | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| Status RandomPosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status RandomPosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| @@ -35,6 +35,7 @@ class RandomResizeOp : public ResizeOp { | |||||
| explicit RandomResizeOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeOp(size_1, size_2) { | explicit RandomResizeOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeOp(size_1, size_2) { | ||||
| random_generator_.seed(GetSeed()); | random_generator_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| ~RandomResizeOp() = default; | ~RandomResizeOp() = default; | ||||
| @@ -36,6 +36,7 @@ class RandomResizeWithBBoxOp : public ResizeWithBBoxOp { | |||||
| static const int32_t kDefTargetWidth; | static const int32_t kDefTargetWidth; | ||||
| explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) { | explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) { | ||||
| random_generator_.seed(GetSeed()); | random_generator_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| ~RandomResizeWithBBoxOp() = default; | ~RandomResizeWithBBoxOp() = default; | ||||
| @@ -46,6 +46,7 @@ RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, float c | |||||
| fill_g_(fill_g), | fill_g_(fill_g), | ||||
| fill_b_(fill_b) { | fill_b_(fill_b) { | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| // main function call for random rotation : Generate the random degrees | // main function call for random rotation : Generate the random degrees | ||||
| @@ -90,6 +90,7 @@ RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector<Subpolicy> &p | |||||
| if (policy_.empty()) { | if (policy_.empty()) { | ||||
| MS_LOG(ERROR) << "policy in RandomSelectSubpolicyOp is empty."; | MS_LOG(ERROR) << "policy in RandomSelectSubpolicyOp is empty."; | ||||
| } | } | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -31,6 +31,7 @@ const float RandomSharpnessOp::kDefEndDegree = 1.9; | |||||
| RandomSharpnessOp::RandomSharpnessOp(float start_degree, float end_degree) | RandomSharpnessOp::RandomSharpnessOp(float start_degree, float end_degree) | ||||
| : start_degree_(start_degree), end_degree_(end_degree) { | : start_degree_(start_degree), end_degree_(end_degree) { | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| /// main function call for random sharpness : Generate the random degrees | /// main function call for random sharpness : Generate the random degrees | ||||
| @@ -32,7 +32,10 @@ namespace dataset { | |||||
| class RandomSolarizeOp : public SolarizeOp { | class RandomSolarizeOp : public SolarizeOp { | ||||
| public: | public: | ||||
| // Pick a random threshold value to solarize the image with | // Pick a random threshold value to solarize the image with | ||||
| explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) { rnd_.seed(GetSeed()); } | |||||
| explicit RandomSolarizeOp(std::vector<uint8_t> threshold = {0, 255}) : threshold_(threshold) { | |||||
| rnd_.seed(GetSeed()); | |||||
| is_deterministic_ = false; | |||||
| } | |||||
| ~RandomSolarizeOp() = default; | ~RandomSolarizeOp() = default; | ||||
| @@ -34,6 +34,7 @@ class RandomVerticalFlipOp : public TensorOp { | |||||
| explicit RandomVerticalFlipOp(float probability = kDefProbability) : distribution_(probability) { | explicit RandomVerticalFlipOp(float probability = kDefProbability) : distribution_(probability) { | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| ~RandomVerticalFlipOp() override = default; | ~RandomVerticalFlipOp() override = default; | ||||
| @@ -34,6 +34,7 @@ class RandomVerticalFlipWithBBoxOp : public TensorOp { | |||||
| // @param probability: Probablity of Image flipping, 0.5 by default | // @param probability: Probablity of Image flipping, 0.5 by default | ||||
| explicit RandomVerticalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { | explicit RandomVerticalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| is_deterministic_ = false; | |||||
| } | } | ||||
| ~RandomVerticalFlipWithBBoxOp() override = default; | ~RandomVerticalFlipWithBBoxOp() override = default; | ||||
| @@ -168,6 +168,10 @@ class TensorOp { | |||||
| // @return true/false | // @return true/false | ||||
| bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; } | bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; } | ||||
| // Returns true oif the TensorOp produces deterministic result. | |||||
| // @return true/false | |||||
| bool Deterministic() { return is_deterministic_; } | |||||
| // Function to determine the number of inputs the TensorOp can take. 0: means undefined. | // Function to determine the number of inputs the TensorOp can take. 0: means undefined. | ||||
| // @return uint32_t | // @return uint32_t | ||||
| virtual uint32_t NumInput() { return 1; } | virtual uint32_t NumInput() { return 1; } | ||||
| @@ -191,6 +195,9 @@ class TensorOp { | |||||
| virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs); | virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs); | ||||
| virtual std::string Name() const = 0; | virtual std::string Name() const = 0; | ||||
| protected: | |||||
| bool is_deterministic_{true}; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -88,21 +88,21 @@ class Allocator { | |||||
| std::shared_ptr<MemoryPool> pool_; | std::shared_ptr<MemoryPool> pool_; | ||||
| }; | }; | ||||
| /// \brief It is a wrapper of unique_ptr with a custom Allocator class defined above | /// \brief It is a wrapper of unique_ptr with a custom Allocator class defined above | ||||
| template <typename T, typename... Args> | |||||
| Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, Allocator<T> alloc, size_t n, Args &&... args) { | |||||
| template <typename T, typename C = std::allocator<T>, typename... Args> | |||||
| Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, C alloc, size_t n, Args &&... args) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | RETURN_UNEXPECTED_IF_NULL(out); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive"); | CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive"); | ||||
| try { | try { | ||||
| T *data = alloc.allocate(n); | T *data = alloc.allocate(n); | ||||
| if (!std::is_arithmetic<T>::value) { | if (!std::is_arithmetic<T>::value) { | ||||
| for (auto i = 0; i < n; i++) { | for (auto i = 0; i < n; i++) { | ||||
| std::allocator_traits<Allocator<T>>::construct(alloc, &(data[i]), std::forward<Args>(args)...); | |||||
| std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...); | |||||
| } | } | ||||
| } | } | ||||
| auto deleter = [](T *p, Allocator<T> f_alloc, size_t f_n) { | |||||
| auto deleter = [](T *p, C f_alloc, size_t f_n) { | |||||
| if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) { | if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) { | ||||
| for (auto i = 0; i < f_n; ++i) { | for (auto i = 0; i < f_n; ++i) { | ||||
| std::allocator_traits<Allocator<T>>::destroy(f_alloc, &p[i]); | |||||
| std::allocator_traits<C>::destroy(f_alloc, &p[i]); | |||||
| } | } | ||||
| } | } | ||||
| f_alloc.deallocate(p, f_n); | f_alloc.deallocate(p, f_n); | ||||
| @@ -129,7 +129,7 @@ class MemGuard { | |||||
| MemGuard(const MemGuard &) = delete; | MemGuard(const MemGuard &) = delete; | ||||
| MemGuard &operator=(const MemGuard &) = delete; | MemGuard &operator=(const MemGuard &) = delete; | ||||
| // On the other hand, We can support move constructor | // On the other hand, We can support move constructor | ||||
| MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} | |||||
| MemGuard(MemGuard &&lhs) noexcept : n_(lhs.n_), alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)) {} | |||||
| MemGuard &operator=(MemGuard &&lhs) noexcept { | MemGuard &operator=(MemGuard &&lhs) noexcept { | ||||
| if (this != &lhs) { | if (this != &lhs) { | ||||
| this->deallocate(); | this->deallocate(); | ||||
| @@ -37,7 +37,8 @@ struct MemHdr { | |||||
| ArenaImpl::ArenaImpl(void *ptr, size_t sz) : size_in_bytes_(sz), ptr_(ptr) { | ArenaImpl::ArenaImpl(void *ptr, size_t sz) : size_in_bytes_(sz), ptr_(ptr) { | ||||
| // Divide the memory into blocks. Ignore the last partial block. | // Divide the memory into blocks. Ignore the last partial block. | ||||
| uint64_t num_blks = size_in_bytes_ / ARENA_BLK_SZ; | uint64_t num_blks = size_in_bytes_ / ARENA_BLK_SZ; | ||||
| MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << "."; | |||||
| MS_LOG(DEBUG) << "Arena memory pool is created. Number of blocks : " << num_blks << ". Block size : " << ARENA_BLK_SZ | |||||
| << "."; | |||||
| tr_.Insert(0, num_blks); | tr_.Insert(0, num_blks); | ||||
| } | } | ||||
| @@ -233,9 +234,9 @@ std::ostream &operator<<(std::ostream &os, const ArenaImpl &s) { | |||||
| Status Arena::Init() { | Status Arena::Init() { | ||||
| try { | try { | ||||
| auto sz = size_in_MB_ * 1048576L; | |||||
| mem_ = std::make_unique<uint8_t[]>(sz); | |||||
| impl_ = std::make_unique<ArenaImpl>(mem_.get(), sz); | |||||
| int64_t sz = size_in_MB_ * 1048576L; | |||||
| RETURN_IF_NOT_OK(mem_.allocate(sz)); | |||||
| impl_ = std::make_unique<ArenaImpl>(mem_.GetMutablePointer(), sz); | |||||
| } catch (std::bad_alloc &e) { | } catch (std::bad_alloc &e) { | ||||
| return Status(StatusCode::kOutOfMemory); | return Status(StatusCode::kOutOfMemory); | ||||
| } | } | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/util/allocator.h" | |||||
| #include "minddata/dataset/util/memory_pool.h" | #include "minddata/dataset/util/memory_pool.h" | ||||
| #include "minddata/dataset/util/treap.h" | #include "minddata/dataset/util/treap.h" | ||||
| @@ -140,7 +141,7 @@ class Arena : public MemoryPool { | |||||
| protected: | protected: | ||||
| mutable std::mutex mux_; | mutable std::mutex mux_; | ||||
| std::unique_ptr<ArenaImpl> impl_; | std::unique_ptr<ArenaImpl> impl_; | ||||
| std::unique_ptr<uint8_t[]> mem_; | |||||
| MemGuard<uint8_t> mem_; | |||||
| size_t size_in_MB_; | size_t size_in_MB_; | ||||
| explicit Arena(size_t val_in_MB = 4096); | explicit Arena(size_t val_in_MB = 4096); | ||||
| @@ -131,6 +131,23 @@ class BPlusTree { | |||||
| tree_stats() : size_(0), leaves_(0), inner_nodes_(0), level_(0) {} | tree_stats() : size_(0), leaves_(0), inner_nodes_(0), level_(0) {} | ||||
| }; | }; | ||||
| /// \brief Statistics functions | |||||
| /// \return Return the height of the tree | |||||
| auto GetHeight() const { return empty() ? 0 : stats_.level_ + 1; } | |||||
| /// \return Order of the B+ tree | |||||
| auto GetOrder() const { return traits::kLeafSlots; } | |||||
| /// \return Number of leaves nodes | |||||
| auto GetNumLeaves() const { return stats_.leaves_; } | |||||
| /// \return Number of inner nodes | |||||
| auto GetNumInnerNodes() const { return stats_.inner_nodes_; } | |||||
| /// \brief Toggle locking | |||||
| /// \note Once locking is off. It is user's responsibility to ensure concurrency | |||||
| void SetLocking(bool on_off) { | |||||
| UniqueLock lck(&rw_lock_); | |||||
| acquire_lock_ = on_off; | |||||
| } | |||||
| private: | private: | ||||
| // Abstract class of a node (leaf or inner) | // Abstract class of a node (leaf or inner) | ||||
| class BaseNode { | class BaseNode { | ||||
| @@ -288,6 +305,17 @@ class BPlusTree { | |||||
| key_compare key_less_; | key_compare key_less_; | ||||
| // Stat | // Stat | ||||
| tree_stats stats_; | tree_stats stats_; | ||||
| // lock mode | |||||
| bool acquire_lock_; | |||||
| void Init() { | |||||
| typename LeafNode::alloc_type alloc(alloc_); | |||||
| auto *p = alloc.allocate(1); | |||||
| root_ = new (p) LeafNode(alloc_); | |||||
| all_.Prepend(p); | |||||
| leaf_nodes_.Append(p); | |||||
| stats_.leaves_++; | |||||
| } | |||||
| bool LessThan(const key_type &a, const key_type &b) const { return key_less_(a, b); } | bool LessThan(const key_type &a, const key_type &b) const { return key_less_(a, b); } | ||||
| @@ -350,11 +378,11 @@ class BPlusTree { | |||||
| ~Iterator(); | ~Iterator(); | ||||
| explicit Iterator(const Iterator &); | |||||
| Iterator(const Iterator &); | |||||
| Iterator &operator=(const Iterator &lhs); | Iterator &operator=(const Iterator &lhs); | ||||
| explicit Iterator(Iterator &&); | |||||
| Iterator(Iterator &&) noexcept; | |||||
| Iterator &operator=(Iterator &&lhs); | Iterator &operator=(Iterator &&lhs); | ||||
| @@ -399,11 +427,11 @@ class BPlusTree { | |||||
| ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false) | ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false) | ||||
| : cur_(leaf), slot_(slot), locked_(locked) {} | : cur_(leaf), slot_(slot), locked_(locked) {} | ||||
| explicit ConstIterator(const ConstIterator &); | |||||
| ConstIterator(const ConstIterator &); | |||||
| ConstIterator &operator=(const ConstIterator &lhs); | ConstIterator &operator=(const ConstIterator &lhs); | ||||
| explicit ConstIterator(ConstIterator &&); | |||||
| ConstIterator(ConstIterator &&) noexcept; | |||||
| ConstIterator &operator=(ConstIterator &&lhs); | ConstIterator &operator=(ConstIterator &&lhs); | ||||
| @@ -413,11 +413,16 @@ typename BPlusTree<K, V, A, C, T>::IndexRc BPlusTree<K, V, A, C, T>::Locate(RWLo | |||||
| } | } | ||||
| template <typename K, typename V, typename A, typename C, typename T> | template <typename K, typename V, typename A, typename C, typename T> | ||||
| BPlusTree<K, V, A, C, T>::BPlusTree() : leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {} | |||||
| BPlusTree<K, V, A, C, T>::BPlusTree() | |||||
| : leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr), acquire_lock_(true) { | |||||
| Init(); | |||||
| } | |||||
| template <typename K, typename V, typename A, typename C, typename T> | template <typename K, typename V, typename A, typename C, typename T> | ||||
| BPlusTree<K, V, A, C, T>::BPlusTree(const Allocator<V> &alloc) | BPlusTree<K, V, A, C, T>::BPlusTree(const Allocator<V> &alloc) | ||||
| : alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr) {} | |||||
| : alloc_(alloc), leaf_nodes_(&LeafNode::link_), all_(&BaseNode::lru_), root_(nullptr), acquire_lock_(true) { | |||||
| Init(); | |||||
| } | |||||
| template <typename K, typename V, typename A, typename C, typename T> | template <typename K, typename V, typename A, typename C, typename T> | ||||
| BPlusTree<K, V, A, C, T>::~BPlusTree() noexcept { | BPlusTree<K, V, A, C, T>::~BPlusTree() noexcept { | ||||
| @@ -446,20 +451,6 @@ BPlusTree<K, V, A, C, T>::~BPlusTree() noexcept { | |||||
| template <typename K, typename V, typename A, typename C, typename T> | template <typename K, typename V, typename A, typename C, typename T> | ||||
| Status BPlusTree<K, V, A, C, T>::DoInsert(const key_type &key, std::unique_ptr<value_type> &&value) { | Status BPlusTree<K, V, A, C, T>::DoInsert(const key_type &key, std::unique_ptr<value_type> &&value) { | ||||
| IndexRc rc; | IndexRc rc; | ||||
| if (root_ == nullptr) { | |||||
| UniqueLock lck(&rw_lock_); | |||||
| // Check again after we get the lock. Other thread may have created the root node already. | |||||
| if (root_ == nullptr) { | |||||
| LeafNode *leaf = nullptr; | |||||
| rc = AllocateLeaf(&leaf); | |||||
| if (rc != IndexRc::kOk) { | |||||
| return IndexRc2Status(rc); | |||||
| } | |||||
| leaf_nodes_.Append(leaf); | |||||
| root_ = leaf; | |||||
| } | |||||
| // lock will be unlocked when it goes out of scope. | |||||
| } | |||||
| bool retry = false; | bool retry = false; | ||||
| do { | do { | ||||
| // Track all the paths to the target and lock each internal node in S. | // Track all the paths to the target and lock each internal node in S. | ||||
| @@ -468,7 +459,7 @@ Status BPlusTree<K, V, A, C, T>::DoInsert(const key_type &key, std::unique_ptr<v | |||||
| retry = false; | retry = false; | ||||
| BaseNode *new_child = nullptr; | BaseNode *new_child = nullptr; | ||||
| key_type new_key = key_type(); | key_type new_key = key_type(); | ||||
| rc = InsertKeyValue(&InsCB, root_, key, std::move(value), &new_key, &new_child); | |||||
| rc = InsertKeyValue(acquire_lock_ ? &InsCB : nullptr, root_, key, std::move(value), &new_key, &new_child); | |||||
| if (rc == IndexRc::kRetry) { | if (rc == IndexRc::kRetry) { | ||||
| retry = true; | retry = true; | ||||
| } else if (rc != IndexRc::kOk) { | } else if (rc != IndexRc::kOk) { | ||||
| @@ -511,9 +502,12 @@ std::unique_ptr<V> BPlusTree<K, V, A, C, T>::DoUpdate(const key_type &key, std:: | |||||
| if (root_ != nullptr) { | if (root_ != nullptr) { | ||||
| LeafNode *leaf = nullptr; | LeafNode *leaf = nullptr; | ||||
| slot_type slot; | slot_type slot; | ||||
| RWLock *myLock = &this->rw_lock_; | |||||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||||
| myLock->LockShared(); | |||||
| RWLock *myLock = nullptr; | |||||
| if (acquire_lock_) { | |||||
| myLock = &this->rw_lock_; | |||||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||||
| myLock->LockShared(); | |||||
| } | |||||
| IndexRc rc = Locate(myLock, true, root_, key, &leaf, &slot); | IndexRc rc = Locate(myLock, true, root_, key, &leaf, &slot); | ||||
| if (rc == IndexRc::kOk) { | if (rc == IndexRc::kOk) { | ||||
| // All locks from the tree to the parent of leaf are all gone. We still have a X lock | // All locks from the tree to the parent of leaf are all gone. We still have a X lock | ||||
| @@ -521,7 +515,9 @@ std::unique_ptr<V> BPlusTree<K, V, A, C, T>::DoUpdate(const key_type &key, std:: | |||||
| // Swap out the old value and replace it with new value. | // Swap out the old value and replace it with new value. | ||||
| std::unique_ptr<value_type> old = std::move(leaf->data_[leaf->slot_dir_[slot]]); | std::unique_ptr<value_type> old = std::move(leaf->data_[leaf->slot_dir_[slot]]); | ||||
| leaf->data_[leaf->slot_dir_[slot]] = std::move(new_value); | leaf->data_[leaf->slot_dir_[slot]] = std::move(new_value); | ||||
| leaf->rw_lock_.Unlock(); | |||||
| if (acquire_lock_) { | |||||
| leaf->rw_lock_.Unlock(); | |||||
| } | |||||
| return old; | return old; | ||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "Key not found. rc = " << static_cast<int>(rc) << "."; | MS_LOG(DEBUG) << "Key not found. rc = " << static_cast<int>(rc) << "."; | ||||
| @@ -109,7 +109,7 @@ BPlusTree<K, V, A, C, T>::Iterator::Iterator(const BPlusTree<K, V, A, C, T>::Ite | |||||
| } | } | ||||
| template <typename K, typename V, typename A, typename C, typename T> | template <typename K, typename V, typename A, typename C, typename T> | ||||
| BPlusTree<K, V, A, C, T>::Iterator::Iterator(BPlusTree<K, V, A, C, T>::Iterator &&lhs) { | |||||
| BPlusTree<K, V, A, C, T>::Iterator::Iterator(BPlusTree<K, V, A, C, T>::Iterator &&lhs) noexcept { | |||||
| this->cur_ = lhs.cur_; | this->cur_ = lhs.cur_; | ||||
| this->slot_ = lhs.slot_; | this->slot_ = lhs.slot_; | ||||
| this->locked_ = lhs.locked_; | this->locked_ = lhs.locked_; | ||||
| @@ -241,7 +241,7 @@ BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(const BPlusTree<K, V, A, | |||||
| } | } | ||||
| template <typename K, typename V, typename A, typename C, typename T> | template <typename K, typename V, typename A, typename C, typename T> | ||||
| BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(BPlusTree<K, V, A, C, T>::ConstIterator &&lhs) { | |||||
| BPlusTree<K, V, A, C, T>::ConstIterator::ConstIterator(BPlusTree<K, V, A, C, T>::ConstIterator &&lhs) noexcept { | |||||
| this->cur_ = lhs.cur_; | this->cur_ = lhs.cur_; | ||||
| this->slot_ = lhs.slot_; | this->slot_ = lhs.slot_; | ||||
| this->locked_ = lhs.locked_; | this->locked_ = lhs.locked_; | ||||
| @@ -290,9 +290,12 @@ std::pair<typename BPlusTree<K, V, A, C, T>::ConstIterator, bool> BPlusTree<K, V | |||||
| if (root_ != nullptr) { | if (root_ != nullptr) { | ||||
| LeafNode *leaf = nullptr; | LeafNode *leaf = nullptr; | ||||
| slot_type slot; | slot_type slot; | ||||
| RWLock *myLock = &this->rw_lock_; | |||||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||||
| myLock->LockShared(); | |||||
| RWLock *myLock = nullptr; | |||||
| if (acquire_lock_) { | |||||
| myLock = &this->rw_lock_; | |||||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||||
| myLock->LockShared(); | |||||
| } | |||||
| IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot); | IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot); | ||||
| bool find = (rc == IndexRc::kOk); | bool find = (rc == IndexRc::kOk); | ||||
| return std::make_pair(ConstIterator(leaf, slot, find), find); | return std::make_pair(ConstIterator(leaf, slot, find), find); | ||||
| @@ -306,9 +309,12 @@ std::pair<typename BPlusTree<K, V, A, C, T>::Iterator, bool> BPlusTree<K, V, A, | |||||
| if (root_ != nullptr) { | if (root_ != nullptr) { | ||||
| LeafNode *leaf = nullptr; | LeafNode *leaf = nullptr; | ||||
| slot_type slot; | slot_type slot; | ||||
| RWLock *myLock = &this->rw_lock_; | |||||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||||
| myLock->LockShared(); | |||||
| RWLock *myLock = nullptr; | |||||
| if (acquire_lock_) { | |||||
| myLock = &this->rw_lock_; | |||||
| // Lock the tree in S, pass the lock to Locate which will unlock it for us underneath. | |||||
| myLock->LockShared(); | |||||
| } | |||||
| IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot); | IndexRc rc = Locate(myLock, false, root_, key, &leaf, &slot); | ||||
| bool find = (rc == IndexRc::kOk); | bool find = (rc == IndexRc::kOk); | ||||
| return std::make_pair(Iterator(leaf, slot, find), find); | return std::make_pair(Iterator(leaf, slot, find), find); | ||||
| @@ -69,7 +69,7 @@ Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) n | |||||
| *p = addr; | *p = addr; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } else { | } else { | ||||
| return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore."); | |||||
| return Status(StatusCode::kBuddySpaceFull, "BuddySpace full. Not an error. Please ignore."); | |||||
| } | } | ||||
| } | } | ||||
| @@ -20,8 +20,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| CachePool::CachePool(const value_allocator &alloc, const std::string &root) | |||||
| : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} | |||||
| CachePool::CachePool(const value_allocator &alloc, bool ourOwnArena, const std::string &root) | |||||
| : alloc_(alloc), | |||||
| root_(root), | |||||
| subfolder_(Services::GetUniqueID()), | |||||
| sm_(nullptr), | |||||
| tree_(nullptr), | |||||
| custom_arena_(ourOwnArena) {} | |||||
| Status CachePool::DoServiceStart() { | Status CachePool::DoServiceStart() { | ||||
| tree_ = std::make_shared<data_index>(); | tree_ = std::make_shared<data_index>(); | ||||
| @@ -45,9 +50,12 @@ Status CachePool::DoServiceStop() { | |||||
| } | } | ||||
| } | } | ||||
| sm_.reset(); | sm_.reset(); | ||||
| for (auto &bl : *tree_) { | |||||
| if (bl.ptr != nullptr) { | |||||
| alloc_.deallocate(bl.ptr, bl.sz); | |||||
| // If it is our own arena, skip freeing individual pieces. | |||||
| if (!custom_arena_) { | |||||
| for (auto &bl : *tree_) { | |||||
| if (bl.ptr != nullptr) { | |||||
| alloc_.deallocate(bl.ptr, bl.sz); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| tree_.reset(); | tree_.reset(); | ||||
| @@ -68,7 +76,7 @@ Status CachePool::DoServiceStop() { | |||||
| return rc2; | return rc2; | ||||
| } | } | ||||
| CachePool::~CachePool() noexcept { (void)ServiceStop(); } | CachePool::~CachePool() noexcept { (void)ServiceStop(); } | ||||
| Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_type *key) { | |||||
| Status CachePool::Insert(CachePool::key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly) { | |||||
| DataLocator bl; | DataLocator bl; | ||||
| Status rc; | Status rc; | ||||
| size_t sz = 0; | size_t sz = 0; | ||||
| @@ -78,22 +86,31 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t | |||||
| } | } | ||||
| bl.sz = sz; | bl.sz = sz; | ||||
| try { | try { | ||||
| bl.ptr = alloc_.allocate(sz); | |||||
| // We will do a piecewise copy. | |||||
| WritableSlice dest(bl.ptr, bl.sz); | |||||
| size_t pos = 0; | |||||
| for (auto &v : buf) { | |||||
| WritableSlice out(dest, pos); | |||||
| rc = WritableSlice::Copy(&out, v); | |||||
| if (!writeToDiskDirectly) { | |||||
| bl.ptr = alloc_.allocate(sz); | |||||
| // We will do a piecewise copy. | |||||
| WritableSlice dest(bl.ptr, bl.sz); | |||||
| size_t pos = 0; | |||||
| for (auto &v : buf) { | |||||
| WritableSlice out(dest, pos); | |||||
| rc = WritableSlice::Copy(&out, v); | |||||
| if (rc.IsError()) { | |||||
| break; | |||||
| } | |||||
| pos += v.GetSize(); | |||||
| } | |||||
| if (rc.IsError()) { | if (rc.IsError()) { | ||||
| break; | |||||
| alloc_.deallocate(bl.ptr, sz); | |||||
| bl.ptr = nullptr; | |||||
| return rc; | |||||
| } | } | ||||
| pos += v.GetSize(); | |||||
| } | |||||
| if (rc.IsError()) { | |||||
| alloc_.deallocate(bl.ptr, sz); | |||||
| bl.ptr = nullptr; | |||||
| return rc; | |||||
| } else if (sm_ != nullptr) { | |||||
| MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes."; | |||||
| RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); | |||||
| } else { | |||||
| // If asked to spill to disk instead but there is no storage set up, simply return no memory | |||||
| // instead. | |||||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||||
| } | } | ||||
| } catch (std::bad_alloc &e) { | } catch (std::bad_alloc &e) { | ||||
| if (sm_ != nullptr) { | if (sm_ != nullptr) { | ||||
| @@ -102,7 +119,13 @@ Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_t | |||||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | ||||
| } | } | ||||
| } | } | ||||
| rc = tree_->insert(bl, key); | |||||
| // Insert into the B+ tree. We may still get out of memory error. So need to catch it. | |||||
| try { | |||||
| rc = tree_->DoInsert(key, bl); | |||||
| } catch (const std::bad_alloc &e) { | |||||
| rc = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||||
| } | |||||
| // Duplicate key is treated as error and we will also free the memory. | |||||
| if (rc.IsError() && bl.ptr != nullptr) { | if (rc.IsError() && bl.ptr != nullptr) { | ||||
| alloc_.deallocate(bl.ptr, sz); | alloc_.deallocate(bl.ptr, sz); | ||||
| } | } | ||||
| @@ -138,15 +161,26 @@ Path CachePool::GetSpillPath() const { | |||||
| auto spill = Path(root_) / subfolder_; | auto spill = Path(root_) / subfolder_; | ||||
| return spill; | return spill; | ||||
| } | } | ||||
| CachePool::CacheStat CachePool::GetStat() const { | |||||
| CacheStat cs{0}; | |||||
| CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { | |||||
| CacheStat cs{-1, -1, 0, 0, 0}; | |||||
| int64_t total_sz = 0; | int64_t total_sz = 0; | ||||
| for (auto &it : *tree_) { | |||||
| total_sz += it.sz; | |||||
| if (it.ptr != nullptr) { | |||||
| ++cs.num_mem_cached; | |||||
| } else { | |||||
| ++cs.num_disk_cached; | |||||
| if (tree_->begin() != tree_->end()) { | |||||
| cs.min_key = tree_->begin().key(); | |||||
| cs.max_key = cs.min_key; // will adjust later. | |||||
| for (auto it = tree_->begin(); it != tree_->end(); ++it) { | |||||
| total_sz += it.value().sz; | |||||
| if (it.value().ptr != nullptr) { | |||||
| ++cs.num_mem_cached; | |||||
| } else { | |||||
| ++cs.num_disk_cached; | |||||
| } | |||||
| auto cur_key = it.key(); | |||||
| if (GetMissingKeys) { | |||||
| for (auto i = cs.max_key + 1; i < cur_key; ++i) { | |||||
| cs.gap.push_back((i)); | |||||
| } | |||||
| } | |||||
| cs.max_key = cur_key; | |||||
| } | } | ||||
| } | } | ||||
| if (total_sz > 0) { | if (total_sz > 0) { | ||||
| @@ -25,13 +25,13 @@ | |||||
| #include "minddata/dataset/util/slice.h" | #include "minddata/dataset/util/slice.h" | ||||
| #include "minddata/dataset/util/storage_manager.h" | #include "minddata/dataset/util/storage_manager.h" | ||||
| #include "minddata/dataset/util/auto_index.h" | #include "minddata/dataset/util/auto_index.h" | ||||
| #include "minddata/dataset/util/btree.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| /// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of | /// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of | ||||
| /// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to | /// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to | ||||
| /// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to | |||||
| /// restore the buffer. | |||||
| /// disk (if a disk directory is provided). User must provide a key to insert the buffer. | |||||
| /// \see ReadableSlice | /// \see ReadableSlice | ||||
| class CachePool : public Service { | class CachePool : public Service { | ||||
| public: | public: | ||||
| @@ -73,22 +73,25 @@ class CachePool : public Service { | |||||
| StorageManager::key_type storage_key; | StorageManager::key_type storage_key; | ||||
| }; | }; | ||||
| using data_index = AutoIndexObj<DataLocator>; | |||||
| using data_index = BPlusTree<int64_t, DataLocator>; | |||||
| using key_type = data_index::key_type; | using key_type = data_index::key_type; | ||||
| using bl_alloc_type = typename value_allocator::template rebind<DataLocator>::other; | using bl_alloc_type = typename value_allocator::template rebind<DataLocator>::other; | ||||
| /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and | /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and | ||||
| /// how many elements are spilled to disk. | /// how many elements are spilled to disk. | ||||
| struct CacheStat { | struct CacheStat { | ||||
| key_type min_key; | |||||
| key_type max_key; | |||||
| int64_t num_mem_cached; | int64_t num_mem_cached; | ||||
| int64_t num_disk_cached; | int64_t num_disk_cached; | ||||
| int64_t average_cache_sz; | int64_t average_cache_sz; | ||||
| std::vector<key_type> gap; | |||||
| }; | }; | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| /// \param alloc Allocator to allocate memory from | /// \param alloc Allocator to allocate memory from | ||||
| /// \param root Optional disk folder to spill | /// \param root Optional disk folder to spill | ||||
| explicit CachePool(const value_allocator &alloc, const std::string &root = ""); | |||||
| explicit CachePool(const value_allocator &alloc, bool customArena, const std::string &root = ""); | |||||
| CachePool(const CachePool &) = delete; | CachePool(const CachePool &) = delete; | ||||
| CachePool(CachePool &&) = delete; | CachePool(CachePool &&) = delete; | ||||
| @@ -103,10 +106,11 @@ class CachePool : public Service { | |||||
| /// \brief Insert a sequence of ReadableSlice objects into the pool. | /// \brief Insert a sequence of ReadableSlice objects into the pool. | ||||
| /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk. | /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk. | ||||
| /// \param[in] key User supplied key | |||||
| /// \param[in] buf A sequence of ReadableSlice objects. | /// \param[in] buf A sequence of ReadableSlice objects. | ||||
| /// \param[out] key Generated key | |||||
| /// \param[in] writeToDiskDirectly If true, no spill to disk if spill is enabled, or return no memory | |||||
| /// \return Error code | /// \return Error code | ||||
| Status Insert(const std::vector<ReadableSlice> &buf, key_type *key); | |||||
| Status Insert(key_type key, const std::vector<ReadableSlice> &buf, bool writeToDiskDirectly); | |||||
| /// \brief Restore a cached buffer (from memory or disk) | /// \brief Restore a cached buffer (from memory or disk) | ||||
| /// \param[in] key A previous key returned from Insert | /// \param[in] key A previous key returned from Insert | ||||
| /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice | /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice | ||||
| @@ -122,18 +126,23 @@ class CachePool : public Service { | |||||
| /// \brief Get statistics. | /// \brief Get statistics. | ||||
| /// \return CacheStat object | /// \return CacheStat object | ||||
| CacheStat GetStat() const; | |||||
| CacheStat GetStat(bool GetMissingKeys = false) const; | |||||
| const value_allocator &get_allocator() const; | const value_allocator &get_allocator() const; | ||||
| std::string MyName() const { return subfolder_; } | std::string MyName() const { return subfolder_; } | ||||
| /// \brief Toggle locking | |||||
| /// \note Once locking is off. It is user's responsibility to ensure concurrency | |||||
| void SetLocking(bool on_off) { tree_->SetLocking(on_off); } | |||||
| private: | private: | ||||
| value_allocator alloc_; | value_allocator alloc_; | ||||
| Path root_; | Path root_; | ||||
| const std::string subfolder_; | const std::string subfolder_; | ||||
| std::shared_ptr<StorageManager> sm_; | std::shared_ptr<StorageManager> sm_; | ||||
| std::shared_ptr<data_index> tree_; | std::shared_ptr<data_index> tree_; | ||||
| bool custom_arena_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -133,12 +133,13 @@ void CircularPool::Deallocate(void *p) { | |||||
| // Lock in the chain in shared mode and find out which | // Lock in the chain in shared mode and find out which | ||||
| // segment it comes from | // segment it comes from | ||||
| SharedLock lock(&rw_lock_); | SharedLock lock(&rw_lock_); | ||||
| auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr<Arena> &b) -> bool { | |||||
| auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr<Arena> &b) -> bool { | |||||
| char *q = reinterpret_cast<char *>(p); | char *q = reinterpret_cast<char *>(p); | ||||
| char *base = const_cast<char *>(reinterpret_cast<const char *>(b->get_base_addr())); | |||||
| return (q > base && q < base + b->get_max_size()); | |||||
| auto *base = reinterpret_cast<const char *>(b->get_base_addr()); | |||||
| return (q > base && q < base + arena_size_ * 1048576L); | |||||
| }); | }); | ||||
| lock.Unlock(); | lock.Unlock(); | ||||
| MS_ASSERT(it != mem_segments_.end()); | |||||
| it->get()->Deallocate(p); | it->get()->Deallocate(p); | ||||
| } | } | ||||
| @@ -150,10 +151,10 @@ Status CircularPool::Reallocate(void **pp, size_t old_sz, size_t new_sz) { | |||||
| } | } | ||||
| void *p = *pp; | void *p = *pp; | ||||
| SharedLock lock(&rw_lock_); | SharedLock lock(&rw_lock_); | ||||
| auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr<Arena> &b) -> bool { | |||||
| auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr<Arena> &b) -> bool { | |||||
| char *q = reinterpret_cast<char *>(p); | char *q = reinterpret_cast<char *>(p); | ||||
| char *base = const_cast<char *>(reinterpret_cast<const char *>(b->get_base_addr())); | |||||
| return (q > base && q < base + b->get_max_size()); | |||||
| auto *base = reinterpret_cast<const char *>(b->get_base_addr()); | |||||
| return (q > base && q < base + arena_size_ * 1048576L); | |||||
| }); | }); | ||||
| lock.Unlock(); | lock.Unlock(); | ||||
| MS_ASSERT(it != mem_segments_.end()); | MS_ASSERT(it != mem_segments_.end()); | ||||
| @@ -16,11 +16,14 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_ | ||||
| #include <atomic> | |||||
| #include <deque> | #include <deque> | ||||
| #include <iostream> | |||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include "minddata/dataset/util/allocator.h" | #include "minddata/dataset/util/allocator.h" | ||||
| #include "minddata/dataset/util/system_pool.h" | |||||
| #include "minddata/dataset/util/semaphore.h" | #include "minddata/dataset/util/semaphore.h" | ||||
| #include "minddata/dataset/util/services.h" | #include "minddata/dataset/util/services.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -37,7 +40,7 @@ class QueueMap { | |||||
| using key_type = K; | using key_type = K; | ||||
| using value_type = T; | using value_type = T; | ||||
| QueueMap() = default; | |||||
| QueueMap() : num_rows_(0) {} | |||||
| virtual ~QueueMap() = default; | virtual ~QueueMap() = default; | ||||
| /// Add an element <key, T> to the map and wake up any consumer that is waiting | /// Add an element <key, T> to the map and wake up any consumer that is waiting | ||||
| @@ -48,6 +51,7 @@ class QueueMap { | |||||
| RequestQueue *rq = nullptr; | RequestQueue *rq = nullptr; | ||||
| RETURN_IF_NOT_OK(GetRq(key, &rq)); | RETURN_IF_NOT_OK(GetRq(key, &rq)); | ||||
| RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload))); | RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload))); | ||||
| ++num_rows_; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -56,9 +60,35 @@ class QueueMap { | |||||
| RequestQueue *rq = nullptr; | RequestQueue *rq = nullptr; | ||||
| RETURN_IF_NOT_OK(GetRq(key, &rq)); | RETURN_IF_NOT_OK(GetRq(key, &rq)); | ||||
| RETURN_IF_NOT_OK(rq->Wait(out)); | RETURN_IF_NOT_OK(rq->Wait(out)); | ||||
| --num_rows_; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| /// Get the number of elements in the container | |||||
| /// \return The number of elements in the container | |||||
| int64_t size() const { return num_rows_; } | |||||
| /// \return if the container is empty | |||||
| bool empty() const { return num_rows_ == 0; } | |||||
| /// Print out some useful information about the container | |||||
| friend std::ostream &operator<<(std::ostream &out, const QueueMap &qm) { | |||||
| std::unique_lock<std::mutex> lck(qm.mux_); | |||||
| out << "Number of elements: " << qm.num_rows_ << "\n"; | |||||
| out << "Dumping internal info:\n"; | |||||
| int64_t k = 0; | |||||
| for (auto &it : qm.all_) { | |||||
| auto key = it.first; | |||||
| const RequestQueue *rq = it.second.GetPointer(); | |||||
| out << "(k:" << key << "," << *rq << ") "; | |||||
| ++k; | |||||
| if (k % 6 == 0) { | |||||
| out << "\n"; | |||||
| } | |||||
| } | |||||
| return out; | |||||
| } | |||||
| protected: | protected: | ||||
| /// This is a handshake structure between producer and consumer | /// This is a handshake structure between producer and consumer | ||||
| class RequestQueue { | class RequestQueue { | ||||
| @@ -86,8 +116,13 @@ class QueueMap { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| friend std::ostream &operator<<(std::ostream &out, const RequestQueue &rq) { | |||||
| out << "sz:" << rq.row_.size() << ",uc:" << rq.use_count_.Peek(); | |||||
| return out; | |||||
| } | |||||
| private: | private: | ||||
| std::mutex dq_mux_; | |||||
| mutable std::mutex dq_mux_; | |||||
| Semaphore use_count_; | Semaphore use_count_; | ||||
| std::deque<T> row_; | std::deque<T> row_; | ||||
| }; | }; | ||||
| @@ -104,7 +139,7 @@ class QueueMap { | |||||
| *out = it->second.GetMutablePointer(); | *out = it->second.GetMutablePointer(); | ||||
| } else { | } else { | ||||
| // We will create a new one. | // We will create a new one. | ||||
| auto alloc = Services::GetAllocator<RequestQueue>(); | |||||
| auto alloc = SystemPool::GetAllocator<RequestQueue>(); | |||||
| auto r = all_.emplace(key, MemGuard<RequestQueue, Allocator<RequestQueue>>(alloc)); | auto r = all_.emplace(key, MemGuard<RequestQueue, Allocator<RequestQueue>>(alloc)); | ||||
| if (r.second) { | if (r.second) { | ||||
| auto &mem = r.first->second; | auto &mem = r.first->second; | ||||
| @@ -118,8 +153,9 @@ class QueueMap { | |||||
| } | } | ||||
| private: | private: | ||||
| std::mutex mux_; | |||||
| mutable std::mutex mux_; | |||||
| std::map<K, MemGuard<RequestQueue, Allocator<RequestQueue>>> all_; | std::map<K, MemGuard<RequestQueue, Allocator<RequestQueue>>> all_; | ||||
| std::atomic<int64_t> num_rows_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,10 +29,7 @@ void Semaphore::V() { | |||||
| ++value_; | ++value_; | ||||
| wait_cond_.NotifyOne(); | wait_cond_.NotifyOne(); | ||||
| } | } | ||||
| int Semaphore::Peek() { | |||||
| std::unique_lock<std::mutex> lck(mutex_); | |||||
| return value_; | |||||
| } | |||||
| int Semaphore::Peek() const { return value_; } | |||||
| Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } | Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } | ||||
| Status Semaphore::Deregister() { return (wait_cond_.Deregister()); } | Status Semaphore::Deregister() { return (wait_cond_.Deregister()); } | ||||
| void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); } | void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); } | ||||
| @@ -38,7 +38,7 @@ class Semaphore { | |||||
| void V(); | void V(); | ||||
| /// \brief Peek the internal value | /// \brief Peek the internal value | ||||
| /// \return The internal value | /// \return The internal value | ||||
| int Peek(); | |||||
| int Peek() const; | |||||
| Status Register(TaskGroup *vg); | Status Register(TaskGroup *vg); | ||||
| Status Deregister(); | Status Deregister(); | ||||
| void ResetIntrpState(); | void ResetIntrpState(); | ||||
| @@ -51,6 +51,12 @@ std::string CodeAsString(const StatusCode c) { | |||||
| case StatusCode::kSyntaxError: | case StatusCode::kSyntaxError: | ||||
| s = "Syntax error"; | s = "Syntax error"; | ||||
| break; | break; | ||||
| case StatusCode::kBuddySpaceFull: | |||||
| s = "BuddySpace full"; | |||||
| break; | |||||
| case StatusCode::kNetWorkError: | |||||
| s = "Network error"; | |||||
| break; | |||||
| case StatusCode::kUnexpectedError: | case StatusCode::kUnexpectedError: | ||||
| default: | default: | ||||
| s = "Unexpected error"; | s = "Unexpected error"; | ||||
| @@ -82,6 +82,8 @@ enum class StatusCode : char { | |||||
| kBoundingBoxInvalidShape = 12, | kBoundingBoxInvalidShape = 12, | ||||
| kSyntaxError = 13, | kSyntaxError = 13, | ||||
| kTimeOut = 14, | kTimeOut = 14, | ||||
| kBuddySpaceFull = 14, | |||||
| kNetWorkError = 15, | |||||
| // Make this error code the last one. Add new error code above it. | // Make this error code the last one. Add new error code above it. | ||||
| kUnexpectedError = 127 | kUnexpectedError = 127 | ||||
| }; | }; | ||||
| @@ -137,6 +139,8 @@ class Status { | |||||
| bool IsNoSpace() const { return (get_code() == StatusCode::kNoSpace); } | bool IsNoSpace() const { return (get_code() == StatusCode::kNoSpace); } | ||||
| bool IsNetWorkError() const { return (get_code() == StatusCode::kNetWorkError); } | |||||
| private: | private: | ||||
| StatusCode code_; | StatusCode code_; | ||||
| std::string err_msg_; | std::string err_msg_; | ||||
| @@ -99,7 +99,11 @@ Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const | |||||
| #endif | #endif | ||||
| if (r_sz != sz) { | if (r_sz != sz) { | ||||
| errno_t err = (r_sz == 0) ? EOF : errno; | errno_t err = (r_sz == 0) ? EOF : errno; | ||||
| RETURN_STATUS_UNEXPECTED(strerror(err)); | |||||
| if (errno == ENOSPC) { | |||||
| return Status(StatusCode::kNoSpace, __LINE__, __FILE__); | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED(strerror(err)); | |||||
| } | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -71,10 +71,11 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu | |||||
| key_type out_key; | key_type out_key; | ||||
| value_type out_value; | value_type out_value; | ||||
| bool create_new_container = false; | bool create_new_container = false; | ||||
| size_t last_num_container = -1; | |||||
| do { | do { | ||||
| SharedLock lock_s(&rw_lock_); | SharedLock lock_s(&rw_lock_); | ||||
| size_t num_containers = containers_.size(); | size_t num_containers = containers_.size(); | ||||
| if (create_new_container) { | |||||
| if (create_new_container && (num_containers == last_num_container)) { | |||||
| // Upgrade to exclusvie lock. | // Upgrade to exclusvie lock. | ||||
| lock_s.Upgrade(); | lock_s.Upgrade(); | ||||
| create_new_container = false; | create_new_container = false; | ||||
| @@ -95,8 +96,11 @@ Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &bu | |||||
| cont = containers_.at(num_containers - 1); | cont = containers_.at(num_containers - 1); | ||||
| off64_t offset; | off64_t offset; | ||||
| Status rc = cont->Insert(buf, &offset); | Status rc = cont->Insert(buf, &offset); | ||||
| if (rc.IsNoSpace()) { | |||||
| if (rc.get_code() == StatusCode::kBuddySpaceFull) { | |||||
| create_new_container = true; | create_new_container = true; | ||||
| // Remember how many containers we saw. In the next iteration we will do a comparision to see | |||||
| // if someone has already created it. | |||||
| last_num_container = num_containers; | |||||
| } else if (rc.IsOk()) { | } else if (rc.IsOk()) { | ||||
| out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz)); | out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz)); | ||||
| RETURN_IF_NOT_OK(index_.insert(out_value, &out_key)); | RETURN_IF_NOT_OK(index_.insert(out_value, &out_key)); | ||||
| @@ -15,6 +15,7 @@ | |||||
| """Cache client | """Cache client | ||||
| """ | """ | ||||
| import os | |||||
| import copy | import copy | ||||
| from ..core.validator_helpers import type_check, check_uint32, check_uint64 | from ..core.validator_helpers import type_check, check_uint32, check_uint64 | ||||
| @@ -25,11 +26,11 @@ class DatasetCache: | |||||
| A client to interface with tensor caching service | A client to interface with tensor caching service | ||||
| """ | """ | ||||
| def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, prefetch_size=20): | |||||
| def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None, | |||||
| prefetch_size=None): | |||||
| check_uint32(session_id, "session_id") | check_uint32(session_id, "session_id") | ||||
| check_uint64(size, "size") | check_uint64(size, "size") | ||||
| type_check(spilling, (bool,), "spilling") | type_check(spilling, (bool,), "spilling") | ||||
| check_uint32(prefetch_size, "prefetch size") | |||||
| self.session_id = session_id | self.session_id = session_id | ||||
| self.size = size | self.size = size | ||||
| @@ -37,8 +38,13 @@ class DatasetCache: | |||||
| self.hostname = hostname | self.hostname = hostname | ||||
| self.port = port | self.port = port | ||||
| self.prefetch_size = prefetch_size | self.prefetch_size = prefetch_size | ||||
| # temporary disable cache feature in the current release | |||||
| self.cache_client = None | |||||
| self.num_connections = num_connections | |||||
| if os.getenv('MS_ENABLE_CACHE') != 'TRUE': | |||||
| # temporary disable cache feature in the current release | |||||
| self.cache_client = None | |||||
| else: | |||||
| from mindspore._c_dataengine import CacheClient | |||||
| self.cache_client = CacheClient(session_id, size, spilling, hostname, port, num_connections, prefetch_size) | |||||
| def GetStat(self): | def GetStat(self): | ||||
| return self.cache_client.GetStat() | return self.cache_client.GetStat() | ||||
| @@ -55,5 +61,6 @@ class DatasetCache: | |||||
| new_cache.hostname = copy.deepcopy(self.hostname, memodict) | new_cache.hostname = copy.deepcopy(self.hostname, memodict) | ||||
| new_cache.port = copy.deepcopy(self.port, memodict) | new_cache.port = copy.deepcopy(self.port, memodict) | ||||
| new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) | new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) | ||||
| new_cache.num_connections = copy.deepcopy(self.num_connections, memodict) | |||||
| new_cache.cache_client = self.cache_client | new_cache.cache_client = self.cache_client | ||||
| return new_cache | return new_cache | ||||
| @@ -1234,5 +1234,8 @@ def check_paddeddataset(method): | |||||
| def check_cache_option(cache): | def check_cache_option(cache): | ||||
| """Sanity check for cache parameter""" | """Sanity check for cache parameter""" | ||||
| if cache is not None: | if cache is not None: | ||||
| # temporary disable cache feature in the current release | |||||
| raise ValueError("Caching is disabled in the current release") | |||||
| if os.getenv('MS_ENABLE_CACHE') != 'TRUE': | |||||
| # temporary disable cache feature in the current release | |||||
| raise ValueError("Caching is disabled in the current release") | |||||
| from . import cache_client | |||||
| type_check(cache, (cache_client.DatasetCache,), "cache") | |||||
| @@ -156,6 +156,24 @@ def update_permissions(path): | |||||
| if filename == "ms_serving": | if filename == "ms_serving": | ||||
| os.chmod(file_fullpath, stat.S_IREAD | stat.S_IEXEC) | os.chmod(file_fullpath, stat.S_IREAD | stat.S_IEXEC) | ||||
| def bin_files(): | |||||
| """ | |||||
| Gets the binary files to be installed. | |||||
| """ | |||||
| data_files = [] | |||||
| binary_files = [] | |||||
| cache_server_bin = os.path.join('mindspore', 'bin', 'cache_server') | |||||
| if not os.path.exists(cache_server_bin): | |||||
| return data_files | |||||
| binary_files.append(cache_server_bin) | |||||
| cache_admin_bin = os.path.join('mindspore', 'bin', 'cache_admin') | |||||
| if not os.path.exists(cache_admin_bin): | |||||
| return data_files | |||||
| binary_files.append(cache_admin_bin) | |||||
| data_files.append(('bin', binary_files)) | |||||
| return data_files | |||||
| class EggInfo(egg_info): | class EggInfo(egg_info): | ||||
| """Egg info.""" | """Egg info.""" | ||||
| @@ -192,6 +210,7 @@ setup( | |||||
| 'framework that could be used for mobile, edge and cloud scenarios.', | 'framework that could be used for mobile, edge and cloud scenarios.', | ||||
| long_description="\n\n".join([readme, release]), | long_description="\n\n".join([readme, release]), | ||||
| long_description_content_type="text/markdown", | long_description_content_type="text/markdown", | ||||
| data_files=bin_files(), | |||||
| packages=find_packages(), | packages=find_packages(), | ||||
| package_data=package_data, | package_data=package_data, | ||||
| include_package_data=True, | include_package_data=True, | ||||
| @@ -24,27 +24,26 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| using namespace mindspore::dataset; | using namespace mindspore::dataset; | ||||
| using mindspore::MsLogLevel::INFO; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | using mindspore::LogStream; | ||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::MsLogLevel::INFO; | |||||
| // For testing purposes, we will make the branching factor very low. | // For testing purposes, we will make the branching factor very low. | ||||
| struct mytraits { | struct mytraits { | ||||
| using slot_type = uint16_t; | |||||
| static const slot_type kLeafSlots = 6; | |||||
| static const slot_type kInnerSlots = 3; | |||||
| using slot_type = uint16_t; | |||||
| static const slot_type kLeafSlots = 6; | |||||
| static const slot_type kInnerSlots = 3; | |||||
| }; | }; | ||||
| class MindDataTestBPlusTree : public UT::Common { | class MindDataTestBPlusTree : public UT::Common { | ||||
| public: | public: | ||||
| MindDataTestBPlusTree() = default; | |||||
| MindDataTestBPlusTree() = default; | |||||
| }; | }; | ||||
| // Test serial insert. | // Test serial insert. | ||||
| TEST_F(MindDataTestBPlusTree, Test1) { | TEST_F(MindDataTestBPlusTree, Test1) { | ||||
| Allocator<std::string> alloc(std::make_shared<SystemPool>()); | Allocator<std::string> alloc(std::make_shared<SystemPool>()); | ||||
| BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<uint64_t>, mytraits> btree(alloc); | |||||
| BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<>, mytraits> btree(alloc); | |||||
| Status rc; | Status rc; | ||||
| for (int i = 0; i < 100; i++) { | for (int i = 0; i < 100; i++) { | ||||
| uint64_t key = 2 * i; | uint64_t key = 2 * i; | ||||
| @@ -109,23 +108,24 @@ TEST_F(MindDataTestBPlusTree, Test1) { | |||||
| // Test concurrent insert. | // Test concurrent insert. | ||||
| TEST_F(MindDataTestBPlusTree, Test2) { | TEST_F(MindDataTestBPlusTree, Test2) { | ||||
| Allocator<std::string> alloc(std::make_shared<SystemPool>()); | Allocator<std::string> alloc(std::make_shared<SystemPool>()); | ||||
| BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<uint64_t>, mytraits> btree(alloc); | |||||
| BPlusTree<uint64_t, std::string, Allocator<std::string>, std::less<>, mytraits> btree(alloc); | |||||
| TaskGroup vg; | TaskGroup vg; | ||||
| auto f = [&](int k) -> Status { | auto f = [&](int k) -> Status { | ||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| for (int i = 0; i < 100; i++) { | |||||
| uint64_t key = k * 100 + i; | |||||
| std::ostringstream oss; | |||||
| oss << "Hello World. I am " << key; | |||||
| Status rc = btree.DoInsert(key, oss.str()); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| } | |||||
| return Status::OK(); | |||||
| for (int i = 0; i < 100; i++) { | |||||
| uint64_t key = k * 100 + i; | |||||
| std::ostringstream oss; | |||||
| oss << "Hello World. I am " << key; | |||||
| Status rc = btree.DoInsert(key, oss.str()); | |||||
| EXPECT_TRUE(rc.IsOk()); | |||||
| } | |||||
| return Status::OK(); | |||||
| }; | }; | ||||
| auto g = [&](int k) -> Status { | auto g = [&](int k) -> Status { | ||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| for (int i = 0; i < 1000; i++) { | for (int i = 0; i < 1000; i++) { | ||||
| uint64_t key = rand() % 10000;; | |||||
| uint64_t key = rand() % 10000; | |||||
| ; | |||||
| auto it = btree.Search(key); | auto it = btree.Search(key); | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -226,3 +226,22 @@ TEST_F(MindDataTestBPlusTree, Test4) { | |||||
| EXPECT_EQ(cnt, 1000); | EXPECT_EQ(cnt, 1000); | ||||
| } | } | ||||
| } | } | ||||
| TEST_F(MindDataTestBPlusTree, TestPerfNoLocking) { | |||||
| AutoIndexObj<int64_t> btree; | |||||
| // No locking test | |||||
| btree.SetLocking(false); | |||||
| // Insert a million entries using the default traits. | |||||
| for (auto i = 0; i < 1000000; ++i) { | |||||
| ASSERT_TRUE(btree.insert(i)); | |||||
| } | |||||
| std::cout << "Tree height : " << btree.GetHeight() << std::endl; | |||||
| std::cout << "Tree Order : " << btree.GetOrder() << std::endl; | |||||
| std::cout << "Number of leaves : " << btree.GetNumLeaves() << std::endl; | |||||
| std::cout << "Number of inner nodes : " << btree.GetNumInnerNodes() << std::endl; | |||||
| auto r = btree.Search(3); | |||||
| EXPECT_TRUE(r.second); | |||||
| r = btree.Search(999999); | |||||
| EXPECT_TRUE(r.second); | |||||
| } | |||||
| @@ -35,6 +35,23 @@ using mindspore::dataset::TaskGroup; | |||||
| using mindspore::ExceptionType::NoExceptionType; | using mindspore::ExceptionType::NoExceptionType; | ||||
| using mindspore::MsLogLevel::INFO; | using mindspore::MsLogLevel::INFO; | ||||
| // Helper function to get the session id from SESSION_ID env variable | |||||
| Status GetSessionFromEnv(session_id_type *session_id) { | |||||
| RETURN_UNEXPECTED_IF_NULL(session_id); | |||||
| if (const char *session_env = std::getenv("SESSION_ID")) { | |||||
| std::string session_id_str(session_env); | |||||
| try { | |||||
| *session_id = std::stoul(session_id_str); | |||||
| } catch (const std::exception &e) { | |||||
| std::string err_msg = "Invalid numeric value for session id in env var: " + session_id_str; | |||||
| return Status(StatusCode::kSyntaxError, err_msg); | |||||
| } | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Test case requires a session id to be provided via SESSION_ID environment variable."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| class MindDataTestCacheOp : public UT::DatasetOpTesting { | class MindDataTestCacheOp : public UT::DatasetOpTesting { | ||||
| public: | public: | ||||
| void SetUp() override { | void SetUp() override { | ||||
| @@ -46,8 +63,12 @@ class MindDataTestCacheOp : public UT::DatasetOpTesting { | |||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) { | TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) { | ||||
| Status rc; | Status rc; | ||||
| CacheClient::Builder builder; | CacheClient::Builder builder; | ||||
| session_id_type env_session; | |||||
| rc = GetSessionFromEnv(&env_session); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // use arbitrary session of 1, size of 0, spilling// is true | // use arbitrary session of 1, size of 0, spilling// is true | ||||
| builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||||
| builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | std::shared_ptr<CacheClient> myClient; | ||||
| rc = builder.Build(&myClient); | rc = builder.Build(&myClient); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| @@ -118,9 +139,6 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheServer) { | |||||
| cmp = (map_out == map); | cmp = (map_out == map); | ||||
| ASSERT_TRUE(cmp); | ASSERT_TRUE(cmp); | ||||
| // Test Purge and Destroy | |||||
| rc = myClient->PurgeCache(); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| rc = myClient->DestroyCache(); | rc = myClient->DestroyCache(); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| } | } | ||||
| @@ -130,10 +148,15 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) { | |||||
| (void)TaskManager::GetMasterThreadRc(); | (void)TaskManager::GetMasterThreadRc(); | ||||
| TaskGroup vg; | TaskGroup vg; | ||||
| Status rc; | Status rc; | ||||
| session_id_type env_session; | |||||
| rc = GetSessionFromEnv(&env_session); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // use arbitrary session of 1, size 1, spilling is true | // use arbitrary session of 1, size 1, spilling is true | ||||
| CacheClient::Builder builder; | CacheClient::Builder builder; | ||||
| // use arbitrary session of 1, size of 0, spilling// is true | // use arbitrary session of 1, size of 0, spilling// is true | ||||
| builder.SetSessionId(1).SetCacheMemSz(1).SetSpill(true); | |||||
| builder.SetSessionId(env_session).SetCacheMemSz(1).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | std::shared_ptr<CacheClient> myClient; | ||||
| rc = builder.Build(&myClient); | rc = builder.Build(&myClient); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| @@ -199,8 +222,15 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestConcurrencyRequest) { | |||||
| // RandomDataOp | // RandomDataOp | ||||
| // | // | ||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | ||||
| // Clear the rc of the master thread if any | |||||
| (void)TaskManager::GetMasterThreadRc(); | |||||
| Status rc; | Status rc; | ||||
| int32_t rank = 0; // not used | int32_t rank = 0; // not used | ||||
| session_id_type env_session; | |||||
| rc = GetSessionFromEnv(&env_session); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| MS_LOG(INFO) << "UT test TestRandomDataCache1"; | MS_LOG(INFO) << "UT test TestRandomDataCache1"; | ||||
| // Start with an empty execution tree | // Start with an empty execution tree | ||||
| auto myTree = std::make_shared<ExecutionTree>(); | auto myTree = std::make_shared<ExecutionTree>(); | ||||
| @@ -236,8 +266,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | |||||
| // CacheOp | // CacheOp | ||||
| // size of 0, spilling is true | // size of 0, spilling is true | ||||
| CacheClient::Builder builder; | CacheClient::Builder builder; | ||||
| // use arbitrary session of 1, size of 0, spilling// is true | |||||
| builder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||||
| builder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | std::shared_ptr<CacheClient> myClient; | ||||
| rc = builder.Build(&myClient); | rc = builder.Build(&myClient); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| @@ -273,7 +302,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | |||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | MS_LOG(INFO) << "Launching tree and begin iteration"; | ||||
| rc = myTree->Prepare(); | |||||
| rc = myTree->Prepare(1); | |||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| // quick check to see what tree looks like | // quick check to see what tree looks like | ||||
| @@ -314,9 +343,16 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { | |||||
| //// RandomDataOp | //// RandomDataOp | ||||
| //// | //// | ||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | ||||
| // Clear the rc of the master thread if any | |||||
| (void)TaskManager::GetMasterThreadRc(); | |||||
| Status rc; | Status rc; | ||||
| int32_t rank = 0; // not used | int32_t rank = 0; // not used | ||||
| MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; | MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; | ||||
| session_id_type env_session; | |||||
| rc = GetSessionFromEnv(&env_session); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| // Start with an empty execution tree | // Start with an empty execution tree | ||||
| auto myTree = std::make_shared<ExecutionTree>(); | auto myTree = std::make_shared<ExecutionTree>(); | ||||
| @@ -353,8 +389,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | |||||
| int64_t start_index = 0; | int64_t start_index = 0; | ||||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | ||||
| CacheClient::Builder builder; | CacheClient::Builder builder; | ||||
| // use arbitrary session of 1, size of 0, spilling// is true | |||||
| builder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); | |||||
| builder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | std::shared_ptr<CacheClient> myClient; | ||||
| rc = builder.Build(&myClient); | rc = builder.Build(&myClient); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| @@ -386,7 +421,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | |||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | MS_LOG(INFO) << "Launching tree and begin iteration"; | ||||
| rc = myTree->Prepare(); | |||||
| rc = myTree->Prepare(1); | |||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| std::cout << *myClient << std::endl; | std::cout << *myClient << std::endl; | ||||
| @@ -413,14 +448,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { | |||||
| } | } | ||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | ||||
| // Clear the rc of the master thread if any | |||||
| (void)TaskManager::GetMasterThreadRc(); | |||||
| Status rc; | Status rc; | ||||
| int64_t num_samples = 0; | int64_t num_samples = 0; | ||||
| int64_t start_index = 0; | int64_t start_index = 0; | ||||
| session_id_type env_session; | |||||
| rc = GetSessionFromEnv(&env_session); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | ||||
| CacheClient::Builder ccbuilder; | CacheClient::Builder ccbuilder; | ||||
| // use arbitrary session of 1, size of 0, spilling// is true | |||||
| ccbuilder.SetSessionId(1).SetCacheMemSz(0).SetSpill(true); | |||||
| ccbuilder.SetSessionId(env_session).SetCacheMemSz(0).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | std::shared_ptr<CacheClient> myClient; | ||||
| rc = ccbuilder.Build(&myClient); | rc = ccbuilder.Build(&myClient); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| @@ -468,7 +509,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||||
| rc = myCacheOp->AddChild(so); | rc = myCacheOp->AddChild(so); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| rc = myTree->Prepare(); | |||||
| rc = myTree->Prepare(1); | |||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| rc = myTree->Launch(); | rc = myTree->Launch(); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| @@ -507,10 +548,16 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||||
| //// RandomDataOp | //// RandomDataOp | ||||
| //// | //// | ||||
| TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { | TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { | ||||
| // Clear the rc of the master thread if any | |||||
| (void)TaskManager::GetMasterThreadRc(); | |||||
| Status rc; | Status rc; | ||||
| int32_t rank = 0; // not used | int32_t rank = 0; // not used | ||||
| MS_LOG(INFO) << "UT test TestCacheInheritSampler"; | MS_LOG(INFO) << "UT test TestCacheInheritSampler"; | ||||
| session_id_type env_session; | |||||
| rc = GetSessionFromEnv(&env_session); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| int64_t num_samples = 0; | int64_t num_samples = 0; | ||||
| int64_t start_index = 0; | int64_t start_index = 0; | ||||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | ||||
| @@ -550,7 +597,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { | |||||
| // CacheOp | // CacheOp | ||||
| CacheClient::Builder ccbuilder; | CacheClient::Builder ccbuilder; | ||||
| // use arbitrary session of 1, size of 0, spilling// is true | // use arbitrary session of 1, size of 0, spilling// is true | ||||
| ccbuilder.SetSessionId(1).SetCacheMemSz(4).SetSpill(true); | |||||
| ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true); | |||||
| std::shared_ptr<CacheClient> myClient; | std::shared_ptr<CacheClient> myClient; | ||||
| rc = ccbuilder.Build(&myClient); | rc = ccbuilder.Build(&myClient); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| @@ -577,7 +624,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { | |||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | MS_LOG(INFO) << "Launching tree and begin iteration"; | ||||
| rc = myTree->Prepare(); | |||||
| rc = myTree->Prepare(1); | |||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| std::cout << *myClient << std::endl; | std::cout << *myClient << std::endl; | ||||
| @@ -25,13 +25,13 @@ using namespace mindspore::dataset; | |||||
| class MindDataTestMemoryPool : public UT::Common { | class MindDataTestMemoryPool : public UT::Common { | ||||
| public: | public: | ||||
| std::shared_ptr<MemoryPool> mp_; | |||||
| MindDataTestMemoryPool() {} | |||||
| std::shared_ptr<MemoryPool> mp_; | |||||
| MindDataTestMemoryPool() {} | |||||
| void SetUp() { | |||||
| Status rc = CircularPool::CreateCircularPool(&mp_, 1, 1, true); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| } | |||||
| void SetUp() { | |||||
| Status rc = CircularPool::CreateCircularPool(&mp_, 1, 1, true); | |||||
| ASSERT_TRUE(rc.IsOk()); | |||||
| } | |||||
| }; | }; | ||||
| TEST_F(MindDataTestMemoryPool, DumpPoolInfo) { | TEST_F(MindDataTestMemoryPool, DumpPoolInfo) { | ||||
| @@ -40,7 +40,7 @@ TEST_F(MindDataTestMemoryPool, DumpPoolInfo) { | |||||
| TEST_F(MindDataTestMemoryPool, TestOperator1) { | TEST_F(MindDataTestMemoryPool, TestOperator1) { | ||||
| Status rc; | Status rc; | ||||
| int *p = new(&rc, mp_) int; | |||||
| int *p = new (&rc, mp_) int; | |||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| *p = 2048; | *p = 2048; | ||||
| ::operator delete(p, mp_); | ::operator delete(p, mp_); | ||||
| @@ -61,12 +61,11 @@ TEST_F(MindDataTestMemoryPool, TestOperator3) { | |||||
| TEST_F(MindDataTestMemoryPool, TestAllocator) { | TEST_F(MindDataTestMemoryPool, TestAllocator) { | ||||
| class A { | class A { | ||||
| public: | public: | ||||
| explicit A (int x) : a(x) {} | |||||
| int val_a() const { | |||||
| return a; | |||||
| } | |||||
| explicit A(int x) : a(x) {} | |||||
| int val_a() const { return a; } | |||||
| private: | private: | ||||
| int a; | |||||
| int a; | |||||
| }; | }; | ||||
| Allocator<A> alloc(mp_); | Allocator<A> alloc(mp_); | ||||
| std::shared_ptr<A> obj_a = std::allocate_shared<A>(alloc, 3); | std::shared_ptr<A> obj_a = std::allocate_shared<A>(alloc, 3); | ||||
| @@ -74,3 +73,16 @@ TEST_F(MindDataTestMemoryPool, TestAllocator) { | |||||
| ASSERT_EQ(v, 3); | ASSERT_EQ(v, 3); | ||||
| MS_LOG(DEBUG) << *(std::dynamic_pointer_cast<CircularPool>(mp_)) << std::endl; | MS_LOG(DEBUG) << *(std::dynamic_pointer_cast<CircularPool>(mp_)) << std::endl; | ||||
| } | } | ||||
| TEST_F(MindDataTestMemoryPool, TestMemGuard) { | |||||
| MemGuard<uint8_t> mem; | |||||
| // Try some large value. | |||||
| int64_t sz = 5LL * 1024LL * 1024LL * 1024LL; | |||||
| Status rc = mem.allocate(sz); | |||||
| ASSERT_TRUE(rc.IsOk() || rc.IsOutofMemory()); | |||||
| if (rc.IsOk()) { | |||||
| // Try write a character half way. | |||||
| auto *p = mem.GetMutablePointer(); | |||||
| p[sz / 2] = 'a'; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,48 @@ | |||||
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# This script is the driver of the individual test scenarios
CURRPATH=$(cd "$(dirname "$0")"; pwd)

# Accumulate failures across all sections. Previously each section simply
# overwrote num_failures and the script implicitly exited 0, so a failing
# run was indistinguishable from a clean one to any caller (e.g. CI).
total_failures=0

echo "----------------------------------------------"
echo "Invalid syntax and cache_admin failure testing"
echo "----------------------------------------------"
echo
${CURRPATH}/cachetest_args.sh
num_failures=$?
total_failures=$((total_failures + num_failures))
echo
echo "Invalid syntax and cache_admin failure testing complete. Number of failures: $num_failures"
echo
echo "----------------------------------------------"
echo "Test pipelines with cache (python)"
echo "----------------------------------------------"
echo
${CURRPATH}/cachetest_py.sh
num_failures=$?
total_failures=$((total_failures + num_failures))
echo
echo "Test pipelines with cache complete. Number of failures: $num_failures"
echo
echo "----------------------------------------------"
echo "Cache cpp tests"
echo "----------------------------------------------"
echo
${CURRPATH}/cachetest_cpp.sh
num_failures=$?
total_failures=$((total_failures + num_failures))
echo
echo "Cache cpp tests complete. Number of failures: $num_failures"
echo
# Shell exit status is truncated to 0-255; nonzero still reliably signals
# that at least one scenario failed.
exit ${total_failures}
| @@ -0,0 +1,207 @@ | |||||
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# source the globals and functions for use with cache testing
# NOTE(review): CacheAdminCmd's second argument appears to be the expected
# outcome (1 = expect the command to fail, 0 = expect success) and
# HandleRcExit's trailing arguments appear to control fatality/cleanup —
# confirm against cachetest_lib.sh.
SKIP_ADMIN_COUNTER=false
. cachetest_lib.sh
echo

################################################################################
# Cache testing: cache_admin argument testing                                  #
# Summary: Various tests that expect to get failure messages returned          #
################################################################################

# Double-command test
cmd="${CACHE_ADMIN} --start --stop"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# missing command test
cmd="${CACHE_ADMIN} --port 50082"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# bad arg test
cmd="${CACHE_ADMIN} -p abc --start"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# missing arg test
cmd="${CACHE_ADMIN} -p --start"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# invalid command
cmd="${CACHE_ADMIN} -p 50082 --start --not_exist_cmd"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# spill directory does not exist
cmd="${CACHE_ADMIN} --start --spilldir /path_that_does_not_exist"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# start cache server twice
StartServer
HandleRcExit $? 1 1

# start the cache server again, however, this time we expect an error
cmd="${CACHE_ADMIN} --start"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1

StopServer
HandleRcExit $? 1 1

# start cache server twice with different ports
# this one starts with the default port 50052
StartServer
HandleRcExit $? 1 1

# this one starts with port 50053
cmd="${CACHE_ADMIN} --start -p 50053"
CacheAdminCmd "${cmd}" 0
HandleRcExit $? 1 1

# stop the cache server with default port
StopServer
HandleRcExit $? 1 1

# stop the cache server with port 50053
cmd="${CACHE_ADMIN} --stop -p 50053"
CacheAdminCmd "${cmd}" 0
HandleRcExit $? 1 1

# stop the cache server without bringing it up
cmd="${CACHE_ADMIN} --stop"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1

# start the cache server with illegal hostname
cmd="${CACHE_ADMIN} --start -h 0.0.0.0"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -h illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -h"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -h --hostname"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -h --hostname 127.0.0.1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1

# start the cache server with illegal port
cmd="${CACHE_ADMIN} --start -p 0"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -p -1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -p 65536"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -p illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1
cmd="${CACHE_ADMIN} --start -p"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 1

# find a port that is occupied using netstat
if [ -x "$(command -v netstat)" ]; then
  port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1)
  # Guard against an empty pipeline result: with no matching sockets,
  # an unquoted "[ ${port} -gt 1025 ]" degenerates to "[ -gt 1025 ]" and
  # fails with a test-syntax error instead of skipping the scenario.
  if [ -n "${port}" ] && [ "${port}" -gt 1025 ]; then
    # start cache server with occupied port
    cmd="${CACHE_ADMIN} --start -p ${port}"
    CacheAdminCmd "${cmd}" 1
    HandleRcExit $? 0 1
  fi
fi

# generate session before starting the cache server
cmd="${CACHE_ADMIN} -g"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# illegal generate session command
StartServer
HandleRcExit $? 1 1
cmd="${CACHE_ADMIN} -g 1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# illegal destroy session command
cmd="${CACHE_ADMIN} -d -2"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} -d illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} -d"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# destroy a non-existing session
cmd="${CACHE_ADMIN} -d 99999"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# stop cache server at this point
StopServer
HandleRcExit $? 1 1

# illegal number of workers
cmd="${CACHE_ADMIN} --start -w 0"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w -1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w illegal"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w 101"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w 9999999"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -w"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# illegal spill path
cmd="${CACHE_ADMIN} --start -s"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# spill path without writing perm (root can write anywhere, so skip as root)
if [ "$EUID" -ne 0 ]; then
  cmd="${CACHE_ADMIN} --start -s /"
  CacheAdminCmd "${cmd}" 1
  HandleRcExit $? 0 0
fi

# illegal log level
cmd="${CACHE_ADMIN} --start -l 4"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -l -1"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0
cmd="${CACHE_ADMIN} --start -l"
CacheAdminCmd "${cmd}" 1
HandleRcExit $? 0 0

# failed_tests is maintained by the helpers in cachetest_lib.sh
exit ${failed_tests}