Turn --stop into sync. Ignore the rc of shutdown. (tags/v1.0.0)
| @@ -23,11 +23,13 @@ namespace dataset { | |||||
| PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { | PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) { | ||||
| (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient") | (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient") | ||||
| .def( | .def( | ||||
| py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) { | |||||
| py::init([](session_id_type id, uint64_t mem_sz, bool spill, std::optional<std::string> hostname, | |||||
| std::optional<int32_t> port, int32_t prefetch_sz) { | |||||
| std::shared_ptr<CacheClient> cc; | std::shared_ptr<CacheClient> cc; | ||||
| CacheClient::Builder builder; | CacheClient::Builder builder; | ||||
| builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize( | |||||
| prefetch_sz); | |||||
| builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPrefetchSize(prefetch_sz); | |||||
| if (hostname) builder.SetHostname(hostname.value()); | |||||
| if (port) builder.SetPort(port.value()); | |||||
| THROW_IF_ERROR(builder.Build(&cc)); | THROW_IF_ERROR(builder.Build(&cc)); | ||||
| return cc; | return cc; | ||||
| })) | })) | ||||
| @@ -19,10 +19,37 @@ | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <string> | #include <string> | ||||
| #include "mindspore/core/utils/log_adapter.h" | |||||
| #include "minddata/dataset/util/system_pool.h" | #include "minddata/dataset/util/system_pool.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| ConfigManager::ConfigManager() | |||||
| : rows_per_buffer_(kCfgRowsPerBuffer), | |||||
| num_parallel_workers_(kCfgParallelWorkers), | |||||
| worker_connector_size_(kCfgWorkerConnectorSize), | |||||
| op_connector_size_(kCfgOpConnectorSize), | |||||
| seed_(kCfgDefaultSeed), | |||||
| monitor_sampling_interval_(kCfgMonitorSamplingInterval), | |||||
| callback_timout_(kCfgCallbackTimeout), | |||||
| cache_host_(kCfgDefaultCacheHost), | |||||
| cache_port_(kCfgDefaultCachePort) { | |||||
| auto env_cache_host = std::getenv("MS_CACHE_HOST"); | |||||
| auto env_cache_port = std::getenv("MS_CACHE_PORT"); | |||||
| if (env_cache_host) { | |||||
| cache_host_ = env_cache_host; | |||||
| } | |||||
| if (env_cache_port) { | |||||
| char *end = nullptr; | |||||
| cache_port_ = strtol(env_cache_port, &end, 10); | |||||
| if (*end != '\0') { | |||||
| MS_LOG(WARNING) << "\nCache port from env variable MS_CACHE_PORT is invalid, back to use default " | |||||
| << kCfgDefaultCachePort << std::endl; | |||||
| cache_port_ = kCfgDefaultCachePort; | |||||
| } | |||||
| } | |||||
| } | |||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| void ConfigManager::Print(std::ostream &out) const { | void ConfigManager::Print(std::ostream &out) const { | ||||
| // Don't show the test/internal ones. Only display the main ones here. | // Don't show the test/internal ones. Only display the main ones here. | ||||
| @@ -42,6 +69,8 @@ Status ConfigManager::FromJson(const nlohmann::json &j) { | |||||
| set_op_connector_size(j.value("opConnectorSize", op_connector_size_)); | set_op_connector_size(j.value("opConnectorSize", op_connector_size_)); | ||||
| set_seed(j.value("seed", seed_)); | set_seed(j.value("seed", seed_)); | ||||
| set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); | set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); | ||||
| set_cache_host(j.value("cacheHost", cache_host_)); | |||||
| set_cache_port(j.value("cachePort", cache_port_)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -91,5 +120,8 @@ void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_s | |||||
| void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } | void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } | ||||
| void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = cache_host; } | |||||
| void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -41,7 +41,7 @@ namespace dataset { | |||||
| // those values. | // those values. | ||||
| class ConfigManager { | class ConfigManager { | ||||
| public: | public: | ||||
| ConfigManager() = default; | |||||
| ConfigManager(); | |||||
| // destructor | // destructor | ||||
| ~ConfigManager() = default; | ~ConfigManager() = default; | ||||
| @@ -89,6 +89,14 @@ class ConfigManager { | |||||
| // @return The internal worker-to-master connector queue size | // @return The internal worker-to-master connector queue size | ||||
| int32_t worker_connector_size() const { return worker_connector_size_; } | int32_t worker_connector_size() const { return worker_connector_size_; } | ||||
| // getter function | |||||
| // @return The hostname of cache server | |||||
| std::string cache_host() const { return cache_host_; } | |||||
| // getter function | |||||
| // @return The port of cache server | |||||
| int32_t cache_port() const { return cache_port_; } | |||||
| // setter function | // setter function | ||||
| // @param rows_per_buffer - The setting to apply to the config | // @param rows_per_buffer - The setting to apply to the config | ||||
| void set_rows_per_buffer(int32_t rows_per_buffer); | void set_rows_per_buffer(int32_t rows_per_buffer); | ||||
| @@ -105,6 +113,14 @@ class ConfigManager { | |||||
| // @param connector_size - The setting to apply to the config | // @param connector_size - The setting to apply to the config | ||||
| void set_op_connector_size(int32_t connector_size); | void set_op_connector_size(int32_t connector_size); | ||||
| // setter function | |||||
| // @param cache_host - The hostname of cache server | |||||
| void set_cache_host(std::string cache_host); | |||||
| // setter function | |||||
| // @param cache_port - The port of cache server | |||||
| void set_cache_port(int32_t cache_port); | |||||
| uint32_t seed() const; | uint32_t seed() const; | ||||
| // setter function | // setter function | ||||
| @@ -128,13 +144,15 @@ class ConfigManager { | |||||
| int32_t callback_timeout() const { return callback_timout_; } | int32_t callback_timeout() const { return callback_timout_; } | ||||
| private: | private: | ||||
| int32_t rows_per_buffer_{kCfgRowsPerBuffer}; | |||||
| int32_t num_parallel_workers_{kCfgParallelWorkers}; | |||||
| int32_t worker_connector_size_{kCfgWorkerConnectorSize}; | |||||
| int32_t op_connector_size_{kCfgOpConnectorSize}; | |||||
| uint32_t seed_{kCfgDefaultSeed}; | |||||
| uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval}; | |||||
| uint32_t callback_timout_{kCfgCallbackTimeout}; | |||||
| int32_t rows_per_buffer_; | |||||
| int32_t num_parallel_workers_; | |||||
| int32_t worker_connector_size_; | |||||
| int32_t op_connector_size_; | |||||
| uint32_t seed_; | |||||
| uint32_t monitor_sampling_interval_; | |||||
| uint32_t callback_timout_; | |||||
| std::string cache_host_; | |||||
| int32_t cache_port_; | |||||
| // Private helper function that takes a nlohmann json format and populates the settings | // Private helper function that takes a nlohmann json format and populates the settings | ||||
| // @param j - The json nlohmann json info | // @param j - The json nlohmann json info | ||||
| @@ -69,6 +69,8 @@ constexpr uint32_t kCfgOpConnectorSize = 16; | |||||
| constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed; | constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed; | ||||
| constexpr uint32_t kCfgMonitorSamplingInterval = 10; | constexpr uint32_t kCfgMonitorSamplingInterval = 10; | ||||
| constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds | constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds | ||||
| constexpr int32_t kCfgDefaultCachePort = 50052; | |||||
| constexpr char kCfgDefaultCacheHost[] = "127.0.0.1"; | |||||
| // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) | // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) | ||||
| constexpr uint8_t kCVInvalidType = 255; | constexpr uint8_t kCVInvalidType = 255; | ||||
| @@ -25,21 +25,21 @@ | |||||
| #include "minddata/dataset/engine/cache/cache_request.h" | #include "minddata/dataset/engine/cache/cache_request.h" | ||||
| #include "minddata/dataset/engine/cache/cache_client.h" | #include "minddata/dataset/engine/cache/cache_client.h" | ||||
| #include "minddata/dataset/util/path.h" | #include "minddata/dataset/util/path.h" | ||||
| #include "minddata/dataset/core/constants.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1"; | |||||
| const char CacheAdminArgHandler::kServerBinary[] = "cache_server"; | const char CacheAdminArgHandler::kServerBinary[] = "cache_server"; | ||||
| const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp"; | const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp"; | ||||
| CacheAdminArgHandler::CacheAdminArgHandler() | CacheAdminArgHandler::CacheAdminArgHandler() | ||||
| : port_(kDefaultPort), | |||||
| : port_(kCfgDefaultCachePort), | |||||
| session_id_(0), | session_id_(0), | ||||
| num_workers_(kDefaultNumWorkers), | num_workers_(kDefaultNumWorkers), | ||||
| shm_mem_sz_(kDefaultSharedMemorySizeInGB), | shm_mem_sz_(kDefaultSharedMemorySizeInGB), | ||||
| log_level_(kDefaultLogLevel), | log_level_(kDefaultLogLevel), | ||||
| hostname_(kDefaultHost), | |||||
| hostname_(kCfgDefaultCacheHost), | |||||
| spill_dir_(kDefaultSpillDir), | spill_dir_(kDefaultSpillDir), | ||||
| command_id_(CommandId::kCmdUnknown) { | command_id_(CommandId::kCmdUnknown) { | ||||
| // Initialize the command mappings | // Initialize the command mappings | ||||
| @@ -376,6 +376,8 @@ Status CacheAdminArgHandler::StopServer() { | |||||
| RETURN_IF_NOT_OK(comm.ServiceStart()); | RETURN_IF_NOT_OK(comm.ServiceStart()); | ||||
| auto rq = std::make_shared<ShutdownRequest>(); | auto rq = std::make_shared<ShutdownRequest>(); | ||||
| RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | RETURN_IF_NOT_OK(comm.HandleRequest(rq)); | ||||
| // We will ignore the rc because if the shutdown is successful, the server will not reply back. | |||||
| (void)rq->Wait(); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -29,11 +29,9 @@ namespace dataset { | |||||
| class CacheAdminArgHandler { | class CacheAdminArgHandler { | ||||
| public: | public: | ||||
| static constexpr int32_t kDefaultPort = 50052; | |||||
| static constexpr int32_t kDefaultNumWorkers = 32; | static constexpr int32_t kDefaultNumWorkers = 32; | ||||
| static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; | static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; | ||||
| static constexpr int32_t kDefaultLogLevel = 1; | static constexpr int32_t kDefaultLogLevel = 1; | ||||
| static const char kDefaultHost[]; | |||||
| static const char kServerBinary[]; | static const char kServerBinary[]; | ||||
| static const char kDefaultSpillDir[]; | static const char kDefaultSpillDir[]; | ||||
| @@ -23,6 +23,35 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| CacheClient::Builder::Builder() | |||||
| : session_id_(0), cache_mem_sz_(0), spill_(false), hostname_(""), port_(0), num_workers_(0), prefetch_size_(0) { | |||||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||||
| hostname_ = cfg->cache_host(); | |||||
| port_ = cfg->cache_port(); | |||||
| num_workers_ = cfg->num_parallel_workers(); | |||||
| prefetch_size_ = 20; // rows_per_buf is too small (1 by default). | |||||
| } | |||||
| Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| RETURN_IF_NOT_OK(SanityCheck()); | |||||
| *out = | |||||
| std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, prefetch_size_); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CacheClient::Builder::SanityCheck() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "illegal port number"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1", | |||||
| "now cache client has to be on the same host with cache server"); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Constructor | // Constructor | ||||
| CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, | CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, | ||||
| @@ -44,13 +44,7 @@ class CacheClient { | |||||
| /// \brief A builder to help creating a CacheClient object | /// \brief A builder to help creating a CacheClient object | ||||
| class Builder { | class Builder { | ||||
| public: | public: | ||||
| Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) { | |||||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||||
| hostname_ = "127.0.0.1"; | |||||
| port_ = 50052; | |||||
| num_workers_ = cfg->num_parallel_workers(); | |||||
| prefetch_size_ = 20; // rows_per_buf is too small (1 by default). | |||||
| } | |||||
| Builder(); | |||||
| /// Setter function to set the session id | /// Setter function to set the session id | ||||
| /// \param session_id | /// \param session_id | ||||
| @@ -117,22 +111,9 @@ class CacheClient { | |||||
| int32_t getNumWorkers() const { return num_workers_; } | int32_t getNumWorkers() const { return num_workers_; } | ||||
| int32_t getPrefetchSize() const { return prefetch_size_; } | int32_t getPrefetchSize() const { return prefetch_size_; } | ||||
| Status SanityCheck() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status SanityCheck(); | |||||
| Status Build(std::shared_ptr<CacheClient> *out) { | |||||
| RETURN_UNEXPECTED_IF_NULL(out); | |||||
| RETURN_IF_NOT_OK(SanityCheck()); | |||||
| *out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_, | |||||
| prefetch_size_); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status Build(std::shared_ptr<CacheClient> *out); | |||||
| private: | private: | ||||
| session_id_type session_id_; | session_id_type session_id_; | ||||
| @@ -20,24 +20,25 @@ from mindspore._c_dataengine import CacheClient | |||||
| from ..core.validator_helpers import type_check, check_uint32, check_uint64 | from ..core.validator_helpers import type_check, check_uint32, check_uint64 | ||||
| class DatasetCache: | class DatasetCache: | ||||
| """ | """ | ||||
| A client to interface with tensor caching service | A client to interface with tensor caching service | ||||
| """ | """ | ||||
def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, prefetch_size=20):
    """
    Validate the arguments and create the underlying CacheClient.

    Args:
        session_id: Session id, checked with check_uint32.
        size: Cache memory size, checked with check_uint64 (default 0).
        spilling: Must be a bool (default False).
        hostname: Hostname of the cache server, or None to leave the choice to the C++ builder,
            which then keeps the value taken from its ConfigManager.
        port: Port of the cache server, or None to leave the choice to the C++ builder.
        prefetch_size: Prefetch size, checked with check_uint32 (default 20).
    """
    check_uint32(session_id, "session_id")
    check_uint64(size, "size")
    type_check(spilling, (bool,), "spilling")
    # hostname and port are optional; validate them only when the caller actually supplied a value,
    # so a bad argument is reported here instead of surfacing from the C++ binding.
    if hostname is not None:
        type_check(hostname, (str,), "hostname")
    if port is not None:
        check_uint32(port, "port")
    check_uint32(prefetch_size, "prefetch size")

    self.session_id = session_id
    self.size = size
    self.spilling = spilling
    self.hostname = hostname
    self.port = port
    self.prefetch_size = prefetch_size
    self.cache_client = CacheClient(session_id, size, spilling, hostname, port, prefetch_size)
| def GetStat(self): | def GetStat(self): | ||||
| return self.cache_client.GetStat() | return self.cache_client.GetStat() | ||||
| @@ -51,6 +52,7 @@ class DatasetCache: | |||||
| new_cache.session_id = copy.deepcopy(self.session_id, memodict) | new_cache.session_id = copy.deepcopy(self.session_id, memodict) | ||||
| new_cache.spilling = copy.deepcopy(self.spilling, memodict) | new_cache.spilling = copy.deepcopy(self.spilling, memodict) | ||||
| new_cache.size = copy.deepcopy(self.size, memodict) | new_cache.size = copy.deepcopy(self.size, memodict) | ||||
| new_cache.hostname = copy.deepcopy(self.hostname, memodict) | |||||
| new_cache.port = copy.deepcopy(self.port, memodict) | new_cache.port = copy.deepcopy(self.port, memodict) | ||||
| new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) | new_cache.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) | ||||
| new_cache.cache_client = self.cache_client | new_cache.cache_client = self.cache_client | ||||